import os
# These variables must be set before torch / huggingface_hub are imported
# below: CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is
# initialised, and huggingface_hub reads HF_ENDPOINT / HF_HOME into
# module-level constants when it is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# NOTE(review): HF_CACHE_DIR does not appear to be a variable that
# huggingface_hub reads; HF_HOME below is what relocates the hub cache — confirm.
os.environ["HF_CACHE_DIR"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
os.environ["TRANSFORMERS_CACHE"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
os.environ["HF_HOME"] = "/mnt/workspace/workgroup/zheliu.lzy/.cache"
# PYTHONWARNINGS is only read at interpreter startup, so setting it here only
# affects child processes. Install the filter explicitly for this process too.
os.environ["PYTHONWARNINGS"] = "ignore::FutureWarning"
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

import huggingface_hub

# Never hard-code access tokens in source. The token previously committed here
# is exposed and should be revoked. Read it from the environment instead; when
# HF_TOKEN is unset, huggingface_hub falls back to a token cached by an earlier
# `huggingface-cli login`.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    huggingface_hub.login(token=_hf_token)

from diffusers.pipelines import FluxPipeline
import torch
from PIL import Image

# def get_state_dict(pipeline):
#     state = pipeline.transformer.state_dict()
#     for name in list(state.keys()):
#         if "lora" not in name:  # <-- adapt the condition to your use case
#             state.pop(name)

#     return state

def get_state_dict(pipeline, adapter_name="default_0"):
    """Collect the LoRA entries of ``pipeline.transformer`` into a flat dict.

    Only state-dict entries whose key contains ``"lora"`` are kept. Each kept
    key is prefixed with ``"transformer."`` and has the ``.<adapter_name>.``
    path segment collapsed to ``"."``. For example,
    ``x.lora_A.default_0.weight`` becomes ``transformer.x.lora_A.weight``.

    Args:
        pipeline: Object exposing ``transformer.state_dict()``. Presumably a
            diffusers pipeline on which ``load_lora_weights`` has already been
            called.
        adapter_name: Adapter path segment to strip from keys. The default
            ``"default_0"`` matches the original hard-coded value. It is
            presumably the name assigned when no explicit adapter name is
            given (confirm for other diffusers versions).

    Returns:
        dict mapping each rewritten key to the value taken as-is from the
        transformer's state dict. Empty if no key contains ``"lora"``.
    """
    # Match the adapter name as a whole dotted segment, so that module names
    # which merely end in the adapter name are not mangled.
    segment = f".{adapter_name}."
    return {
        f"transformer.{key}".replace(segment, "."): value
        for key, value in pipeline.transformer.state_dict().items()
        if "lora" in key  # keep only LoRA parameters
    }

pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
# pretrained_model_name_or_path = "/mnt/sh_nas/zheliu.lzy/.cache/entity_pretrain"
lora_path = "/mnt/workspace/workgroup/zheliu.lzy/vision_cot/DiffSynth-Studio/models/lora/entity_control/model_bf16.safetensors"
output_path = "checkpoints/eligen.bin"

# Load the base FLUX pipeline with bfloat16 weights.
pipeline = FluxPipeline.from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=torch.bfloat16,
)

# Attach the LoRA to the pipeline, then read its parameters back out of
# pipeline.transformer's state dict in the re-keyed form get_state_dict builds.
pipeline.load_lora_weights(lora_path)
state_dict = get_state_dict(pipeline)
if not state_dict:
    # Fail loudly instead of silently writing an empty checkpoint.
    raise RuntimeError(f"No LoRA weights found in pipeline.transformer after loading {lora_path}")

# torch.save does not create parent directories.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torch.save(state_dict, output_path)

# pipeline.fuse_lora()
# pipeline.unload_lora_weights()
# pipeline.save_pretrained("entity_pretrain")

# condition_type = "depth"
# pipeline.load_lora_weights(
#     "Yuanshi/OminiControl",
#     weight_name=f"experimental/{condition_type}.safetensors",
#     adapter_name=condition_type,
# )
# pipeline.fuse_lora()
# pipeline.unload_lora_weights()
# pipeline.save_pretrained("depth_pretrain")