import json
import os
from collections import defaultdict
from typing import Optional, Tuple

# Relation phrases recognised inside a question id, keyed by relation family.
RELATIONS = {
    "3d": [
        "in front of",
        "at the back left of",
        "at the front left of",
        "behind of",
        "at the back right of",
        "at the front right of",
        "hidden by",
    ],
}

# Scene phrases that may trail a question id; only the strings themselves are
# needed for parsing.
SCENES_PROMPT = [
    "on the desert", "in the room",
    "on the street", "in the jungle",
    "on the road", "in the studio",
    "on the beach", "on a snowy landscape",
    "in the apartment", "in the library",
]

# Object names grouped by category.
OBJECTS_CATEGORIES = {
    "animals": [
        "dog", "mouse", "sheep", "cat", "cow",
        "chicken", "turtle", "giraffe", "pig", "butterfly",
        "horse", "bird", "rabbit", "frog", "fish",
    ],
    "indoor": [
        "bed", "desk", "key", "chair", "vase", "candle",
        "cup", "phone", "computer", "bowl", "sofa", "balloon",
        "plate", "refrigerator", "wallet", "bag", "painting", "suitcase",
        "table", "couch", "clock", "book", "lamp", "television",
    ],
    "outdoor": [
        "car", "motorcycle", "backpack", "bench",
        "train", "airplane", "bicycle",
    ],
    "person": ["woman", "man", "boy", "girl"],
}


def _strip_leading_article(phrase: str) -> str:
    """Return *phrase* trimmed, without a leading indefinite article ("a "/"an ")."""
    phrase = phrase.strip()
    # Only a prefix is removed; an "a " occurring later in the phrase
    # (e.g. inside "panda bear") is part of the object name and must survive.
    for article in ("an ", "a "):
        if phrase.startswith(article):
            return phrase[len(article):].strip()
    return phrase


def parse_question_id(question_id: str, relations: dict, scenes: list) -> Optional[Tuple[str, str]]:
    """
    Parse a question id and return the two objects it mentions.

    The expected shape is 'a [obj1] [relation] a [obj2] [scene]', where the
    trailing scene is optional.

    Args:
        question_id: The text to parse.
        relations: Mapping whose values are lists of relation phrases; all
            lists are pooled and searched as plain substrings.
        scenes: Scene phrases that may follow the second object. The list is
            not modified.

    Returns:
        The two object names as an alphabetically sorted tuple (so the same
        pair always produces the same key regardless of which object came
        first), or None when no relation phrase is found or either object
        turns out to be empty.
    """
    # Try longer relation phrases first so that a phrase which is a substring
    # of a longer one cannot shadow it. Empty phrases can never be a match.
    candidate_relations = sorted(
        (rel for rel_list in relations.values() for rel in rel_list if rel),
        key=len,
        reverse=True,
    )

    for rel in candidate_relations:
        relation_start = question_id.find(rel)
        if relation_start != -1:
            relation_end = relation_start + len(rel)
            break  # Assume only one relation per question
    else:
        return None  # No known relation: parsing failed

    # Object 1 is everything before the relation.
    obj1 = _strip_leading_article(question_id[:relation_start])

    # Object 2 is everything after the relation, up to the scene if there is
    # one. sorted() works on a copy, so the caller's list keeps its order.
    remainder = question_id[relation_end:]
    obj2_text = remainder
    for scene in sorted(scenes, key=len, reverse=True):
        if not scene:
            continue
        scene_index = remainder.find(scene)
        if scene_index != -1:
            obj2_text = remainder[:scene_index]
            break  # Assume only one scene per question

    obj2 = _strip_leading_article(obj2_text)

    # Both objects must be non-empty for the parse to count.
    if not obj1 or not obj2:
        return None

    return tuple(sorted((obj1, obj2)))

# --- Main Program ---
# Reads the VQA result file, groups the 'answer' of every entry by the object
# pair parsed from its 'question_id', and prints the pairs ranked by their
# average answer value.

# Inputs for the "eligen" run. For the "gen" run use
# 'exp/gen/labels/annotation_obj_detection_3d/vqa_result.json' and
# 'exp/gen/samples' instead.
file_path = 'exp/eligen/labels/annotation_obj_detection_3d/vqa_result.json'
image_path = 'exp/eligen/samples'
# Not referenced again in this section; raises if the samples directory is missing.
image_list = os.listdir(image_path)

# How many of the best-scoring pairs are printed and summed.
TOP_N_PAIRS = 100

# Dictionary to store scores for each object pair
# Key: tuple of sorted object names (e.g., ('cat', 'dog'))
# Value: list of scores
scores_by_pair = defaultdict(list)

# Explicit encoding so the read does not depend on the platform default.
with open(file_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Process each entry in the JSON data
for entry in data:
    question_id = entry['question_id']
    answer = entry['answer']

    # Entries whose question_id cannot be parsed are skipped.
    pair_key = parse_question_id(question_id, RELATIONS, SCENES_PROMPT)
    if pair_key:
        scores_by_pair[pair_key].append(answer)

# Calculate average score for each pair (the guard avoids division by zero)
average_scores = {
    pair_key: {
        'avg_score': sum(scores) / len(scores),
        'scores': scores,
    }
    for pair_key, scores in scores_by_pair.items()
    if scores
}

# Sort the pairs by their average score in descending order
sorted_avg_scores = sorted(
    average_scores.items(),
    key=lambda item: item[1]['avg_score'],
    reverse=True,
)

# Print the results
print("\nObject pair average scores (sorted highest to lowest):")
if not sorted_avg_scores:
    print("No valid object pairs found or processed from the data.")
else:
    # total_score is the sum of the printed averages, i.e. of the top
    # TOP_N_PAIRS pairs only.
    total_score = 0
    for i, ((obj1, obj2), pair_stats) in enumerate(sorted_avg_scores[:TOP_N_PAIRS]):
        print(f"  Pair{i}: ({obj1}, {obj2}) - Average Score: {pair_stats['avg_score']:.4f}")
        total_score += pair_stats['avg_score']

    print(f'total_score: {total_score}')