# Process-wide environment configuration. These assignments execute before the
# torch / diffusers imports further down in this file.
import os
# Restrict CUDA device visibility for this process to device index 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Use a mirror endpoint for Hugging Face downloads.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# Point the cache-related variables at one shared workspace directory.
os.environ["HF_CACHE_DIR"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
os.environ["TRANSFORMERS_CACHE"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
os.environ["HF_HOME"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
# NOTE(review): PYTHONWARNINGS is normally read at interpreter startup, so
# setting it here may only influence child processes — confirm it has the
# intended effect on FutureWarning output.
os.environ["PYTHONWARNINGS"] = "ignore::FutureWarning"

import json
import argparse
from tqdm import tqdm
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from diffusers.pipelines import FluxPipeline

import sys
sys.path.append('/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl')

from src.flux.generate import generate, seed_everything
from src.flux.condition import Condition
from src.flux.block import Inter_Controller
from src.flux.pipeline_tools import visualize_masks

def sort_depth(condition_imgs):
    """Rank images by the mean of their non-zero pixels, largest mean first.

    Args:
        condition_imgs: 4-D tensor; the mean is taken over dimensions 1, 2
            and 3, so dimension 0 indexes the individual images.

    Returns:
        A Python list of indices into dimension 0, ordered by descending mean.
        Images without any non-zero pixel are treated as having mean 0.
    """
    reduce_dims = (1, 2, 3)
    nonzero = condition_imgs != 0

    pixel_totals = torch.mul(condition_imgs, nonzero).sum(dim=reduce_dims)
    pixel_counts = nonzero.sum(dim=reduce_dims).float()

    # An all-zero image gives 0 / 0 = NaN; map that to a mean of 0.
    mean_values = torch.nan_to_num(pixel_totals / pixel_counts, nan=0.0)

    order = torch.argsort(mean_values, descending=True)
    return order.tolist()

def load_inter_controller(pipe, module_path):
    """Create an ``Inter_Controller``, load its weights and attach it to the pipeline.

    The controller is placed on the same device and dtype as
    ``pipe.transformer.x_embedder.weight`` and stored as
    ``pipe.transformer.inter_controller``.

    Args:
        pipe: Pipeline whose ``transformer`` receives the controller.
        module_path: Path to the checkpoint file holding the controller's
            state dict. Any ``'cam_embedder.'`` substring in a key is removed
            before loading.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing file
        or mismatching keys. In that case nothing is attached to ``pipe``.
    """
    # Deserialize onto the CPU so a checkpoint saved from a CUDA device index
    # that is not visible in this process still loads; load_state_dict copies
    # the values onto the controller's own device afterwards.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(module_path, map_location="cpu")
    new_dict = {k.replace('cam_embedder.', ''): v for k, v in state_dict.items()}

    reference_weight = pipe.transformer.x_embedder.weight
    controller = Inter_Controller(
        dim=24 * 128,
        num_attention_heads=24,
        attention_head_dim=128,
    )
    # Module.to moves parameters and buffers alike, unlike a manual loop over
    # parameters().
    controller.to(device=reference_weight.device, dtype=reference_weight.dtype)
    controller.load_state_dict(new_dict)

    # Attach only after the weights loaded successfully.
    pipe.transformer.inter_controller = controller

def prepare_condition_data(conditions, data_path, condition_size, device, dtype, use_depth=True):
    """Build one generation test case per entry of ``conditions``.

    For the entry at list position ``n`` and each of its entities ``i`` the
    render ``{data_path}/{n}/render_depth_{i}.png`` is loaded (or
    ``render_bbox_{i}.png`` when ``use_depth`` is false, falling back to the
    depth render if that file does not exist). Every render is resized to
    ``condition_size`` x ``condition_size`` and converted to RGB.

    Args:
        conditions: Sequence of dicts with a ``'caption'`` and an
            ``'entities'`` list whose items carry an ``'entity'`` prompt.
        data_path: Root folder containing one sub-folder per entry index.
        condition_size: Side length of the condition images; the model-input
            masks use ``condition_size // 8``.
        device: Device the mask tensors are moved to.
        dtype: Dtype the mask tensors are cast to.
        use_depth: Prefer depth renders (True) or bbox renders (False).

    Returns:
        List of ``(condition_data, [0, 0], caption, image_id)`` tuples.
        ``condition_data`` contains the stacked image tensor, entity prompts,
        mask tensors and PIL masks, all reordered by the reverse of the
        ``sort_depth`` ranking.
    """
    mask_side = condition_size // 8
    test_list = []

    progress = tqdm(enumerate(conditions), desc="🚀 Loading conditions", total=len(conditions))
    for sample_idx, sample in progress:
        # The list position (not sample['image_id']) names the data folder.
        image_id = str(sample_idx)
        caption = sample['caption']
        entities = sample['entities']
        entity_count = len(entities)

        # Resolve one render path per entity, falling back to the depth render.
        condition_paths = []
        for i in range(entity_count):
            depth_path = f'{data_path}/{image_id}/render_depth_{i}.png'
            if use_depth:
                wanted_path = depth_path
            else:
                wanted_path = f'{data_path}/{image_id}/render_bbox_{i}.png'
            if os.path.exists(wanted_path):
                condition_paths.append(wanted_path)
            else:
                condition_paths.append(depth_path)

        # Load every render at the requested size as RGB.
        rgb_images = []
        for render_path in condition_paths[:entity_count]:
            render = Image.open(render_path)
            render = render.resize((condition_size, condition_size))
            rgb_images.append(render.convert("RGB"))

        eligen_entity_prompts = [entity['entity'] for entity in entities]

        # Binary masks: a downsampled tensor for the model and a full-size
        # PIL image for visualization.
        eligen_entity_masks = []
        eligen_entity_masks_pil = []
        for rgb in rgb_images:
            small = np.array(rgb.resize((mask_side, mask_side)))
            small_binary = np.where(small > 0, 1, 0).astype(np.uint8)
            small_tensor = torch.from_numpy(small_binary).to(device=device, dtype=dtype)
            eligen_entity_masks.append(small_tensor.unsqueeze(0))

            full_binary = np.where(np.array(rgb) > 0, 1, 0).astype(np.uint8)
            eligen_entity_masks_pil.append(Image.fromarray(full_binary*255))

        # Stack the images and take the sort_depth ranking in reverse order.
        to_tensor = T.ToTensor()
        stacked = torch.stack([to_tensor(rgb) for rgb in rgb_images])
        order = sort_depth(stacked)[::-1]

        condition_data = {
            "condition": stacked[order],
            "eligen_entity_prompts": [eligen_entity_prompts[j] for j in order],
            "eligen_entity_masks": [eligen_entity_masks[j] for j in order],
            'eligen_entity_masks_pil': [eligen_entity_masks_pil[j] for j in order],
        }
        test_list.append((condition_data, [0, 0], caption, image_id))

    return test_list

def unified_generate(
    pipe, 
    prompt, 
    condition=None, 
    eligen_entity_prompts=None,
    eligen_entity_masks=None,
    eligen_entity_masks_pil=None,
    file_path='result.png', 
    target_size=512, 
    seed=42,
):
    """Generate one image with ``generate`` and save it to ``file_path``.

    Args:
        pipe: Pipeline forwarded to ``generate``.
        prompt: Text prompt.
        condition: Optional condition object; when given it is passed as
            ``conditions=[condition]`` together with ``height``/``width``.
        eligen_entity_prompts: Optional per-entity prompts.
        eligen_entity_masks: Optional per-entity masks. Prompts and masks are
            only forwarded when both are non-empty.
        eligen_entity_masks_pil: Optional PIL masks; when non-empty a mask
            visualization is written next to the result as
            ``<file_path without extension>_mask.png``.
        file_path: Destination of the generated image.
        target_size: Side length the saved image is resized to; also used as
            ``height``/``width`` when a condition or entity inputs are given.
        seed: Seed passed to ``seed_everything`` before generation.
    """
    seed_everything(seed)
    
    # Arguments common to every mode.
    generate_kwargs = {
        "prompt": prompt,
        "default_lora": True,
    }
    
    # Mode-specific arguments. Compare against None explicitly so the check
    # does not depend on how the condition object defines truthiness.
    if condition is not None:
        generate_kwargs["conditions"] = [condition]
        generate_kwargs["height"] = target_size
        generate_kwargs["width"] = target_size
    
    if eligen_entity_prompts and eligen_entity_masks:
        generate_kwargs.update({
            "eligen_entity_prompts": eligen_entity_prompts,
            "eligen_entity_masks": eligen_entity_masks,
            "height": target_size,
            "width": target_size,
        })
    
    # Run generation.
    res = generate(
        pipe, 
        **generate_kwargs
    )
    
    # Save the first generated image at the target size.
    save_image = res.images[0]
    save_image.resize((target_size, target_size)).save(file_path)
    
    # Optional mask visualization.
    if eligen_entity_masks_pil:
        # splitext strips whatever extension is present (or none) instead of
        # assuming it is exactly four characters long.
        stem, _ = os.path.splitext(file_path)
        mask_path = f"{stem}_mask.png"
        visualize_masks(save_image, eligen_entity_masks_pil, eligen_entity_prompts, mask_path)

def _generate_missing(pipe, test_list, save_path, prefix, target_size, seed, extra_kwargs):
    """Generate every test case whose ``{save_path}/{prefix}_{image_id}.png`` is missing.

    ``extra_kwargs(condition, position_delta)`` is called only for cases that
    are actually generated and returns the mode-specific keyword arguments
    forwarded to ``unified_generate``.
    """
    progress = tqdm(test_list, total=len(test_list), desc=f"🚀 Processing {prefix} batches", unit="batch")
    for condition, position_delta, prompt, image_id in progress:
        file_path = f"{save_path}/{prefix}_{image_id}.png"
        if os.path.exists(file_path):
            print(f"{file_path} exists")
            continue
        unified_generate(
            pipe, prompt,
            file_path=file_path, target_size=target_size, seed=seed,
            **extra_kwargs(condition, position_delta),
        )


def main():
    """Generate images for every entry of ``test/condition.json``.

    The hard-coded ``model`` switch selects which modes run: ``'flux'``
    enables the plain FLUX and the EliGen-LoRA runs, ``'eligen'`` enables the
    loose-condition LoRA run. Results are written to ``outputs/human_data``
    as ``<mode>_<image_id>.png``; existing files are skipped.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0, help='GPU index to use')
    parser.add_argument('--total_gpus', type=int, default=1, help='Total number of GPUs used')
    parser.add_argument('--batch_size', type=int, default=1, help='Number of objects to process per batch')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    args = parser.parse_args()
    # NOTE(review): only args.seed is consumed below; --gpu, --total_gpus and
    # --batch_size are accepted but currently have no effect.

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Configuration
    model = 'eligen'  # ['flux', 'eligen']
    flux_flag = eligen_flag = (model == 'flux')
    loose_woeligen_flag = False and (model == 'eligen')  # hard-disabled
    loose_flag = (model == 'eligen')
    lora_flag = False

    # Paths and parameters
    flux_path = "black-forest-labs/FLUX.1-dev" if flux_flag else "entity_pretrain"
    condition_type = "loose_condition"
    condition_path = 'test/condition.json'
    save_path = "outputs/human_data"
    os.makedirs(save_path, exist_ok=True)
    condition_size = 512
    target_size = 512
    ckpt = "9000"
    ckpt_path = "runs/20250310-161942_entity_pretrain_loose_condition_None"
    lora_path = f"{ckpt_path}/ckpt/{ckpt}/pytorch_lora_weights.safetensors"
    # Only needed when the (currently disabled) inter-controller load is enabled.
    module_path = f"{ckpt_path}/ckpt/{ckpt}/inter_controller.pth"
    data_path = "test/human_data"

    # Load conditions. JSON is decoded as UTF-8 explicitly so non-ASCII
    # captions do not depend on the locale's default encoding.
    with open(condition_path, 'r', encoding='utf-8') as f:
        conditions = json.load(f)
    print(f"Load {len(conditions)} conditions successfully")

    # Load model
    pipe = FluxPipeline.from_pretrained(flux_path, torch_dtype=torch.bfloat16)
    print("Load Flux model successfully")
    pipe = pipe.to(device=device, dtype=torch.bfloat16)

    def entity_kwargs(condition, position_delta):
        # Entity prompts and masks taken from the prepared condition dict.
        return {
            "eligen_entity_prompts": condition['eligen_entity_prompts'],
            "eligen_entity_masks": condition['eligen_entity_masks'],
            "eligen_entity_masks_pil": condition['eligen_entity_masks_pil'],
        }

    def loose_condition(condition, position_delta):
        # prepare_condition_data always yields a dict here, so it is handed to
        # Condition unchanged (no PIL resize/convert path is reachable).
        return Condition(
            condition_type=condition_type,
            condition=condition,
            position_delta=position_delta,
        )

    # FLUX generation: prompt only.
    if flux_flag:
        test_list = prepare_condition_data(conditions, data_path, condition_size, pipe.device, pipe.dtype)
        _generate_missing(
            pipe, test_list, save_path, "flux", target_size, args.seed,
            lambda condition, position_delta: {},
        )

    # ELIGEN generation: entity prompts and masks at 1024 px.
    if eligen_flag:
        eligen_path = '/mnt/workspace/workgroup/zheliu.lzy/vision_cot/DiffSynth-Studio/models/lora/entity_control/model_bf16.safetensors'
        pipe.load_lora_weights(eligen_path, weight_name="eligen_lora", adapter_name="eligen")
        pipe = pipe.to(device=device, dtype=torch.bfloat16)

        eligen_size = 1024
        test_list = prepare_condition_data(conditions, data_path, eligen_size, pipe.device, pipe.dtype)
        _generate_missing(pipe, test_list, save_path, "eligen", eligen_size, args.seed, entity_kwargs)

    # LOOSE_WOELIGEN generation: condition object without entity prompts/masks.
    if loose_woeligen_flag:
        pipe.load_lora_weights(lora_path)
        print("Load lora weights successfully")
        # Optional: load_inter_controller(pipe, module_path)
        pipe = pipe.to(device=device, dtype=torch.bfloat16)
        lora_flag = True

        test_list = prepare_condition_data(conditions, data_path, condition_size, pipe.device, pipe.dtype)
        _generate_missing(
            pipe, test_list, save_path, "loose_woeligen", target_size, args.seed,
            lambda condition, position_delta: {
                "condition": loose_condition(condition, position_delta),
                "eligen_entity_masks_pil": condition['eligen_entity_masks_pil'],
            },
        )

    # LOOSE generation: condition object plus entity prompts and masks.
    if loose_flag:
        # Skip loading when the LOOSE_WOELIGEN branch already loaded the LoRA.
        if not lora_flag:
            pipe.load_lora_weights(lora_path, weight_name="depth_lora", adapter_name="depth")
            print("Load lora weights successfully")
            # Optional: load_inter_controller(pipe, module_path)
            pipe = pipe.to(device=device, dtype=torch.bfloat16)

        test_list = prepare_condition_data(conditions, data_path, condition_size, pipe.device, pipe.dtype)
        _generate_missing(
            pipe, test_list, save_path, "loose", target_size, args.seed,
            lambda condition, position_delta: {
                "condition": loose_condition(condition, position_delta),
                **entity_kwargs(condition, position_delta),
            },
        )

# Script entry point: run main() only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()

