import torch
import yaml, os
from diffusers.pipelines import FluxPipeline
from typing import List, Union, Optional, Dict, Any, Callable
from .transformer import tranformer_forward
from .condition import Condition
from .pipeline_tools import process_entity_masks, prepare_text_input_eligen
from .pipeline_tools import encode_images, decode_images, prepare_text_input, prepare_text_input_eligen, encode_poses

from diffusers.pipelines.flux.pipeline_flux import (
    FluxPipelineOutput,
    calculate_shift,
    retrieve_timesteps,
    np,
)


def get_config(config_path: str = None):
    """Load a YAML configuration file into a dict.

    Args:
        config_path: Path to the YAML file. When falsy, the ``XFL_CONFIG``
            environment variable is used instead.

    Returns:
        The parsed configuration. Returns ``{}`` when no path is configured or
        when the file is empty (``yaml.safe_load`` yields ``None`` for an empty
        document, and callers chain ``.get(...)`` on the result).
    """
    config_path = config_path or os.environ.get("XFL_CONFIG")
    if not config_path:
        return {}
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # An empty YAML file parses to None; normalise so `.get(...)` keeps working.
    return config if config is not None else {}


def prepare_params(
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = 512,
    width: Optional[int] = 512,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
    max_sequence_length: int = 512,
    **kwargs: dict,
):
    """Apply defaults to the sampling parameters and return them as a fixed-order tuple.

    The tuple order is the parameter order above. Any keyword that is not a
    named parameter is collected into ``kwargs``, which is returned as the last
    element.

    Args:
        callback_on_step_end_tensor_inputs: Names of the tensors handed to
            ``callback_on_step_end``. ``None`` (the default) means ``["latents"]``.
        (all other arguments are returned unchanged)

    Returns:
        A 19-tuple:
        ``(prompt, prompt_2, height, width, num_inference_steps, timesteps,
        guidance_scale, num_images_per_prompt, generator, latents,
        prompt_embeds, pooled_prompt_embeds, output_type, return_dict,
        joint_attention_kwargs, callback_on_step_end,
        callback_on_step_end_tensor_inputs, max_sequence_length, kwargs)``.
    """
    if callback_on_step_end_tensor_inputs is None:
        # Build a fresh list on every call. A list literal as the default would
        # be a single shared object, so a caller mutating the returned list
        # would silently change the default for all later calls.
        callback_on_step_end_tensor_inputs = ["latents"]
    return (
        prompt,
        prompt_2,
        height,
        width,
        num_inference_steps,
        timesteps,
        guidance_scale,
        num_images_per_prompt,
        generator,
        latents,
        prompt_embeds,
        pooled_prompt_embeds,
        output_type,
        return_dict,
        joint_attention_kwargs,
        callback_on_step_end,
        callback_on_step_end_tensor_inputs,
        max_sequence_length,
        kwargs,
    )


def seed_everything(seed: int = 42):
    """Seed the global random number generators for reproducible runs.

    This seeds Python's ``random``, NumPy's global RNG and torch. It also puts
    cuDNN into deterministic, non-benchmarking mode, because benchmark mode
    picks convolution algorithms non-deterministically.

    Args:
        seed: The seed value used for every generator.
    """
    import random  # local import: only this helper needs the stdlib RNG

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)


def _set_attn_condition_scale(transformer, condition_scale: float) -> None:
    """Attach ``c_factor = ones(1, 1) * condition_scale`` to every ``*.attn`` submodule."""
    for name, module in transformer.named_modules():
        if name.endswith(".attn"):
            module.c_factor = torch.ones(1, 1) * condition_scale


def _clear_attn_condition_scale(transformer) -> None:
    """Remove any ``c_factor`` instance attribute from the ``*.attn`` submodules."""
    for name, module in transformer.named_modules():
        if name.endswith(".attn") and "c_factor" in vars(module):
            del module.c_factor


