import torch
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import numpy as np
from packaging import version as pver

from src.flux.pipeline_tools import Camera

# def custom_meshgrid(x, y):
#     return torch.meshgrid(x, y, indexing='ij')

def custom_meshgrid(*args):
    """Return ``torch.meshgrid(*args)`` with matrix ('ij') indexing on any torch version.

    The ``indexing`` keyword was only added to ``torch.meshgrid`` in 1.10;
    earlier releases do not accept it and always index as 'ij'.
    ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    """
    has_indexing_kw = pver.parse(torch.__version__) >= pver.parse('1.10')
    if not has_indexing_kw:
        # Legacy torch: implicit 'ij' indexing, keyword would be rejected.
        return torch.meshgrid(*args)
    return torch.meshgrid(*args, indexing='ij')

def encode_poses(K, c2w, H, W, device, dtype=None, flip_flag=None):
    """Compute per-pixel Plücker ray embeddings for a batch of cameras.

    For each pixel centre ``(u + 0.5, v + 0.5)`` the pinhole ray
    ``((u - cx) / fx, (v - cy) / fy, 1)`` is normalized to a unit direction,
    rotated to world space with the ``[:3, :3]`` block of ``c2w``, and
    combined with the camera origin ``o = c2w[..., :3, 3]`` into the
    6-vector ``(o x d, d)``.

    Args:
        K: intrinsics whose last dim is ordered (fx, fy, cx, cy) in pixel
            units, e.g. shape (B, V, 4); ``K.shape[:2]`` gives (B, V).
        c2w: camera-to-world matrices with leading dims (B, V); only the
            rotation ``[..., :3, :3]`` and translation ``[..., :3, 3]`` are read.
        H: output height in pixels.
        W: output width in pixels.
        device: device for the computation and the result; ``K`` and ``c2w``
            are moved to it.
        dtype: computation/output dtype, defaulting to ``c2w.dtype``; ``K``
            and ``c2w`` are cast to it.
        flip_flag: optional mask over the view dimension selecting views that
            use a horizontally mirrored pixel grid (presumably a bool tensor
            of shape (V,) — confirm with callers).

    Returns:
        Tensor of shape (B, c2w.shape[1], H, W, 6) holding ``[o x d, d]``.
    """
    dtype = dtype if dtype is not None else c2w.dtype
    # torch.matmul does not promote mixed dtypes (and ops cannot mix devices),
    # so bring the camera parameters to the requested device/dtype first.
    K = K.to(device=device, dtype=dtype)
    c2w = c2w.to(device=device, dtype=dtype)

    B, V = K.shape[:2]

    # Pixel grid: i runs along the width (x), j along the height (y).
    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=dtype),
    )
    # `+ 0.5` moves to pixel centres and materializes fresh tensors, so the
    # in-place writes of the flip branch below do not hit an expanded view.
    i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5
    j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5

    n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
    if n_flip > 0:
        # Same grid with the x axis reversed, written into the flipped views.
        j_flip, i_flip = custom_meshgrid(
            torch.linspace(0, H - 1, H, device=device, dtype=dtype),
            torch.linspace(W - 1, 0, W, device=device, dtype=dtype),
        )
        i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        i[:, flip_flag, ...] = i_flip
        j[:, flip_flag, ...] = j_flip

    fx, fy, cx, cy = K.chunk(4, dim=-1)

    # Back-project onto the z = 1 plane, then normalize to unit length.
    zs = torch.ones_like(i)
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    directions = torch.stack((xs, ys, zs), dim=-1)
    directions = directions / directions.norm(dim=-1, keepdim=True)

    # Row-vector form of d_world = R @ d_cam.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)
    rays_o = c2w[..., :3, 3][:, :, None].expand_as(rays_d)
    # Plücker coordinates: (moment, direction) = (o x d, d).
    moment = torch.linalg.cross(rays_o, rays_d)
    plucker = torch.cat([moment, rays_d], dim=-1)
    return plucker.reshape(B, c2w.shape[1], H, W, 6)

def visualize_plucker(plucker, H, W, step=10):
    """Plot the ray-direction field of the first batch/view with matplotlib.

    For every ``step``-th row and column, an arrow starts at pixel
    ``(col, row)``. Its offset is the first two direction channels
    (``plucker[0, 0, row, col, 3]`` and ``[..., 4]``). The y axis is inverted
    so row 0 sits at the top, as in image coordinates. Ends with
    ``plt.show()``.

    Args:
        plucker: tensor indexed as ``plucker[b, v, row, col, channel]`` with at
            least 5 channels (e.g. the output of ``encode_poses``).
        H: number of rows to sample / y-axis extent.
        W: number of columns to sample / x-axis extent.
        step: sampling stride in pixels.
    """
    fig, ax = plt.subplots()
    sampled_rows = range(0, H, step)
    sampled_cols = range(0, W, step)
    for row in sampled_rows:
        for col in sampled_cols:
            direction = plucker[0, 0, row, col, 3:].cpu().numpy()
            ax.arrow(
                col, row, direction[0], direction[1],
                head_width=2, head_length=2, fc='red', ec='red',
            )
    ax.set_xlim(0, W)
    ax.set_ylim(H, 0)  # inverted: image row 0 at the top
    plt.show()

