import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
import numpy as np
import PIL.Image as Image
import json


# PyTorch3D 的相关模块
from pytorch3d.structures import Meshes, join_meshes_as_scene
from pytorch3d.renderer import (
    FoVPerspectiveCameras,     
    look_at_view_transform,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,           
    TexturesVertex,
)

# -----------------------
# 0. 准备工作 & 设备检查
# -----------------------
# Prefer the first visible GPU (CUDA_VISIBLE_DEVICES is pinned near the top of
# the file); fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ==========================
# 1. 基础形状创建函数
# ==========================
def create_cube(center=[0, 0, 0], size=1.0, color=[0.2, 0.3, 0.8]):
    """
    Build an axis-aligned cube mesh centred on ``center``.

    Example shape dict (as consumed by ``create_shape``)::

        {
            "type": "cube",
            "center": [0.0, 0.0, 0.0],
            "size": 0.2,
            "color": [0.2, 0.3, 0.8]
        }

    :param center: (x, y, z) of the cube's geometric centre.
    :param size:   edge length.
    :param color:  RGB colour assigned to every vertex via ``TexturesVertex``.
    :return: a single-mesh ``Meshes`` placed on the module-level ``device``.
    """
    # Sign pattern of the 8 corners: first the z = -half ring, then z = +half.
    corner_signs = torch.tensor([
        [-1, -1, -1],
        [ 1, -1, -1],
        [ 1,  1, -1],
        [-1,  1, -1],
        [-1, -1,  1],
        [ 1, -1,  1],
        [ 1,  1,  1],
        [-1,  1,  1],
    ], dtype=torch.float32)
    half_edge = size / 2.0
    corners = corner_signs * half_edge + torch.tensor(center, dtype=torch.float32)

    # Two triangles per cube side, indexing into `corners`.
    triangles = torch.tensor([
        [0, 1, 2], [0, 2, 3],
        [4, 6, 5], [4, 7, 6],
        [0, 4, 5], [0, 5, 1],
        [1, 5, 6], [1, 6, 2],
        [2, 6, 7], [2, 7, 3],
        [3, 7, 4], [3, 4, 0],
    ], dtype=torch.int64)

    # One identical RGB row per vertex.
    per_vertex_rgb = torch.tensor(color, dtype=torch.float32).unsqueeze(0).repeat(corners.shape[0], 1)
    vertex_texture = TexturesVertex(verts_features=[per_vertex_rgb])

    return Meshes(
        verts=[corners.to(device)],
        faces=[triangles.to(device)],
        textures=vertex_texture.to(device),
    )


def _sphere_geometry(center, radius, subdiv):
    """
    Compute vertices and triangle faces of a latitude/longitude sphere.

    The grid has ``subdiv`` polar samples (theta in [0, pi]) times ``subdiv``
    azimuth samples (phi in [0, 2*pi]); both ends are included. As a result,
    the seam column (phi = 0 and phi = 2*pi) and the pole rows are duplicated
    rather than merged. The vertex (theta_i, phi_j) is stored at row
    ``i * subdiv + j``.

    :param center: (x, y, z) offset added to every vertex.
    :param radius: sphere radius.
    :param subdiv: samples per angular direction; must be >= 2.
    :return: ``(verts, faces)``. ``verts`` is a float32 tensor of shape
             (subdiv**2, 3) and ``faces`` is an int64 tensor of shape
             (2 * (subdiv - 1)**2, 3).
    :raises ValueError: if ``subdiv < 2`` (no faces could be formed).
    """
    if subdiv < 2:
        raise ValueError(f"subdiv must be >= 2, got {subdiv}")

    theta_vals = torch.linspace(0, np.pi, subdiv)
    phi_vals = torch.linspace(0, 2 * np.pi, subdiv)

    # Broadcast (subdiv, 1) x (1, subdiv) instead of a Python double loop.
    sin_t = torch.sin(theta_vals)[:, None]
    cos_t = torch.cos(theta_vals)[:, None]
    cos_p = torch.cos(phi_vals)[None, :]
    sin_p = torch.sin(phi_vals)[None, :]
    x = radius * sin_t * cos_p
    y = radius * sin_t * sin_p
    z = (radius * cos_t).expand(-1, subdiv)
    verts = torch.stack([x, y, z], dim=-1).reshape(-1, 3).to(torch.float32)
    verts = verts + torch.tensor(center, dtype=torch.float32)

    # Split each grid cell (i, j) into two triangles, in the same (i, j) order
    # as the original loop.
    v00 = torch.arange(subdiv - 1)[:, None] * subdiv + torch.arange(subdiv - 1)[None, :]  # idx(i, j)
    v01 = v00 + 1        # idx(i, j + 1)
    v10 = v00 + subdiv   # idx(i + 1, j)
    v11 = v10 + 1        # idx(i + 1, j + 1)
    upper = torch.stack([v00, v01, v10], dim=-1)
    lower = torch.stack([v10, v01, v11], dim=-1)
    faces = torch.stack([upper, lower], dim=2).reshape(-1, 3).to(torch.int64)
    return verts, faces


