import os
import json
from collections import defaultdict

# Vocabulary tables used when parsing question ids and grouping objects.
# Only the layout is normalised here; every string value is kept verbatim
# because the parser matches these phrases as literal substrings.
RELATIONS = {
    "3d": [
        "in front of",
        "behind of",
        "hidden by",
        "at the back left of",
        "at the front left of",
        "at the back right of",
        "at the front right of",
    ],
}

# Scene phrases that may follow the second object in a question id.
SCENES_PROMPT = [
    "on the desert",
    "in the room",
    "on the street",
    "in the jungle",
    "on the road",
    "in the studio",
    "on the beach",
    "on a snowy landscape",
    "in the apartment",
    "in the library",
]

# Category name -> object names belonging to that category.
OBJECTS_CATEGORIES = {
    "animals": [
        "dog", "mouse", "sheep", "cat", "cow",
        "chicken", "turtle", "giraffe", "pig", "butterfly",
        "horse", "bird", "rabbit", "frog", "fish",
    ],
    "indoor": [
        "bed", "desk", "key", "chair", "vase", "candle",
        "cup", "phone", "computer", "bowl", "sofa", "balloon",
        "plate", "refrigerator", "wallet", "bag", "painting", "suitcase",
        "table", "couch", "clock", "book", "lamp", "television",
    ],
    "outdoor": [
        "car", "motorcycle", "backpack", "bench",
        "train", "airplane", "bicycle",
    ],
    "person": ["woman", "man", "boy", "girl"],
}

# --- Reverse lookup: object name -> category name ---
# Flattens OBJECTS_CATEGORIES once at import time. If an object were listed
# under several categories, the category that appears last would win.
object_to_category_map = {
    obj: category
    for category, objects in OBJECTS_CATEGORIES.items()
    for obj in objects
}

def get_object_category(obj_name: str, obj_category_map: dict) -> str:
    """Return the category registered for ``obj_name``.

    Args:
        obj_name: Object name to look up.
        obj_category_map: Mapping of object name -> category name.

    Returns:
        The mapped category, or the string ``"unknown"`` when ``obj_name``
        is not a key of ``obj_category_map``.
    """
    if obj_name in obj_category_map:
        return obj_category_map[obj_name]
    # Objects outside the known vocabulary share one placeholder category.
    return "unknown"

def get_pair_category_combination(obj1: str, obj2: str, obj_category_map: dict) -> tuple:
    """Return the order-independent category pair for two objects.

    Each object is resolved with ``get_object_category``; the two category
    names are then sorted alphabetically so that, for example,
    ('animals', 'indoor') and ('indoor', 'animals') map to the same key.

    Args:
        obj1: Name of the first object.
        obj2: Name of the second object.
        obj_category_map: Mapping of object name -> category name.

    Returns:
        A 2-tuple of category names in ascending alphabetical order.
    """
    categories = [
        get_object_category(name, obj_category_map)
        for name in (obj1, obj2)
    ]
    categories.sort()
    return tuple(categories)


def _strip_leading_article(text: str) -> str:
    """Remove one leading 'a ' or 'an ' from ``text`` and trim whitespace."""
    text = text.strip()
    for article in ("a ", "an "):
        if text.startswith(article):
            text = text[len(article):]
            break
    return text.strip()


