import os
import json
from collections import defaultdict

# Provided data structures (keep these as they are for parsing).
# parse_question_id matches these strings verbatim, as plain substrings, against
# question IDs, so editing any of them changes which questions can be parsed.

# Relation phrases grouped by family. parse_question_id flattens every family
# and tries the longest phrases first.
# NOTE(review): 'behind of' is kept verbatim, presumably because the question IDs
# use that exact wording — confirm before changing it.
RELATIONS = {
    "3d": [
        'in front of',
        'behind of',
        'hidden by',
        'at the back left of',
        'at the front left of',
        'at the back right of',
        'at the front right of',
    ]
}

# Scene phrases that may end a question ID. parse_question_id searches only the
# text after the relation (longest phrase first) and treats the first match as
# the end of the second object.
SCENES_PROMPT = [
    "on the desert",
    "in the room",
    "on the street",
    "in the jungle",
    "on the road",
    "in the studio",
    "on the beach",
    "on a snowy landscape",
    "in the apartment",
    "in the library",
]

# Object vocabulary grouped by category. It is inverted into object_to_category_map
# below, which is used to group report results by category combination.
OBJECTS_CATEGORIES = {
    "animals": ['dog', 'mouse', 'sheep', 'cat', 'cow', 'chicken', 'turtle', 'giraffe', 'pig', 'butterfly', 'horse', 'bird', 'rabbit', 'frog', 'fish'],
    "indoor": ['bed', 'desk', 'key', 'chair', 'vase', 'candle', 'cup', 'phone', 'computer', 'bowl', 'sofa', 'balloon', 'plate', 'refrigerator', 'wallet', 'bag', 'painting', 'suitcase', 'table', 'couch', 'clock', 'book', 'lamp', 'television'],
    "outdoor": ["car", "motorcycle", "backpack", "bench", 'train', 'airplane', 'bicycle'],
    "person": ['woman', 'man', 'boy', 'girl'],
}

# --- Reverse index: object name -> category name ---
# Built once from OBJECTS_CATEGORIES. If an object were listed under more than
# one category, the category iterated last would win.
object_to_category_map = {
    object_name: category_name
    for category_name, members in OBJECTS_CATEGORIES.items()
    for object_name in members
}

def get_object_category(obj_name: str, obj_category_map: dict) -> str:
    """Return the category that *obj_name* belongs to.

    Args:
        obj_name: Object name as extracted from a question ID (e.g. 'dog').
        obj_category_map: Mapping of object name -> category name.

    Returns:
        The mapped category, or the literal string 'unknown' when *obj_name*
        is not a key of the mapping.
    """
    fallback_category = "unknown"  # label used for objects outside the known vocabulary
    return obj_category_map.get(obj_name, fallback_category)

def get_pair_category_combination(obj1: str, obj2: str, obj_category_map: dict) -> tuple:
    """Return the order-independent category pair for two objects.

    Each object is resolved with ``get_object_category``, so unmapped objects
    become 'unknown'. The two category names are then sorted alphabetically so
    that ('indoor', 'animals') and ('animals', 'indoor') group together.

    Returns:
        A 2-tuple of category names in ascending order.
    """
    categories = [get_object_category(name, obj_category_map) for name in (obj1, obj2)]
    categories.sort()
    return tuple(categories)


def _strip_leading_article(phrase: str) -> str:
    """Trim *phrase* and drop one leading indefinite article ('a ' or 'an ')."""
    phrase = phrase.strip()
    if phrase.startswith("a "):
        phrase = phrase[2:]
    elif phrase.startswith("an "):
        phrase = phrase[3:]
    return phrase.strip()