def create_sphere(center=(0, 0, 0), radius=1.0, color=(0.8, 0.3, 0.2), subdiv=10):
    """
    Build an approximate sphere mesh from a latitude/longitude grid.

    Example shape dict (as consumed by ``create_shape``)::

        {
            "type": "sphere",
            "center": [-0.2, 0.0, -0.3],
            "radius": 0.3,
            "color": [0.8, 0.3, 0.2],
            "subdiv": 20
        }

    :param center: (x, y, z) of the sphere centre.
    :param radius: sphere radius.
    :param color:  RGB colour assigned to every vertex via ``TexturesVertex``.
    :param subdiv: samples per angular direction (>= 2). Higher values give a
                   smoother sphere.
    :return: a single-mesh ``Meshes`` placed on the module-level ``device``.
    :raises ValueError: if ``subdiv < 2``.
    """
    verts, faces = _sphere_geometry(center, radius, subdiv)

    colors = torch.tensor([color], dtype=torch.float32).repeat(len(verts), 1)
    textures = TexturesVertex(verts_features=[colors])

    return Meshes(
        verts=[verts.to(device)],
        faces=[faces.to(device)],
        textures=textures.to(device),
    )


def _cuboid_vertices(center, dims, theta):
    """
    Return the 8 corners of a box as a float32 tensor of shape (8, 3).

    Before rotation the box spans x in [-lx/2, lx/2], y in [0, ly] and
    z in [-lz, 0], where ``dims = (lx, ly, lz)``. The local origin is the
    middle of the box's min-y / max-z edge. The box is then rotated by
    ``theta`` degrees about the Y axis through that origin and translated
    by ``center``.
    """
    # Unit box with the anchor convention above. It is scaled by `dims` exactly
    # once; the previous version baked dims into the corners *and* scaled
    # again, which made each edge dims**2 long.
    unit_corners = torch.tensor([
        [-0.5, 0.0, -1.0],
        [ 0.5, 0.0, -1.0],
        [ 0.5, 1.0, -1.0],
        [-0.5, 1.0, -1.0],
        [-0.5, 0.0,  0.0],
        [ 0.5, 0.0,  0.0],
        [ 0.5, 1.0,  0.0],
        [-0.5, 1.0,  0.0],
    ], dtype=torch.float32)
    verts = unit_corners * torch.tensor(dims, dtype=torch.float32)  # scale

    # Rotation about the Y axis. Row vectors are multiplied by R^T,
    # which is equivalent to applying R to column vectors.
    rad = np.deg2rad(theta)
    rot_y = np.array([
        [np.cos(rad), 0.0, np.sin(rad)],
        [0.0, 1.0, 0.0],
        [-np.sin(rad), 0.0, np.cos(rad)],
    ])
    verts = verts @ torch.tensor(rot_y.T, dtype=torch.float32)  # rotate
    return verts + torch.tensor(center, dtype=torch.float32)      # translate