def visualize_plucker_pil(plucker, output_path='plucker_visualization.png', step=20, arrow_scale=10):
    """
    Visualize a Plücker embedding with PIL and save it as a PNG image.

    Only the first batch element and the first view are drawn. On a sparse
    grid (every ``step`` pixels) an arrow is drawn from the pixel along the
    x/y components of the ray direction (channels 3 and 4). The components
    are normalized so every arrow has length ``arrow_scale``.

    Args:
    - plucker: Tensor of shape (B, V, H, W, 6)
    - output_path: path of the output image
    - step: grid stride, in pixels, between drawn arrows
    - arrow_scale: arrow length in pixels
    """
    _, _, H, W, _ = plucker.shape
    # First batch element and first view. detach() lets tensors that require
    # grad be converted with .numpy() below, which otherwise raises.
    plucker = plucker[0, 0].detach()

    # White background canvas.
    img = Image.new('RGB', (W, H), color='white')
    draw = ImageDraw.Draw(img)

    # Direction part of the embedding (channels 3:6); only its x and y
    # components are used for the 2D arrows.
    directions_xy = plucker[..., 3:5]  # H, W, 2

    # Normalize so all arrows have the same length; the epsilon avoids a
    # division by zero for rays with no x/y component.
    norms = torch.norm(directions_xy, dim=-1, keepdim=True) + 1e-6
    directions_xy_norm = (directions_xy / norms).cpu().numpy()

    for h in range(0, H, step):
        for w in range(0, W, step):
            dx, dy = directions_xy_norm[h, w]
            start = (w, h)
            end = (w + dx * arrow_scale, h + dy * arrow_scale)
            draw_arrow(draw, start, end, fill='red', arrow=True)

    # Save as PNG.
    img.save(output_path)
    print(f"Plücker embedding visualization saved to {output_path}")

def draw_arrow(draw, start, end, fill='red', arrow=True, arrow_ratio=0.3, arrow_size=10):
    """
    Draw a 1-pixel line segment on a PIL ImageDraw object, optionally capped
    with a filled triangular arrowhead at ``end``.

    Args:
    - draw: ImageDraw object (only its ``line`` and ``polygon`` methods are used)
    - start: start point (x, y)
    - end: end point (x, y); the arrowhead tip is placed here
    - fill: colour of both the line and the arrowhead
    - arrow: whether to draw the arrowhead
    - arrow_ratio: angle between each head edge and the reversed segment
      direction, as a fraction of pi
    - arrow_size: length of each head edge
    """
    draw.line([start, end], fill=fill, width=1)
    if not arrow:
        return

    import math

    heading = math.atan2(end[1] - start[1], end[0] - start[0])
    # Both head edges point back along the segment (heading + pi),
    # splayed to either side by arrow_ratio * pi.
    backwards = heading + math.pi
    spread = arrow_ratio * math.pi
    corners = [
        (end[0] + arrow_size * math.cos(theta), end[1] + arrow_size * math.sin(theta))
        for theta in (backwards - spread, backwards + spread)
    ]
    draw.polygon([end, *corners], fill=fill)

if __name__ == "__main__":
    # Demo: encode a few example camera poses as Plücker embeddings and save a
    # PNG visualization of each one (plucker_visualization_<idx>.png).
    device = 'cpu'
    dtype = torch.float32
    target_size = 512  # images are target_size x target_size pixels

    # Each entry is a whitespace-separated list of floats handed to Camera.
    # NOTE(review): the layout looks like "fx fy cx cy <2 values> <3x4 pose,
    # row-major>" with image-normalized intrinsics — confirm against Camera.
    poses = [
        "0.532139961 0.946026558 0.500000000 0.500000000 0.000000000 0.000000000 0.978989959 -0.010294991 -0.203648433 -0.000762398 -0.007398812 0.996273518 -0.085932352 -0.031535059 0.203774214 0.085633665 0.975265563 -0.153683138",
        "0.591079453 1.050807901 0.500000000 0.500000000 0.000000000 0.000000000 0.999850392 0.008867561 -0.014851474 0.359474093 -0.009553785 0.998859048 -0.046790831 0.313312825 0.014419609 0.046925720 0.998794317 -0.689018860",
        "0.501289166 0.891180703 0.500000000 0.500000000 0.000000000 0.000000000 0.998422921 0.009509329 0.055328209 -0.088786468 -0.006474003 0.998477280 -0.054783192 -0.013369506 -0.055764910 0.054338600 0.996964216 -1.103852641",
    ]

    H = W = target_size
    for idx, pose in enumerate(poses):
        # split() with no argument tolerates repeated/leading/trailing
        # whitespace; split(' ') would yield '' tokens that float() rejects.
        cam_param = Camera([float(x) for x in pose.split()])

        # Scale the (presumably image-normalized) intrinsics to pixel units.
        K = torch.tensor(
            [
                cam_param.fx * target_size,
                cam_param.fy * target_size,
                cam_param.cx * target_size,
                cam_param.cy * target_size,
            ],
            device=device,
            dtype=dtype,
        )[None, None]  # -> (B=1, V=1, 4)

        c2w = torch.tensor(
            cam_param.c2w_mat,
            device=device,
            dtype=dtype,
        )[None, None]  # add batch and view dims

        plucker = encode_poses(K, c2w, H, W, device, dtype)

        # Visualize and save as PNG.
        visualize_plucker_pil(plucker, output_path=f'plucker_visualization_{idx}.png', step=30, arrow_scale=50)