def parse_question_id(question_id: str, relations: dict, scenes: list) -> "tuple[str, str, str] | None":
    """
    Parse a question ID into its two objects and the relation between them.

    Expected shape: 'a[n] <obj1> <relation> a[n] <obj2> [<scene>]', e.g.
    'a dog in front of a cat in the room' -> ('dog', 'cat', 'in front of').

    Matching is plain substring search, with no word boundaries:
    - The relation is the first phrase, trying longest first, that occurs
      anywhere in ``question_id``.
    - obj2 ends at the first scene phrase, trying longest first, that occurs
      after the relation. If no scene is found, obj2 runs to the end of the string.

    Neither ``relations`` nor ``scenes`` is modified.

    Args:
        question_id: The question text to parse.
        relations: Mapping of relation family -> list of relation phrases.
        scenes: List of scene phrases that may trail the second object.

    Returns:
        ``(obj1, obj2, relation)``, or None if no relation is found, either
        object is empty after stripping the article, or both objects are identical.
    """
    # Longest relation phrases first, so a longer phrase wins over any shorter
    # phrase it contains. sorted() is stable, so ties keep their listed order.
    candidate_relations = sorted(
        (rel for rel_list in relations.values() for rel in rel_list),
        key=len,
        reverse=True,
    )

    found_relation = None
    relation_start = relation_end = -1
    for rel in candidate_relations:
        index = question_id.find(rel)
        if index != -1:
            found_relation = rel
            relation_start = index
            relation_end = index + len(rel)
            break  # assume a single relation per question

    if not found_relation:
        return None

    # Sort a copy: sorting the caller's list in place would reorder it (e.g. the
    # module-level SCENES_PROMPT) as a hidden side effect of every call.
    tail = question_id[relation_end:]
    found_scene = None
    scene_start = -1
    for scene in sorted(scenes, key=len, reverse=True):
        index = tail.find(scene)
        if index != -1:
            found_scene = scene
            scene_start = relation_end + index
            break  # assume a single scene per question

    # Without a scene, everything after the relation is taken as obj2.
    obj2_end = scene_start if found_scene else len(question_id)

    obj1 = _strip_leading_article(question_id[:relation_start])
    obj2 = _strip_leading_article(question_id[relation_end:obj2_end])

    if not obj1 or not obj2:
        return None

    # Pairs of an object with itself are not meaningful for comparison.
    if obj1 == obj2:
        return None

    return (obj1, obj2, found_relation)


def process_vqa_results(file_path: str, method_name: str) -> tuple:
    """
    Load one method's VQA results and aggregate scores per (obj1, obj2, relation).

    Each entry's 'question_id' is parsed with ``parse_question_id`` using the
    module-level RELATIONS and SCENES_PROMPT. An entry is skipped when:
    - it is not a JSON object;
    - it lacks a string 'question_id' or a non-null 'answer';
    - its question_id cannot be parsed;
    - its answer cannot be converted to float. These are counted and reported.

    Args:
        file_path: Path to the vqa_result.json file. Expected to hold a JSON
            list of objects with 'question_id' and 'answer' keys.
        method_name: Name of the method being processed (e.g. 'eligen', 'rpg').
            Used only in log output.

    Returns:
        A tuple (average_scores, question_ids_by_pair_relation):
        - average_scores: dict keyed by (obj1, obj2, relation) tuples. Each value
          is {'avg_score': float, 'scores': list of float}.
        - question_ids_by_pair_relation: dict with the same keys. Each value is
          the list of original question_ids that contributed a score.
        Both dicts are empty if the file is missing, unreadable, not valid
        JSON, or not a JSON list.
    """
    if not os.path.exists(file_path):
        print(f"Error: File not found at {file_path}")
        return {}, {}

    try:
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            entries = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        print(f"Error: Could not decode JSON from {file_path}")
        return {}, {}
    except OSError as err:
        # The path exists but cannot be read (e.g. it is a directory or lacks permissions).
        print(f"Error: Could not read {file_path}: {err}")
        return {}, {}

    if not isinstance(entries, list):
        print(f"Error: Expected a JSON list of entries in {file_path}, got {type(entries).__name__}")
        return {}, {}

    print(f"Processing {len(entries)} entries from {method_name} results...")

    scores_and_ids = defaultdict(lambda: {'scores': [], 'question_ids': []})
    skipped_answers = 0

    for entry in entries:
        if not isinstance(entry, dict):
            continue
        question_id = entry.get('question_id')
        answer = entry.get('answer')

        # The parser does substring search, so it needs a string question_id.
        if not isinstance(question_id, str) or answer is None:
            continue

        pair_relation_key = parse_question_id(question_id, RELATIONS, SCENES_PROMPT)
        if not pair_relation_key:
            continue

        try:
            score = float(answer)
        except (ValueError, TypeError):
            skipped_answers += 1
            continue

        bucket = scores_and_ids[pair_relation_key]
        bucket['scores'].append(score)
        bucket['question_ids'].append(question_id)

    if skipped_answers:
        print(f"Warning: Skipped {skipped_answers} entries with non-numeric answers in {method_name} results.")

    average_scores = {}
    question_ids_by_pair_relation = {}
    for pair_relation_key, bucket in scores_and_ids.items():
        scores = bucket['scores']
        if not scores:  # defensive: avoid division by zero
            continue
        average_scores[pair_relation_key] = {
            'avg_score': sum(scores) / len(scores),
            'scores': scores,  # kept in case per-question scores are needed later
        }
        question_ids_by_pair_relation[pair_relation_key] = bucket['question_ids']

    print(f"Finished processing {method_name}. Found {len(average_scores)} unique object/relation pairs.")

    return average_scores, question_ids_by_pair_relation

