Spaces:
Paused
Paused
| import json | |
| import os | |
| import random | |
| from dataclasses import dataclass, field | |
| import cv2 | |
| import numpy as np | |
| import pytorch_lightning as pl | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from torch.utils.data import DataLoader, Dataset | |
| from ..utils.config import parse_structured | |
| from ..utils.geometry import ( | |
| get_plucker_embeds_from_cameras, | |
| get_plucker_embeds_from_cameras_ortho, | |
| get_position_map_from_depth, | |
| get_position_map_from_depth_ortho, | |
| ) | |
| from ..utils.typing import * | |
| os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" | |
| def _parse_scene_list_single(scene_list_path: str, root_data_dir: str): | |
| all_scenes = [] | |
| if scene_list_path.endswith(".json"): | |
| with open(scene_list_path) as f: | |
| for p in json.loads(f.read()): | |
| if "/" in p: | |
| all_scenes.append(os.path.join(root_data_dir, p)) | |
| else: | |
| all_scenes.append(os.path.join(root_data_dir, p[:2], p)) | |
| elif scene_list_path.endswith(".txt"): | |
| with open(scene_list_path) as f: | |
| for p in f.readlines(): | |
| p = p.strip() | |
| if "/" in p: | |
| all_scenes.append(os.path.join(root_data_dir, p)) | |
| else: | |
| all_scenes.append(os.path.join(root_data_dir, p[:2], p)) | |
| else: | |
| raise NotImplementedError | |
| return all_scenes | |
| def _parse_scene_list( | |
| scene_list_path: Union[str, List[str]], root_data_dir: Union[str, List[str]] | |
| ): | |
| all_scenes = [] | |
| if isinstance(scene_list_path, str): | |
| scene_list_path = [scene_list_path] | |
| if isinstance(root_data_dir, str): | |
| root_data_dir = [root_data_dir] | |
| for scene_list_path_, root_data_dir_ in zip(scene_list_path, root_data_dir): | |
| all_scenes += _parse_scene_list_single(scene_list_path_, root_data_dir_) | |
| return all_scenes | |
| def _parse_reference_scene_list(reference_scenes: List[str], all_scenes: List[str]): | |
| all_ids = set(scene.split("/")[-1] for scene in all_scenes) | |
| ref_ids = set(scene.split("/")[-1] for scene in reference_scenes) | |
| common_ids = ref_ids.intersection(all_ids) | |
| all_scenes = [scene for scene in all_scenes if scene.split("/")[-1] in common_ids] | |
| all_ids = {scene.split("/")[-1]: idx for idx, scene in enumerate(all_scenes)} | |
| ref_scenes = [ | |
| scene for scene in reference_scenes if scene.split("/")[-1] in all_ids | |
| ] | |
| sorted_ref_scenes = sorted(ref_scenes, key=lambda x: all_ids[x.split("/")[-1]]) | |
| scene2ref = { | |
| scene: ref_scene for scene, ref_scene in zip(all_scenes, sorted_ref_scenes) | |
| } | |
| return all_scenes, scene2ref | |
| class MultiviewDataModuleConfig: | |
| root_dir: Any = "" | |
| scene_list: Any = "" | |
| image_suffix: str = "webp" | |
| background_color: Union[str, float] = "gray" | |
| image_names: List[str] = field(default_factory=lambda: []) | |
| image_modality: str = "render" | |
| num_views: int = 1 | |
| random_view_list: Optional[List[List[int]]] = None | |
| prompt_db_path: Optional[str] = None | |
| return_prompt: bool = False | |
| use_empty_prompt: bool = False | |
| prompt_prefix: Optional[Any] = None | |
| return_one_prompt: bool = True | |
| projection_type: str = "ORTHO" | |
| # source conditions | |
| source_image_modality: Any = "position" | |
| use_camera_space_normal: bool = False | |
| position_offset: float = 0.5 | |
| position_scale: float = 1.0 | |
| plucker_offset: float = 1.0 | |
| plucker_scale: float = 2.0 | |
| # reference image | |
| reference_root_dir: Optional[Any] = None | |
| reference_scene_list: Optional[Any] = None | |
| reference_image_modality: str = "render" | |
| reference_image_names: List[str] = field(default_factory=lambda: []) | |
| reference_augment_resolutions: Optional[List[int]] = None | |
| reference_mask_aug: bool = False | |
| repeat: int = 1 # for debugging purpose | |
| train_indices: Optional[Tuple[Any, Any]] = None | |
| val_indices: Optional[Tuple[Any, Any]] = None | |
| test_indices: Optional[Tuple[Any, Any]] = None | |
| height: int = 768 | |
| width: int = 768 | |
| batch_size: int = 1 | |
| eval_batch_size: int = 1 | |
| num_workers: int = 16 | |
| class MultiviewDataset(Dataset): | |
| def __init__(self, cfg: Any, split: str = "train") -> None: | |
| super().__init__() | |
| assert split in ["train", "val", "test"] | |
| self.cfg: MultiviewDataModuleConfig = cfg | |
| self.all_scenes = _parse_scene_list(self.cfg.scene_list, self.cfg.root_dir) | |
| if ( | |
| self.cfg.reference_root_dir is not None | |
| and self.cfg.reference_scene_list is not None | |
| ): | |
| reference_scenes = _parse_scene_list( | |
| self.cfg.reference_scene_list, self.cfg.reference_root_dir | |
| ) | |
| self.all_scenes, self.reference_scenes = _parse_reference_scene_list( | |
| reference_scenes, self.all_scenes | |
| ) | |
| else: | |
| self.reference_scenes = None | |
| self.split = split | |
| if self.split == "train" and self.cfg.train_indices is not None: | |
| self.all_scenes = self.all_scenes[ | |
| self.cfg.train_indices[0] : self.cfg.train_indices[1] | |
| ] | |
| self.all_scenes = self.all_scenes * self.cfg.repeat | |
| elif self.split == "val" and self.cfg.val_indices is not None: | |
| self.all_scenes = self.all_scenes[ | |
| self.cfg.val_indices[0] : self.cfg.val_indices[1] | |
| ] | |
| elif self.split == "test" and self.cfg.test_indices is not None: | |
| self.all_scenes = self.all_scenes[ | |
| self.cfg.test_indices[0] : self.cfg.test_indices[1] | |
| ] | |
| if self.cfg.prompt_db_path is not None: | |
| self.prompt_db = json.load(open(self.cfg.prompt_db_path)) | |
| else: | |
| self.prompt_db = None | |
| def __len__(self): | |
| return len(self.all_scenes) | |
| def get_bg_color(self, bg_color): | |
| if bg_color == "white": | |
| bg_color = np.array([1.0, 1.0, 1.0], dtype=np.float32) | |
| elif bg_color == "black": | |
| bg_color = np.array([0.0, 0.0, 0.0], dtype=np.float32) | |
| elif bg_color == "gray": | |
| bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32) | |
| elif bg_color == "random": | |
| bg_color = np.random.rand(3) | |
| elif bg_color == "random_gray": | |
| bg_color = random.uniform(0.3, 0.7) | |
| bg_color = np.array([bg_color] * 3, dtype=np.float32) | |
| elif isinstance(bg_color, float): | |
| bg_color = np.array([bg_color] * 3, dtype=np.float32) | |
| elif isinstance(bg_color, list) or isinstance(bg_color, tuple): | |
| bg_color = np.array(bg_color, dtype=np.float32) | |
| else: | |
| raise NotImplementedError | |
| return bg_color | |
| def load_image( | |
| self, | |
| image: Union[str, Image.Image], | |
| height: int, | |
| width: int, | |
| background_color: torch.Tensor, | |
| rescale: bool = False, | |
| mask_aug: bool = False, | |
| ): | |
| if isinstance(image, str): | |
| image = Image.open(image) | |
| image = image.resize((width, height)) | |
| image = torch.from_numpy(np.array(image)).float() / 255.0 | |
| if mask_aug: | |
| alpha = image[:, :, 3] # Extract alpha channel | |
| h, w = alpha.shape | |
| y_indices, x_indices = torch.where(alpha > 0.5) | |
| if len(y_indices) > 0 and len(x_indices) > 0: | |
| idx = torch.randint(len(y_indices), (1,)).item() | |
| y_center = y_indices[idx].item() | |
| x_center = x_indices[idx].item() | |
| mask_h = random.randint(h // 8, h // 4) | |
| mask_w = random.randint(w // 8, w // 4) | |
| y1 = max(0, y_center - mask_h // 2) | |
| y2 = min(h, y_center + mask_h // 2) | |
| x1 = max(0, x_center - mask_w // 2) | |
| x2 = min(w, x_center + mask_w // 2) | |
| alpha[y1:y2, x1:x2] = 0.0 | |
| image[:, :, 3] = alpha | |
| image = image[:, :, :3] * image[:, :, 3:4] + background_color * ( | |
| 1 - image[:, :, 3:4] | |
| ) | |
| if rescale: | |
| image = image * 2.0 - 1.0 | |
| return image | |
| def load_normal_image( | |
| self, | |
| path, | |
| height, | |
| width, | |
| background_color, | |
| camera_space: bool = False, | |
| c2w: Optional[torch.FloatTensor] = None, | |
| ): | |
| image = Image.open(path).resize((width, height), resample=Image.NEAREST) | |
| image = torch.from_numpy(np.array(image)).float() / 255.0 | |
| alpha = image[:, :, 3:4] | |
| image = image[:, :, :3] | |
| if camera_space: | |
| w2c = torch.linalg.inv(c2w)[:3, :3] | |
| image = ( | |
| F.normalize(((image * 2 - 1)[:, :, None, :] * w2c).sum(-1), dim=-1) | |
| * 0.5 | |
| + 0.5 | |
| ) | |
| image = image * alpha + background_color * (1 - alpha) | |
| return image | |
| def load_depth(self, path, height, width): | |
| depth = cv2.imread(path, cv2.IMREAD_UNCHANGED) | |
| depth = cv2.resize(depth, (width, height), interpolation=cv2.INTER_NEAREST) | |
| depth = torch.from_numpy(depth[..., 0:1]).float() | |
| mask = torch.ones_like(depth) | |
| mask[depth > 1000.0] = 0.0 # depth = 65535 is the invalid value | |
| depth[~(mask > 0.5)] = 0.0 | |
| return depth, mask | |
| def retrieve_prompt(self, scene_dir): | |
| assert self.prompt_db is not None | |
| source_id = os.path.basename(scene_dir) | |
| return self.prompt_db.get(source_id, "") | |
| def __getitem__(self, index): | |
| background_color = torch.as_tensor(self.get_bg_color(self.cfg.background_color)) | |
| scene_dir = self.all_scenes[index] | |
| with open(os.path.join(scene_dir, "meta.json")) as f: | |
| meta = json.load(f) | |
| name2loc = {loc["index"]: loc for loc in meta["locations"]} | |
| # target multi-view images | |
| image_paths = [ | |
| os.path.join( | |
| scene_dir, f"{self.cfg.image_modality}_{f}.{self.cfg.image_suffix}" | |
| ) | |
| for f in self.cfg.image_names | |
| ] | |
| images = [ | |
| self.load_image( | |
| p, | |
| height=self.cfg.height, | |
| width=self.cfg.width, | |
| background_color=background_color, | |
| ) | |
| for p in image_paths | |
| ] | |
| images = torch.stack(images, dim=0).permute(0, 3, 1, 2) | |
| # camera | |
| c2w = [ | |
| torch.as_tensor(name2loc[name]["transform_matrix"]) | |
| for name in self.cfg.image_names | |
| ] | |
| c2w = torch.stack(c2w, dim=0) | |
| if self.cfg.projection_type == "PERSP": | |
| camera_angle_x = ( | |
| meta.get("camera_angle_x", None) | |
| or meta["locations"][0]["camera_angle_x"] | |
| ) | |
| focal_length = 0.5 * self.cfg.width / np.tan(0.5 * camera_angle_x) | |
| intrinsics = ( | |
| torch.as_tensor( | |
| [ | |
| [focal_length, 0.0, 0.5 * self.cfg.width], | |
| [0.0, focal_length, 0.5 * self.cfg.height], | |
| [0.0, 0.0, 1.0], | |
| ] | |
| ) | |
| .unsqueeze(0) | |
| .float() | |
| .repeat(len(self.cfg.image_names), 1, 1) | |
| ) | |
| elif self.cfg.projection_type == "ORTHO": | |
| ortho_scale = ( | |
| meta.get("ortho_scale", None) or meta["locations"][0]["ortho_scale"] | |
| ) | |
| # source conditions | |
| source_image_modality = self.cfg.source_image_modality | |
| if isinstance(source_image_modality, str): | |
| source_image_modality = [source_image_modality] | |
| source_images = [] | |
| for modality in source_image_modality: | |
| if modality == "position": | |
| depth_masks = [ | |
| self.load_depth( | |
| os.path.join(scene_dir, f"depth_{f}.exr"), | |
| self.cfg.height, | |
| self.cfg.width, | |
| ) | |
| for f in self.cfg.image_names | |
| ] | |
| depths = torch.stack([d for d, _ in depth_masks]) | |
| masks = torch.stack([m for _, m in depth_masks]) | |
| c2w_ = c2w.clone() | |
| c2w_[:, :, 1:3] *= -1 | |
| if self.cfg.projection_type == "PERSP": | |
| position_maps = get_position_map_from_depth( | |
| depths, | |
| masks, | |
| intrinsics, | |
| c2w_, | |
| image_wh=(self.cfg.width, self.cfg.height), | |
| ) | |
| elif self.cfg.projection_type == "ORTHO": | |
| position_maps = get_position_map_from_depth_ortho( | |
| depths, | |
| masks, | |
| c2w_, | |
| ortho_scale, | |
| image_wh=(self.cfg.width, self.cfg.height), | |
| ) | |
| position_maps = ( | |
| (position_maps + self.cfg.position_offset) / self.cfg.position_scale | |
| ).clamp(0.0, 1.0) | |
| source_images.append(position_maps) | |
| elif modality == "normal": | |
| normal_maps = [ | |
| self.load_normal_image( | |
| os.path.join( | |
| scene_dir, f"{modality}_{f}.{self.cfg.image_suffix}" | |
| ), | |
| height=self.cfg.height, | |
| width=self.cfg.width, | |
| background_color=background_color, | |
| camera_space=self.cfg.use_camera_space_normal, | |
| c2w=c, | |
| ) | |
| for c, f in zip(c2w, self.cfg.image_names) | |
| ] | |
| source_images.append(torch.stack(normal_maps, dim=0)) | |
| elif modality == "plucker": | |
| if self.cfg.projection_type == "ORTHO": | |
| plucker_embed = get_plucker_embeds_from_cameras_ortho( | |
| c2w, [ortho_scale] * len(c2w), self.cfg.width | |
| ) | |
| elif self.cfg.projection_type == "PERSP": | |
| plucker_embed = get_plucker_embeds_from_cameras( | |
| c2w, [camera_angle_x] * len(c2w), self.cfg.width | |
| ) | |
| else: | |
| raise NotImplementedError | |
| plucker_embed = plucker_embed.permute(0, 2, 3, 1) | |
| plucker_embed = ( | |
| (plucker_embed + self.cfg.plucker_offset) / self.cfg.plucker_scale | |
| ).clamp(0.0, 1.0) | |
| source_images.append(plucker_embed) | |
| else: | |
| raise NotImplementedError | |
| source_images = torch.cat(source_images, dim=-1).permute(0, 3, 1, 2) | |
| rv = {"rgb": images, "c2w": c2w, "source_rgb": source_images} | |
| num_images = len(self.cfg.image_names) | |
| # prompt | |
| if self.cfg.return_prompt: | |
| if self.cfg.use_empty_prompt: | |
| prompt = "" | |
| else: | |
| prompt = self.retrieve_prompt(scene_dir) | |
| prompts = [prompt] * num_images | |
| if self.cfg.prompt_prefix is not None: | |
| prompt_prefix = self.cfg.prompt_prefix | |
| if isinstance(prompt_prefix, str): | |
| prompt_prefix = [prompt_prefix] * num_images | |
| for i, prompt in enumerate(prompts): | |
| prompts[i] = f"{prompt_prefix[i]} {prompt}" | |
| if self.cfg.return_one_prompt: | |
| rv.update({"prompts": prompts[0]}) | |
| else: | |
| rv.update({"prompts": prompts}) | |
| # reference image | |
| if self.reference_scenes is not None: | |
| reference_scene_dir = self.reference_scenes[scene_dir] | |
| reference_image_paths = [ | |
| os.path.join( | |
| reference_scene_dir, | |
| f"{self.cfg.reference_image_modality}_{f}.{self.cfg.image_suffix}", | |
| ) | |
| for f in self.cfg.reference_image_names | |
| ] | |
| reference_image_path = random.choice(reference_image_paths) | |
| if self.cfg.reference_augment_resolutions is None: | |
| reference_image = self.load_image( | |
| reference_image_path, | |
| height=self.cfg.height, | |
| width=self.cfg.width, | |
| background_color=background_color, | |
| mask_aug=self.cfg.reference_mask_aug, | |
| ).permute(2, 0, 1) | |
| rv.update({"reference_rgb": reference_image}) | |
| else: | |
| random_resolution = random.choice( | |
| self.cfg.reference_augment_resolutions | |
| ) | |
| reference_image_ = Image.open(reference_image_path).resize( | |
| (random_resolution, random_resolution) | |
| ) | |
| reference_image = self.load_image( | |
| reference_image_, | |
| height=self.cfg.height, | |
| width=self.cfg.width, | |
| background_color=background_color, | |
| mask_aug=self.cfg.reference_mask_aug, | |
| ).permute(2, 0, 1) | |
| rv.update({"reference_rgb": reference_image}) | |
| return rv | |
| def collate(self, batch): | |
| batch = torch.utils.data.default_collate(batch) | |
| pack = lambda t: t.view(-1, *t.shape[2:]) | |
| if self.cfg.random_view_list is not None: | |
| indices = random.choice(self.cfg.random_view_list) | |
| else: | |
| indices = list(range(self.cfg.num_views)) | |
| num_views = len(indices) | |
| for k in batch.keys(): | |
| if k in ["rgb", "source_rgb", "c2w"]: | |
| batch[k] = batch[k][:, indices] | |
| batch[k] = pack(batch[k]) | |
| for k in ["prompts"]: | |
| if not self.cfg.return_one_prompt: | |
| batch[k] = [item for pair in zip(*batch[k]) for item in pair] | |
| batch.update( | |
| { | |
| "num_views": num_views, | |
| # For SDXL | |
| "original_size": (self.cfg.height, self.cfg.width), | |
| "target_size": (self.cfg.height, self.cfg.width), | |
| "crops_coords_top_left": (0, 0), | |
| } | |
| ) | |
| return batch | |
| class MultiviewDataModule(pl.LightningDataModule): | |
| cfg: MultiviewDataModuleConfig | |
| def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: | |
| super().__init__() | |
| self.cfg = parse_structured(MultiviewDataModuleConfig, cfg) | |
| def setup(self, stage=None) -> None: | |
| if stage in [None, "fit"]: | |
| self.train_dataset = MultiviewDataset(self.cfg, "train") | |
| if stage in [None, "fit", "validate"]: | |
| self.val_dataset = MultiviewDataset(self.cfg, "val") | |
| if stage in [None, "test", "predict"]: | |
| self.test_dataset = MultiviewDataset(self.cfg, "test") | |
| def prepare_data(self): | |
| pass | |
| def train_dataloader(self) -> DataLoader: | |
| return DataLoader( | |
| self.train_dataset, | |
| batch_size=self.cfg.batch_size, | |
| num_workers=self.cfg.num_workers, | |
| shuffle=True, | |
| collate_fn=self.train_dataset.collate, | |
| ) | |
| def val_dataloader(self) -> DataLoader: | |
| return DataLoader( | |
| self.val_dataset, | |
| batch_size=self.cfg.eval_batch_size, | |
| num_workers=self.cfg.num_workers, | |
| shuffle=False, | |
| collate_fn=self.val_dataset.collate, | |
| ) | |
| def test_dataloader(self) -> DataLoader: | |
| return DataLoader( | |
| self.test_dataset, | |
| batch_size=self.cfg.eval_batch_size, | |
| num_workers=self.cfg.num_workers, | |
| shuffle=False, | |
| collate_fn=self.test_dataset.collate, | |
| ) | |
| def predict_dataloader(self) -> DataLoader: | |
| return self.test_dataloader() | |