import json
import os
from collections import defaultdict
from typing import Optional

# Provided data structures (keep these as they are for parsing).
# parse_question_id() matches these literal phrases inside each question_id,
# so changing wording or spacing changes which question_ids can be parsed.

# Relation phrases, grouped under a relation-type key. Every phrase from every
# group is a candidate when searching a question_id for its relation.
RELATIONS = {
    "3d": [
        'in front of',
        'behind of',
        # 'hidden by',
        'at the back left of',
        'at the front left of',
        'at the back right of',
        'at the front right of',
    ]
}

# Scene phrases that may trail a question_id after the second object.
SCENES_PROMPT = [
    "on the desert",
    "in the room",
    "on the street",
    "in the jungle",
    "on the road",
    "in the studio",
    "on the beach",
    "on a snowy landscape",
    "in the apartment",
    "in the library",
]

# Category name -> object names belonging to that category.
OBJECTS_CATEGORIES = {
    "animals": ['dog', 'mouse', 'sheep', 'cat', 'cow', 'chicken', 'turtle', 'giraffe', 'pig', 'butterfly', 'horse', 'bird', 'rabbit', 'frog', 'fish'],
    "indoor": ['bed', 'desk', 'key', 'chair', 'vase', 'candle', 'cup', 'phone', 'computer', 'bowl', 'sofa', 'balloon', 'plate', 'refrigerator', 'wallet', 'bag', 'painting', 'suitcase', 'table', 'couch', 'clock', 'book', 'lamp', 'television'],
    "outdoor": ["car", "motorcycle", "backpack", "bench", 'train', 'airplane', 'bicycle'],
    "person": ['woman', 'man', 'boy', 'girl'],
}

# --- Helper to map object names to categories ---
# Inverted view of OBJECTS_CATEGORIES: object name -> category name. If a name
# were listed under several categories, the last category listed would win.
object_to_category_map = {
    name: group
    for group, names in OBJECTS_CATEGORIES.items()
    for name in names
}

def get_object_category(obj_name: str, obj_category_map: dict) -> str:
    """Return the category stored for ``obj_name`` in ``obj_category_map``.

    Names without an entry in the map yield the string "unknown".
    """
    if obj_name in obj_category_map:
        return obj_category_map[obj_name]
    return "unknown"

# Not used for the primary grouping in this script; kept for potential other uses.
def get_pair_category_combination(obj1: str, obj2: str, obj_category_map: dict) -> tuple:
    """Return the two objects' categories as an alphabetically sorted 2-tuple.

    Each category comes from get_object_category(), so an object missing from
    ``obj_category_map`` contributes "unknown". Sorting makes the result
    independent of the order in which the two objects are passed.
    """
    categories = [
        get_object_category(name, obj_category_map) for name in (obj1, obj2)
    ]
    categories.sort()
    return tuple(categories)


def _strip_leading_article(text: str) -> str:
    """Trim surrounding whitespace and drop one leading 'a ' or 'an ' from text."""
    text = text.strip()
    if text.startswith("a "):
        text = text[2:]
    elif text.startswith("an "):
        text = text[3:]
    return text.strip()