# --- Method names compared in this report ---
gen_method_name = 'gen_14000'
eligen_method_name = 'eligen'
rpg_method_name = 'rpg'

# --- Define file paths for all three methods ---
# All result files share one directory layout. Deriving every path from its
# method name keeps paths and report labels in sync when a name changes.
VQA_RESULT_PATH_TEMPLATE = 'exp/{method}/labels/annotation_obj_detection_3d/vqa_result.json'
gen_file_path = VQA_RESULT_PATH_TEMPLATE.format(method=gen_method_name)
eligen_file_path = VQA_RESULT_PATH_TEMPLATE.format(method=eligen_method_name)
rpg_file_path = VQA_RESULT_PATH_TEMPLATE.format(method=rpg_method_name)

# Output file names embed all three method names.
_output_prefix = f'{gen_method_name}_{eligen_method_name}_{rpg_method_name}'
output_txt_file_path = f'{_output_prefix}_detailed_comparison.txt'
output_json_file_path = f'{_output_prefix}_filtered_question_ids.json'  # JSON output file path


# --- Process all three sets of results ---
gen_avg_scores, gen_question_ids = process_vqa_results(gen_file_path, gen_method_name)
eligen_avg_scores, eligen_question_ids = process_vqa_results(eligen_file_path, eligen_method_name)
rpg_avg_scores, rpg_question_ids = process_vqa_results(rpg_file_path, rpg_method_name)


def _mean_pair_score(avg_scores: dict):
    """Return the unweighted mean of each pair's 'avg_score' (0 if there are no pairs).

    Every object/relation pair counts once, regardless of how many questions
    contributed to its own average.
    """
    if not avg_scores:
        return 0
    pair_means = [entry['avg_score'] for entry in avg_scores.values()]
    return sum(pair_means) / len(pair_means)


# --- Calculate Overall Averages ---
overall_gen_avg = _mean_pair_score(gen_avg_scores)
overall_eligen_avg = _mean_pair_score(eligen_avg_scores)
overall_rpg_avg = _mean_pair_score(rpg_avg_scores)


# --- Find Common Pairs Across All Three Methods ---
common_pairs_all_three = set(gen_avg_scores.keys()) & set(eligen_avg_scores.keys()) & set(rpg_avg_scores.keys())
print(f"\nFound {len(common_pairs_all_three)} common object/relation pairs across all three methods.")

# --- Prepare detailed results for common pairs and collect question IDs for output ---
detailed_common_results = []
output_question_ids_dict = {}  # str(pair_relation_key) -> list of question IDs, written to JSON

