import os
# Process-wide settings. They are applied here, before torch / diffusers / transformers
# are imported below, so those libraries see them at import time.
# NOTE(review): HF_CACHE_DIR does not appear to be a variable huggingface_hub reads
# (HF_HOME is); TRANSFORMERS_CACHE is kept for older transformers versions — confirm.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0,1,2,3",
    "HF_ENDPOINT": "https://hf-mirror.com",
    "HF_CACHE_DIR": "/mnt/workspace/workgroup/zheliu.lzy/.cache",
    "TRANSFORMERS_CACHE": "/mnt/workspace/workgroup/zheliu.lzy/.cache",
    "HF_HOME": "/mnt/workspace/workgroup/zheliu.lzy/.cache",
    "PYTHONWARNINGS": "ignore::FutureWarning",
})

from PIL import Image
import json
import argparse
from tqdm import tqdm
import numpy as np
import torch
import torchvision.transforms as T
from diffusers.pipelines import FluxPipeline

import sys
sys.path.append('/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl')

from src.flux.generate import generate, seed_everything
from src.flux.condition import Condition
from src.flux.block import Inter_Controller
from src.flux.pipeline_tools import visualize_masks

size = 512  # NOTE(review): not referenced in the visible code; condition_size / target_size in __main__ are what is actually used.

def sort_depth(condition_imgs):
    """Return entity indices ordered by mean non-zero value, largest first.

    Args:
        condition_imgs: tensor of shape (N, ...) with at least 2 dims; entry ``i``
            along the first axis is one entity's condition image.

    Returns:
        A Python list of ``N`` indices sorted by descending mean over the non-zero
        elements of each entry. Entries that are entirely zero get a mean of 0.
        Ties keep their original relative order (stable sort).
    """
    # Treat everything after the leading (entity) axis as one flat vector.
    flat = condition_imgs.flatten(start_dim=1)
    nonzero = flat != 0

    # Zero elements add nothing to the sum, so no explicit masking is needed.
    sum_per_img = flat.sum(dim=1)
    count_per_img = nonzero.sum(dim=1).float()
    means = sum_per_img / count_per_img

    # All-zero entries give 0/0 = NaN; map them to 0.
    means = torch.nan_to_num(means, nan=0.0)

    # stable=True makes the order of tied means (e.g. several empty masks)
    # deterministic instead of implementation-dependent.
    return torch.argsort(means, descending=True, stable=True).tolist()

def load_inter_controller(pipe, module_path):
    """Create an ``Inter_Controller``, load weights into it and attach it to the pipe.

    The controller is built with dim=24*128, 24 heads and head dim 128. It is moved
    to the device/dtype of ``pipe.transformer.x_embedder.weight``, its state dict is
    loaded, and it is stored as ``pipe.transformer.inter_controller``.

    Args:
        pipe: pipeline whose ``transformer`` receives the controller.
        module_path: path to a checkpoint readable by ``torch.load`` holding the
            controller's state dict.
    """
    # map_location="cpu" avoids re-creating tensors on whatever device they were
    # saved from (that GPU may not be visible here). load_state_dict below copies
    # them onto the controller's device/dtype.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(module_path, map_location="cpu")
    # Strip every 'cam_embedder.' occurrence so keys match the controller's own names.
    # Presumably this is a prefix left over from the training wrapper — confirm.
    new_dict = {k.replace('cam_embedder.', ''): v for k, v in state_dict.items()}

    controller = Inter_Controller(
        dim=24 * 128,
        num_attention_heads=24,
        attention_head_dim=128,
    )
    reference = pipe.transformer.x_embedder.weight
    # Module.to moves buffers as well as parameters; the former per-parameter loop
    # only handled parameters.
    controller.to(device=reference.device, dtype=reference.dtype)
    controller.load_state_dict(new_dict)
    # Attach only after a successful load, so a bad checkpoint leaves no half-set module.
    pipe.transformer.inter_controller = controller

def main(pipe, prompt, file_path='result.png', params=None, target_size=512, seed=42):
    """Plain FLUX text-to-image generation; an alias that delegates to ``flux``.

    This previously duplicated ``flux`` line for line; delegating keeps the two
    from drifting apart.

    Args:
        pipe: pipeline passed through to ``flux``.
        prompt: text prompt.
        file_path: where the generated image is saved.
        params: accepted for backward compatibility; not used.
        target_size: forwarded to ``flux``.
        seed: forwarded to ``flux``.
    """
    flux(pipe, prompt, file_path=file_path, target_size=target_size, seed=seed)

