import os
import sys
import warnings

# Process-wide environment configuration. This block sits above the torch /
# loosecontrol imports; presumably the Hugging Face libraries they pull in read
# these variables when first imported, so keep it before those imports.

# Uncomment to pin this process to a single GPU.
# os.environ["CUDA_VISIBLE_DEVICES"]="0"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# NOTE(review): HF_CACHE_DIR does not appear to be a variable huggingface_hub
# documents; HF_HOME below is what sets the cache root -- confirm it is needed.
os.environ["HF_CACHE_DIR"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
# NOTE(review): newer transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME -- confirm before relying on it.
os.environ["TRANSFORMERS_CACHE"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
os.environ["HF_HOME"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
# PYTHONWARNINGS is only consulted when an interpreter starts, so assigning it
# here only affects child processes. Install the same filter in this process
# explicitly so FutureWarnings raised by the imports below are silenced too.
os.environ["PYTHONWARNINGS"] = "ignore::FutureWarning"
warnings.filterwarnings("ignore", category=FutureWarning)
# Make the local LooseControl checkout importable (provides `loosecontrol`).
sys.path.append("/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl/LooseControl")

import argparse
import json
from PIL import Image
from tqdm import tqdm
import random
import numpy as np
import torch
from loosecontrol import LooseControlNet

def seed_everything(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU
    generator, plus all CUDA devices when CUDA is available, so that a
    generation run can be reproduced from ``seed``.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

def main():
    """Generate LooseControl images for this worker's shard of prompts.

    Every entry of ``--data_path`` is treated as one prompt: its name is the
    text prompt and ``<data_path>/<prompt>/bas_depth_2.png`` is the boxy depth
    control image. Results are written to
    ``<save_path>/samples/<prompt>_<seed>.png``. Outputs that already exist
    are skipped, so an interrupted run can simply be restarted.

    Work is split across ``--total_gpus`` workers: worker ``--gpu`` takes every
    ``total_gpus``-th prompt (in sorted order) starting at index ``gpu``.
    Failures on individual prompts are logged to stderr and do not stop the run.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0, help='GPU index to use')
    parser.add_argument('--total_gpus', type=int, default=1, help='Total number of GPUs used')
    # NOTE(review): --batch_size, --model, --json_path and --ckpt are accepted
    # for CLI compatibility but are not read anywhere in this script.
    parser.add_argument('--batch_size', type=int, default=1, help='Number of objects to process per batch')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--model', type=str, default='eligen_3d')
    parser.add_argument('--data_path', type=str, default='t2i_compbench/dataset/non_spatial/render')
    parser.add_argument('--json_path', type=str, default="t2i_compbench/dataset/non_spatial/json")
    parser.add_argument('--save_path', type=str, default="LooseControl/t2i_compbench/non_spatial")
    parser.add_argument('--ckpt', type=str, default="30000")
    args = parser.parse_args()

    # A zero step would crash the slice below, and an out-of-range index would
    # silently produce overlapping or missing shards across workers.
    if args.total_gpus < 1:
        parser.error("--total_gpus must be >= 1")
    if not 0 <= args.gpu < args.total_gpus:
        parser.error("--gpu must satisfy 0 <= gpu < total_gpus")

    data_path = args.data_path
    # Commas are stripped from the output directory name.
    save_path = args.save_path.replace(',', '')
    samples_dir = os.path.join(save_path, "samples")
    os.makedirs(samples_dir, exist_ok=True)  # also creates save_path

    # NOTE(review): always cuda:0 regardless of --gpu; presumably each worker is
    # launched with CUDA_VISIBLE_DEVICES limited to its own device -- confirm.
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    lcn = LooseControlNet("shariqfarooq/loose-control-3dbox")
    lcn = lcn.to(device, torch.bfloat16)

    # os.listdir order is arbitrary; sort so every worker agrees on the order
    # and the stride-based sharding is a true partition of the prompts.
    prompt_list = sorted(os.listdir(data_path))
    prompt_list = prompt_list[args.gpu::args.total_gpus]
    prompt_list = prompt_list[::-1]
    print(f'test_list length: {len(prompt_list)}')

    seeds = [args.seed]
    pos_prompts = "4k, high-res, realistic, "
    negative_prompt = "blurry, text, caption, lowquality,lowresolution, low res, grainy, ugly"

    for prompt in tqdm(prompt_list, total=len(prompt_list), desc="🚀 Processing batches", unit="batch"):
        depth_path = os.path.join(data_path, prompt, "bas_depth_2.png")
        for seed in seeds:
            file_path = os.path.join(samples_dir, f"{prompt}_{seed}.png")
            if os.path.exists(file_path):
                continue  # already generated; allows resuming
            seed_everything(seed)
            try:
                # Context manager so the depth-map file handle is released.
                with Image.open(depth_path) as boxy_depth:
                    gen_image = lcn(pos_prompts + prompt, negative_prompt=negative_prompt, control_image=boxy_depth)
                gen_image.save(file_path)
            except Exception as err:
                # Best effort: one bad prompt (e.g. missing depth map) must not
                # abort the whole shard, but the failure should be visible.
                # KeyboardInterrupt is deliberately not caught.
                tqdm.write(f"[skip] {prompt} (seed {seed}): {type(err).__name__}: {err}", file=sys.stderr)

# Script entry point: main() parses its own command-line arguments.
if __name__ == "__main__":
    main()