# Iterate in sorted order. A set of string tuples iterates in an order that
# depends on the per-process hash seed, which would make the report's tie
# order and the JSON key order change from run to run.
for pair_relation_key in sorted(common_pairs_all_three):
    gen_score = gen_avg_scores[pair_relation_key]['avg_score']
    eligen_score = eligen_avg_scores[pair_relation_key]['avg_score']
    rpg_score = rpg_avg_scores[pair_relation_key]['avg_score']

    diff_gen_eligen = gen_score - eligen_score
    diff_eligen_rpg = eligen_score - rpg_score

    # Keep a pair only when all three hold:
    # - gen scores above 0.5;
    # - gen beats eligen by more than 0.05;
    # - eligen strictly beats rpg.
    if not (gen_score > 0.5 and diff_gen_eligen > 0.05 and diff_eligen_rpg > 0):
        continue

    detailed_common_results.append({
        'pair_relation_key': pair_relation_key,
        gen_method_name: gen_score,
        eligen_method_name: eligen_score,
        rpg_method_name: rpg_score,
        f'{gen_method_name}-{eligen_method_name}': diff_gen_eligen,
        f'{eligen_method_name}-{rpg_method_name}': diff_eligen_rpg,
    })

    # The 'gen' results' question IDs serve as the representative set for this pair.
    # The tuple key is stringified because JSON object keys must be strings.
    output_question_ids_dict[str(pair_relation_key)] = gen_question_ids.get(pair_relation_key, [])


# --- Group Detailed Results by Relation AND Category Combination ---
# Two-level index:
#   relation phrase -> sorted (category, category) tuple -> list of result dicts.
# Each list keeps the order of detailed_common_results.
grouped_by_relation_then_category = defaultdict(lambda: defaultdict(list))

for result in detailed_common_results:
    first_obj, second_obj, relation_phrase = result['pair_relation_key']
    categories = get_pair_category_combination(first_obj, second_obj, object_to_category_map)
    grouped_by_relation_then_category[relation_phrase][categories].append(result)


# --- Write Results to Text File, Grouped by Relation and Category ---
print(f"\nWriting detailed comparison results to '{output_txt_file_path}'...")

# Keys of the per-pair difference entries in detailed_common_results.
gen_eligen_key = f'{gen_method_name}-{eligen_method_name}'
eligen_rpg_key = f'{eligen_method_name}-{rpg_method_name}'