@torch.no_grad()
def generate(
    pipeline: FluxPipeline,
    conditions: List[Condition] = None,
    config_path: str = None,
    model_config: Optional[Dict[str, Any]] = None,
    condition_scale: float = 1.0,
    default_lora: bool = False,
    cfg_scale=1.0,
    negative_prompt="",
    num_langevin_steps: int = 5,
    langevin_tmp_dir: Optional[str] = "Origen/tmp",
    **params: dict,
):
    """Sample with a Flux pipeline, optionally refined by orientation-guided Langevin updates.

    One full denoising pass is always run. When orientation targets are supplied
    (``orient_pose`` and ``orient_bboxs`` in ``params``), ``num_langevin_steps``
    passes are run in total. After every pass except the last,
    ``langevin_update`` adjusts the *initial* noise using the orientation loss
    of the denoised result, and the next pass restarts from that updated noise.

    Args:
        pipeline: The Flux pipeline. It is referred to as ``self`` in the body.
        conditions: Optional list with at most one ``Condition``. Its encoded
            tokens, ids and type ids are passed to ``tranformer_forward``.
            ``None`` or an empty list disables conditioning.
        config_path: YAML config consulted when ``model_config`` is empty. Its
            ``model`` section is forwarded to ``tranformer_forward``.
        model_config: Explicit model config; it takes precedence over the file.
        condition_scale: When != 1, every ``*.attn`` submodule of the
            transformer gets a temporary ``c_factor`` attribute carrying this
            value. The attribute is removed before returning, and also if an
            error is raised.
        default_lora: If False, the adapter named after the condition's
            ``condition_type`` is activated via ``set_adapters``.
        cfg_scale: Classifier-free guidance scale. A value != 1.0 adds a
            negative-prompt forward pass without condition tokens at each step.
        negative_prompt: Prompt for the CFG branch. A single string is repeated
            for every prompt in the batch.
        num_langevin_steps: Total number of denoising passes when orientation
            targets are given. Values < 1 are treated as 1.
        langevin_tmp_dir: Directory for per-pass preview images. ``None``
            disables saving.
        **params: Forwarded to ``prepare_params`` (prompt, height, width,
            num_inference_steps, ...). Keys not recognised there end up in
            ``kwargs``, which is passed to every ``tranformer_forward`` call.
            This function additionally reads ``orient_pose``, ``orient_bboxs``,
            ``eligen_entity_prompts``, ``eligen_entity_masks`` and ``orient``
            from it.

    Returns:
        ``FluxPipelineOutput(images=...)``, or ``(images,)`` when
        ``return_dict`` is False.
    """
    model_config = model_config or get_config(config_path).get("model", {})
    self = pipeline

    try:
        if condition_scale != 1:
            _set_attn_condition_scale(pipeline.transformer, condition_scale)

        (
            prompt,
            prompt_2,
            height,
            width,
            num_inference_steps,
            timesteps,
            guidance_scale,
            num_images_per_prompt,
            generator,
            latents,
            prompt_embeds,
            pooled_prompt_embeds,
            output_type,
            return_dict,
            joint_attention_kwargs,
            callback_on_step_end,
            callback_on_step_end_tensor_inputs,
            max_sequence_length,
            kwargs,
        ) = prepare_params(**params)

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode prompts
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None)
            if self.joint_attention_kwargs is not None
            else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )
        if cfg_scale != 1.0:
            # Repeat a single negative prompt so its embeddings match the
            # batch_size * num_images_per_prompt rows of the positive branch.
            if isinstance(negative_prompt, str):
                negative_prompt = [negative_prompt] * batch_size
            (
                negative_prompt_embeds,
                negative_pooled_prompt_embeds,
                _,
            ) = self.encode_prompt(
                prompt=negative_prompt,
                prompt_2=None,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                lora_scale=lora_scale,
            )

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 4.1. Prepare conditions (None means "no condition tokens").
        condition_latents = condition_ids = condition_type_ids = None
        use_condition = bool(conditions)
        if use_condition:
            assert len(conditions) <= 1, "Only one condition is supported for now."
            if not default_lora:
                pipeline.set_adapters(conditions[0].condition_type)
            tokens_list, ids_list, type_ids_list = [], [], []
            for condition in conditions:
                tokens, ids, type_id = condition.encode(self, ids=latent_image_ids)
                tokens_list.append(tokens)  # [batch_size, token_n, token_dim]
                ids_list.append(ids)  # [token_n, id_dim(3)]
                type_ids_list.append(type_id)  # [token_n, 1]
            condition_latents = torch.cat(tokens_list, dim=1)
            # `ids` is the last condition's ids; with a single condition this
            # tells whether positional ids were produced at all.
            condition_ids = torch.cat(ids_list, dim=0) if ids is not None else None
            condition_type_ids = torch.cat(type_ids_list, dim=0)

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0
        )
        self._num_timesteps = len(timesteps)

        if kwargs.get("eligen_entity_prompts", None) and kwargs.get("eligen_entity_masks", None):
            eligen_entity_prompts = kwargs["eligen_entity_prompts"]
            eligen_entity_masks = kwargs["eligen_entity_masks"]
            eligen_kwargs = prepare_text_input_eligen(
                self, eligen_entity_prompts, eligen_entity_masks, orients=kwargs.get("orient", None),
            )
            kwargs["eligen_kwargs"] = eligen_kwargs

        # 5.1. Langevin refinement is only possible with orientation targets;
        # without them a single plain denoising pass is run.
        orient_pose = kwargs.get("orient_pose", None)
        orient_bboxs = kwargs.get("orient_bboxs", None)
        has_orient_targets = (
            orient_pose is not None
            and orient_bboxs is not None
            and len(orient_pose) > 0
        )
        num_passes = max(1, int(num_langevin_steps)) if has_orient_targets else 1
        if langevin_tmp_dir is not None and num_passes > 1:
            os.makedirs(langevin_tmp_dir, exist_ok=True)

        # 6. Denoising loop (repeated once per Langevin pass)
        latent_ori = latents
        for step in range(num_passes):
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                # Every pass restarts from the (possibly Langevin-updated) initial noise.
                latents = latent_ori.detach().clone()
                for i, t in enumerate(timesteps):
                    if self.interrupt:
                        continue

                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                    timestep = t.expand(latents.shape[0]).to(latents.dtype)

                    # handle guidance
                    if self.transformer.config.guidance_embeds:
                        guidance = torch.tensor([guidance_scale], device=device)
                        guidance = guidance.expand(latents.shape[0])
                    else:
                        guidance = None

                    noise_pred = tranformer_forward(
                        self.transformer,
                        model_config=model_config,
                        # Inputs of the condition (new feature)
                        condition_latents=condition_latents,
                        condition_ids=condition_ids,
                        condition_type_ids=condition_type_ids,
                        # Inputs to the original transformer
                        hidden_states=latents,
                        # The transformer multiplies the timestep by 1000 internally.
                        timestep=timestep / 1000,
                        guidance=guidance,
                        pooled_projections=pooled_prompt_embeds,
                        encoder_hidden_states=prompt_embeds,
                        txt_ids=text_ids,
                        img_ids=latent_image_ids,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                        **kwargs,
                    )[0]

                    if cfg_scale != 1.0:
                        # Negative side: negative prompt, no condition tokens.
                        noise_pred_nega = tranformer_forward(
                            self.transformer,
                            model_config=model_config,
                            condition_latents=None,
                            condition_ids=None,
                            condition_type_ids=None,
                            hidden_states=latents,
                            timestep=timestep / 1000,
                            guidance=guidance,
                            pooled_projections=negative_pooled_prompt_embeds,
                            encoder_hidden_states=negative_prompt_embeds,
                            txt_ids=text_ids,
                            img_ids=latent_image_ids,
                            joint_attention_kwargs=self.joint_attention_kwargs,
                            return_dict=False,
                            **kwargs,
                        )[0]
                        noise_pred = noise_pred_nega + cfg_scale * (noise_pred - noise_pred_nega)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents_dtype = latents.dtype
                    latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                    if latents.dtype != latents_dtype and torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                    if callback_on_step_end is not None:
                        # Explicit loop (not a comprehension) so that locals()
                        # is this function's namespace.
                        callback_kwargs = {}
                        for tensor_name in callback_on_step_end_tensor_inputs:
                            callback_kwargs[tensor_name] = locals()[tensor_name]
                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                        latents = callback_outputs.pop("latents", latents)
                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                    if i == len(timesteps) - 1 or (
                        (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                    ):
                        progress_bar.update()

            if step < num_passes - 1:
                save_path = (
                    os.path.join(langevin_tmp_dir, f"step_{step}.png")
                    if langevin_tmp_dir is not None
                    else None
                )
                latent_ori, reward = langevin_update(
                    self,
                    latent_ori,
                    latents,
                    timesteps[0] / 1000,
                    orient_pose,
                    orient_bboxs,
                    height=height,
                    width=width,
                    save_path=save_path,
                    save_tmp_image=save_path is not None,
                )
                print(f"Step {step+1}/{num_passes} | Reward: {reward:.4f}")
                # Reset the scheduler (step index and sigma schedule) to exactly
                # the state of the first pass: same sigmas and same dynamic
                # shift `mu`. A bare set_timesteps(num_inference_steps) would
                # drop both.
                retrieve_timesteps(
                    self.scheduler,
                    num_inference_steps,
                    device,
                    None,
                    sigmas,
                    mu=mu,
                )

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (
                latents / self.vae.config.scaling_factor
            ) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
    finally:
        # Always undo the c_factor patch so the shared transformer is not left
        # with a stale condition scale, even if sampling raised.
        if condition_scale != 1:
            _clear_attn_condition_scale(pipeline.transformer)

# Hyperparameters (default values from the paper).
# Lower bound, upper bound and slope of the time-rescaling monitor G(R)
# evaluated in `monitor_function`.
s_min = 1 / 3
s_max = 4 / 3
k = 6 / 5
# Gradient scaling factor applied in the Langevin drift term.
eta = 1
# Base Langevin step size; it is multiplied by G(R) at every update.
gamma = 0.3

def cal_pose_loss(
    pipe,
    latents_x0,
    orient,
    orient_bboxs,
    dtype=torch.bfloat16,
    height: int = 1024,
    width: int = 1024,
    loss_type: str = "KL",
):
    """Orientation loss between the orientation model's predictions and target angles.

    The packed latents are decoded to an image. For every (target, box) pair,
    the box crop is fed to ``pipe.orient_model``. Each of its three output
    heads (bins ``[0:360]``, ``[360:540]`` and ``[540:720]``; presumably
    azimuth / polar / rotation, TODO confirm) is compared against a soft target
    distribution centred on the requested angle.

    Args:
        pipe: Pipeline exposing ``device``, ``orient_transform`` and
            ``orient_model``, plus whatever ``decode_images`` needs.
        latents_x0: Packed latents of the denoised sample.
        orient: Sequence of targets, each indexable as ``[0]``, ``[1]``, ``[2]``
            in degrees; one per box.
        orient_bboxs: Sequence of ``(x_min, y_min, x_max, y_max)`` boxes,
            used to slice the decoded image.
        dtype: dtype of the target distributions and of the model input.
        height: Pixel height used to decode ``latents_x0``. It must match the
            generation resolution. The default keeps the previous 1024 behaviour.
        width: Pixel width used to decode ``latents_x0``. It must match the
            generation resolution. The default keeps the previous 1024 behaviour.
        loss_type: ``"KL"`` (default) for KL divergence against the soft target,
            or ``"CE"`` for soft-label cross-entropy on the raw logits.

    Returns:
        The sum of the three per-head losses, averaged over boxes.

    Raises:
        ValueError: If ``orient`` is empty, its length differs from
            ``orient_bboxs``, or ``loss_type`` is unknown.
    """
    if len(orient) == 0:
        raise ValueError("cal_pose_loss needs at least one orientation target.")
    if len(orient) != len(orient_bboxs):
        raise ValueError(
            f"Got {len(orient)} orientation targets but {len(orient_bboxs)} boxes."
        )
    if loss_type not in ("KL", "CE"):
        raise ValueError(f"Unknown loss_type {loss_type!r}; expected 'KL' or 'CE'.")

    device = pipe.device
    epsilon = 1e-8  # keeps log() finite when a softmax bin underflows to 0

    def _gauss_logits(center, sigma, n_bins=180):
        """Unnormalised Gaussian log-density over bins 1..n_bins, centred at ``center``."""
        sigma = torch.tensor(sigma)
        diff = torch.arange(1, n_bins + 1, device=device, dtype=dtype) - center
        numerator = -(diff ** 2) / (2 * sigma ** 2)
        return numerator.to(device=device, dtype=dtype)

    def _circular_logits(center, sigma, n_bins=360):
        """Wrap-around log-density cos(delta) / sigma_rad**2 over bins 1..n_bins (degrees)."""
        sigma = torch.tensor(sigma)
        diff = torch.arange(1, n_bins + 1, device=device, dtype=dtype) - center
        numerator = torch.cos(torch.deg2rad(diff)) / torch.deg2rad(sigma) ** 2
        return numerator.to(device=device, dtype=dtype)

    def _head_loss(pred_logits, target_logits):
        """Loss of one prediction head against its (softmax-normalised) target."""
        target = target_logits.softmax(dim=-1)
        if loss_type == "KL":
            return torch.nn.functional.kl_div(
                (pred_logits.softmax(dim=-1) + epsilon).log(), target, reduction="sum"
            )
        return torch.nn.functional.cross_entropy(
            pred_logits, target.unsqueeze(0), reduction="sum"
        )

    image_pred = decode_images(pipe, latents_x0, height, width, "pt")

    pose_loss = 0
    for orient_gt, orient_bbox in zip(orient, orient_bboxs):
        x_min, y_min, x_max, y_max = orient_bbox
        # Soft targets. The second and third angles are shifted by +90 so they
        # fall into the 1..180 bin range of their heads.
        # NOTE(review): target bins start at 1 while head indices start at 0;
        # confirm this matches the orientation model's label convention.
        orient_ax_gt = _circular_logits(orient_gt[0], 20, n_bins=360)
        orient_pl_gt = _gauss_logits(orient_gt[1] + 90, 2, n_bins=180)
        orient_ro_gt = _gauss_logits(orient_gt[2] + 90, 1, n_bins=180)

        image_pred_box = image_pred[:, :, y_min:y_max, x_min:x_max]
        pixel_values = pipe.orient_transform(image_pred_box).to(device=device, dtype=dtype)
        orient_pred = pipe.orient_model({"pixel_values": pixel_values})

        pose_loss_as = _head_loss(orient_pred[:, 0:360], orient_ax_gt)
        pose_loss_pl = _head_loss(orient_pred[:, 360:360 + 180], orient_pl_gt)
        pose_loss_ro = _head_loss(orient_pred[:, 360 + 180:360 + 180 + 180], orient_ro_gt)

        pose_loss += pose_loss_as + pose_loss_pl + pose_loss_ro

    return pose_loss / len(orient)


def monitor_function(R: torch.Tensor):
    """Time-rescaling monitor G(R) and its log-derivative.

    G = s_min - tanh(k * R) * (s_max - s_min), using the module-level
    hyperparameters ``s_min``, ``s_max`` and ``k``.

    Args:
        R: Current reward. ``langevin_update`` passes the negated pose loss.

    Returns:
        A tuple ``(G, d_logG)``, where ``d_logG`` is d(log G)/dR, computed as
        (dG/dR) / (G + 1e-6).
    """
    tanh_term = torch.tanh(k * R)
    G = s_min - tanh_term * (s_max - s_min)

    # d(log G)/dR = (dG/dR) / G; the 1e-6 guards against division by zero.
    dG_dR = - (s_max - s_min) * k * (1 - tanh_term**2)
    d_logG = dG_dR / (G + 1e-6)

    return G, d_logG


def langevin_update(
    pipe,
    latents,
    latents_x0: torch.Tensor,
    t,
    orient,
    orient_bboxs,
    height=1024,
    width=1024,
    save_path=None,
    save_tmp_image=True,
):
    """Perform one reward-guided Langevin update of the starting latents.

    The gradient of the orientation loss is taken with respect to
    ``latents_x0`` (the denoised result of a pass). It is then applied to
    ``latents``; in ``generate`` this is the noise that the pass started from.

    Args:
        pipe: Pipeline, as required by ``cal_pose_loss`` and for decoding.
        latents: Packed starting latents to update.
        latents_x0: Packed denoised latents that the loss is evaluated on.
        t: Unused; kept for interface compatibility.
        orient: Orientation targets, forwarded to ``cal_pose_loss``.
        orient_bboxs: Boxes, forwarded to ``cal_pose_loss``.
        height: Pixel height of the generated image, used for decoding.
        width: Pixel width of the generated image, used for decoding.
        save_path: Where to write a PIL preview of ``latents_x0``. Missing
            parent directories are created.
        save_tmp_image: Whether to write the preview. It is written only when
            ``save_path`` is also given.

    Returns:
        A tuple ``(new_latents, loss)``. ``new_latents`` is cast to
        ``pipe.dtype``. ``loss`` is the pose loss as a Python float (lower is
        better), even though ``generate`` prints it as "Reward".
    """
    with torch.enable_grad():
        # NOTE(review): CUDA autocast normally targets float16/bfloat16; with
        # dtype=float32 it may just be disabled (with a warning). Confirm intent.
        with torch.autocast(device_type="cuda", dtype=torch.float32):
            # Only latents_x0 needs a graph; latents is merely updated.
            latents = latents.detach().to(torch.float32)
            latents_x0 = latents_x0.detach().to(torch.float32).requires_grad_(True)

            pose_loss = cal_pose_loss(
                pipe, latents_x0, orient, orient_bboxs, height=height, width=width
            )
            latent_grad = torch.autograd.grad(outputs=pose_loss, inputs=latents_x0)[0]

    # Time rescaling: the step size shrinks or grows with the reward R = -loss.
    G, d_logG = monitor_function(-pose_loss.detach())
    gamma_i = G * gamma

    # NOTE(review): latent_grad is the gradient of the *loss* (= -reward), yet it
    # is added in both the drift and the chain-rule correction term. If the update
    # is meant to ascend the reward, both terms should use -latent_grad. Verify
    # this against the paper before changing it.
    update_term = torch.sqrt(1 - gamma_i) * (latents + gamma_i * eta * latent_grad)
    correction_term = 0.5 * gamma_i * d_logG * latent_grad
    noise_term = torch.sqrt(gamma_i) * torch.randn_like(latents)

    latents_new = update_term + correction_term + noise_term

    if save_tmp_image and save_path is not None:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        # Preview only: no autograd graph through the VAE.
        with torch.no_grad():
            image_tmp = pipe._unpack_latents(
                latents_x0.detach(), height, width, pipe.vae_scale_factor
            ).to(pipe.dtype)
            image_tmp = (image_tmp / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
            image_tmp = pipe.vae.decode(image_tmp, return_dict=False)[0]
        image_tmp = pipe.image_processor.postprocess(image_tmp, output_type="pil")
        image_tmp[0].save(save_path)

    return latents_new.to(pipe.dtype), pose_loss.item()
