# import lightning as L
import pytorch_lightning as L
from diffusers.pipelines import FluxPipeline
from diffusers.optimization import get_scheduler
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.transforms import ToPILImage, ToTensor, Normalize, Resize, Compose
from peft import LoraConfig, get_peft_model_state_dict
from einops import rearrange
from transformers import pipeline, AutoImageProcessor, AutoModelForDepthEstimation
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import prodigyopt
from contextlib import contextmanager

from src.flux.transformer import tranformer_forward
from src.flux.condition import Condition
from src.flux.pipeline_tools import encode_images, decode_images, prepare_text_input, prepare_text_input_eligen, encode_poses
from src.flux.module import CameraAdapter, Inter_Controller, Spatial_Controller

# Debug switch. When True, OminiModel.__init__ snapshots the value of one
# transformer parameter (``self.target_param_name``) into ``self.initial_value``
# so you can check whether training updates it, i.e. whether gradients
# actually flow back to the trainable weights.
check_loss = False
# check_loss = True

@contextmanager
def preserve_requires_grad(model):
    """Snapshot every parameter's ``requires_grad`` flag and restore it on exit.

    Used in this file around ``transformer.set_adapters(...)``, which may toggle
    the flags as a side effect. The flags of ``model.named_parameters()`` are
    recorded on entry and written back on exit, also when the ``with`` body
    raises.

    Parameters that did not exist on entry (created inside the body) have no
    snapshot and are left untouched.
    """
    # Back up the requires_grad state of every named parameter.
    requires_grad_backup = {
        name: param.requires_grad for name, param in model.named_parameters()
    }
    try:
        yield
    finally:
        # Restore the requires_grad state, even if the body raised.
        for name, param in model.named_parameters():
            if name in requires_grad_backup:
                param.requires_grad = requires_grad_backup[name]