def flux(pipe, prompt, file_path='result.png', target_size=512, seed=42):
    """Generate a plain FLUX baseline image for ``prompt`` and save it.

    Args:
        pipe: pipeline passed to ``generate`` (with ``default_lora=True``).
        prompt: text prompt.
        file_path: where the image is saved.
        target_size: side length of the saved square image. The generated image is
            resized to (target_size, target_size) when it differs, mirroring ``eligen``.
            Previously this argument was silently ignored.
        seed: value passed to ``seed_everything`` before generation.
    """
    seed_everything(seed)
    res = generate(
        pipe,
        prompt=prompt,
        default_lora=True,
    )
    image = res.images[0]
    # No height/width is passed to `generate`, so its output size is whatever the
    # default is; normalise it to the evaluation resolution before saving.
    if image.size != (target_size, target_size):
        image = image.resize((target_size, target_size))
    image.save(file_path)

def eligen(pipe, prompt, eligen_entity_prompts=None, eligen_entity_masks=None, eligen_entity_masks_pil=None, file_path='result.png', target_size=512, seed=42):
    """Generate an EliGen baseline at 1024x1024 and save it resized to ``target_size``.

    Args:
        pipe: pipeline passed to ``generate`` (with ``default_lora=True``).
        prompt: global text prompt.
        eligen_entity_prompts: per-entity prompts. Forwarded to ``generate`` only
            when ``eligen_entity_masks`` is also given.
        eligen_entity_masks: per-entity masks. Forwarded only together with the prompts.
        eligen_entity_masks_pil: PIL masks for visualisation. When non-empty, an
            overlay on the full-resolution generated image is written to
            ``<file_path stem>_mask.png``.
        file_path: where the resized image is saved.
        target_size: side length of the saved square image.
        seed: value passed to ``seed_everything`` before generation.
    """
    seed_everything(seed)
    entity_kwargs = {}
    if eligen_entity_prompts is not None and eligen_entity_masks is not None:
        entity_kwargs = {
            "eligen_entity_prompts": eligen_entity_prompts,
            "eligen_entity_masks": eligen_entity_masks,
        }
    res = generate(
        pipe,
        prompt=prompt,
        height=1024,
        width=1024,
        default_lora=True,
        **entity_kwargs,
    )
    image = res.images[0]
    image.resize((target_size, target_size)).save(file_path)
    if eligen_entity_masks_pil:
        # splitext copes with any extension length, unlike slicing off 4 characters.
        stem, _ = os.path.splitext(file_path)
        visualize_masks(image, eligen_entity_masks_pil, eligen_entity_prompts, f"{stem}_mask.png")

def loose_woeligen(pipe, prompt, condition, eligen_entity_masks_pil=None, file_path='result.png', target_size=512, seed=42, eligen_entity_prompts=None):
    """Generate a condition-guided image without EliGen entity inputs and save it.

    Args:
        pipe: pipeline passed to ``generate`` (with ``default_lora=True``).
        prompt: text prompt.
        condition: the single condition passed to ``generate`` via ``conditions=[...]``.
        eligen_entity_masks_pil: PIL masks for visualisation. When non-empty, an
            overlay is written to ``<file_path stem>_mask.png``.
        file_path: where the generated image is saved.
        target_size: height and width requested from ``generate``.
        seed: value passed to ``seed_everything`` before generation.
        eligen_entity_prompts: optional labels for the mask overlay (new, optional).
            The body used this name without defining it, which raised NameError
            whenever masks were supplied.
    """
    seed_everything(seed)
    res = generate(
        pipe,
        prompt=prompt,
        conditions=[condition],
        height=target_size,
        width=target_size,
        default_lora=True,
    )
    image = res.images[0]
    image.save(file_path)
    if eligen_entity_masks_pil:
        labels = eligen_entity_prompts
        if labels is None:
            # Generic labels, one string per mask, like the entity-prompt lists other
            # callers pass. NOTE(review): confirm visualize_masks accepts any strings.
            labels = [f"entity_{i}" for i in range(len(eligen_entity_masks_pil))]
        stem, _ = os.path.splitext(file_path)
        visualize_masks(image, eligen_entity_masks_pil, labels, f"{stem}_mask.png")

