from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import os
import warnings

import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from diffusers.models.controlnet import ControlNetModel
from diffusers.models.modeling_utils import ModelMixin
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
from diffusers.utils.torch_utils import is_compiled_module
from einops import rearrange, repeat

class ControlnetPredictor(object):
    def __init__(self, controlnet_model_path: str, *args, **kwargs):
        """Controlnet inference predictor, used to extract the embeddings of the
        controlnet backbone once so they do not have to be re-extracted repeatedly
        during training.

        Args:
            controlnet_model_path (str): controlnet model path.
        """
        super(ControlnetPredictor, self).__init__(*args, **kwargs)
        self.controlnet = ControlNetModel.from_pretrained(
            controlnet_model_path,
        )

    def prepare_image(
        self,
        image,  # b c t h w
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        if height is None:
            height = image.shape[-2]
        if width is None:
            width = image.shape[-1]
        # snap width/height down to multiples of the VAE scale factor
        # (self.control_image_processor is expected to be attached to this predictor elsewhere)
        width, height = (
            x - x % self.control_image_processor.vae_scale_factor
            for x in (width, height)
        )
        # frames are expected as uint8 numpy arrays in [0, 255]
        image = rearrange(image, "b c t h w-> (b t) c h w")
        image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
        image = torch.nn.functional.interpolate(
            image,
            size=(height, width),
            mode="bilinear",
        )
        do_normalize = self.control_image_processor.config.do_normalize
        if image.min() < 0:
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            do_normalize = False
        if do_normalize:
            image = self.control_image_processor.normalize(image)

        image_batch_size = image.shape[0]
        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0)
        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)

        return image

    def __call__(
        self,
        batch_size: int,
        device: str,
        dtype: torch.dtype,
        timesteps: List[float],
        i: int,
        scheduler: KarrasDiffusionSchedulers,
        prompt_embeds: torch.Tensor,
        do_classifier_free_guidance: bool = False,
        # 2b co t ho wo
        latent_model_input: torch.Tensor = None,
        # b co t ho wo
        latents: torch.Tensor = None,
        # b c t h w
        image: Union[
            torch.FloatTensor,
            PIL.Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[PIL.Image.Image],
            List[np.ndarray],
        ] = None,
        # b c t(1) hi wi
        controlnet_condition_frames: Optional[torch.FloatTensor] = None,
        # b c t ho wo
        controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None,
        # b c t(1) ho wo
        controlnet_condition_latents: Optional[torch.FloatTensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        return_dict: bool = True,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        latent_index: torch.LongTensor = None,
        vision_condition_latent_index: torch.LongTensor = None,
        **kwargs,
    ):
        # exactly one of `image` and `controlnet_latents` should be provided
        assert (image is None) != (
            controlnet_latents is None
        ), "should set one of image and controlnet_latents"
        controlnet = (
            self.controlnet._orig_mod
            if is_compiled_module(self.controlnet)
            else self.controlnet
        )

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(
            control_guidance_end, list
        ):
            control_guidance_start = len(control_guidance_end) * [
                control_guidance_start
            ]
        elif not isinstance(control_guidance_end, list) and isinstance(
            control_guidance_start, list
        ):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(
            control_guidance_end, list
        ):
            mult = (
                len(controlnet.nets)
                if isinstance(controlnet, MultiControlNetModel)
                else 1
            )
            control_guidance_start, control_guidance_end = mult * [
                control_guidance_start
            ], mult * [control_guidance_end]

        if isinstance(controlnet, MultiControlNetModel) and isinstance(
            controlnet_conditioning_scale, float
        ):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
                controlnet.nets
            )

        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions
        # 4. Prepare image
        if isinstance(controlnet, ControlNetModel):
            if (
                controlnet_latents is not None
                and controlnet_condition_latents is not None
            ):
                if isinstance(controlnet_latents, np.ndarray):
                    controlnet_latents = torch.from_numpy(controlnet_latents)
                if isinstance(controlnet_condition_latents, np.ndarray):
                    controlnet_condition_latents = torch.from_numpy(
                        controlnet_condition_latents
                    )
                # TODO: concat by index instead of assuming condition frames come first
                controlnet_latents = torch.concat(
                    [controlnet_condition_latents, controlnet_latents], dim=2
                )
                if not guess_mode and do_classifier_free_guidance:
                    controlnet_latents = torch.concat([controlnet_latents] * 2, dim=0)
                controlnet_latents = rearrange(
                    controlnet_latents, "b c t h w->(b t) c h w"
                )
                controlnet_latents = controlnet_latents.to(device=device, dtype=dtype)
            else:
                # TODO: concat by index
                if controlnet_condition_frames is not None:
                    if isinstance(controlnet_condition_frames, np.ndarray):
                        image = np.concatenate(
                            [controlnet_condition_frames, image], axis=2
                        )
                image = self.prepare_image(
                    image=image,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_videos_per_prompt,
                    num_images_per_prompt=num_videos_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )
                height, width = image.shape[-2:]
        elif isinstance(controlnet, MultiControlNetModel):
            images = []
            # TODO: support using controlnet_latents directly instead of frames
            if controlnet_latents is not None:
                raise NotImplementedError
            else:
                for i_net, image_ in enumerate(image):
                    if controlnet_condition_frames is not None and isinstance(
                        controlnet_condition_frames, list
                    ):
                        if isinstance(controlnet_condition_frames[i_net], np.ndarray):
                            image_ = np.concatenate(
                                [controlnet_condition_frames[i_net], image_], axis=2
                            )
                    image_ = self.prepare_image(
                        image=image_,
                        width=width,
                        height=height,
                        batch_size=batch_size * num_videos_per_prompt,
                        num_images_per_prompt=num_videos_per_prompt,
                        device=device,
                        dtype=controlnet.dtype,
                        do_classifier_free_guidance=do_classifier_free_guidance,
                        guess_mode=guess_mode,
                    )
                    images.append(image_)
                image = images
                height, width = image[0].shape[-2:]
        else:
            assert False, f"unsupported controlnet type: {type(controlnet)}"
        # 7.1 Create tensor stating which controlnets to keep
        controlnet_keep = []
        for step in range(len(timesteps)):
            keeps = [
                1.0
                - float(step / len(timesteps) < s or (step + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(
                keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
            )

        # `i` is the index of the current denoising step, passed in by the caller
        t = timesteps[i]

        # controlnet(s) inference
        if guess_mode and do_classifier_free_guidance:
            # Infer ControlNet only for the conditional batch.
            control_model_input = latents
            control_model_input = scheduler.scale_model_input(control_model_input, t)
            controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
        else:
            control_model_input = latent_model_input
            controlnet_prompt_embeds = prompt_embeds

        if isinstance(controlnet_keep[i], list):
            cond_scale = [
                c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])
            ]
        else:
            cond_scale = controlnet_conditioning_scale * controlnet_keep[i]

        control_model_input_reshape = rearrange(
            control_model_input, "b c t h w -> (b t) c h w"
        )
        encoder_hidden_states_repeat = repeat(
            controlnet_prompt_embeds,
            "b n q->(b t) n q",
            t=control_model_input.shape[2],
        )
        down_block_res_samples, mid_block_res_sample = self.controlnet(
            control_model_input_reshape,
            t,
            encoder_hidden_states_repeat,
            controlnet_cond=image,
            controlnet_cond_latents=controlnet_latents,
            conditioning_scale=cond_scale,
            guess_mode=guess_mode,
            return_dict=False,
        )

        return down_block_res_samples, mid_block_res_sample
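
# Illustrative usage sketch (comments only, not part of the module): the predictor is
# built once from a pretrained ControlNet checkpoint (the path below is just an
# example) and called once per denoising step. `timesteps`, `step_index`, `scheduler`,
# `prompt_embeds`, `latent_model_input`, `latents`, and `pose_frames` are assumed to
# come from the surrounding training / inference loop.
#
#   predictor = ControlnetPredictor("lllyasviel/sd-controlnet-openpose")
#   down_block_res_samples, mid_block_res_sample = predictor(
#       batch_size=1,
#       device="cuda",
#       dtype=torch.float16,
#       timesteps=timesteps,
#       i=step_index,
#       scheduler=scheduler,
#       prompt_embeds=prompt_embeds,
#       do_classifier_free_guidance=True,
#       latent_model_input=latent_model_input,  # 2b co t ho wo (duplicated for CFG)
#       latents=latents,                        # b co t ho wo
#       image=pose_frames,                      # b c t h w, uint8 numpy frames
#       height=512,
#       width=512,
#   )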

class InflatedConv3d(nn.Conv2d):
    # A 2D convolution "inflated" to video input: the frame axis is folded into the
    # batch axis, so each frame is convolved independently with shared 2D weights.
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x
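
# Illustrative shape check for InflatedConv3d (comments only, assumed values): frames
# are folded into the batch dimension, so the temporal length is preserved and every
# frame shares the same 2D kernel.
#
#   conv = InflatedConv3d(3, 16, kernel_size=3, padding=1)
#   x = torch.randn(2, 3, 8, 64, 64)   # b c f h w
#   y = conv(x)                        # -> torch.Size([2, 16, 8, 64, 64])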

def zero_module(module):
    # Zero out the parameters of a module and return it, so the newly added branch
    # (here: the PoseGuider output projection) initially contributes nothing.
    for p in module.parameters():
        p.detach().zero_()
    return module

class PoseGuider(ModelMixin):
    # Lightweight encoder that maps a condition video (b c t h w, e.g. pose maps) to
    # the latent-resolution conditioning embedding; `conv_out` is zero-initialized so
    # the guider initially contributes nothing.
    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 64, 128),
    ):
        super().__init__()
        self.conv_in = InflatedConv3d(
            conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
        )

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(
                InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
            )
            self.blocks.append(
                InflatedConv3d(
                    channel_in, channel_out, kernel_size=3, padding=1, stride=2
                )
            )

        self.conv_out = zero_module(
            InflatedConv3d(
                block_out_channels[-1],
                conditioning_embedding_channels,
                kernel_size=3,
                padding=1,
            )
        )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_path,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 64, 128),
    ):
        if not os.path.exists(pretrained_model_path):
            raise FileNotFoundError(
                f"There is no model file in {pretrained_model_path}"
            )
        print(
            f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ..."
        )

        state_dict = torch.load(pretrained_model_path, map_location="cpu")
        model = cls(
            conditioning_embedding_channels=conditioning_embedding_channels,
            conditioning_channels=conditioning_channels,
            block_out_channels=block_out_channels,
        )

        m, u = model.load_state_dict(state_dict, strict=False)
        # print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        params = [p.numel() for n, p in model.named_parameters()]
        print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M")
        return model
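
# Minimal smoke test (illustrative sketch with random weights; the channel count 320
# and input size are assumed values, not taken from a config): three stride-2 blocks
# downsample the spatial dims by 8x while the frame count is preserved, and the
# zero-initialized conv_out makes the initial embedding all zeros.
if __name__ == "__main__":
    pose_guider = PoseGuider(conditioning_embedding_channels=320)
    pose = torch.randn(1, 3, 4, 64, 64)  # b c t h w
    emb = pose_guider(pose)
    print(emb.shape)  # expected: torch.Size([1, 320, 4, 8, 8])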