import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["HF_CACHE_DIR"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
os.environ["TRANSFORMERS_CACHE"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
os.environ["HF_HOME"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
os.environ["PYTHONWARNINGS"] = "ignore::FutureWarning"

import json
import argparse
from tqdm import tqdm
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from diffusers.pipelines import FluxPipeline

import sys
sys.path.append('/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl')

from src.flux.generate import generate, seed_everything
from src.flux.condition import Condition
from src.flux.module import Inter_Controller, Spatial_Controller
from src.flux.pipeline_tools import visualize_masks

def prepare_condition_data(conditions, data_path, condition_size, device, dtype, use_depth=True):
    """Process and prepare condition data for image generation"""
    test_list = []
    for idx, condition in tqdm(enumerate(conditions), desc="🚀 Loading conditions", total=len(conditions)):
        image_id = condition.get('image_id', str(idx).zfill(4))
        caption = condition['caption']
        entities = condition['entities']

        # Load and process condition images
        condition_paths = []
        for i in range(len(entities)):
            path = f'{data_path}/{image_id}/render_depth_{i}.png'
            condition_paths.append(path if os.path.exists(path) else f'{data_path}/{image_id}/render_depth_{i}.png')

        condition_imgs = [
            Image.open(condition_paths[i])
                 .resize((condition_size, condition_size))
                 .convert("RGB") 
            for i in range(len(entities))
        ]
        eligen_entity_prompts = [entities[i]['entity'] for i in range(len(entities))]

        # Process masks
        eligen_entity_masks = []
        eligen_entity_masks_pil = []
        for img in condition_imgs:
            # Create downsampled mask for model input
            mask = np.array(img.resize((condition_size//8, condition_size//8)))
            mask = np.where(mask > 0, 1, 0).astype(np.uint8)
            mask_tensor = torch.from_numpy(mask).to(device=device, dtype=dtype)
            eligen_entity_masks.append(mask_tensor.unsqueeze(0))
            
            # Create full resolution mask for visualization
            mask_pil = np.where(np.array(img) > 0, 1, 0).astype(np.uint8)
            eligen_entity_masks_pil.append(Image.fromarray(mask_pil*255))
        
        # Convert images to tensors and sort by depth
        condition_imgs = torch.stack([T.ToTensor()(img) for img in condition_imgs])
        
        # Create final condition object
        condition_data = {
            "condition": condition_imgs,
            "eligen_entity_prompts": eligen_entity_prompts,
            "eligen_entity_masks": eligen_entity_masks,
            'eligen_entity_masks_pil': eligen_entity_masks_pil,
        }
        test_list.append((condition_data, [0, 0], caption, image_id))
    
    return test_list

def unified_generate(
    pipe, 
    prompt, 
    condition=None, 
    eligen_entity_prompts=None,
    eligen_entity_masks=None,
    eligen_entity_masks_pil=None,
    file_path='result.png', 
    model_config={},
    target_size=512, 
    seed=42,
):
    seed_everything(seed)
    
    # 构建generate参数
    generate_kwargs = {
        "prompt": prompt,
        "default_lora": True,
    }
    
    # 处理不同模式参数
    if condition:
        generate_kwargs["conditions"] = [condition]
        generate_kwargs["height"] = target_size
        generate_kwargs["width"] = target_size
    
    if eligen_entity_prompts and eligen_entity_masks:
        generate_kwargs.update({
            "eligen_entity_prompts": eligen_entity_prompts,
            "eligen_entity_masks": eligen_entity_masks,
            "height": target_size,
            "width": target_size,
        })
    
    # 执行生成
    res = generate(
        pipe, 
        model_config=model_config,
        **generate_kwargs
    )
    
    # 保存结果图像
    save_image = res.images[0]
    save_image.resize((target_size, target_size)).save(file_path)
    
    # 处理mask可视化
    if eligen_entity_masks_pil:
        mask_path = f"{file_path[:-4]}_mask.png"
        visualize_masks(save_image, eligen_entity_masks_pil, eligen_entity_prompts, mask_path)

def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0, help='GPU index to use')
    parser.add_argument('--total_gpus', type=int, default=1, help='Total number of GPUs used')
    parser.add_argument('--batch_size', type=int, default=1, help='Number of objects to process per batch')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    args = parser.parse_args()

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Paths and parameters
    flux_path = "black-forest-labs/FLUX.1-dev"
    condition_size = 512
    target_size = 512
    condition_type = "eligen_loose"
    data_path = 'exp/bas_data'
    condition_path = "exp/bas_data_floor.json"
    
    models = [
        {
            'ckpt': "30000",
            'ckpt_path': "runs_new/20250405-133306_eligen_loose_reward_weight_split_eligen_loose_None",
            'condition_type': "eligen_loose",
            'inter_controller_type': None,
            'eligen_depth_attn': False,
            'latent_lora': ['eligen']
        },
    ]
    
    # Load conditions
    with open(condition_path, 'r') as f:
        conditions = json.load(f)
        print(f"Load {len(conditions)} conditions successfully")
    
    for model in models:
        save_path = f"exp/bas_floor"
        os.makedirs(save_path, exist_ok=True)

        ckpt = model['ckpt']
        ckpt_path = model['ckpt_path']
        lora_path = f"{ckpt_path}/ckpt/{ckpt}/pytorch_lora_weights.safetensors"
        model_config = model

        lora_names = ['eligen']
        # eligen_path = '/mnt/workspace/workgroup/zheliu.lzy/vision_cot/DiffSynth-Studio/models/lora/entity_control/model_bf16.safetensors'
        eligen_path = 'checkpoints/eligen.bin'
        
        # Load model
        pipe = FluxPipeline.from_pretrained(flux_path, torch_dtype=torch.bfloat16)
        print("Load Flux model successfully")
        pipe = pipe.to(device=device, dtype=torch.bfloat16)

        pipe.transformer.load_lora_adapter(lora_path, adapter_name='default')

        state_dict = torch.load(eligen_path)
        pipe.transformer.load_lora_adapter(state_dict, prefix="transformer", adapter_name="eligen",)
        pipe.transformer.set_adapters(lora_names + ['default'])
        print("Load Flux lora successfully")
        active_adapters = pipe.get_active_adapters()
        print(active_adapters)
        pipe = pipe.to(device=device, dtype=torch.bfloat16)
        
        test_list = prepare_condition_data(conditions, data_path, condition_size, pipe.device, pipe.dtype)
        
        for i, (condition, position_delta, prompt, image_id) in tqdm(enumerate(test_list), total=len(test_list), desc="🚀 Processing loose batches", unit="batch"):
            os.makedirs(f"{save_path}/{image_id}", exist_ok=True)
            seeds = conditions[i]['seeds']
            for seed in seeds[:1]:
                file_path = f"{save_path}/{image_id}/{image_id}_{seed}.png"
                # if True:
                if not os.path.exists(file_path):
                    condition_ = Condition(
                        condition_type=condition_type,
                        condition=condition.resize((condition_size, condition_size)).convert("RGB") 
                            if isinstance(condition, Image.Image) else condition,
                        position_delta=position_delta,
                    )
                    unified_generate(
                        pipe, prompt, condition_, 
                        eligen_entity_prompts=condition['eligen_entity_prompts'], 
                        eligen_entity_masks=condition['eligen_entity_masks'], 
                        eligen_entity_masks_pil=condition['eligen_entity_masks_pil'], 
                        file_path=file_path, target_size=target_size, seed=seed,
                        model_config=model_config,
                    )
                else:
                    print(f"{file_path} exists")

if __name__ == "__main__":
    main()