def parse_question_id(question_id: str, relations: dict, scenes: list) -> "tuple | None":
    """Extract the two objects and the relation from a question id.

    The expected shape is ``'a [obj1] [relation] a [obj2] [scene]'``; the
    scene part is optional. Relations and scenes are matched as literal
    substrings, longest candidate first, so that a longer phrase is
    preferred over a shorter one that may also occur in the text.

    Args:
        question_id: The question id text to parse. Non-string values are
            treated as unparseable.
        relations: Mapping of group name -> list of relation phrases.
        scenes: Sequence of scene phrases. It is NOT modified by this call.

    Returns:
        ``(obj1, obj2, relation)`` with the objects in their order of
        appearance in ``question_id`` (they are not sorted), or ``None``
        when no relation is found, when either object is empty, or when
        both objects are identical.
    """
    if not isinstance(question_id, str):
        return None

    # Longest relation first; the first hit wins (one relation per question).
    relation_candidates = sorted(
        (rel for rel_list in relations.values() for rel in rel_list),
        key=len,
        reverse=True,
    )

    found_relation = None
    relation_start_index = -1
    relation_end_index = -1
    for rel in relation_candidates:
        index = question_id.find(rel)
        if index != -1:
            found_relation = rel
            relation_start_index = index
            relation_end_index = index + len(rel)
            break

    if not found_relation:
        return None

    # Look for a scene only in the text that follows the relation.
    # sorted() builds a new list, so the caller's `scenes` keeps its order.
    part_after_relation = question_id[relation_end_index:]
    found_scene = None
    scene_start_index = -1
    for scene in sorted(scenes, key=len, reverse=True):
        index = part_after_relation.find(scene)
        if index != -1:
            found_scene = scene
            scene_start_index = relation_end_index + index
            break

    # Object 1 is everything before the relation.
    obj1 = _strip_leading_article(question_id[:relation_start_index])

    # Object 2 sits between the relation and the scene; without a scene it
    # is assumed to be the remainder of the string.
    if found_scene:
        part_between = question_id[relation_end_index:scene_start_index]
    else:
        part_between = question_id[relation_end_index:]
    obj2 = _strip_leading_article(part_between)

    # Both objects must be present.
    if not obj1 or not obj2:
        return None

    # Pairs made of the same object twice are ignored.
    if obj1 == obj2:
        return None

    return (obj1, obj2, found_relation)


def process_vqa_results(file_path: str) -> tuple:
    """Load VQA results and average the scores per (obj1, obj2, relation).

    The file must contain a JSON list of objects carrying a 'question_id'
    and an 'answer'. Entries are skipped when they are not JSON objects,
    when either field is missing/null, when the answer cannot be converted
    to ``float``, or when the question id cannot be parsed.

    Args:
        file_path: Path to the vqa_result.json file.

    Returns:
        A tuple ``(average_scores, question_ids_by_pair)``:
        1. ``average_scores``: keys are (obj1, obj2, relation) tuples, values
           are dicts with 'avg_score' (mean) and 'scores' (list of floats).
        2. ``question_ids_by_pair``: keys are the same tuples, values are the
           lists of original question ids that contributed a score.
        ``({}, {})`` is returned, after printing an error, when the file is
        missing, cannot be decoded, or does not hold a JSON list.
    """
    scores_by_pair_relation = defaultdict(list)
    question_ids_by_pair = defaultdict(list)

    try:
        # Read as UTF-8 explicitly instead of relying on the platform's
        # locale-dependent default encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
        return {}, {}
    except (json.JSONDecodeError, UnicodeDecodeError):
        print(f"Error: Could not decode JSON from {file_path}")
        return {}, {}

    if not isinstance(data, list):
        print(f"Error: Expected a JSON list of entries in {file_path}")
        return {}, {}

    for entry in data:
        if not isinstance(entry, dict):
            # Stray scalars/strings in the list carry no question/answer.
            continue

        question_id = entry.get('question_id')
        answer = entry.get('answer')
        if question_id is None or answer is None:
            continue

        # Cheap numeric check first; non-numeric answers are skipped.
        try:
            score = float(answer)
        except (ValueError, TypeError):
            continue

        pair_relation_key = parse_question_id(question_id, RELATIONS, SCENES_PROMPT)
        if not pair_relation_key:
            continue

        scores_by_pair_relation[pair_relation_key].append(score)
        question_ids_by_pair[pair_relation_key].append(question_id)

    # A key only exists once a score has been appended, so no list is empty
    # and the division below is always safe.
    average_scores = {
        pair_relation_key: {
            'avg_score': sum(scores) / len(scores),
            'scores': scores,
        }
        for pair_relation_key, scores in scores_by_pair_relation.items()
    }

    return average_scores, question_ids_by_pair


# --- Input and output locations for the two result sets ---
# Note: the *_image_path values are not used in the score calculation;
# they are kept for context only.