def create_cuboid(center=(0, 0, 0), dims=(1.0, 0.5, 0.2), theta=0, color=(0.2, 0.8, 0.3)):
    """
    Build a box mesh with edge lengths ``dims = [x_len, y_len, z_len]``.

    Note: ``center`` is an anchor point, not the geometric centre. Before
    rotation the box is centred on it in x, extends from it towards +y, and
    extends from it towards -z. The rotation by ``theta`` (degrees) is about
    the Y axis through this anchor.

    Example shape dict (as consumed by ``create_shape``)::

        {
            "type": "cuboid",
            "center": [0.2, 0.1, -0.2],
            "dims": [0.4, 0.6, 0.4],
            "theta": 0,
            "color": [0.3, 0.8, 0.3]
        }

    :param center: (x, y, z) anchor point described above.
    :param dims:   (x_len, y_len, z_len) edge lengths.
    :param theta:  rotation about the Y axis, in degrees.
    :param color:  RGB colour assigned to every vertex via ``TexturesVertex``.
    :return: a single-mesh ``Meshes`` placed on the module-level ``device``.
    """
    verts = _cuboid_vertices(center, dims, theta)

    # Two triangles per box side, indexing into `verts`.
    faces = torch.tensor([
        [0, 1, 2], [0, 2, 3],
        [4, 6, 5], [4, 7, 6],
        [0, 4, 5], [0, 5, 1],
        [1, 5, 6], [1, 6, 2],
        [2, 6, 7], [2, 7, 3],
        [3, 7, 4], [3, 4, 0],
    ], dtype=torch.int64)

    colors = torch.tensor([color], dtype=torch.float32).repeat(len(verts), 1)
    textures = TexturesVertex(verts_features=[colors])

    return Meshes(
        verts=[verts.to(device)],
        faces=[faces.to(device)],
        textures=textures.to(device),
    )


# ==========================
# 2. 用字典管理不同形状参数 + 统一创建函数
# ==========================
def create_shape(shape_dict):
    """
    Build a mesh from a shape description dict by dispatching on its "type".

    Recognised types (case-insensitive; a missing "type" means "cuboid"):
      - "cube":   center, size, color
      - "sphere": center, radius, color, subdiv
      - "cuboid": center, dims, theta, color
    Keys that are absent fall back to per-type defaults. Extra keys (e.g.
    "name") are ignored.

    :raises ValueError: for any other "type" value.
    """
    kind = shape_dict.get("type", "cuboid").lower()
    lookup = shape_dict.get

    # Each builder is deferred so only the selected constructor is invoked.
    builders = {
        "cube": lambda: create_cube(
            center=lookup("center", [0, 0, 0]),
            size=lookup("size", 1.0),
            color=lookup("color", [0.2, 0.3, 0.8]),
        ),
        "sphere": lambda: create_sphere(
            center=lookup("center", [0, 0, 0]),
            radius=lookup("radius", 0.5),
            color=lookup("color", [0.8, 0.3, 0.2]),
            subdiv=lookup("subdiv", 10),
        ),
        "cuboid": lambda: create_cuboid(
            center=lookup("center", [0, 0, 0]),
            dims=lookup("dims", [1.0, 0.5, 0.2]),
            theta=lookup("theta", 0),
            color=lookup("color", [0.2, 0.8, 0.3]),
        ),
    }

    build = builders.get(kind)
    if build is None:
        raise ValueError(f"Unsupported shape type: {kind}")
    return build()