def parse_question_id(question_id: str, relations: dict, scenes: list) -> Optional[tuple]:
    """
    Parses the question_id string to extract the two objects and the relation.
    Assumes format like 'a [obj1] [relation] a [obj2] [scene]'; the scene is optional.

    Args:
        question_id: The question text to parse.
        relations: Mapping of relation group -> list of relation phrases. All
            phrases from all groups are candidates.
        scenes: Scene phrases that may appear at the very end of question_id.

    Returns:
        A tuple of (obj1, obj2, relation) with obj1 and obj2 in the order they
        appear in the question and any leading 'a '/'an ' removed, or None when
        question_id is not a string, no relation phrase is found, either
        object is empty, or both objects are identical.
    """
    if not isinstance(question_id, str):
        return None

    # Try longer phrases first so the most specific relation wins if several match.
    candidates = sorted(
        (rel for rel_list in relations.values() for rel in rel_list),
        key=len,
        reverse=True,
    )

    found_relation = None
    relation_start = -1
    relation_end = -1
    for rel in candidates:
        # Require surrounding spaces so the phrase is matched as whole words.
        hit = question_id.find(f" {rel} ")
        if hit != -1:
            found_relation = rel
            relation_start = hit + 1  # skip the leading space of the match
            relation_end = relation_start + len(rel)
            break

    if not found_relation:
        return None

    # Everything after the relation is "[article] obj2 [scene]". All scene
    # offsets below are computed on this stripped tail, and obj2 is sliced from
    # the same string, so extra whitespace after the relation cannot shift the
    # cut position.
    tail = question_id[relation_end:].strip()
    obj2_part = tail
    for scene in sorted(scenes, key=len, reverse=True):
        if not tail.endswith(scene):
            continue
        cut = len(tail) - len(scene)
        # Accept the scene only on a word boundary, so it is not mistaken for
        # the ending of a longer word belonging to obj2.
        if cut == 0 or tail[cut - 1] == ' ':
            obj2_part = tail[:cut]
            break

    obj1 = _strip_leading_article(question_id[:relation_start])
    obj2 = _strip_leading_article(obj2_part)

    # Incomplete prompts and self-relations are rejected.
    if not obj1 or not obj2 or obj1 == obj2:
        return None

    return (obj1, obj2, found_relation)


def process_vqa_results(file_path: str) -> tuple:
    """
    Loads VQA results, parses question IDs, and calculates average scores
    for each prompt key. Also collects question IDs.

    The file is read as UTF-8 JSON and must contain a list of objects with a
    'question_id' and an 'answer'. An entry is skipped when it is not an
    object, its question_id is not a string, its answer is missing or cannot
    be converted to float, or parse_question_id() cannot parse the question_id.

    Args:
        file_path: Path to the JSON result file.

    Returns:
        A tuple containing two dictionaries that share the same keys. Each key
        is the 4-tuple (obj1, relation, obj2, question_id), so results are
        grouped per distinct question_id:
        1. average_scores_by_prompt: values are dicts with 'avg_score' and
           the 'scores' list it was computed from.
        2. question_ids_by_prompt: values are lists of original question_ids.
        Both are empty when the file is missing, cannot be decoded, or does
        not hold a JSON list; an error message is printed in those cases.
    """
    scores_by_prompt = defaultdict(list)
    question_ids_by_prompt = defaultdict(list)

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
        return {}, {}
    except (json.JSONDecodeError, UnicodeDecodeError):
        print(f"Error: Could not decode JSON from {file_path}")
        return {}, {}

    if not isinstance(data, list):
        print(f"Error: Expected a JSON list of result entries in {file_path}")
        return {}, {}

    for entry in data:
        # Malformed entries are skipped rather than aborting the whole file.
        if not isinstance(entry, dict):
            continue

        question_id = entry.get('question_id')
        answer = entry.get('answer')
        if not isinstance(question_id, str) or answer is None:
            continue

        parsed_info = parse_question_id(question_id, RELATIONS, SCENES_PROMPT)
        if not parsed_info:
            continue

        p_obj1, p_obj2, p_relation = parsed_info
        # The question_id is part of the key on purpose: callers unpack it as
        # the fourth element to recover the original prompt text.
        prompt_key = (p_obj1, p_relation, p_obj2, question_id)

        try:
            score = float(answer)
        except (ValueError, TypeError):
            # Non-numeric answers are ignored (best effort).
            continue

        scores_by_prompt[prompt_key].append(score)
        question_ids_by_prompt[prompt_key].append(question_id)

    # Every list in scores_by_prompt holds at least one score, so the division is safe.
    average_scores_by_prompt = {
        prompt_key: {'avg_score': sum(scores) / len(scores), 'scores': scores}
        for prompt_key, scores in scores_by_prompt.items()
    }

    return average_scores_by_prompt, question_ids_by_prompt


# --- Define the two sets of paths ---
# 'gen' run under comparison. To evaluate the gen_14000 run instead, replace
# 'gen_30000' with 'gen_14000' in both paths below.
gen_file_path = 'exp/gen_30000/labels/annotation_obj_detection_3d/vqa_result.json'
gen_image_path = 'exp/gen_30000/samples'