# 'gen' results (active configuration)
gen_file_path = "exp/gen_14000/labels/annotation_obj_detection_3d/vqa_result.json"
gen_image_path = "exp/gen_14000/samples"
# Disabled alternatives for the 'gen' inputs:
# gen_file_path = "exp/gen/labels/annotation_obj_detection_3d_1/vqa_result.json"
# gen_file_path = "exp/gen_30000/labels/annotation_obj_detection_3d/vqa_result.json"
# gen_image_path = "exp/gen_30000/samples"

# 'eligen' results (active configuration)
eligen_file_path = "exp/eligen/labels/annotation_obj_detection_3d/vqa_result.json"
eligen_image_path = "exp/eligen/samples"
# Disabled alternative for the 'eligen' input:
# eligen_file_path = "exp/eligen/labels/annotation_obj_detection_3d_1/vqa_result.json"

# Report outputs: the text comparison and the JSON list of selected pairs
output_txt_file_path = "score_comparison_results.txt"
# output_txt_file_path = "score_comparison_results_1.txt"
output_json_file_path = "highly_scored_prompt.json"

# --- Process both sets of results ---
print(f"Processing results from: {gen_file_path}")
gen_avg_scores, gen_qids_by_pair = process_vqa_results(gen_file_path)
print(f"Finished processing {len(gen_avg_scores)} unique object/relation pairs from 'gen'.")

print(f"Processing results from: {eligen_file_path}")
eligen_avg_scores, eligen_qids_by_pair = process_vqa_results(eligen_file_path)
print(f"Finished processing {len(eligen_avg_scores)} unique object/relation pairs from 'eligen'.")

# --- Calculate Overall Averages (Across All Relations and Categories) ---
# Each value is the mean of the per-pair averages; 0 when nothing was loaded.
overall_gen_avg = (
    sum(item['avg_score'] for item in gen_avg_scores.values()) / len(gen_avg_scores)
    if gen_avg_scores else 0
)
overall_eligen_avg = (
    sum(item['avg_score'] for item in eligen_avg_scores.values()) / len(eligen_avg_scores)
    if eligen_avg_scores else 0
)

# --- Find Common Pairs and Calculate Differences ---
common_pairs_relation = set(gen_avg_scores.keys()) & set(eligen_avg_scores.keys())

# Walk the common keys in sorted order. Iterating the set directly yields an
# order that changes from run to run (string hashes are randomised per
# process), which would shuffle every output derived from this list.
score_differences = []
for pair_relation_key in sorted(common_pairs_relation):
    gen_score = gen_avg_scores[pair_relation_key]['avg_score']
    eligen_score = eligen_avg_scores[pair_relation_key]['avg_score']
    difference = gen_score - eligen_score # Difference: Gen - Eligen
    score_differences.append((pair_relation_key, difference, gen_score, eligen_score))

# --- Bucket the filtered differences by relation, then by category combination ---
# Shape: relation -> category_combination -> list of
# (pair_relation_key, diff, gen_score, eligen_score) rows.
grouped_by_relation_then_category = defaultdict(lambda: defaultdict(list))

# (obj1, relation, obj2) triples for every pair that passes the filter
filtered_question_ids = []

for pair_relation_key, diff, gen_s, eligen_s in score_differences:
    # Keep only pairs where 'gen' beats 'eligen' and scores above 0.5; this
    # is the same filter the text report states in its header.
    if not (diff > 0 and gen_s > 0.5):
        continue

    obj1, obj2, relation = pair_relation_key
    category_combination = get_pair_category_combination(obj1, obj2, object_to_category_map)
    grouped_by_relation_then_category[relation][category_combination].append(
        (pair_relation_key, diff, gen_s, eligen_s)
    )

    # Record the pair only when the 'gen' run has question ids for it.
    if pair_relation_key in gen_qids_by_pair:
        filtered_question_ids.append((obj1, relation, obj2))


