from diffusers.pipelines import FluxPipeline
from diffusers.utils import logging
from diffusers.pipelines.flux.pipeline_flux import logger
from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin

import math
import numpy as np
import torch
from torch import Tensor
from packaging import version as pver
from einops import rearrange
from PIL import Image, ImageDraw, ImageFont
import random

def encode_images(pipeline: FluxPipeline, images: Tensor):
    """Encode images into packed Flux latent tokens and their positional ids.

    Steps: ``image_processor.preprocess`` -> move to the pipeline device/dtype ->
    VAE encode (a sample of the latent distribution) -> apply the VAE
    ``shift_factor`` / ``scaling_factor`` -> ``_pack_latents``.

    Returns:
        ``(images_tokens, images_ids)``: the packed latent tokens and the ids
        produced by ``pipeline._prepare_latent_image_ids``.
    """
    pixels = pipeline.image_processor.preprocess(images)
    pixels = pixels.to(pipeline.device).to(pipeline.dtype)

    latents = pipeline.vae.encode(pixels).latent_dist.sample()
    vae_config = pipeline.vae.config
    latents = (latents - vae_config.shift_factor) * vae_config.scaling_factor

    tokens = pipeline._pack_latents(latents, *latents.shape)

    def _ids_for(grid_h, grid_w):
        # Positional ids for a grid_h x grid_w token layout.
        return pipeline._prepare_latent_image_ids(
            latents.shape[0],
            grid_h,
            grid_w,
            pipeline.device,
            pipeline.dtype,
        )

    ids = _ids_for(latents.shape[2], latents.shape[3])
    if ids.shape[0] != tokens.shape[1]:
        # Token/id count mismatch: retry with halved latent sizes. Presumably this
        # covers diffusers versions whose _prepare_latent_image_ids expects
        # already-halved sizes -- TODO confirm.
        ids = _ids_for(latents.shape[2] // 2, latents.shape[3] // 2)
    return tokens, ids

def decode_images(pipeline: FluxPipeline, latents: Tensor, height, width, output_type):
    """Decode packed Flux latents back to an image tensor.

    The latents are unpacked, the VAE scaling/shift is undone, the VAE decodes
    them, and the result is mapped with ``x * 0.5 + 0.5``. Presumably this maps
    the VAE's [-1, 1] output to [0, 1]. The values are NOT clamped.

    ``output_type`` is accepted but currently unused; no post-processing into
    PIL/numpy happens here.
    """
    vae = pipeline.vae
    spatial = pipeline._unpack_latents(latents, height, width, pipeline.vae_scale_factor)
    spatial = spatial / vae.config.scaling_factor + vae.config.shift_factor
    decoded = vae.decode(spatial, return_dict=False)[0]
    return decoded * 0.5 + 0.5

def get_t5_prompt_embeds(
    pipe,
    prompt = None,
    orients = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 512,
    device = None,
    dtype = None,
):
    """Encode prompts with T5 and inject orientation ("compass") embeddings.

    Every ``<extra_id_0>`` token in the tokenized prompts is a compass slot. Its
    T5 hidden state is overwritten with ``pipe.transformer.cam_embedder`` applied
    to the matching orientation.

    Args:
        pipe: Flux pipeline exposing ``tokenizer_2``, ``text_encoder_2`` and
            ``transformer.cam_embedder``.
        prompt: A prompt string or a list of prompt strings.
        orients: One ``(a0, a1, a2)`` triple of angles in degrees per compass
            token. The triples follow the order in which the tokens appear
            (prompt by prompt, left to right). ``None`` means no orientations.
            Each angle becomes ``(cos, sin)``, giving 6 features per token.
        num_images_per_prompt: Number of copies of each prompt embedding.
        max_sequence_length: T5 padding/truncation length.
        device: Target device; defaults to ``pipe._execution_device``.
        dtype: Accepted but not applied. Embeddings are always cast to
            ``pipe.text_encoder_2.dtype`` (kept for backward compatibility).

    Returns:
        Tensor of shape ``(batch_size * num_images_per_prompt, seq_len, hidden)``.

    Raises:
        ValueError: if the number of compass tokens differs from the number of
            orientations.
    """
    device = device or pipe._execution_device

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if isinstance(pipe, TextualInversionLoaderMixin):
        prompt = pipe.maybe_convert_prompt(prompt, pipe.tokenizer_2)

    text_inputs = pipe.tokenizer_2(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    untruncated_ids = pipe.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        # Slice at the T5 limit (max_sequence_length), not the CLIP
        # tokenizer_max_length, so the warning lists the text actually dropped.
        removed_text = pipe.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = pipe.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

    # NOTE: the `dtype` argument is intentionally not applied (see docstring).
    dtype = pipe.text_encoder_2.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # (batch_index, token_position) of every compass token. nonzero() yields
    # them in row-major order: prompt by prompt, left to right.
    compass_token_id = pipe.tokenizer_2.convert_tokens_to_ids("<extra_id_0>")
    all_indices = (text_input_ids == compass_token_id).nonzero().tolist()

    orients = [] if orients is None else orients
    orient_features = []
    for orient in orients:
        o0, o1, o2 = (int(o) for o in orient)  # exactly three angles, in degrees
        orient_features.append(
            [trig(math.radians(angle)) for angle in (o0, o1, o2) for trig in (math.cos, math.sin)]
        )

    if len(orient_features) != len(all_indices):
        raise ValueError(
            f"Prompts has {len(all_indices)} compass tokens but {len(orient_features)} orientations provided."
        )

    if orient_features:
        orient_tensor = torch.tensor(orient_features, dtype=dtype, device=device)
        compass_embeds = pipe.transformer.cam_embedder(orient_tensor)
        for idx, (i, pos) in enumerate(all_indices):
            # Per-element assignment (copy_) casts to prompt_embeds' dtype if needed.
            prompt_embeds[i, pos] = compass_embeds[idx]

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds

def prepare_text_input(pipeline: FluxPipeline, prompts, orients=None, max_sequence_length=512):
    """Encode prompts with the pipeline, optionally injecting compass embeddings.

    When ``orients`` is given, the T5 embeddings are first built by
    ``get_t5_prompt_embeds`` (with orientation embeddings written into the
    ``<extra_id_0>`` slots). They are then passed to ``pipeline.encode_prompt``
    as precomputed ``prompt_embeds``.

    The flux logger is raised to ERROR only while ``encode_prompt`` runs (to hide
    e.g. CLIP overflow warnings). Its previous level is always restored, even if
    encoding raises.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds, text_ids)`` from ``encode_prompt``.
    """
    prompt_embeds = None
    if orients is not None:
        prompt_embeds = get_t5_prompt_embeds(
            pipeline,
            prompts,
            orients,
            max_sequence_length=max_sequence_length,
            device=pipeline.device,
        )

    previous_level = logger.level
    logger.setLevel(logging.ERROR)
    try:
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = pipeline.encode_prompt(
            prompt=prompts,
            prompt_2=None,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=None,
            device=pipeline.device,
            num_images_per_prompt=1,
            max_sequence_length=max_sequence_length,
            lora_scale=None,
        )
    finally:
        logger.setLevel(previous_level)
    return prompt_embeds, pooled_prompt_embeds, text_ids

def prepare_text_input_eligen(pipeline: FluxPipeline, eligen_entity_prompts, eligen_entity_masks, orients=None, max_sequence_length=512):
    """Build the EliGen kwargs: per-entity prompt embeddings plus binary masks.

    Returns a dict with keys ``entity_prompt_emb``, ``pooled_entity_prompt_embs``
    and ``entity_masks``. All three are ``None`` when ``eligen_entity_masks`` is
    ``None``.

    Each mask tensor is made channel-first if its dim 1 is not 3. It is then run
    through ``image_processor.preprocess``, averaged over channels and
    thresholded at > 0. Finally it is moved to the pipeline device/dtype. The
    masks are concatenated along dim 0.
    """
    if eligen_entity_masks is None:
        return {"entity_prompt_emb": None, "pooled_entity_prompt_embs": None, "entity_masks": None}

    # Presumably [n_mask, 512, 4096] and [n_mask, 768] -- TODO confirm.
    prompt_embs, pooled_embs, _ = prepare_text_input(
        pipeline,
        eligen_entity_prompts,
        orients=orients,
        max_sequence_length=max_sequence_length,
    )

    binary_masks = []
    for raw_mask in eligen_entity_masks:
        # Assumes channel-last input when dim 1 is not the 3-channel axis.
        chw_mask = raw_mask if raw_mask.shape[1] == 3 else raw_mask.permute(0, 3, 1, 2)
        processed = pipeline.image_processor.preprocess(chw_mask)
        binary = processed.mean(dim=1, keepdim=True) > 0
        binary_masks.append(binary.to(pipeline.device).to(pipeline.dtype))

    return {
        "entity_prompt_emb": prompt_embs,
        "pooled_entity_prompt_embs": pooled_embs,
        "entity_masks": torch.cat(binary_masks, dim=0),
    }

def process_entity_masks(pipe, hidden_states, encoder_hidden_states, eligen_kwargs, txt_ids, use_condition=False, condition_length=None, eligen_depth_attn=True):
    """Assemble the EliGen text context, text ids and attention mask.

    Without ``entity_masks`` in ``eligen_kwargs``, this just projects
    ``encoder_hidden_states`` with ``pipe.context_embedder`` and returns
    ``attention_mask=None``.

    With masks, the context is laid out as
    ``[entity_0, ..., entity_{n-1}, global]``. Each segment is projected by
    ``pipe.context_embedder`` and the segments are concatenated along dim 1.
    ``txt_ids`` is tiled ``n + 1`` times along dim 0. The additive attention
    mask comes from ``construct_mask_eligen`` (an all-ones global mask is
    appended last) and is unsqueezed at dim 1.

    Returns:
        ``(prompt_emb, txt_ids, attention_mask)``.
    """
    entity_prompt_emb = eligen_kwargs.get("entity_prompt_emb", None)
    entity_masks = eligen_kwargs.get("entity_masks", None)

    max_masks = 0
    attention_mask = None
    prompt_embs = [encoder_hidden_states]
    if entity_masks is not None:
        # One mask per entity along dim 0; the batch is treated as 1.
        max_masks = entity_masks.shape[0]
        # Masks are not repeated along the channel axis: construct_mask_eligen
        # only checks whether each patch's channel sum is > 0, which a channel
        # repeat cannot change. Repeating by the sequence length only wasted memory.
        entity_masks = [entity_masks[i, None] for i in range(max_masks)]
        # global mask
        global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
        entity_masks = entity_masks + [global_mask]  # append global to last
        # attention mask
        attention_mask = construct_mask_eligen(
            entity_masks,
            encoder_hidden_states.shape[1],
            hidden_states.shape[1],
            use_condition=use_condition,
            condition_length=condition_length,
            eligen_depth_attn=eligen_depth_attn,
        )
        attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
        attention_mask = attention_mask.unsqueeze(1)
        # embeds: n_masks * b * seq * d
        local_embs = [entity_prompt_emb[i, None] for i in range(max_masks)]
        prompt_embs = local_embs + prompt_embs  # global prompt stays last
    prompt_embs = [pipe.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
    prompt_emb = torch.cat(prompt_embs, dim=1)

    # positional embedding: one copy of txt_ids per context segment
    txt_ids = torch.cat([txt_ids] * (max_masks + 1), dim=0)
    return prompt_emb, txt_ids, attention_mask

def construct_mask_eligen(entity_masks, prompt_seq_len, image_seq_len, use_condition=False, condition_length=None, eligen_depth_attn=True):
    """Build the additive EliGen attention mask over ``[prompts | image | (condition)]``.

    Layout of the joint sequence:

    * ``len(entity_masks)`` prompt segments of ``prompt_seq_len`` tokens each;
    * ``image_seq_len`` image tokens;
    * when ``use_condition``: ``condition_length`` condition tokens
      (``image_seq_len`` if ``condition_length`` is falsy).

    Rules:

    * prompt segment k <-> image tokens are allowed only where the patchified
      ``entity_masks[k]`` has a non-zero channel sum, in both directions;
    * prompt segment k <-> condition tokens use the same pattern when
      ``eligen_depth_attn`` is true and are blocked otherwise. The image mask is
      reused, so this assumes ``condition_length == image_seq_len``;
    * different prompt segments never attend to each other;
    * every other pair stays allowed.

    Returns:
        float32 tensor of shape ``(1, T, T)``: 0 where allowed, ``-inf`` where blocked.
    """
    n_segments = len(entity_masks)
    cond_len = condition_length or image_seq_len

    text_len = n_segments * prompt_seq_len
    img_lo, img_hi = text_len, text_len + image_seq_len
    cond_lo, cond_hi = img_hi, img_hi + cond_len
    total_len = cond_hi if use_condition else img_hi

    token_masks = [patchify(mask) for mask in entity_masks]
    allowed = torch.ones((1, total_len, total_len), dtype=torch.bool, device=entity_masks[0].device)

    for seg, tokens in enumerate(token_masks):
        p_lo, p_hi = seg * prompt_seq_len, (seg + 1) * prompt_seq_len
        hit = torch.sum(tokens, dim=-1) > 0                   # image tokens covered by this entity
        prompt_to_img = hit.unsqueeze(1).repeat(1, prompt_seq_len, 1)
        img_to_prompt = prompt_to_img.transpose(1, 2)

        allowed[:, p_lo:p_hi, img_lo:img_hi] = prompt_to_img
        allowed[:, img_lo:img_hi, p_lo:p_hi] = img_to_prompt
        if use_condition:
            allowed[:, p_lo:p_hi, cond_lo:cond_hi] = prompt_to_img if eligen_depth_attn else 0
            allowed[:, cond_lo:cond_hi, p_lo:p_hi] = img_to_prompt if eligen_depth_attn else 0

    # Prompt segments are isolated from each other (diagonal blocks stay allowed).
    for a in range(n_segments):
        for b in range(n_segments):
            if a == b:
                continue
            allowed[:, a * prompt_seq_len:(a + 1) * prompt_seq_len, b * prompt_seq_len:(b + 1) * prompt_seq_len] = False

    additive = torch.zeros(allowed.shape, dtype=torch.float32, device=allowed.device)
    return additive.masked_fill(~allowed, float('-inf'))

def construct_mask_camera(attention_mask, cam_entity_masks, prompt_seq_len, image_seq_len, cond_seq_len, eligen_camera_attn=True):
    """Restrict each entity's camera-condition tokens to that entity's image region.

    ``attention_mask`` is an additive mask indexed as ``[:, :, query, key]``
    (0 = allowed). It is turned into a binary keep-mask, edited, and returned in
    additive form again (float32, 0 / ``-inf``).

    Sequence layout assumed here:

    * the condition tokens are the last ``cond_seq_len[-1]`` positions;
    * the ``image_seq_len`` image tokens sit right before them;
    * ``cond_seq_len`` holds cumulative offsets, so entity k's condition tokens
      are ``cond_seq_len[k]:cond_seq_len[k + 1]`` inside the condition span.

    For each of the first ``len(cam_entity_masks) - 1`` masks (the last one is
    not used here; presumably a global mask -- TODO confirm with caller):

    * the entity's condition tokens <-> image tokens are set to allowed where the
      patchified mask has a non-zero channel sum, and blocked elsewhere;
    * condition tokens of different entities cannot attend to each other.

    ``prompt_seq_len`` and ``eligen_camera_attn`` are currently unused.
    """
    # Additive (0 / -inf) -> binary keep-mask (1 / 0).
    keep = torch.where(attention_mask == 0, 1, 0)

    n_entities = len(cam_entity_masks) - 1
    seq_total = keep.shape[-1]
    cond_base = seq_total - cond_seq_len[-1]
    img_lo, img_hi = cond_base - image_seq_len, cond_base

    token_masks = patchify(cam_entity_masks)

    def _cond_span(k):
        # Absolute [start, end) positions of entity k's condition tokens.
        return cond_base + cond_seq_len[k], cond_base + cond_seq_len[k + 1]

    for k in range(n_entities):
        c_lo, c_hi = _cond_span(k)
        width = cond_seq_len[k + 1] - cond_seq_len[k]
        hit = torch.sum(token_masks[k], dim=-1) > 0
        cond_to_img = hit.unsqueeze(0).unsqueeze(0).repeat(1, width, 1)
        keep[:, :, c_lo:c_hi, img_lo:img_hi] = cond_to_img
        keep[:, :, img_lo:img_hi, c_lo:c_hi] = cond_to_img.transpose(1, 2)

    for k in range(n_entities):
        k_lo, k_hi = _cond_span(k)
        for other in range(n_entities):
            if other == k:
                continue
            o_lo, o_hi = _cond_span(other)
            keep[:, :, k_lo:k_hi, o_lo:o_hi] = 0

    additive = torch.zeros(keep.shape, dtype=torch.float32, device=keep.device)
    return additive.masked_fill(keep == 0, float('-inf'))

def patchify(hidden_states, patch_size=2):
    """Fold non-overlapping ``patch_size x patch_size`` spatial patches into tokens.

    Same layout as ``rearrange(x, "B C (H P) (W Q) -> B (H W) (C P Q)", P=p, Q=p)``.
    An input of shape ``(B, C, H*p, W*p)`` becomes ``(B, H*W, C*p*p)``. Tokens are
    ordered row-major over the patch grid; features are ordered
    (channel, row-in-patch, col-in-patch).

    Args:
        hidden_states: 4-D tensor ``(B, C, height, width)``.
        patch_size: Patch edge length (default 2, the original hard-coded value).

    Raises:
        RuntimeError: if ``height`` or ``width`` is not divisible by ``patch_size``.
    """
    b, c, h, w = hidden_states.shape
    p = patch_size
    grid_h, grid_w = h // p, w // p
    x = hidden_states.reshape(b, c, grid_h, p, grid_w, p)
    x = x.permute(0, 2, 4, 1, 3, 5)  # -> B, H, W, C, P, Q
    return x.reshape(b, grid_h * grid_w, c * p * p)

class Camera(object):
    """Pinhole camera parsed from one flat parameter row.

    ``entry[0:4]`` are ``fx, fy, cx, cy``. ``entry[4:6]`` are ignored here.
    ``entry[6:]`` holds a row-major 3x4 world-to-camera matrix, which is padded
    to 4x4 with a ``[0, 0, 0, 1]`` bottom row.

    Attributes:
        fx, fy, cx, cy: intrinsics copied from ``entry``.
        w2c_mat: 4x4 float64 world-to-camera matrix.
        c2w_mat: its inverse (camera-to-world).
    """

    def __init__(self, entry):
        self.fx, self.fy, self.cx, self.cy = entry[:4]
        extrinsic = np.identity(4)
        extrinsic[:3] = np.reshape(entry[6:], (3, 4))
        self.w2c_mat = extrinsic
        self.c2w_mat = np.linalg.inv(extrinsic)

def custom_meshgrid(*args):
    """``torch.meshgrid`` with ``'ij'`` indexing across torch versions.

    torch < 1.10 rejects the ``indexing`` keyword, so the call is retried
    without it in that case.
    Ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html
    """
    try:
        return torch.meshgrid(*args, indexing='ij')
    except TypeError as err:
        # Only the "unexpected keyword argument 'indexing'" failure of old torch
        # triggers the fallback; any other TypeError is a real argument error.
        if 'indexing' not in str(err):
            raise
        return torch.meshgrid(*args)

def _pixel_centers(H, W, device, dtype, flip=False):
    """Return flattened pixel-centre coordinates ``(x, y)``, each of shape ``(H*W,)``.

    Pixels are enumerated row-major (flat index ``h * W + w``), the same order as
    an ``'ij'`` meshgrid over (rows, cols) that is then flattened. With
    ``flip=True`` the column coordinates run from ``W - 1`` down to ``0``.
    """
    if flip:
        cols = torch.linspace(W - 1, 0, W, device=device, dtype=dtype)
    else:
        cols = torch.linspace(0, W - 1, W, device=device, dtype=dtype)
    rows = torch.linspace(0, H - 1, H, device=device, dtype=dtype)
    x = cols.repeat(H) + 0.5              # x[h * W + w] = cols[w] + 0.5
    y = rows.repeat_interleave(W) + 0.5   # y[h * W + w] = rows[h] + 0.5
    return x, y


def encode_poses(K, c2w, H, W, device, dtype=None, flip_flag=None):
    """Compute per-pixel Plücker ray embeddings ``(o x d, d)`` for a batch of views.

    Args:
        K: intrinsics ``(fx, fy, cx, cy)`` per view, shape ``(B, V, 4)``.
        c2w: camera-to-world matrices, shape ``(B, V, 4, 4)``.
        H, W: height and width of the pixel grid.
        device: device on which the pixel grid is created.
        dtype: dtype of the pixel grid; defaults to ``c2w.dtype``.
        flip_flag: optional index over the view axis. The selected views use a
            horizontally mirrored pixel grid. Presumably a bool tensor of shape
            ``(V,)``, since its sum is used as the flip count -- TODO confirm.

    Returns:
        Tensor ``(B, V, H, W, 6)``: the moment ``o x d`` (ray origin cross unit
        direction) followed by the unit world-space direction ``d``.
    """
    dtype = dtype if dtype is not None else c2w.dtype

    B, V = K.shape[:2]
    n_pix = H * W

    x, y = _pixel_centers(H, W, device, dtype)
    # clone(): expand() gives a stride-0 view, which cannot be written in place below.
    i = x.expand(B, V, n_pix).clone()          # [B, V, HxW]
    j = y.expand(B, V, n_pix).clone()          # [B, V, HxW]

    n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
    if n_flip > 0:
        x_flip, y_flip = _pixel_centers(H, W, device, dtype, flip=True)
        i[:, flip_flag, ...] = x_flip
        j[:, flip_flag, ...] = y_flip

    fx, fy, cx, cy = K.chunk(4, dim=-1)        # each [B, V, 1]

    zs = torch.ones_like(i)                    # [B, V, HxW]
    xs = (i - cx) / fx
    ys = (j - cy) / fy

    directions = torch.stack((xs, ys, zs), dim=-1)                      # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)     # B, V, HW, 3

    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)            # B, V, HW, 3
    rays_o = c2w[..., :3, 3][:, :, None].expand_as(rays_d)              # B, V, HW, 3
    # dim=-1 is required: without it torch.cross uses the FIRST dimension of
    # size 3, which is wrong whenever B, V or H*W happens to equal 3.
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)                      # B, V, HW, 3
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    return plucker.reshape(B, c2w.shape[1], H, W, 6)                    # B, V, H, W, 6

# def visualize_masks(image, masks, mask_prompts, output_path=None, font_size=35, use_random_colors=False):
def visualize_masks(image, masks, mask_prompts, output_path=None, font_size=20, use_random_colors=False):
    """Overlay coloured, labelled entity masks on a PIL image.

    Args:
        image: base PIL image.
        masks: PIL mask images of the same size as ``image``. Pixels whose RGB is
            exactly (255, 255, 255) are painted with the mask's colour.
        mask_prompts: label drawn 10 px inside the top-left corner of each mask's
            bounding box. Masks and prompts are paired with ``zip``.
        output_path: if given, the composited result is saved there.
        font_size: label font size.
        use_random_colors: use random RGB colours (alpha 80) instead of the palette.

    Returns:
        The composited RGBA PIL image.
    """
    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))

    # The palette is cycled, so any number of masks gets a colour.
    palette = [
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
    ]
    if use_random_colors:
        palette = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]

    try:
        font = ImageFont.truetype("arial", font_size)
    except OSError:
        try:
            font = ImageFont.load_default(font_size)
        except TypeError:
            # Pillow < 10.1: load_default() does not accept a size.
            font = ImageFont.load_default()

    for idx, (mask, mask_prompt) in enumerate(zip(masks, mask_prompts)):
        color = palette[idx % len(palette)]
        try:
            mask_rgba = mask.convert('RGBA')
            mask_rgba.putdata(
                [color if px[:3] == (255, 255, 255) else (0, 0, 0, 0) for px in mask_rgba.getdata()]
            )

            mask_bbox = mask.getbbox()  # None for an all-black mask
            if mask_bbox is not None:
                draw = ImageDraw.Draw(mask_rgba)
                text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)
                draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)

            overlay = Image.alpha_composite(overlay, mask_rgba)
        except Exception as e:
            # Best-effort visualisation: report the bad mask (e.g. a size
            # mismatch with `image`) and keep going.
            print(e)

    result = Image.alpha_composite(image.convert('RGBA'), overlay)

    if output_path is not None:
        result.save(output_path)

    return result
