import os
# Process-wide environment configuration. These assignments are placed above
# the torch / transformers / diffusers imports further down the file.
# NOTE(review): presumably they have to run before those imports so the
# libraries pick them up -- confirm before reordering.

# Restrict CUDA to the device with index 2.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# Hugging Face Hub endpoint (a mirror site).
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# Three cache-related variables, all pointing at the same directory.
# NOTE(review): which of these the installed library versions actually
# honour should be confirmed; some may be redundant or deprecated.
os.environ["HF_CACHE_DIR"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
os.environ["TRANSFORMERS_CACHE"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
os.environ["HF_HOME"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
# NOTE(review): PYTHONWARNINGS is normally read when the interpreter starts,
# so setting it here may only affect child processes -- confirm, or use
# warnings.filterwarnings() if the goal is to silence this process.
os.environ["PYTHONWARNINGS"] = "ignore::FutureWarning"

import json
import argparse
from tqdm import tqdm
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageDraw
from diffusers.pipelines import FluxPipeline
from transformers import pipeline
from transformers import SamModel, SamProcessor
from transformers import AutoImageProcessor, AutoModelForDepthEstimation

import sys
sys.path.append('/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl')

from src.flux.generate import generate, seed_everything
from src.flux.condition import Condition
from src.flux.module import Inter_Controller, Spatial_Controller
from src.flux.pipeline_tools import visualize_masks
from src.utils.dataset import find_applicable_scenes 
# from src.utils.dataset_eligen import json_generation 

from src.utils.scene import DiffusionScene
from src.utils.prompt import gen_prompt, edit_prompt, identity_prompt, gen_prompt_2d
from src.utils.vlm import vlm_request, extract_and_parse_json, extract_and_parse_list

# ---------------------------------------------------------------------------
# Global model setup. Everything below runs at import time and populates the
# module-level objects used by the functions and the __main__ driver.
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Segment Anything. The processor is loaded from the same checkpoint as the
# model (the original mixed "sam-vit-base" weights with the "sam-vit-huge"
# processor), so preprocessing is guaranteed to match the weights.
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Monocular depth estimator used to score generated images.
# NOTE(review): depth_model is never moved to `device`, so it runs wherever
# from_pretrained places it (presumably the CPU) -- confirm this is intended.
image_processor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-base-hf")
depth_model = AutoModelForDepthEstimation.from_pretrained("LiheYoung/depth-anything-base-hf", torch_dtype=torch.bfloat16)

# Paths and parameters
# The FLUX.1-schnell checkpoint is the one in use; the earlier, immediately
# overwritten assignment to "black-forest-labs/FLUX.1-dev" was dead code.
flux_path = "black-forest-labs/FLUX.1-schnell"
condition_size = 512
target_size = 512
condition_type = "eligen_loose"

model_config = {
    'ckpt': "30000",
    'ckpt_path': "/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl/runs_new/20250405-133306_eligen_loose_reward_weight_split_eligen_loose_None",
    'condition_type': "eligen_loose",
    'inter_controller_type': None,
    'eligen_depth_attn': False,
    'latent_lora': ['eligen']
}

eligen_path = '/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl/checkpoints/eligen.bin'

# Load model
pipe = FluxPipeline.from_pretrained(flux_path, torch_dtype=torch.bfloat16)
print("Load Flux model successfully")

# Deserialize onto the CPU: without map_location, tensors are restored to the
# device they were saved from, which can fail when that CUDA index is not
# visible to this process. The pipeline is moved to `device` afterwards.
state_dict = torch.load(eligen_path, map_location="cpu")
pipe.transformer.load_lora_adapter(state_dict, prefix="transformer", adapter_name="eligen",)

print("Load Flux lora successfully")
active_adapters = pipe.get_active_adapters()
print(active_adapters)

pipe = pipe.to(device=device, dtype=torch.bfloat16)

def json_generation(caption, entities=None):
    """Query the VLM for a layout JSON describing ``caption``.

    When ``entities`` is not supplied, a first request built from
    ``identity_prompt`` asks the VLM for the entity list. A second request,
    built from ``gen_prompt``, then asks for the layout of those entities.

    Args:
        caption: Text substituted for ``<caption>`` in the prompt templates.
        entities: Optional entity list substituted (JSON-encoded) for
            ``<entities>``; extracted via the VLM when ``None``.

    Returns:
        A dict with the keys ``'caption'``, ``'entities'``, ``'ans_json'``
        (the parsed layout) and ``'content'`` (the raw reply to the layout
        request).
    """

    def _ask(text):
        """Send one single-turn text message; return (raw reply, answer part)."""
        reply = vlm_request([
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": text},
                ]
            }
        ])
        # Only the text after the last '</think>' marker is the answer.
        return reply, reply.split('</think>')[-1]

    if entities is None:
        _, entity_answer = _ask(identity_prompt.replace('<caption>', caption))
        entities = extract_and_parse_list(entity_answer)

    layout_request = gen_prompt.replace('<caption>', caption).replace('<entities>', json.dumps(entities))
    content, layout_answer = _ask(layout_request)

    return {
        'caption': caption,
        'entities': entities,
        'ans_json': extract_and_parse_json(layout_answer),
        'content': content,
    }

def find_nonzero_bounding_box(vector):
    """Return the bounding box of the non-zero region of a 2-D NumPy array.

    Args:
        vector: A two-dimensional NumPy array.

    Returns:
        A tuple ``(x_min, y_min, x_max, y_max)`` of inclusive column (x) and
        row (y) indices enclosing every non-zero element, or ``None`` when
        the array contains no non-zero element.

    Raises:
        TypeError: If ``vector`` is not a NumPy array.
        ValueError: If ``vector`` is not two-dimensional.
    """
    if not isinstance(vector, np.ndarray):
        raise TypeError("输入必须是 NumPy 数组")
    if vector.ndim != 2:
        raise ValueError("输入数组必须是二维的")

    # np.nonzero yields one index array per axis: rows (y) first, columns (x) second.
    rows, cols = np.nonzero(vector)
    if rows.size == 0:
        return None

    return (cols.min(), rows.min(), cols.max(), rows.max())

def calculate_mask_centroid(mask_tensor):
    """Return the bounding box of the non-zero region of a mask tensor.

    Despite its name, this function does not compute a centroid: it returns
    the extreme coordinates of the mask, from which callers derive a centre.

    Args:
        mask_tensor (torch.Tensor): Mask whose non-zero / ``True`` elements
            mark the region of interest. Dimension 0 is treated as y and
            dimension 1 as x.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as a tuple of 0-dim index tensors.
        For a mask with no non-zero element, a 1-D tensor holding four NaNs
        is returned instead, so the result still unpacks into four values
        (previously it had ``mask_tensor.ndim`` elements, which broke callers
        that unpack four).
    """
    if not mask_tensor.any():
        return torch.full((4,), float('nan'), device=mask_tensor.device)

    # One index tensor per dimension of the mask.
    coordinates = torch.nonzero(mask_tensor, as_tuple=True)
    y_indices = coordinates[0]
    x_indices = coordinates[1]

    return (x_indices.min(), y_indices.min(), x_indices.max(), y_indices.max())

def generate_scene(ans_json, total_move=None):
    """Build a ``DiffusionScene`` from a layout JSON and render it.

    Args:
        ans_json: Layout dict. Reads ``scene_parameters.scene_size``,
            ``scene_parameters.camera_pitch_angle`` and, for every item of
            ``entity_layout``, the keys ``size``, ``position`` and
            ``entity_name``.
        total_move: Camera translation applied along the second translation
            component after the initial pitch rotation. When ``None``, the
            camera is re-centred iteratively (at most 40 steps) and the sum
            of the applied translations is used.

    Returns:
        A tuple ``(depth_all, total_move)``: the output of
        ``scene.render_consis`` and the camera translation that was applied.
    """
    scene_size = ans_json['scene_parameters']['scene_size'] / 2
    cam_pitch_angle = 90 - ans_json['scene_parameters']['camera_pitch_angle']
    floor_scale_x = 1
    floor_scale_y = 1
    floor_offset = - scene_size / 2

    # Build the scene. Author's note on argument order:
    # rotation_axis(x, z, y), translation(x, z, y).
    scene = DiffusionScene(scene_size=scene_size, fov=(60, 60))
    scene.move_camera(rotation_angle=cam_pitch_angle, rotation_axis=[1, 0, 0], translation=[0, 0, 0])
    scene.build_floor(scale_x=floor_scale_x, scale_y=floor_scale_y, floor_offset=floor_offset)

    for i, entity in enumerate(ans_json['entity_layout']):
        scene.add_box(id=f"box_{i}", size=entity['size'], origin=entity['position'], prompt=entity['entity_name'])

    if total_move is None:
        # NOTE(review): `entity_center` is neither defined nor imported in the
        # visible part of this file; this branch raises NameError unless it
        # is provided elsewhere. The only caller here passes total_move=0.
        def _next_move():
            """Render the entities and return the next re-centring step."""
            rendered = scene.render(single=True, floor=False, render_floor=False, depth_max=4*scene_size)
            x_min, y_min, x_max, y_max = find_nonzero_bounding_box(rendered[-1])
            return entity_center(x_min, y_min, x_max, y_max, rendered[-1].shape)

        num = 0
        total_move = 0
        move = _next_move()
        while move != 0 and num < 40:
            scene.move_camera(rotation_angle=0, rotation_axis=[1, 0, 0], translation=[0, move, 0])
            # Accumulate the step that was just applied. The original added
            # the *next* step after recomputing it, so the returned total
            # omitted the first applied move and included the final,
            # never-applied one.
            # NOTE(review): assumes move_camera translations are relative.
            total_move += move
            move = _next_move()
            num += 1
    else:
        scene.move_camera(rotation_angle=0, rotation_axis=[1, 0, 0], translation=[0, total_move, 0])

    depth_all = scene.render_consis(depth_max=2.4, depth_min=0.4)
    return depth_all, total_move

def main(prompt, data):
    """Generate one image for ``prompt`` constrained by the layout ``data``.

    The scene is rendered with ``generate_scene`` (camera translation fixed
    to 0). For every entity, the bounding box of its rendered depth map is
    turned into a rectangular mask that is passed to the generator as an
    EliGen entity mask.

    Args:
        prompt: Global text prompt for the image.
        data: Layout dict; ``entity_layout[*].entity_name`` provides the
            per-entity prompts. The remaining keys are consumed by
            ``generate_scene``.

    Returns:
        A tuple ``(image, image_mask, depth_all)``: the generated image
        resized to ``target_size``, the mask visualisation (``None`` when the
        layout has no entities), and the rendering output of
        ``generate_scene``.
    """
    depth_all, _ = generate_scene(data, total_move=0)

    eligen_entity_prompts = [entity['entity_name'] for entity in data["entity_layout"]]

    # Process masks: one rectangle per entity, taken from its depth rendering.
    # NOTE(review): the box is drawn in depth-map pixel coordinates on a
    # condition_size canvas, which presumably assumes the depth maps are
    # rendered at condition_size x condition_size -- confirm.
    eligen_entity_masks = []
    eligen_entity_masks_pil = []
    for depth in depth_all[:len(eligen_entity_prompts)]:
        x_min, y_min, x_max, y_max = find_nonzero_bounding_box(depth)

        mask = Image.new("L", (condition_size, condition_size), 0)
        draw = ImageDraw.Draw(mask)
        draw.rectangle([x_min, y_min, x_max, y_max], fill=255)
        mask = mask.convert("RGB")
        eligen_entity_masks_pil.append(mask)

        # Downsample by 8 and binarise before handing the mask to the pipeline.
        mask = np.array(mask.resize((condition_size//8, condition_size//8)))
        mask = np.where(mask > 0, 1, 0).astype(np.uint8)
        mask_tensor = torch.from_numpy(mask).to(device=pipe.device, dtype=pipe.dtype)
        eligen_entity_masks.append(mask_tensor.unsqueeze(0))

    # Fixed seed so repeated calls are comparable.
    seed = 42
    seed_everything(seed)

    # Build the arguments for generate().
    num_inference_steps = 4
    generate_kwargs = {
        "prompt": prompt,
        "default_lora": True,
        "num_inference_steps": num_inference_steps,
        "conditions": None,
        "height": condition_size,
        "width": condition_size,
        "eligen_entity_prompts": eligen_entity_prompts,
        "eligen_entity_masks": eligen_entity_masks,
    }

    # Run the generation.
    image = generate(
        pipe,
        model_config=model_config,
        **generate_kwargs
    ).images[0].resize((target_size, target_size))

    # Mask visualisation. Defaults to None so the return statement no longer
    # raises UnboundLocalError when the layout contains no entities.
    image_mask = None
    if eligen_entity_masks_pil:
        image_mask = visualize_masks(image, eligen_entity_masks_pil, eligen_entity_prompts)

    return image, image_mask, depth_all

def _estimate_normalized_depth(image):
    """Estimate depth for ``image`` and min-max normalise it to [0, 1].

    The prediction is resized to the image resolution by the image
    processor's post-processing step.
    """
    inputs = image_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = depth_model(**inputs)

    post_processed_output = image_processor.post_process_depth_estimation(
        outputs,
        target_sizes=[(image.height, image.width)],
    )
    predicted_depth = post_processed_output[0]["predicted_depth"]
    return (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())


def _segment_entity(image, box):
    """Segment the object inside ``box`` (x_min, y_min, x_max, y_max) with SAM.

    Returns the first predicted mask for the box, at the image resolution.
    """
    x_min, y_min, x_max, y_max = box
    input_boxes = [[[x_min, y_min, x_max, y_max]]]
    inputs = processor(image, input_boxes=[input_boxes], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    return masks[0][0, 0]


def _measure_entities(image, depth_gt):
    """Return ``[mean_x, mean_y, mean_depth]`` for the first two entities.

    The result is a float32 tensor with one row per entity. The depth map is
    estimated once per image because it does not depend on the entity.
    """
    depth = _estimate_normalized_depth(image)

    gen = []
    for i in range(2):
        # The rendered ground-truth box of entity i serves as the SAM prompt.
        mask = _segment_entity(image, find_nonzero_bounding_box(depth_gt[i]))
        x_min, y_min, x_max, y_max = calculate_mask_centroid(mask)

        mean_x = (x_min + x_max) / 2 / image.size[0]
        # Normalise y by the image height (the original divided by the width).
        mean_y = (y_min + y_max) / 2 / image.size[1]
        mean_depth = depth[mask].mean()
        gen.append([mean_x, mean_y, mean_depth])

    return torch.tensor(gen).to(torch.float32)


def _run_depth_sweep(prompt, data, diffs, out_dir):
    """Generate and score one image per depth offset in ``diffs``.

    For every offset, the third position component of entity 1 is set to
    that of entity 0 plus the offset (``data`` is modified in place), an
    image is generated, and the image and the last rendered depth map are
    saved under ``out_dir``.

    Returns:
        A dict with ``'similaritys'`` (one score per offset) and
        ``'differents'`` (``[dz_gt, dz_gen, |dz_gt - dz_gen|]`` per offset).
    """
    similaritys = []
    differents = []
    for diff in diffs:
        data['entity_layout'][1]['position'][2] = data['entity_layout'][0]['position'][2] + diff

        image, image_mask, depth_gt = main(prompt, data)
        image.save(f"{out_dir}/image_{diff}.png")
        Image.fromarray(depth_gt[-1]).save(f"{out_dir}/depth_{diff}.png")

        gen = _measure_entities(image, depth_gt)

        # Expected depth difference between the two entities.
        # NOTE(review): the -diff / 2 scaling presumably maps scene units to
        # normalised estimated depth -- confirm.
        dz_gt = torch.tensor(- diff / 2).to(torch.float32)
        # Depth difference measured on the generated image.
        dz_gen = gen[1][2] - gen[0][2]
        epsilon = 1e-8
        # 1 when both differences agree; 0 when they have opposite signs.
        similarity = 1.0 - torch.abs(dz_gt - dz_gen) / (torch.abs(dz_gt) + torch.abs(dz_gen) + epsilon)

        print(similarity)
        similaritys.append(similarity.item())
        differents.append([dz_gt.item(), dz_gen.item(), torch.abs(dz_gt - dz_gen).item()])

    return {
        'similaritys': similaritys,
        'differents': differents,
    }


if __name__ == "__main__":
    prompt_list = os.listdir('consist/prompt')
    os.makedirs('consist/data/eligen', exist_ok=True)

    # Earlier results are read from eligen_2.json; the merged results are
    # written to eligen_3.json at the end of the run.
    if os.path.exists("consist/eligen_2.json"):
        with open("consist/eligen_2.json", 'r') as f:
            all_data = json.load(f)
    else:
        all_data = {}

    # Depth offsets of entity 1 relative to entity 0; the range around zero
    # (-0.2, 0.2) is skipped. This replaces a dead assignment of a coarser
    # grid that was overwritten immediately.
    diffs = [-0.6, -0.55, -0.5, -0.45, -0.4, -0.35, -0.3, -0.25, -0.2, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6]

    for prompt in tqdm(prompt_list):
        out_dir = f"consist/data/eligen/{prompt}"
        reverse_dir = f"consist/data/eligen/reverse_{prompt}"
        os.makedirs(out_dir, exist_ok=True)
        os.makedirs(reverse_dir, exist_ok=True)

        with open(f"consist/prompt/{prompt}/align.json", 'r') as f:
            json_data = json.load(f)
        data = json_data['data']

        # Pass 1: layout as given.
        all_data[prompt] = _run_depth_sweep(prompt, data, diffs, out_dir)

        # Pass 2: same sweep with the first position component of both
        # entities negated.
        data['entity_layout'][0]['position'][0] = - data['entity_layout'][0]['position'][0]
        data['entity_layout'][1]['position'][0] = - data['entity_layout'][1]['position'][0]
        all_data['reverse_'+prompt] = _run_depth_sweep(prompt, data, diffs, reverse_dir)

    with open("consist/eligen_3.json", 'w') as f:
        json.dump(all_data, f, indent=4)