with open(output_txt_file_path, 'w', encoding='utf-8') as f:
    f.write(f"--- Detailed VQA Score Comparison Results ({gen_method_name}, {eligen_method_name}, {rpg_method_name}) ---\n\n")
    f.write(f"{gen_method_name} Results File: {gen_file_path}\n")
    f.write(f"{eligen_method_name} Results File: {eligen_file_path}\n")
    f.write(f"{rpg_method_name} Results File: {rpg_file_path}\n")
    f.write("-" * 100 + "\n\n")  # Longer separator

    f.write("Overall Average Scores (Across All Relations and Categories):\n")
    f.write(f"  Average score for '{gen_method_name}' ({len(gen_avg_scores)} pairs): {overall_gen_avg:.4f}\n")
    f.write(f"  Average score for '{eligen_method_name}' ({len(eligen_avg_scores)} pairs): {overall_eligen_avg:.4f}\n")
    f.write(f"  Average score for '{rpg_method_name}' ({len(rpg_avg_scores)} pairs): {overall_rpg_avg:.4f}\n")
    f.write("-" * 100 + "\n\n")

    f.write(f"Detailed Scores and Differences for Common Object Pairs ({len(detailed_common_results)} filtered pairs), Grouped by Relation and Object Category Combination:\n")
    # This line must state the exact condition used to build
    # detailed_common_results (gen > 0.5, gen - eligen > 0.05, eligen > rpg).
    f.write(f"Filtering criteria: {gen_method_name}_score > 0.5 AND {gen_method_name}_score - {eligen_method_name}_score > 0.05 AND {eligen_method_name}_score > {rpg_method_name}_score\n")
    f.write(f"Format: (Object1, Object2, Relation): Differences [{gen_method_name}-{eligen_method_name}, {eligen_method_name}-{rpg_method_name}] Scores [{gen_method_name}, {eligen_method_name}, {rpg_method_name}]\n")
    f.write("-" * 100 + "\n")

    # Walk relations in their RELATIONS order for a stable report layout.
    all_relations = [rel for rel_list in RELATIONS.values() for rel in rel_list]
    for relation in all_relations:
        relation_groups = grouped_by_relation_then_category.get(relation)
        if not relation_groups:
            continue

        f.write(f"\n>>> Relation: {relation} <<<\n")
        f.write("=" * (len(relation) + 15) + "\n")

        # Category combinations in alphabetical order.
        for category_combination in sorted(relation_groups):
            results_list = relation_groups[category_combination]
            if not results_list:
                continue  # Skip empty groups

            cat_str = f"{category_combination[0]} - {category_combination[1]}"
            f.write(f"  Category Combination: {cat_str} ({len(results_list)} pairs)\n")
            f.write("    " + "-" * (len(f"Category Combination: {cat_str}") - 2) + "\n")

            # Group-level averages of scores and differences.
            group_size = len(results_list)
            avg_gen_group = sum(item[gen_method_name] for item in results_list) / group_size
            avg_eligen_group = sum(item[eligen_method_name] for item in results_list) / group_size
            avg_rpg_group = sum(item[rpg_method_name] for item in results_list) / group_size
            avg_diff_gen_eligen = sum(item[gen_eligen_key] for item in results_list) / group_size
            avg_diff_eligen_rpg = sum(item[eligen_rpg_key] for item in results_list) / group_size

            f.write(f"    Average scores for this group: {gen_method_name}={avg_gen_group:.4f}, {eligen_method_name}={avg_eligen_group:.4f}, {rpg_method_name}={avg_rpg_group:.4f}\n")
            f.write(f"    Average differences for this group: {gen_method_name}-{eligen_method_name}={avg_diff_gen_eligen:.4f}, {eligen_method_name}-{rpg_method_name}={avg_diff_eligen_rpg:.4f}\n")
            # The label must match the sort key used just below.
            f.write(f"    Individual pairs (sorted by {gen_eligen_key} difference descending):\n")

            ranked_results = sorted(results_list, key=lambda item: item[gen_eligen_key], reverse=True)

            object_set = set()
            for i, item in enumerate(ranked_results, start=1):
                obj1, obj2, rel = item['pair_relation_key']
                object_set.add(obj1)
                object_set.add(obj2)
                f.write(f"      {i}. ({obj1}, {obj2}, {rel}): Differences [{item[gen_eligen_key]:.4f}, {item[eligen_rpg_key]:.4f}] Scores [{item[gen_method_name]:.4f}, {item[eligen_method_name]:.4f}, {item[rpg_method_name]:.4f}]\n")

            sorted_object_list = sorted(object_set)
            f.write(f"    object_list: {sorted_object_list}\n")
            f.write(f"    object_len:  {len(sorted_object_list)}\n")
            f.write("\n")  # Blank line after each category group

    f.write("-" * 100 + "\n\n")
    f.write("Notes:\n")
    f.write(f"- Scores for common pairs are listed as [{gen_method_name}, {eligen_method_name}, {rpg_method_name}].\n")
    f.write(f"- Differences are listed as [{gen_method_name}-{eligen_method_name}, {eligen_method_name}-{rpg_method_name}].\n")
    f.write("- Object category combinations are sorted alphabetically (e.g., 'animals - indoor').\n")
    f.write(f"- Filtered question IDs corresponding to the pairs listed above are saved in '{output_json_file_path}'.\n")


print(f"Detailed comparison results saved to '{output_txt_file_path}'")

# --- Write Filtered Question IDs to JSON File ---
# A write failure is reported but does not abort the script.
print(f"\nWriting filtered question IDs to '{output_json_file_path}'...")
try:
    with open(output_json_file_path, 'w') as out_handle:
        json.dump(output_question_ids_dict, out_handle, indent=4)
except OSError as write_error:  # IOError is an alias of OSError in Python 3
    print(f"Error writing JSON file {output_json_file_path}: {write_error}")
else:
    print(f"Filtered question IDs saved to '{output_json_file_path}'")