from re import A
import torch
import torch.nn as nn

class CameraAdapter(nn.Module):
    """Map a camera-parameter vector to an ``input_channels``-wide embedding.

    The network is a three-layer MLP:
    ``camera_dim -> 1024 -> input_channels -> input_channels``, with a ReLU
    after the first two linear layers and no activation after the last one.

    Args:
        input_channels: width of the second hidden layer and of the output.
        camera_dim: size of the input's last dimension (default 9).
    """

    def __init__(self, input_channels, camera_dim=9):
        super().__init__()
        # The order of these modules fixes the state_dict keys
        # (camera_mlp.0 / .2 / .4) that saved checkpoints are loaded into,
        # so it must not change.
        layers = [
            nn.Linear(camera_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, input_channels),
            nn.ReLU(),
            nn.Linear(input_channels, input_channels),
        ]
        self.camera_mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Return the MLP projection of ``x`` (last dim ``camera_dim``)."""
        return self.camera_mlp(x)

# Camera-embedder weights from the run's ckpt/500 directory (presumably
# training step 500 — confirm against the training script).
module_path_0 = "runs_pose/20250413-131142_eligen_pose_split_eligen_pose_None/ckpt/500/cam_embedder.pth"
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; the model below is built on the CPU anyway.
state_0 = torch.load(module_path_0, map_location="cpu")

# Optional sanity check: load a later checkpoint and print, per tensor,
# whether it is unchanged between the two saves.
# module_path_1 = "runs_pose/20250413-131142_eligen_pose_split_eligen_pose_None/ckpt/1000/cam_embedder.pth"
# state_1 = torch.load(module_path_1, map_location="cpu")
#
# for k in state_0.keys():
#     print(f'{k}: ', (state_0[k]==state_1[k]).all())

# camera_dim=3 must match the first Linear layer in the checkpoint, otherwise
# load_state_dict raises on the shape mismatch.
cam_embedder = CameraAdapter(4096, 3)
cam_embedder.load_state_dict(state_0)
cam_embedder.eval()


def _normalize_orientation(angle_0, angle_1, angle_2):
    """Scale one orientation triple into the adapter's input encoding.

    ``angle_0`` is wrapped into [-180, 180) and divided by 360, giving a value
    in [-0.5, 0.5); ``angle_1`` and ``angle_2`` are divided by 90. Angles are
    assumed to be in degrees (given the 360 wrap), and this presumably mirrors
    the training-time encoding — TODO confirm against the data pipeline.
    """
    return [((angle_0 + 180) % 360 - 180) / 360, angle_1 / 90, angle_2 / 90]


# Sweep the first angle through a full turn (0, 90, ..., 360) with the other
# two held at 0. Note that 0 and 360 wrap to the same encoding, so the first
# and last output rows are expected to be identical.
orientations = [(angle, 0, 0) for angle in range(0, 361, 90)]
normalized = [_normalize_orientation(*o) for o in orientations]

# Inference only: no_grad avoids building an autograd graph through the
# 4096-wide layers.
with torch.no_grad():
    # a: (5, 4096) embeddings, one row per orientation.
    a = cam_embedder(torch.tensor(normalized))

# Placeholder statement — presumably a breakpoint anchor for inspecting `a`.
print(1)
