import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import warnings

# Hugging Face configuration: route hub traffic through a mirror and point the
# caches at the shared workspace. These assignments run before the diffusers
# import further down.
# NOTE(review): presumably they must precede that import so the libraries pick
# them up when first loaded — confirm before reordering.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# NOTE(review): HF_CACHE_DIR does not look like a variable the Hugging Face
# libraries read (HF_HOME is the documented one) — confirm whether it is needed.
os.environ["HF_CACHE_DIR"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
os.environ["TRANSFORMERS_CACHE"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
os.environ["HF_HOME"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
# PYTHONWARNINGS is only consulted when an interpreter starts, so assigning it
# here can only influence child processes. Install the equivalent filter for
# the current process as well so FutureWarning is actually silenced.
os.environ["PYTHONWARNINGS"] = "ignore::FutureWarning"
warnings.filterwarnings("ignore", category=FutureWarning)

import json
import argparse
from tqdm import tqdm
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from diffusers.pipelines import FluxPipeline
from diffusers.schedulers import KarrasDiffusionSchedulers,DPMSolverMultistepScheduler

import sys
# Add the local OminiControl and RPG-DiffusionMaster checkouts to the module
# search path (presumably required by the project imports that follow).
sys.path.extend((
    '/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl',
    '/mnt/workspace/workgroup/zheliu.lzy/vision_cot/RPG-DiffusionMaster',
))

from RegionalDiffusion_base import RegionalDiffusionPipeline
from RegionalDiffusion_xl import RegionalDiffusionXLPipeline

from src.flux.generate import generate, seed_everything
from src.flux.condition import Condition
from src.flux.module import Inter_Controller, Spatial_Controller
from src.flux.pipeline_tools import visualize_masks
from src.utils.dataset import find_applicable_scenes 
from src.utils.dataset_eligen import json_generation  

def prepare_condition_data(json_path, data_path, condition_size, device, dtype, use_depth=True):
    """Build the list of test samples from the JSON files in ``json_path``.

    Files whose name contains ``'right'`` or ``'behind'`` are skipped. Every
    remaining file must hold a JSON object with ``caption`` and ``entities``
    keys. One sample is emitted per scene returned by
    ``find_applicable_scenes(caption)``.

    Args:
        json_path: Directory containing the condition JSON files.
        data_path: Unused by the current implementation.
        condition_size: Unused by the current implementation.
        device: Unused by the current implementation.
        dtype: Unused by the current implementation.
        use_depth: Unused by the current implementation.

    Returns:
        A list of ``(None, [0, 0], "<caption> <scene>")`` tuples.
    """
    file_names = os.listdir(json_path)
    samples = []
    progress = tqdm(file_names, desc="🚀 Loading conditions", total=len(file_names))
    for file_name in progress:
        if any(tag in file_name for tag in ('right', 'behind')):
            continue

        with open(f'{json_path}/{file_name}', 'r') as handle:
            record = json.load(handle)

        caption = record['caption']
        # The entities are not used below, but the lookup is kept so a file
        # without the key still fails exactly as before.
        _entities = record['entities']

        # No condition object is built here; each sample carries None in its place.
        samples.extend(
            (None, [0, 0], caption + f" {scene}")
            for scene in find_applicable_scenes(caption)
        )

    return samples

def unified_generate(
    pipe,
    prompt,
    condition=None,
    eligen_entity_prompts=None,
    eligen_entity_masks=None,
    eligen_entity_masks_pil=None,
    file_path='result.png',
    model_config=None,
    condition_size=512,
    target_size=512,
    num_inference_steps=50,
    seed=42,
):
    """Run ``generate`` once and save the first resulting image.

    Args:
        pipe: Pipeline object forwarded to ``generate``.
        prompt: Text prompt.
        condition: Optional condition; when truthy it is passed as
            ``conditions=[condition]`` and the output size is set to
            ``condition_size``.
        eligen_entity_prompts: Optional per-entity prompts. Forwarded (with
            ``eligen_entity_masks``) only when both are truthy.
        eligen_entity_masks: Optional per-entity masks, see above.
        eligen_entity_masks_pil: Optional masks handed to ``visualize_masks``
            for an overlay image; skipped when falsy.
        file_path: Destination of the generated image.
        model_config: Configuration forwarded to ``generate``. ``None``
            (the default) means an empty dict; a fresh dict is created per
            call so state cannot leak between calls.
        condition_size: Height and width requested from ``generate`` when a
            condition or entity masks are supplied.
        target_size: Side length the saved image is resized to.
        num_inference_steps: Number of sampling steps.
        seed: Seed passed to ``seed_everything`` before generating.

    Returns:
        None. The image (and optionally the mask visualisation) is written
        to disk.
    """
    if model_config is None:
        model_config = {}

    seed_everything(seed)

    # Assemble the keyword arguments for generate().
    generate_kwargs = {
        "prompt": prompt,
        "default_lora": True,
        "num_inference_steps": num_inference_steps,
    }

    # Add the arguments specific to each conditioning mode.
    if condition:
        generate_kwargs["conditions"] = [condition]
        generate_kwargs["height"] = condition_size
        generate_kwargs["width"] = condition_size

    if eligen_entity_prompts and eligen_entity_masks:
        generate_kwargs.update({
            "eligen_entity_prompts": eligen_entity_prompts,
            "eligen_entity_masks": eligen_entity_masks,
            "height": condition_size,
            "width": condition_size,
        })

    # Run the generation.
    res = generate(
        pipe,
        model_config=model_config,
        **generate_kwargs
    )

    # Save the resulting image at the requested output resolution.
    save_image = res.images[0]
    save_image.resize((target_size, target_size)).save(file_path)

    # Write the mask visualisation next to the sample, under 'visual'.
    if eligen_entity_masks_pil:
        # splitext strips whatever extension is present; slicing off four
        # characters was only right for three-letter extensions.
        stem, _ext = os.path.splitext(file_path)
        mask_path = f"{stem.replace('samples', 'visual')}_mask.png"
        mask_dir = os.path.dirname(mask_path)
        if mask_dir:
            # The 'visual' directory may not exist yet; creating it is harmless
            # if visualize_masks already does so.
            os.makedirs(mask_dir, exist_ok=True)
        # The full-resolution image (not the resized copy) is visualised.
        visualize_masks(save_image, eligen_entity_masks_pil, eligen_entity_prompts, mask_path)

def main():
    """Generate images for this worker's share of the RPG JSON files.

    The JSON files in ``data/json_rpg`` are split across workers by taking
    every ``--total_gpus``-th file starting at index ``--gpu``. For each file
    and each scene returned by ``find_applicable_scenes`` the regional SDXL
    pipeline is run once and the result is written to
    ``exp/rpg/samples/<prompt>_<seed>.png``. Samples that already exist are
    skipped, so an interrupted run can be resumed.

    ``--gpu`` only selects the shard; the pipeline is always moved to
    ``"cuda"``. ``--batch_size`` is accepted but not used.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0, help='GPU index to use')
    parser.add_argument('--total_gpus', type=int, default=1, help='Total number of GPUs used')
    parser.add_argument('--batch_size', type=int, default=1, help='Number of objects to process per batch')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    args = parser.parse_args()

    # A zero step would raise inside the slice below, and an out-of-range
    # index would silently duplicate part of another worker's shard.
    if args.total_gpus < 1 or not 0 <= args.gpu < args.total_gpus:
        parser.error('--gpu must satisfy 0 <= gpu < total_gpus')

    # Paths and parameters
    condition_size = 1024
    target_size = 512
    json_path = "data/json_rpg"
    save_path = "exp/rpg"
    seed = args.seed

    # os.listdir returns entries in arbitrary order. Sort so every worker
    # sees the same sequence and the strided shards are disjoint and complete.
    json_list = sorted(os.listdir(json_path))[args.gpu::args.total_gpus]
    print(f'test_list length: {len(json_list)}')

    pipe = RegionalDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
    pipe.to("cuda")
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
    pipe.enable_xformers_memory_efficient_attention()

    os.makedirs(f"{save_path}/samples", exist_ok=True)

    for json_file in tqdm(json_list, total=len(json_list), desc="🚀 Processing batches", unit="batch"):
        with open(f'{json_path}/{json_file}', 'r') as f:
            data = json.load(f)

        for scene in find_applicable_scenes(data['caption']):
            prompt = data['caption'] + f' {scene}'
            # NOTE(review): the prompt is used verbatim as the file name; a
            # prompt containing a path separator would change the target path.
            file_path = f"{save_path}/samples/{prompt}_{seed}.png"

            if os.path.exists(file_path):
                print(f"{file_path} exists")
                continue

            try:
                split_ratio = data['ans_json']['Final_split_ratio']
                if isinstance(split_ratio, list):
                    # A list is flattened into one comma-separated string
                    # before being handed to the pipeline.
                    split_ratio = ','.join(f'{ratio}' for ratio in split_ratio)
                regional_prompt = data['ans_json']['Regional_Prompt']
                negative_prompt = ""
                image = pipe(
                    prompt=regional_prompt,
                    split_ratio=split_ratio,  # region layout for the regional prompts
                    batch_size=1,
                    base_ratio=0.3,  # ratio of the base prompt
                    base_prompt=prompt,
                    num_inference_steps=20,  # sampling steps
                    negative_prompt=negative_prompt,
                    height=condition_size,
                    width=condition_size,
                    seed=seed,
                    guidance_scale=7.0,
                ).images[0]
                image.resize((target_size, target_size)).save(file_path)
            except Exception as e:
                # Best effort: one bad sample must not abort the batch, but the
                # report has to say which file and sample failed.
                print(f"[{json_file}] failed to generate {file_path}: {e!r}")

if __name__ == "__main__":
    # A leftover single-prompt demo used to run after main(). It called a
    # `GPT4` helper that this file never imports or defines (NameError once
    # main() returned), loaded the SDXL pipeline a second time, and embedded
    # an API key in source. It was removed; treat that key as compromised and
    # rotate it.
    main()