def loose(pipe, prompt, condition, eligen_entity_prompts=None, eligen_entity_masks=None, eligen_entity_masks_pil=None, file_path='result.png', target_size=512, seed=42):
    """Generate a condition-guided image, optionally with EliGen entity inputs, and save it.

    Args:
        pipe: pipeline passed to ``generate`` (with ``default_lora=True``).
        prompt: global text prompt.
        condition: the single condition passed to ``generate`` via ``conditions=[...]``.
        eligen_entity_prompts: per-entity prompts. Forwarded only when
            ``eligen_entity_masks`` is also given.
        eligen_entity_masks: per-entity masks. Forwarded only together with the prompts.
        eligen_entity_masks_pil: PIL masks for visualisation. When non-empty, an
            overlay is written to ``<file_path stem>_mask.png``.
        file_path: where the generated image is saved.
        target_size: height and width requested from ``generate``.
        seed: value passed to ``seed_everything`` before generation.
    """
    seed_everything(seed)
    entity_kwargs = {}
    if eligen_entity_prompts is not None and eligen_entity_masks is not None:
        entity_kwargs = {
            "eligen_entity_prompts": eligen_entity_prompts,
            "eligen_entity_masks": eligen_entity_masks,
        }
    res = generate(
        pipe,
        prompt=prompt,
        conditions=[condition],
        height=target_size,
        width=target_size,
        default_lora=True,
        **entity_kwargs,
    )
    image = res.images[0]
    image.save(file_path)
    if eligen_entity_masks_pil:
        # splitext copes with any extension length, unlike slicing off 4 characters.
        stem, _ = os.path.splitext(file_path)
        visualize_masks(image, eligen_entity_masks_pil, eligen_entity_prompts, f"{stem}_mask.png")