# --- Write the comparison report, grouped by relation and category combination ---
with open(output_txt_file_path, 'w') as f:
    # Fixed header: input files, overall averages, and the legend.
    f.writelines([
        "--- VQA Score Comparison Results (Grouped by Relation and Object Category Combination) ---\n\n",
        f"Gen Results File: {gen_file_path}\n",
        f"Eligen Results File: {eligen_file_path}\n",
        "-" * 60 + "\n\n",
        "Overall Average Scores (Across All Relations and Categories):\n",
        f"  Average score for 'gen' ({len(gen_avg_scores)} pairs): {overall_gen_avg:.4f}\n",
        f"  Average score for 'eligen' ({len(eligen_avg_scores)} pairs): {overall_eligen_avg:.4f}\n",
        "-" * 60 + "\n\n",
        "Score Differences for Common Object Pairs, Grouped by Relation and Object Category Combination:\n",
        "Format: (Object1, Object2): Diff (Gen - Eligen) [Gen Score, Eligen Score]\n",
        "Pairs listed here satisfy: (Gen Avg Score - Eligen Avg Score) > 0 AND Gen Avg Score > 0.5\n",
        "-" * 60 + "\n",
    ])

    # Follow the order of RELATIONS so sections always appear in the same sequence.
    all_relations = [rel for rel_list in RELATIONS.values() for rel in rel_list]
    for relation in all_relations:
        relation_groups = grouped_by_relation_then_category.get(relation)
        if not relation_groups:
            # Nothing passed the filter for this relation.
            continue

        f.write(f"\n>>> Relation: {relation} <<<\n")
        f.write("=" * (len(relation) + 15) + "\n")

        # Category combinations in alphabetical order within a relation.
        for category_combination in sorted(relation_groups.keys()):
            results_list = relation_groups[category_combination]

            cat_str = f"{category_combination[0]} - {category_combination[1]}"
            heading = f"Category Combination: {cat_str}"
            f.write(f"  {heading} ({len(results_list)} pairs)\n")
            # Underline sized from the heading text.
            f.write("    " + "-" * (len(heading) - 2) + "\n")

            # Largest (Gen - Eligen) difference first.
            results_list.sort(key=lambda item: item[1], reverse=True)

            # Group-level means over the rows of this category combination.
            group_size = len(results_list)
            avg_diff_group = sum(item[1] for item in results_list) / group_size
            avg_gen_group = sum(item[2] for item in results_list) / group_size
            avg_eligen_group = sum(item[3] for item in results_list) / group_size

            f.write(f"    Average scores for this group: Gen={avg_gen_group:.4f}, Eligen={avg_eligen_group:.4f}, Avg Diff={avg_diff_group:.4f}\n")
            f.write("    Individual pairs (sorted by difference):\n")

            object_set = set()
            for rank, (pair_relation_key, diff, gen_s, eligen_s) in enumerate(results_list, start=1):
                # The key is (obj1, obj2, relation); only the objects are shown.
                first_obj, second_obj = pair_relation_key[0], pair_relation_key[1]
                object_set.update((first_obj, second_obj))
                f.write(f"      {rank}. ({first_obj}, {second_obj}): {diff:.4f} [{gen_s:.4f}, {eligen_s:.4f}]\n")

            sorted_object_list = sorted(object_set)
            f.write(f"    object_list: {sorted_object_list}\n")
            f.write(f"    object_len:  {len(sorted_object_list)}\n")
            f.write("\n")  # blank line between category groups

    # Fixed footer explaining how to read the numbers.
    f.writelines([
        "-" * 60 + "\n\n",
        "Notes:\n",
        "- Positive difference means 'gen' score is higher than 'eligen' score.\n",
        "- Negative difference means 'eligen' score is higher than 'gen' score.\n",
        "- Object category combinations are sorted alphabetically (e.g., 'animals - indoor').\n",
    ])


print(f"\nComparison results saved to '{output_txt_file_path}'")

# --- Persist the selected (obj1, relation, obj2) triples as JSON ---
print(f"Writing filtered question IDs to '{output_json_file_path}'")
with open(output_json_file_path, 'w') as f:
    # Tuples are serialised as JSON arrays; 4-space indentation for readability.
    f.write(json.dumps(filtered_question_ids, indent=4))

print(f"Successfully wrote {len(filtered_question_ids)} question IDs to '{output_json_file_path}'")