class OminiModel(L.LightningModule):
    def __init__(
        self,
        flux_pipe_id: str,
        lora_path: str = None,
        lora_config: dict = None,
        condition_type: str = None,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        model_config: dict = None,
        optimizer_config: dict = None,
        gradient_checkpointing: bool = False,
    ):
        """Build the Flux training module.

        Args:
            flux_pipe_id: Id/path passed to ``FluxPipeline.from_pretrained``.
            lora_path: Optional LoRA weights loaded as the ``'default'`` adapter;
                when falsy a new adapter is created from ``lora_config``.
            lora_config: Keyword arguments for ``peft.LoraConfig`` (used by
                ``init_lora`` when ``lora_path`` is not given).
            condition_type: Tag whose substrings ('camera', 'pose', 'loose',
                'flux', 'depth', 'reward') select which auxiliary modules are built.
            device: Device the pipeline and this module are moved to.
            dtype: Dtype the pipeline and this module are cast to.
            model_config: Extra options read via ``.get`` (``latent_lora``,
                ``lora_names``, ``depth_model_type``, ``orient_loss``, ...).
                ``None`` (the default) means an empty config.
            optimizer_config: Stored for ``configure_optimizers``.
            gradient_checkpointing: Value assigned to
                ``transformer.gradient_checkpointing``.
        """
        # Initialize the LightningModule
        super().__init__()
        # A fresh dict per instance: a ``{}`` default would be shared by all instances.
        self.model_config = {} if model_config is None else model_config
        self.optimizer_config = optimizer_config

        # Load the Flux pipeline
        print('Load the Flux pipeline')
        self.flux_pipe: FluxPipeline = (
            FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
        )
        print('Load the Flux pipeline successful')
        self.transformer = self.flux_pipe.transformer
        self.transformer.gradient_checkpointing = gradient_checkpointing
        self.transformer.train()

        # Freeze the text encoders and the VAE; only the transformer side is trained.
        self.flux_pipe.text_encoder.requires_grad_(False).eval()
        self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
        self.flux_pipe.vae.requires_grad_(False).eval()

        self.condition_type = condition_type
        print(self.condition_type)

        # Extra (named) LoRA adapters loaded from checkpoints/<name>.safetensors.
        latent_lora = self.model_config.get('latent_lora', [])
        lora_names = list(self.model_config.get('lora_names', latent_lora))
        self.load_lora(lora_names)

        # The trainable 'default' adapter.
        self.lora_path = lora_path
        self.lora_layers = self.init_lora(lora_path, lora_config)

        # Activate all adapters without letting set_adapters change which
        # parameters are trainable.
        with preserve_requires_grad(self.transformer):
            self.transformer.set_adapters(lora_names + ['default'])

        self.depth_model_type = self.model_config.get('depth_model_type', 'depth_anything') # 'depth_pro'

        if 'camera' in self.condition_type or 'pose' in self.condition_type:
            self.init_camera()
            if self.model_config.get('orient_loss', None) is not None:
                self.init_orient()
        if 'loose' in self.condition_type:
            self.init_loose_condition()
            if 'flux' in self.condition_type:
                self.init_flux()
        if 'depth' in self.condition_type or self.condition_type == 'loose_condition' or 'reward' in self.condition_type:
            self.init_depth(dtype=dtype)

        self.to(device).to(dtype)
        self.total_steps = 0

        # Debug aid: snapshot one trainable weight to later verify that training
        # really updates it (i.e. that gradients flow back).
        if check_loss:
            self.target_param_name = "transformer_blocks.7.attn.to_out.0.lora_B.default.weight"  # example parameter name
            # self.target_param_name = "cam_embedder.camera_mlp.0.weight"  # example parameter name
            self.initial_value = self.transformer.get_parameter(self.target_param_name).clone().detach()

    def init_depth(self, dtype=torch.bfloat16):
        """Load the frozen monocular depth estimator and its input transform.

        The back-end is chosen by ``self.depth_model_type``:

        * ``'depth_pro'``: ``depth_pro.create_model_and_transforms`` on CPU with
          precision ``dtype``. ``self.image_transform`` normalizes with mean/std
          0.5, which maps [0, 1] inputs to [-1, 1].
        * anything else: Depth Anything V2. The encoder size and checkpoint can
          be overridden with ``model_config['depth_anything_encoder']`` (default
          ``'vits'``) and ``model_config['depth_anything_ckpt']``. The default
          path is the ``depth_anything_v2_<encoder>.pth`` file under the
          workspace directory used below. ``self.image_transform`` resizes to
          518x518 and applies ImageNet mean/std normalization.

        Args:
            dtype: Precision passed to the Depth Pro factory (unused for Depth
                Anything V2).

        Raises:
            ValueError: if the requested Depth Anything V2 encoder is unknown.
        """
        if self.depth_model_type == 'depth_pro':
            import depth_pro.depth_pro as depth_pro

            self.depth_model, self.depth_transform = depth_pro.create_model_and_transforms(device='cpu', precision=dtype)
            self.image_transform = Compose([Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        else:
            from depth_anything_v2.dpt import DepthAnythingV2

            model_configs = {
                'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
                'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
                'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
                'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
            }

            config = getattr(self, 'model_config', None) or {}
            encoder = config.get('depth_anything_encoder', 'vits')
            if encoder not in model_configs:
                raise ValueError(
                    f"Unsupported Depth Anything V2 encoder {encoder!r}; "
                    f"expected one of {sorted(model_configs)}"
                )
            ckpt_path = config.get(
                'depth_anything_ckpt',
                f'/mnt/workspace/workgroup/zheliu.lzy/vision_cot/3d_box/depth_anything_v2/model/depth_anything_v2_{encoder}.pth',
            )

            self.depth_model = DepthAnythingV2(**model_configs[encoder])
            # NOTE: torch.load unpickles the file; only use trusted checkpoints.
            self.depth_model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
            self.image_transform = Compose([
                Resize((518, 518), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        # The depth estimator is never optimized: freeze it and keep it in eval mode.
        self.depth_model = self.depth_model.requires_grad_(False).eval()

    def init_flux(self):
        """Attach a ``loose_embedder`` projection (2 * inner_dim -> inner_dim).

        The whole module is first cast to bfloat16. The new ``nn.Linear`` is
        then placed on the same device/dtype as ``transformer.x_embedder``.
        """
        self.to(torch.bfloat16)
        assert 'flux' in self.condition_type

        hidden = self.transformer.inner_dim
        reference = self.transformer.x_embedder.weight
        projection = nn.Linear(hidden * 2, hidden)
        self.transformer.loose_embedder = projection

        # Match the input embedder's placement so the projection can consume its outputs.
        for weight in projection.parameters():
            weight.data = weight.to(device=reference.device, dtype=reference.dtype)

        print(hasattr(self.transformer, 'loose_embedder'))
        print('init flux module successfully')

    def save_flux(self, path: str):
        """Write the ``loose_embedder`` state dict to ``path`` via ``torch.save``."""
        embedder_state = self.transformer.loose_embedder.state_dict()
        torch.save(embedder_state, path)
        print('save flux module successfully')
    
    def init_loose_condition(self):
        """Optionally attach ``transformer.inter_controller`` for 'loose' conditions.

        ``model_config['inter_controller_type']`` selects the controller:

        * ``'spatial+layer'``: ``Inter_Controller``
        * ``'spatial'``: ``Spatial_Controller``
        * ``None``: train without a controller

        If ``model_config['inter_controller_path']`` is set, that state dict is
        loaded into the controller. The controller's parameters are then moved
        to the device/dtype of ``transformer.x_embedder``.

        Raises:
            ValueError: if ``inter_controller_type`` is set to an unsupported value.
        """
        self.to(torch.bfloat16)
        assert 'loose' in self.condition_type
        inter_controller_type = self.model_config.get('inter_controller_type', None)
        inter_controller_path = self.model_config.get('inter_controller_path', None)

        if inter_controller_type is None:
            print('train w/o loose_condition module')
            return

        controller_classes = {
            'spatial+layer': Inter_Controller,
            'spatial': Spatial_Controller,
        }
        if inter_controller_type not in controller_classes:
            # Fail loudly instead of silently training without a controller.
            raise ValueError(
                f"Unsupported inter_controller_type {inter_controller_type!r}; "
                f"expected one of {sorted(controller_classes)} or None"
            )
        self.transformer.inter_controller = controller_classes[inter_controller_type](
            dim=24*128,
            num_attention_heads=24,
            attention_head_dim=128,
        )
        print('init loose_condition module successfully')

        if inter_controller_path:
            # The new controller lives on CPU at this point. Load the weights on CPU
            # as well, so no GPU memory is needed and the file's saved devices don't matter.
            state_dict = torch.load(inter_controller_path, map_location='cpu')
            self.transformer.inter_controller.load_state_dict(state_dict)
            print('load loose_condition module successfully')

        device = self.transformer.x_embedder.weight.device
        dtype = self.transformer.x_embedder.weight.dtype
        for param in self.transformer.inter_controller.parameters():
            param.data = param.to(device=device, dtype=dtype)

    def save_loose_condition(self, path: str):
        """Write the ``inter_controller`` state dict to ``path`` via ``torch.save``."""
        controller_state = self.transformer.inter_controller.state_dict()
        torch.save(controller_state, path)
        print('save loose_condition module successfully')

    def init_orient(self):
        """Load the frozen Orient-Anything (``DINOv2_MLP``) orientation model.

        The checkpoint defaults to the cached ``dino_weight.pt`` snapshot below.
        It can be overridden with ``model_config['orient_ckpt_path']``.
        ``self.orient_transform`` resizes crops to 224x224 and applies ImageNet
        mean/std normalization.

        The head has ``360 + 180 + 180 + 2`` outputs. ``step`` slices the first
        360 as one orientation distribution and softmaxes the last 2 as a
        confidence score. NOTE(review): the middle blocks are presumably the
        other two angles; confirm against the Orient-Anything repo.
        """
        import sys

        orient_repo = '/mnt/workspace/workgroup/zheliu.lzy/vision_cot/OminiControl/Orient_Anything'
        # Extend sys.path only once, even if this method runs repeatedly.
        if orient_repo not in sys.path:
            sys.path.append(orient_repo)
        from Orient_Anything.vision_tower import DINOv2_MLP

        default_ckpt = "/mnt/workspace/workgroup/zheliu.lzy/.cache/hub/models--Viglong--Orient-Anything/snapshots/ecfa089828d9fab521d47133bedbdbfdef151cd4/croplargeEX2/dino_weight.pt"
        config = getattr(self, 'model_config', None) or {}
        ckpt_path = config.get('orient_ckpt_path', default_ckpt)

        self.orient_model = DINOv2_MLP(
            dino_mode   = 'large',
            in_dim      = 1024,
            out_dim     = 360+180+180+2,
            evaluate    = True,
            mask_dino   = False,
            frozen_back = False
        )
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        self.orient_model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        # Used only as a frozen scorer: never optimized.
        self.orient_model = self.orient_model.requires_grad_(False).eval()

        self.orient_transform = Compose([
            Resize((224, 224), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def init_camera(self):
        """Attach a ``CameraAdapter`` as ``transformer.cam_embedder``.

        * 'camera' in ``condition_type``: ``CameraAdapter(inner_dim, 9)``
        * otherwise, 'pose' in ``condition_type``: ``CameraAdapter(4096, 6)``

        The whole module is first cast to bfloat16. The adapter's parameters
        are then moved to the device/dtype of ``transformer.x_embedder``.
        """
        self.to(torch.bfloat16)
        inner_dim = self.transformer.inner_dim

        if 'camera' in self.condition_type:
            adapter_args = (inner_dim, 9)
        elif 'pose' in self.condition_type:
            adapter_args = (4096, 6)
        else:
            adapter_args = None

        if adapter_args is not None:
            self.transformer.cam_embedder = CameraAdapter(*adapter_args)

        reference = self.transformer.x_embedder.weight
        for p in self.transformer.cam_embedder.parameters():
            p.data = p.to(device=reference.device, dtype=reference.dtype)
        
    def save_camera(self, path: str):
        """Write the ``cam_embedder`` state dict to ``path`` via ``torch.save``."""
        adapter_state = self.transformer.cam_embedder.state_dict()
        torch.save(adapter_state, path)
        print('save cam_embedder module successfully')
    
    def init_grounding_dino(self, model_id: str = "IDEA-Research/grounding-dino-base"):
        """Load a Grounding DINO processor and zero-shot detection model.

        Args:
            model_id: Hugging Face hub id or local path. Defaults to
                ``"IDEA-Research/grounding-dino-base"``.

        The model is moved to this module's current ``device`` and ``dtype``.
        """
        self.dino_processor = AutoProcessor.from_pretrained(model_id)
        self.dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(self.device, self.dtype)

    def init_lora(self, lora_path: str, lora_config: dict):
        """Create or load the ``'default'`` LoRA adapter on the transformer.

        Args:
            lora_path: If truthy, the adapter is loaded with
                ``transformer.load_lora_adapter(lora_path, adapter_name='default')``
                and ``lora_config`` is ignored.
            lora_config: Keyword arguments for ``peft.LoraConfig``. Required when
                ``lora_path`` is not given.

        Returns:
            A list of every transformer parameter whose ``requires_grad`` is True
            after the adapter is set up. NOTE(review): this is not restricted to
            the ``'default'`` adapter; any other still-trainable parameter is
            included too.

        Raises:
            ValueError: if neither ``lora_path`` nor ``lora_config`` is provided.
        """
        if lora_path:
            self.transformer.load_lora_adapter(lora_path, adapter_name='default')
            message = f'init lora successfully from {lora_path}'
        else:
            if lora_config is None:
                raise ValueError('init_lora requires either `lora_path` or `lora_config`.')
            self.transformer.add_adapter(LoraConfig(**lora_config))
            message = 'init lora successfully'
        print(message)
        return [p for p in self.transformer.parameters() if p.requires_grad]

    def save_lora(self, path: str):
        """Save the ``'default'`` adapter's LoRA weights into directory ``path`` (safetensors)."""
        default_adapter_state = get_peft_model_state_dict(self.transformer, adapter_name='default')
        FluxPipeline.save_lora_weights(
            save_directory=path,
            transformer_lora_layers=default_adapter_state,
            safe_serialization=True,
        )
        print('save lora successfully')

    def load_lora(self, lora_names):
        """Load each named LoRA as a separate adapter on the transformer.

        Args:
            lora_names: Iterable of names. Each ``name`` is read from the
                relative path ``checkpoints/<name>.safetensors`` (resolved
                against the current working directory) and registered under
                adapter name ``name``.
        """
        for adapter_name in lora_names:
            checkpoint_file = f"checkpoints/{adapter_name}.safetensors"
            self.transformer.load_lora_adapter(checkpoint_file, adapter_name=adapter_name)
            # self.transformer.add_adapter(f"checkpoints/{condition_type}.safetensors", adapter_name=condition_type)

    def _collect_trainable_params(self):
        """Return a new list of the parameters ``configure_optimizers`` trains."""
        # Copy, so that extending the list never mutates self.lora_layers
        # (repeated calls would otherwise accumulate duplicate parameters).
        if self.condition_type == 'loose_condition':
            params = []
        else:
            params = list(self.lora_layers)

        if 'camera' in self.condition_type or 'pose' in self.condition_type:
            params.extend(self.transformer.cam_embedder.parameters())
        if 'loose' in self.condition_type and hasattr(self.transformer, 'inter_controller'):
            params.extend(self.transformer.inter_controller.parameters())
        if 'flux' in self.condition_type and hasattr(self.transformer, 'loose_embedder'):
            params.extend(self.transformer.loose_embedder.parameters())
        return params

    def configure_optimizers(self):
        """Freeze the transformer, unfreeze the trainable subset and build the optimizer.

        Trainable parameters are the LoRA layers from ``init_lora`` (skipped
        when ``condition_type == 'loose_condition'``), plus, depending on
        ``condition_type``, the ``cam_embedder``, ``inter_controller`` and
        ``loose_embedder`` modules.

        ``optimizer_config['type']`` selects 'AdamW', 'Prodigy' or 'SGD', and
        ``optimizer_config['params']`` is forwarded to its constructor. AdamW
        also gets a cosine schedule with warmup, stepped after every optimizer
        step. Its length can be set with ``optimizer_config['warmup_steps']``
        (default 500) and ``optimizer_config['total_steps']`` (default 30000).

        Raises:
            NotImplementedError: for an unsupported optimizer type.
        """
        # Freeze the transformer
        self.transformer.requires_grad_(False)
        opt_config = self.optimizer_config

        # Set and unfreeze the trainable parameters
        self.trainable_params = self._collect_trainable_params()
        for p in self.trainable_params:
            p.requires_grad_(True)

        print(f"trainable_params num: {len(self.trainable_params)}")
        lora_layers_name = set([name.split('.')[-2] for name, param in self.transformer.named_parameters() if param.requires_grad])
        print('The traing lora is: ', lora_layers_name)
        trainable_param_name = [name for name, param in self.transformer.named_parameters() if param.requires_grad and 'lora' not in name]
        print(trainable_param_name)

        # Initialize the optimizer
        opt_type = opt_config["type"]
        if opt_type == "AdamW":
            optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
            lr_scheduler = get_scheduler(
                "cosine",
                optimizer=optimizer,
                num_warmup_steps=opt_config.get("warmup_steps", 500),
                num_training_steps=opt_config.get("total_steps", 30000),
            )
            # Lightning steps schedulers once per *epoch* by default. The warmup
            # and total counts above are optimizer steps, so request per-step updates.
            return {
                'optimizer': optimizer,
                'lr_scheduler': {'scheduler': lr_scheduler, 'interval': 'step'},
            }
        if opt_type == "Prodigy":
            return prodigyopt.Prodigy(
                self.trainable_params,
                **opt_config["params"],
            )
        if opt_type == "SGD":
            return torch.optim.SGD(self.trainable_params, **opt_config["params"])
        raise NotImplementedError

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and record the latest scalar losses.

        ``self.log_loss`` maps each loss name to the most recent value returned
        by ``self.step``. Names absent from the current step keep their previous
        value. ``self.total_steps`` is incremented once per call.

        Returns:
            The loss tensor returned by ``self.step`` (what Lightning backpropagates).
        """
        step_loss, losses = self.step(batch)

        if hasattr(self, "log_loss"):
            # Overwrite only the entries reported by this step.
            self.log_loss.update(losses)
        else:
            self.log_loss = losses

        self.total_steps += 1
        return step_loss

    def step(self, batch):
        self.to(torch.bfloat16)
        
        imgs = batch["image"]
        conditions = batch.get("condition", None)
        condition_types = batch["condition_type"]
        prompts = batch["description"]
        position_delta = batch["position_delta"][0]
        
        kwargs = {}
        # Prepare inputs
        with torch.no_grad():
            # Prepare image input
            x_0, img_ids = encode_images(self.flux_pipe, imgs) # [1, 4096, 64], [4096, 3]

            # Prepare text input
            prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
                self.flux_pipe, prompts
            ) # [1, 512, 4096], [1, 768], [512, 3]

            # Prepare t and x_t
            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device, dtype=self.dtype))
            x_1 = torch.randn_like(x_0).to(self.device)
            t_ = t.unsqueeze(1).unsqueeze(1)
            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)

            # Prepare conditions
            if conditions is not None:
                conditions = conditions.to(self.dtype)

                if 'camera' in self.condition_type:
                    kwargs["cam_entity_idx"] = [idx.item() for idx in batch["cam_entity_idx"]]
                    condition_ids = img_ids.clone()
                    condition_latents = conditions[0]
                    # print(condition_latents.shape)
                elif "depth" in self.condition_type:
                    if self.depth_model_type == 'depth_pro':
                        depth = self.depth_model.infer(conditions[0], f_px=None)["depth"][None]
                        depth_img = np.array(ToPILImage()(((depth.max()-depth[0])/(depth.max()-depth.min()))))
                        depth_img = Image.fromarray(depth_img).convert("RGB")
                        # depth_img.save('depth.png')
                        conditions = ToTensor()(depth_img)[None]

                    condition_latents, condition_ids = encode_images(self.flux_pipe, conditions[0]) # [n, 4096, 64], [4096, 3]
                elif "loose" in self.condition_type:
                    condition_latents, condition_ids = encode_images(self.flux_pipe, conditions[0]) # [n, 4096, 64], [4096, 3]
                else:
                    condition_latents, condition_ids = encode_images(self.flux_pipe, conditions)
                    
                condition_latents = condition_latents.to(self.dtype)

                # Add position delta
                condition_ids[:, 1] += position_delta[0]
                condition_ids[:, 2] += position_delta[1]

                # Prepare condition type
                condition_type_ids = torch.tensor(
                    [
                        Condition.get_type_id(condition_type)
                        for condition_type in condition_types
                    ]
                ).to(self.device)
                condition_type_ids = (
                    torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0]
                ).unsqueeze(1)
            else:
                condition_latents = None
                condition_ids = None
                condition_type_ids = None
            
            # Prepare guidance
            guidance = (
                torch.ones_like(t).to(self.device)
                if self.transformer.config.guidance_embeds
                else None
            )
            
        # Prepare eligen text input
        if batch.get("eligen_entity_prompts", None) and batch.get("eligen_entity_masks", None):
            # import ipdb; ipdb.set_trace()
            eligen_entity_prompts = batch["eligen_entity_prompts"]
            eligen_entity_prompts = [eligen_entity_prompt[0] for eligen_entity_prompt in eligen_entity_prompts]
            eligen_entity_masks = batch["eligen_entity_masks"]
            eligen_kwargs = prepare_text_input_eligen(self.flux_pipe, eligen_entity_prompts, eligen_entity_masks, orients=batch.get("orient", None))
            kwargs["eligen_kwargs"] = eligen_kwargs
            # print(len(eligen_entity_prompts))

        # Forward pass
        transformer_out = tranformer_forward(
            self.transformer,
            # Model config
            model_config=self.model_config,
            # Inputs of the condition (new feature)
            condition_latents=condition_latents,
            condition_ids=condition_ids,
            condition_type_ids=condition_type_ids,
            # Inputs to the original transformer
            hidden_states=x_t,
            timestep=t,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=img_ids,
            joint_attention_kwargs=None,
            return_dict=False,
            # return_x_embedder_output=("loose_condition" == self.condition_type),
            **kwargs,
        )
        pred = transformer_out[0]

        # Compute loss
        losses = {}
        # import ipdb; ipdb.set_trace()
        if 'weight' in self.condition_type and batch.get("eligen_entity_masks", None) is not None:
            # Compute weighted_mse_loss
            eligen_entity_masks = batch["eligen_entity_masks"]
            eligen_entity_masks = [rearrange(m, "B (H P) (W Q) C -> B (H W) (C P Q)", P=2, Q=2) for m in eligen_entity_masks]
            eligen_entity_masks = [torch.sum(m, dim=-1) > 0 for m in eligen_entity_masks]
            loss_mask = torch.zeros_like(eligen_entity_masks[0]).to(self.dtype)

            for m in eligen_entity_masks:
                foreground_area = torch.sum(m, dim=-1, keepdim=True)
                total_area = torch.numel(m[0])

                loss_mask += torch.where(m == 1, total_area / foreground_area, 1.0).to(self.dtype)
            loss_mask = loss_mask / len(eligen_entity_masks)
        else:
            loss_mask = None

        if loss_mask is not None:
            mse_loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="none")
            weighted_mse_loss = mse_loss * loss_mask[:,:,None]
            mse_loss = torch.mean(weighted_mse_loss)
        else:
            mse_loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
        loss = mse_loss
        losses['mse_loss'] = mse_loss.item()

        if False and "loose_condition" == self.condition_type:
            condition_latents = transformer_out[1]

            imgs_normal = self.image_transform(imgs).to(device=self.device, dtype=self.dtype)
            depth_gt = self.depth_model.infer(imgs_normal, f_px=None)["depth"][None]
            depth_img = np.array(ToPILImage()(((depth_gt.max()-depth_gt[0])/(depth_gt.max()-depth_gt.min()))))
            depth_img = Image.fromarray(depth_img).convert("RGB")
            # depth_img.save('depth.png')
            conditions_gt = ToTensor()(depth_img)[None].to(device=self.device, dtype=self.dtype)
            condition_latents_gt, _ = encode_images(self.flux_pipe, conditions_gt[0]) # [n, 4096, 64], [4096, 3]
            
            condition_latents_gt = self.transformer.x_embedder(condition_latents_gt)
            loss_depth = torch.nn.functional.mse_loss(condition_latents_gt, condition_latents, reduction="mean")
            # print('loss:', loss)
            # print('loss_depth:', loss_depth)
            loss += loss_depth * 0.2
        
        if 'reward' in self.condition_type:
            # if True:
            if 0<t.item()<0.4:
                grad_scale = 30
                linear_reward_scale_steps = 500

                if self.depth_model_type == 'depth_pro':
                    imgs_normal_gt = self.image_transform(imgs).to(device=self.device, dtype=self.dtype)
                    depth_gt = self.depth_model.infer(imgs_normal_gt, f_px=None)["depth"][None]
                    inverse_depth_gt = 1 / depth_gt
                    max_invdepth_vizu = min(inverse_depth_gt.max(), 1 / 0.1)
                    min_invdepth_vizu = max(1 / 250, inverse_depth_gt.min())
                    inverse_depth_normalized_gt = (inverse_depth_gt - min_invdepth_vizu) / (max_invdepth_vizu - min_invdepth_vizu)
                    # depth_img_gt = (depth_gt.max()-depth_gt)/(depth_gt.max()-depth_gt.min())
                    # depth_img_gt = (depth_gt.max()-depth_gt)/(depth_gt.max())
                    outputs_gt = inverse_depth_normalized_gt

                    latents_x0 = x_t - pred * t_
                    image_pred = decode_images(self.flux_pipe, latents_x0, 512, 512, 'pt') # TODO: height and weight
                    imgs_normal_pred = self.image_transform(image_pred).to(device=self.device, dtype=self.dtype)
                    depth_pred = self.depth_model.infer(imgs_normal_pred, f_px=None)["depth"][None]
                    inverse_depth_pred = 1 / depth_pred
                    max_invdepth_vizu = min(inverse_depth_pred.max(), 1 / 0.1)
                    min_invdepth_vizu = max(1 / 250, inverse_depth_pred.min())
                    inverse_depth_normalized_pred = (inverse_depth_pred - min_invdepth_vizu) / (max_invdepth_vizu - min_invdepth_vizu)
                    # depth_img_pred = (depth_pred.max()-depth_pred)/(depth_pred.max()-depth_pred.min())
                    # depth_img_pred = (depth_pred.max()-depth_pred)/(depth_pred.max())
                    outputs_pred = inverse_depth_normalized_pred
                else:
                    # depth_gt_ = self.depth_model_(ToPILImage()(imgs[0]))["predicted_depth"]
                    # a = self.depth_model.image2tensor(cv2.cvtColor(np.asarray(ToPILImage()(imgs[0])),cv2.COLOR_RGB2BGR))[0].to(device=self.device, dtype=self.dtype)
                    # b = self.depth_model.forward(a)
                    # b = torchvision.transforms.functional.resize(b, (imgs.shape[-2], imgs.shape[-1]), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)
                    # c = (b - b.min()) / (b.max() - b.min())

                    # inputs = self.image_processor(images=imgs, return_tensors="pt")
                    # outputs = self.depth_model(**inputs)
                    # depth_gt = self.image_processor.post_process_depth_estimation(outputs, target_sizes=[(imgs.shape[-2], imgs.shape[-1])],)[0]["predicted_depth"]

                    imgs_normal_gt = self.image_transform(imgs).to(device=self.device, dtype=self.dtype)
                    depth_gt = self.depth_model.forward(imgs_normal_gt)
                    depth_gt = torchvision.transforms.functional.resize(depth_gt, (imgs.shape[-2], imgs.shape[-1]), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)
                    outputs_gt = (depth_gt - depth_gt.min()) / (depth_gt.max() - depth_gt.min())

                    latents_x0 = x_t - pred * t_
                    image_pred = decode_images(self.flux_pipe, latents_x0, 512, 512, 'pt') # TODO: height and weight
                    imgs_normal_pred = self.image_transform(image_pred).to(device=self.device, dtype=self.dtype)
                    depth_pred = self.depth_model.forward(imgs_normal_pred)
                    depth_pred = torchvision.transforms.functional.resize(depth_pred, (imgs.shape[-2], imgs.shape[-1]), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)
                    outputs_pred = (depth_pred - depth_pred.min()) / (depth_pred.max() - depth_pred.min())
                    
                if loss_mask is not None:
                    new_shape = [outputs_gt.shape[-2] // 16, outputs_gt.shape[-1] // 16]
                    outputs_gt = torchvision.transforms.functional.resize(outputs_gt, (new_shape[0], new_shape[1]), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)
                    outputs_gt = outputs_gt.reshape(outputs_gt.shape[0], -1)
                    outputs_pred = torchvision.transforms.functional.resize(outputs_pred, (new_shape[0], new_shape[1]), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)
                    outputs_pred = outputs_pred.reshape(outputs_pred.shape[0], -1)

                    reward_loss = torch.nn.functional.l1_loss(outputs_gt, outputs_pred, reduction="none")
                    # reward_loss = torch.nn.functional.mse_loss(outputs_gt, outputs_pred, reduction="none")
                    reward_loss = reward_loss * loss_mask
                    reward_loss = torch.mean(reward_loss)
                else:
                    reward_loss = torch.nn.functional.l1_loss(outputs_gt, outputs_pred, reduction="mean")
                    # reward_loss = torch.nn.functional.mse_loss(outputs_gt, outputs_pred, reduction="mean")
                # grad_scale = min(1.0, self.total_steps / linear_reward_scale_steps) * grad_scale
                loss += reward_loss * grad_scale
                # loss = reward_loss * grad_scale

                losses['reward_loss'] = reward_loss.item()
                losses['reward_loss_scale'] = reward_loss.item() * grad_scale

                # if self.total_steps % 10 ==0 and self.total_steps > 0:
                #     print("reward_loss requires_grad:", reward_loss.requires_grad)  # 应为True
                #     print("reward_loss grad_fn:", reward_loss.grad_fn)              # 应非None
            else:
                losses['reward_loss'] = 0
                losses['reward_loss_scale'] = 0

        orient_loss = self.model_config.get('orient_loss', None)
        if False and 'camera' in self.condition_type or 'pose' in self.condition_type and orient_loss is not None:
            def P_gauss(theta, sigma, device, range=180):
                sigma = torch.tensor(sigma)
                diff = torch.arange(1, range+1, device=device, dtype=torch.bfloat16) - theta
                numerator = -(diff ** 2) / (2 * sigma ** 2)
                return numerator.to(device=device, dtype=torch.bfloat16)

            def P_gauss_circular(phi, sigma, device, range=360):
                """Return unnormalised circular (von Mises-style) logits over bins 1..range.

                Angles are in degrees. Each entry is ``cos(bin - phi) / deg2rad(sigma) ** 2``,
                i.e. ``kappa * cos(delta)`` with ``kappa = 1 / sigma_rad ** 2``. Because
                cosine is periodic, the distance wraps around, and a softmax over the result
                gives a target distribution peaked at ``phi``.

                Args:
                    phi: centre angle in degrees (number or tensor).
                    sigma: spread in degrees (number or tensor).
                    device: device on which the result is created.
                    range: number of bins; the grid is ``1, 2, ..., range`` degrees.

                Returns:
                    A bfloat16 tensor of length ``range`` (for a scalar ``phi``).
                """
                # Compute in float32 and cast only at the end: in bfloat16, odd integers
                # above 256 are not representable, so bins such as 256/257 would collapse
                # onto the same angle, and deg2rad/cos would lose ~0.03 rad of resolution.
                sigma = torch.as_tensor(sigma, dtype=torch.float32, device=device)
                phi = torch.as_tensor(phi, dtype=torch.float32, device=device)
                bins = torch.arange(1, range + 1, device=device, dtype=torch.float32)
                diff_rad = torch.deg2rad(bins - phi)
                numerator = torch.cos(diff_rad) / torch.deg2rad(sigma) ** 2
                return numerator.to(device=device, dtype=torch.bfloat16)

            if True:
                with torch.enable_grad():
                # if 0<t.item()<0.4:
                    pose_flag = False
                    pose_loss = 0
                    grad_scale = 2

                    latents_x0 = x_t - pred * t_
                    with (torch.autocast(device_type='cuda', dtype=torch.float32)):
                        image_pred = decode_images(self.flux_pipe, latents_x0, 512, 512, 'pt') # TODO: height and weight
                    image_pred = image_pred.to(self.dtype)
                    for orient_gt, orient_bbox in zip(batch["orient"], batch["orient_bboxs"]):
                        try:
                            x_min, y_min, x_max, y_max = orient_bbox
                            if (x_max-x_min) < 10 or (y_max-y_min) < 10:
                                continue
                            orient_ax_gt   = P_gauss_circular(orient_gt[0], 20, device=self.device)
                            # orient_pl_gt   = P_gauss(orient_gt[1]+90, 2, device=self.device)
                            # # orient_ro_gt   = P_gauss_circular(orient_gt[2]+90, 1, device=self.device, range=180)
                            # orient_ro_gt   = P_gauss(orient_gt[2]+90, 1, device=self.device)

                            image_pred_box = image_pred[:,:,y_min:y_max, x_min:x_max]
                            imgs_normal_pred_box = self.orient_transform(image_pred_box).to(device=self.device, dtype=self.dtype)
                            orient_pred = self.orient_model({'pixel_values': imgs_normal_pred_box})
                            orient_ax_pred   = orient_pred[:, 0:360]
                            # orient_pl_pred   = orient_pred[:, 360:360+180]
                            # orient_ro_pred   = orient_pred[:, 360+180:360+180+180]
                            confidence = torch.nn.functional.softmax(orient_pred[:, -2:], dim=-1)[0][0]
                            if confidence < 0.95:
                                continue
                            
                            # import ipdb; ipdb.set_trace()
                            epsilon = 1e-8
                            if orient_loss == "KL":
                                pose_loss_as = torch.nn.functional.kl_div((orient_ax_pred.softmax(dim=-1) + epsilon).log(), orient_ax_gt.softmax(dim=-1), reduction='sum')
                                # pose_loss_pl = torch.nn.functional.kl_div((orient_pl_pred.softmax(dim=-1) + epsilon).log(), orient_pl_gt.softmax(dim=-1), reduction='sum')
                                # pose_loss_ro = torch.nn.functional.kl_div((orient_ro_pred.softmax(dim=-1) + epsilon).log(), orient_ro_gt.softmax(dim=-1), reduction='sum')
                            elif orient_loss == "CE":
                                # pose_loss_as = torch.nn.CrossEntropyLoss(reduction='sum')(orient_ax_pred.softmax(dim=-1), orient_ax_gt.softmax(dim=-1).unsqueeze(0))
                                # pose_loss_pl = torch.nn.CrossEntropyLoss(reduction='sum')(orient_pl_pred.softmax(dim=-1), orient_pl_gt.softmax(dim=-1).unsqueeze(0))
                                # pose_loss_ro = torch.nn.CrossEntropyLoss(reduction='sum')(orient_ro_pred.softmax(dim=-1), orient_ro_gt.softmax(dim=-1).unsqueeze(0))
                                pose_loss_as = torch.nn.CrossEntropyLoss(reduction='sum')(orient_ax_pred, orient_ax_gt.softmax(dim=-1).unsqueeze(0))
                                # pose_loss_pl = torch.nn.CrossEntropyLoss(reduction='sum')(orient_pl_pred, orient_pl_gt.softmax(dim=-1).unsqueeze(0))
                                # pose_loss_ro = torch.nn.CrossEntropyLoss(reduction='sum')(orient_ro_pred, orient_ro_gt.softmax(dim=-1).unsqueeze(0))

                            pose_loss += pose_loss_as
                            # pose_loss += pose_loss_as + pose_loss_pl + pose_loss_ro

                            pose_flag=True
                        except Exception as e:
                            print(e)

                    if pose_flag:
                        pose_loss = pose_loss / len(batch["orient"]) # torch.autograd.grad(outputs=pose_loss, inputs=latents, )[0]
                        # pose_loss += self.transformer.transformer_blocks[0].attn.to_q.lora_A.default.weight.mean()
                        loss += pose_loss * grad_scale
                        # loss = pose_loss * grad_scale

                        losses['pose_loss'] = pose_loss.item()
                        losses['pose_loss_as'] = pose_loss_as.item()
                        losses['confidence'] = confidence.item()
                        # losses['pose_loss_pl'] = pose_loss_pl.item()
                        # losses['pose_loss_ro'] = pose_loss_ro.item()
                    else:
                        losses['pose_loss'] = 0
                        losses['pose_loss_as'] = 0
                        losses['confidence'] = 0
            else:
                losses['pose_loss'] = 0
                losses['pose_loss_as'] = 0
                # losses['pose_loss_pl'] = 0
                # losses['pose_loss_ro'] = 0

        if check_loss and self.total_steps % 50 ==0 and self.total_steps > 0:
            # import ipdb; ipdb.set_trace()
            grad = self.transformer.get_parameter(self.target_param_name).grad
            if grad is not None:
                print(f"{self.target_param_name} 平均梯度: {grad.abs().mean().item():.3e}")
            else:
                print(f"{self.target_param_name} 没有梯度")

            # self.target_param_name = "transformer_blocks.7.attn.to_out.0.lora_B.default.weight"
            # self.transformer.get_parameter(self.target_param_name).grad.mean()

            # name_list = []
            # for name, param in self.transformer.named_parameters():
            #     if 'lora_B' not in name and self.transformer.get_parameter(name).grad is not None and self.transformer.get_parameter(name).grad.mean()!=0:
            #         name_list.append(name)
            # print(name_list)

            # grad_norms = [p.grad.abs().mean() for p in self.transformer.cam_embedder.parameters() if p.grad is not None]
            # if grad_norms:
            #     print(f"[AdamW] 平均梯度: {torch.stack(grad_norms).mean().item():.3e}")
            
            current_value = self.transformer.get_parameter(self.target_param_name).clone().detach()
            if not torch.allclose(self.initial_value, current_value, atol=1e-6):
                print(f"参数 {self.target_param_name} 已更新.")
                self.initial_value = current_value
            else:
                print(f"参数 {self.target_param_name} 未更新.")

        # print(loss) # nan when torch==2.5.0
        self.last_t = t.mean().item()
        return loss, losses