def _load_test_list(conditions, data_path, size, pipe):
    """Build ``(condition_dict, position_delta, caption, image_id)`` tuples.

    For each entry of ``conditions`` (parsed from the condition JSON), this:
    - loads ``{data_path}/{image_id}/render_depth_{i}.png`` for every entity and
      resizes it to ``size`` x ``size`` RGB;
    - derives binary masks at ``size // 8`` (as tensors on ``pipe.device`` /
      ``pipe.dtype``) and at full size (as PIL images);
    - reorders everything with ``sort_depth``.

    This replaces four copy-pasted loops that differed only in ``size``.
    """
    test_list = []
    for entry in tqdm(conditions, desc="🚀 Loading conditions"):
        image_id = entry['image_id']
        caption = entry['caption']
        entities = entry['entities']

        depth_imgs = []
        for i in range(len(entities)):
            # `with` closes the file handle; resize() returns an independent image.
            with Image.open(f'{data_path}/{image_id}/render_depth_{i}.png') as img:
                depth_imgs.append(img.resize((size, size)).convert("RGB"))
        entity_prompts = [entity['entity'] for entity in entities]

        latent_masks = []
        pil_masks = []
        for depth_img in depth_imgs:
            small = np.array(depth_img.resize((size // 8, size // 8)))
            small = np.where(small > 0, 1, 0).astype(np.uint8)
            # NOTE(review): the RGB channel axis is kept, so each tensor is
            # (1, size//8, size//8, 3) — confirm this is the layout generate expects.
            latent_masks.append(
                torch.from_numpy(small).to(device=pipe.device, dtype=pipe.dtype).unsqueeze(0)
            )

            full = np.where(np.array(depth_img) > 0, 1, 0).astype(np.uint8)
            pil_masks.append(Image.fromarray(full * 255))

        stacked = torch.stack([T.ToTensor()(depth_img) for depth_img in depth_imgs])
        order = sort_depth(stacked)
        condition = {
            "condition": stacked[order],
            "eligen_entity_prompts": [entity_prompts[idx] for idx in order],
            "eligen_entity_masks": [latent_masks[idx] for idx in order],
            'eligen_entity_masks_pil': [pil_masks[idx] for idx in order],
        }
        test_list.append((condition, [0, 0], caption, image_id))
    return test_list


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0, help='GPU index to use')
    parser.add_argument('--total_gpus', type=int, default=1, help='Total number of GPUs used')
    parser.add_argument('--batch_size', type=int, default=1, help='Number of objects to process per batch')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    args = parser.parse_args()
    # NOTE(review): --gpu, --total_gpus and --batch_size are parsed but never used;
    # every run processes all conditions on the default CUDA device.

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = 'flux'  # ['flux', 'eligen']
    # Both baselines run only on the stock FLUX.1-dev weights; any other `model`
    # value loads "entity_pretrain" and skips them.
    flux_flag = eligen_flag = (model == 'flux')
    loose_woeligen_flag = False
    loose_flag = True
    lora_flag = False  # set once the depth LoRA + inter_controller have been loaded

    flux_path = "black-forest-labs/FLUX.1-dev" if flux_flag else "entity_pretrain"

    pipe = FluxPipeline.from_pretrained(flux_path, torch_dtype=torch.bfloat16)
    print("Load Flux model successfully")
    pipe = pipe.to(device=device, dtype=torch.bfloat16)

    condition_type = "loose_condition"
    # condition_path = 'src/evaluate/condition.json'
    condition_path = 'test/condition.json'
    save_path = "outputs/human_data_1"
    os.makedirs(save_path, exist_ok=True)
    condition_size = 512
    target_size = 512
    ckpt = "30000"
    lora_path = f"runs/20250303-213357/ckpt/{ckpt}/pytorch_lora_weights.safetensors"
    module_path = f"runs/20250303-213357/ckpt/{ckpt}/inter_controller.pth"
    # data_path = "/mnt/workspace/workgroup/zheliu.lzy/vision_cot/3d_box/datasets_1"
    data_path = "test"

    with open(condition_path, 'r', encoding='utf-8') as f:
        conditions = json.load(f)

    if flux_flag:
        pipe = pipe.to(device=device, dtype=torch.bfloat16)
        test_list = _load_test_list(conditions, data_path, condition_size, pipe)
        for condition, position_delta, prompt, image_id in tqdm(test_list, desc="🚀 Processing flux batches", unit="batch"):
            file_path = f"{save_path}/flux_{image_id}.png"
            if not os.path.exists(file_path):
                flux(pipe, prompt, file_path=file_path, target_size=target_size, seed=args.seed)
            else:
                print(f"{file_path} exists")

    if eligen_flag:
        eligen_path = '/mnt/workspace/workgroup/zheliu.lzy/vision_cot/DiffSynth-Studio/models/lora/entity_control/model_bf16.safetensors'
        # NOTE(review): the 'eligen' adapter is never unloaded or deactivated, so it may
        # still be active during the loose runs below — confirm this is intended.
        pipe.load_lora_weights(eligen_path, weight_name="eligen_lora", adapter_name="eligen")
        pipe = pipe.to(device=device, dtype=torch.bfloat16)

        eligen_size = 1024
        test_list = _load_test_list(conditions, data_path, eligen_size, pipe)
        for condition, position_delta, prompt, image_id in tqdm(test_list, desc="🚀 Processing eligen batches", unit="batch"):
            file_path = f"{save_path}/eligen_{image_id}.png"
            if not os.path.exists(file_path):
                eligen(pipe, prompt, eligen_entity_prompts=condition['eligen_entity_prompts'], eligen_entity_masks=condition['eligen_entity_masks'], eligen_entity_masks_pil=condition['eligen_entity_masks_pil'], file_path=file_path, target_size=target_size, seed=args.seed)
            else:
                print(f"{file_path} exists")

    if loose_woeligen_flag:
        pipe.load_lora_weights(lora_path)
        print("Load lora weights successfully")
        load_inter_controller(pipe, module_path)
        print("Load inter_controller successfully")
        pipe = pipe.to(device=device, dtype=torch.bfloat16)
        lora_flag = True

        test_list = _load_test_list(conditions, data_path, condition_size, pipe)
        for condition, position_delta, prompt, image_id in tqdm(test_list, desc="🚀 Processing loose_woeligen batches", unit="batch"):
            file_path = f"{save_path}/loose_woeligen_{image_id}.png"
            if not os.path.exists(file_path):
                # The whole dict built by _load_test_list is the condition payload.
                condition_ = Condition(
                    condition_type=condition_type,
                    condition=condition,
                    position_delta=position_delta,
                )
                loose_woeligen(pipe, prompt, condition_, eligen_entity_masks_pil=condition['eligen_entity_masks_pil'], file_path=file_path, target_size=target_size, seed=args.seed)
            else:
                print(f"{file_path} exists")

    if loose_flag:
        if not lora_flag:
            pipe.load_lora_weights(lora_path, weight_name="depth_lora", adapter_name="depth")
            print("Load lora weights successfully")
            load_inter_controller(pipe, module_path)
            print("Load inter_controller successfully")
            pipe = pipe.to(device=device, dtype=torch.bfloat16)

        test_list = _load_test_list(conditions, data_path, condition_size, pipe)
        for condition, position_delta, prompt, image_id in tqdm(test_list, desc="🚀 Processing loose batches", unit="batch"):
            file_path = f"{save_path}/loose_{image_id}.png"
            if not os.path.exists(file_path):
                # The whole dict built by _load_test_list is the condition payload.
                condition_ = Condition(
                    condition_type=condition_type,
                    condition=condition,
                    position_delta=position_delta,
                )
                loose(pipe, prompt, condition_, eligen_entity_prompts=condition['eligen_entity_prompts'], eligen_entity_masks=condition['eligen_entity_masks'], eligen_entity_masks_pil=condition['eligen_entity_masks_pil'], file_path=file_path, target_size=target_size, seed=args.seed)
            else:
                print(f"{file_path} exists")