# ================================
# 3. 构建场景中的 Mesh，渲染并成像
# ================================
def _normalize_depth(zbuf, min_z, max_z):
    """
    Map a 2-D z-buffer tensor to a NumPy float array, with min_z -> 1 and max_z -> 0.

    Pixels with no covering face (zbuf == -1) are treated as very far away and
    therefore end up at 0. The input tensor is not modified.
    """
    # Use torch.where instead of in-place assignment so that the caller's
    # fragments.zbuf stays intact.
    zbuf = torch.where(zbuf == -1, torch.full_like(zbuf, 1e10), zbuf)
    depth_clamped = torch.clamp(zbuf, min_z, max_z)
    return (1.0 - (depth_clamped - min_z) / (max_z - min_z)).cpu().numpy()


def render_scene_from_shapes(
    shapes,
    distance=1.0,
    elev=0.0,
    azim=0.0,
    fov=60.0,
    image_size=512,
    depth_range=(0.0, 10.0),
    out_rgb="scene.png",
    out_depth="scene_depth.png",
    save_image=True,
    save_depth=True,
    apply_blur=False,
    kernal_size=5,
    blur_sigma=1.0
):
    """
    Build a scene from shape dicts and render an RGB image plus a depth image.

    :param shapes:      list of dicts; each one is passed to ``create_shape``.
    :param distance:    camera distance from the origin (``look_at_view_transform`` ``dist``).
    :param elev:        camera elevation, in degrees (``look_at_view_transform`` default).
    :param azim:        camera azimuth, in degrees (``look_at_view_transform`` default).
    :param fov:         field of view of the ``FoVPerspectiveCameras``.
    :param image_size:  output resolution passed to ``RasterizationSettings``.
    :param depth_range: ``(min_z, max_z)``. The z-buffer is clamped to this range
                        and mapped linearly so that min_z -> 255 (white) and
                        max_z -> 0 (black). Empty pixels also map to 0.
    :param out_rgb:     path of the RGB image (written only when ``save_image``).
    :param out_depth:   path of the depth image (written only when ``save_depth``).
    :param save_image:  whether to save the RGB render.
    :param save_depth:  whether to save the depth image.
    :param apply_blur:  if True, Gaussian-blur the normalised depth map with OpenCV.
    :param kernal_size: Gaussian kernel size; OpenCV requires a positive odd int.
    :param blur_sigma:  Gaussian ``sigmaX``.
    :return: ``(rgb_image, depth_image_pil)`` as PIL images. The depth image is
             8-bit grayscale converted to RGB.
    :raises ValueError: if ``shapes`` is empty or ``depth_range`` is not increasing.
    """
    # Validate cheaply before building meshes or touching the GPU.
    if not shapes:
        raise ValueError("shapes must contain at least one shape")
    min_z, max_z = float(depth_range[0]), float(depth_range[1])
    if max_z <= min_z:
        raise ValueError(f"depth_range must satisfy min_z < max_z, got {depth_range}")

    # Scene mesh: every shape dict becomes one mesh; all are merged into a single scene.
    scene_mesh = join_meshes_as_scene([create_shape(shape_info) for shape_info in shapes])

    # Camera: world-to-view R, T from a look-at transform. Replace with known
    # extrinsics (or PerspectiveCameras with fx/fy/cx/cy) if available.
    R, T = look_at_view_transform(dist=distance, elev=elev, azim=azim, device=device)
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=fov)

    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=1,
    )
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
    shader = SoftPhongShader(device=device, cameras=cameras)

    # Rasterize once and feed the same fragments to the shader. MeshRenderer
    # would simply rasterize the identical mesh a second time.
    fragments = rasterizer(scene_mesh)
    images = shader(fragments, scene_mesh)

    rgb = images[0, ..., :3].cpu().numpy()
    # Clip before quantising: out-of-range floats would otherwise wrap in uint8.
    rgb_image = Image.fromarray((np.clip(rgb, 0.0, 1.0) * 255).astype(np.uint8))
    if save_image:
        rgb_image.save(out_rgb)
        print(f"保存了彩色渲染图: {out_rgb}")

    # Depth: nearest face per pixel (faces_per_pixel=1), normalised to [0, 1].
    depth_image_final = _normalize_depth(fragments.zbuf[0, ..., 0], min_z, max_z)

    # Optional low-pass filter on the normalised depth map.
    if apply_blur:
        import cv2
        depth_image_final = cv2.GaussianBlur(
            depth_image_final, ksize=(kernal_size, kernal_size), sigmaX=blur_sigma
        )
    print(depth_image_final.max(), depth_image_final.min())
    depth_image_pil = Image.fromarray((depth_image_final * 255).astype(np.uint8)).convert("RGB")

    if save_depth:
        depth_image_pil.save(out_depth)
        print(f"保存了深度图(低通滤波={apply_blur}): {out_depth}")

    return rgb_image, depth_image_pil