# 'eligen' run that 'gen' is compared against.
eligen_file_path = 'exp/eligen/labels/annotation_obj_detection_3d/vqa_result.json'
eligen_image_path = 'exp/eligen/samples'

# Output locations: human-readable report and machine-readable prompt list.
output_txt_file_path = 'score_comparison_results_by_prompt.txt'
output_json_file_path = 'highly_scored_prompts.json'

# --- Process both sets of results ---
print(f"Processing results from: {gen_file_path}")
gen_avg_scores_by_prompt, gen_qids_by_prompt = process_vqa_results(gen_file_path)
print(f"Finished processing {len(gen_avg_scores_by_prompt)} unique prompts from 'gen'.")

print(f"Processing results from: {eligen_file_path}")
eligen_avg_scores_by_prompt, eligen_qids_by_prompt = process_vqa_results(eligen_file_path)
print(f"Finished processing {len(eligen_avg_scores_by_prompt)} unique prompts from 'eligen'.")

# --- Calculate Overall Averages ---
# Mean of the per-prompt averages; 0 when a result set is empty.
overall_gen_avg = (
    sum(item['avg_score'] for item in gen_avg_scores_by_prompt.values()) / len(gen_avg_scores_by_prompt)
    if gen_avg_scores_by_prompt else 0
)
overall_eligen_avg = (
    sum(item['avg_score'] for item in eligen_avg_scores_by_prompt.values()) / len(eligen_avg_scores_by_prompt)
    if eligen_avg_scores_by_prompt else 0
)

# --- Find Common Prompts and Calculate Differences ---
common_prompts = set(gen_avg_scores_by_prompt.keys()) & set(eligen_avg_scores_by_prompt.keys())

score_differences = []  # Stores (prompt_key, difference, gen_score, eligen_score)
# Iterate in sorted order: the iteration order of a set of string tuples can
# change from run to run (string hash randomization), which would otherwise
# make the order of the generated outputs irreproducible.
for prompt_key in sorted(common_prompts):
    gen_score = gen_avg_scores_by_prompt[prompt_key]['avg_score']
    eligen_score = eligen_avg_scores_by_prompt[prompt_key]['avg_score']
    difference = gen_score - eligen_score  # positive: 'gen' scored higher
    score_differences.append((prompt_key, difference, gen_score, eligen_score))

# --- Group Filtered Prompts by Relation ---
# Acceptance thresholds per relation as (min_diff, min_gen_score). A common
# prompt is kept when diff > min_diff and, unless min_gen_score is None,
# gen_score > min_gen_score. Relations missing from this table are never kept.
relation_filter_thresholds = {
    'in front of': (-0.05, 0.45),
    'behind of': (-0.05, 0.45),
    'at the back left of': (-0.05, None),
    'at the front left of': (-0.05, None),
    'at the back right of': (0, 0.4),
    'at the front right of': (-0.05, None),
}

# Structure: relation -> list of prompt data dicts
grouped_prompts_by_relation = defaultdict(list)
filtered_prompts_for_json = []  # Stores (obj1, relation, obj2, question_id) keys

for prompt_key, diff, gen_s, eligen_s in score_differences:
    obj1, relation, obj2, prompt = prompt_key

    thresholds = relation_filter_thresholds.get(relation)
    if thresholds is None:
        continue
    min_diff, min_gen_score = thresholds
    if not (diff > min_diff):
        continue
    if min_gen_score is not None and not (gen_s > min_gen_score):
        continue

    grouped_prompts_by_relation[relation].append({
        'prompt_tuple': prompt_key,
        'prompt': prompt,
        'obj1': obj1,
        'relation': relation,  # This is the grouping key
        'obj2': obj2,
        'diff': diff,
        'gen_score': gen_s,
        'eligen_score': eligen_s,
        'obj1_category': get_object_category(obj1, object_to_category_map),
        'obj2_category': get_object_category(obj2, object_to_category_map),
    })
    filtered_prompts_for_json.append(prompt_key)