# ================================
# 4. Demo: 传入 shapes 列表 直接调用
# ================================

if __name__ == "__main__":
    # Demo driver:
    #   1. render a depth map for every shape on its own;
    #   2. save the scene description to tar_dir/data.json;
    #   3. render a flattened "bbox" depth map for every cuboid.
    src_dir = 'test/human_data/6'    # scene read in "add" mode
    tar_dir = 'test/human_data/26'   # output directory
    os.makedirs(tar_dir, exist_ok=True)

    # "create": render only `shapes` below.
    # "add": prepend the shapes (and camera) stored in src_dir/data.json.
    mode = 'create'

    distance, elev, azim = 1.0, 0.0, 0.0
    # Shape list (one dict per shape). Without a "type" key, create_shape builds a cuboid.
    shapes = [
        {
            "center": [-0.0, -0.0, -8],
            "dims": [1.6, 1.6, 1.0],
            "theta": 0,
            "name": "cat",
        },
        {
            "center": [0.0, 0.0, -1.5],
            "dims": [1.2, 1.35, 1.0],
            "theta": 0,
            "name": "dog",
        },
    ]

    if mode == 'add':
        with open(f"{src_dir}/data.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        distance, elev, azim = data['distance'], data['elev'], data['azim']
        shapes = data['shapes'] + shapes

    # Camera and output settings shared by every render below.
    render_kwargs = dict(
        distance=distance,
        elev=elev,
        azim=azim,
        fov=60.0,
        image_size=512,
        depth_range=(1.0, 10.0),
        save_image=False,
        save_depth=True,
    )

    for i, shape in enumerate(shapes):
        render_scene_from_shapes(
            [shape],
            out_rgb=f"{tar_dir}/render_{i}.png",
            out_depth=f"{tar_dir}/render_depth_{i}.png",
            **render_kwargs,
        )

    # Explicit UTF-8: ensure_ascii=False may write non-ASCII (e.g. names).
    with open(f"{tar_dir}/data.json", "w", encoding="utf-8") as f:
        json.dump({
            "distance": distance,
            "elev": elev,
            "azim": azim,
            "shapes": shapes,
        }, f, indent=4, ensure_ascii=False)

    # "bbox" pass: squash each cuboid to a thin slab (z-length 0.01). A copy
    # is edited so that the caller's shape dicts are left untouched. Non-cuboid
    # shapes have no "dims" to flatten and are skipped.
    for i, shape in enumerate(shapes):
        if shape.get("type", "cuboid").lower() != "cuboid":
            print(f"Skipping bbox render for non-cuboid shape {i}")
            continue
        flat_shape = dict(shape)
        flat_shape["dims"] = list(shape.get("dims", [1.0, 0.5, 0.2]))
        flat_shape["dims"][-1] = 0.01
        render_scene_from_shapes(
            [flat_shape],
            out_rgb=f"{tar_dir}/render_{i}.png",
            out_depth=f"{tar_dir}/render_bbox_{i}.png",
            **render_kwargs,
        )