# --- Write Results to Text File, Grouped by Relation then by Prompt ---
with open(output_txt_file_path, 'w', encoding='utf-8') as f:
    f.write("--- VQA Score Comparison Results (Grouped by Relation and Prompt) ---\n\n")
    f.write(f"Gen Results File: {gen_file_path}\n")
    f.write(f"Eligen Results File: {eligen_file_path}\n")
    f.write("-" * 70 + "\n\n")

    f.write("Overall Average Scores (Across All Prompts):\n")
    f.write(f"  Average score for 'gen' ({len(gen_avg_scores_by_prompt)} prompts): {overall_gen_avg:.4f}\n")
    f.write(f"  Average score for 'eligen' ({len(eligen_avg_scores_by_prompt)} prompts): {overall_eligen_avg:.4f}\n")
    f.write("-" * 70 + "\n\n")

    f.write("Score Differences for Common Prompts, Grouped by Relation:\n")
    f.write("Format: (Prompt): Diff (Gen - Eligen) [Gen Score, Eligen Score]\n")
    # The criteria text is generated from the threshold table so the report
    # always states the filter that was actually applied.
    f.write("Prompts listed here satisfy the per-relation criteria below (Diff = Gen Avg Score - Eligen Avg Score):\n")
    for relation, (min_diff, min_gen_score) in relation_filter_thresholds.items():
        criterion = f"Diff > {min_diff}"
        if min_gen_score is not None:
            criterion += f" AND Gen Avg Score > {min_gen_score}"
        f.write(f"  {relation}: {criterion}\n")
    f.write("-" * 70 + "\n")

    # Report relations in the order they are declared in RELATIONS.
    all_relations_in_order = [rel for rel_list in RELATIONS.values() for rel in rel_list]

    for relation in all_relations_in_order:
        # .get() avoids creating empty entries in the defaultdict.
        prompts_data_list = grouped_prompts_by_relation.get(relation)
        if not prompts_data_list:
            continue  # relations with no filtered prompts are omitted

        f.write(f"\n>>> Relation: {relation} ({len(prompts_data_list)} filtered prompts) <<<\n")
        f.write("=" * (len(relation) + 30) + "\n")

        # Largest advantage of 'gen' over 'eligen' first.
        prompts_data_list.sort(key=lambda item: item['diff'], reverse=True)

        # Average scores over the filtered prompts *in this relation*
        count = len(prompts_data_list)
        avg_diff_relation = sum(item['diff'] for item in prompts_data_list) / count
        avg_gen_relation = sum(item['gen_score'] for item in prompts_data_list) / count
        avg_eligen_relation = sum(item['eligen_score'] for item in prompts_data_list) / count
        f.write(f"  Average scores for these filtered prompts: Gen={avg_gen_relation:.4f}, Eligen={avg_eligen_relation:.4f}, Avg Diff={avg_diff_relation:.4f}\n")
        f.write("  Individual prompts (sorted by difference):\n")

        for rank, prompt_data in enumerate(prompts_data_list, start=1):
            f.write(
                f"    {rank}. ({prompt_data['prompt']}): {prompt_data['diff']:.4f} "
                f"[{prompt_data['gen_score']:.4f}, {prompt_data['eligen_score']:.4f}]\n"
            )

    f.write("-" * 70 + "\n\n")
    f.write("Notes:\n")
    f.write("- A 'prompt' is one question_id; it is parsed into (Object1, Relation, Object2) and grouped by Relation.\n")
    f.write("- Positive difference means 'gen' score is higher than 'eligen' score for that prompt.\n")

print(f"\nComparison results saved to '{output_txt_file_path}'")

# --- Write Filtered Prompts to JSON File ---
# Each entry of filtered_prompts_for_json is a 4-item key that unpacks into
# object1, relation, object2 and the original prompt text.
print(f"Writing {len(filtered_prompts_for_json)} filtered prompts to '{output_json_file_path}'")
with open(output_json_file_path, 'w') as f:
    # Emit a list of dicts instead of bare tuples so the JSON is self-describing.
    json_output_list = [
        {"object1": first, "relation": rel, "object2": second, "prompt": text}
        for first, rel, second, text in filtered_prompts_for_json
    ]
    json.dump(json_output_list, f, indent=4)

print(f"Successfully wrote filtered prompts to '{output_json_file_path}'")