Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import logging | |
| from typing import Any, Dict, Tuple, Union, Optional | |
| from einops import rearrange, repeat | |
| from torch import nn | |
| import torch | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.models.modeling_utils import ModelMixin, load_state_dict | |
| from ..data.data_util import align_repeat_tensor_single_dim | |
| from .unet_3d_condition import UNet3DConditionModel | |
| from .referencenet import ReferenceNet2D | |
| from ip_adapter.ip_adapter import ImageProjModel | |
| logger = logging.getLogger(__name__) | |
class SuperUNet3DConditionModel(nn.Module):
    """Bundle the UNet and its optional companion sub-models into one module.

    Similar in spirit to a diffusers pipeline, but this is a model definition
    rather than an inference pipeline. It serves two purposes:

    1. keep the referencenet / controlnet / ip-adapter glue code in one place
       so callers only issue a single ``forward`` call;
    2. expose everything as one ``nn.Module``, which is convenient for
       accelerator-based distributed training.

    Wrapped sub-models: unet, referencenet, controlnet, vae, text_encoder,
    tokenizer, text_emb_extractor, clip_vision_extractor and
    ip_adapter_image_proj. Only unet, referencenet, controlnet and
    ip_adapter_image_proj are used by ``forward``; the others are merely
    stored on the instance.
    """

    _supports_gradient_checkpointing = True
    # Number of completed forward calls. Debug statistics are only logged
    # while this is 0, i.e. during the first forward call of an instance.
    print_idx = 0

    def __init__(
        self,
        unet: nn.Module,
        referencenet: Optional[nn.Module] = None,
        controlnet: Optional[nn.Module] = None,
        vae: Optional[nn.Module] = None,
        text_encoder: Optional[nn.Module] = None,
        tokenizer: Optional[nn.Module] = None,
        text_emb_extractor: Optional[nn.Module] = None,
        clip_vision_extractor: Optional[nn.Module] = None,
        ip_adapter_image_proj: Optional[nn.Module] = None,
    ) -> None:
        """Store the sub-models on the instance.

        Args:
            unet: denoising network; called last in ``forward`` with the
                referencenet / controlnet outputs as extra keyword arguments.
            referencenet: optional network whose output is unpacked into
                ``(down_block_refer_embs, mid_block_refer_emb,
                refer_self_attn_emb)``. Defaults to None.
            controlnet: optional network whose output is unpacked into
                ``(down_block_additional_residuals,
                mid_block_additional_residual)``. Defaults to None.
            vae: stored only, not used by ``forward``. Defaults to None.
            text_encoder: stored only, not used by ``forward``.
                Defaults to None.
            tokenizer: stored only, not used by ``forward``.
                Defaults to None.
            text_emb_extractor: wraps text_encoder and tokenizer for str2emb;
                stored only. Defaults to None.
            clip_vision_extractor: stored only, not used by ``forward``.
                Defaults to None.
            ip_adapter_image_proj: optional projection applied to
                ``vision_clip_emb`` in ``forward``. Defaults to None.
        """
        super().__init__()
        self.unet = unet
        self.referencenet = referencenet
        self.controlnet = controlnet
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.text_emb_extractor = text_emb_extractor
        self.clip_vision_extractor = clip_vision_extractor
        self.ip_adapter_image_proj = ip_adapter_image_proj

    def _should_log_debug(self) -> bool:
        """Return True only on the first forward call with DEBUG logging on.

        Checking the logger level up front avoids computing tensor means for
        messages that would be discarded anyway.
        """
        return self.print_idx == 0 and logger.isEnabledFor(logging.DEBUG)

    def _prepare_vision_clip_emb(
        self,
        vision_clip_emb: Optional[torch.Tensor],
        batch_size: int,
        time_size: int,
    ) -> Optional[torch.Tensor]:
        """Project the image prompt and align it to ``time_size`` frames.

        Args:
            vision_clip_emb: ``b t d`` or ``b t n d`` tensor, or None.
            batch_size: batch dimension used to undo the ``(b t)`` flattening.
            time_size: target length of dim 1 of the returned tensor.

        Returns:
            The ``b t n d`` embedding, or None when no image prompt was given.
        """
        if vision_clip_emb is None:
            return None

        debug = self._should_log_debug()
        if debug:
            logger.debug(
                f"vision_clip_emb, before ip_adapter_image_proj, shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
            )
        if vision_clip_emb.ndim == 3:
            # Insert a singleton token axis so the layout is always b t n d.
            vision_clip_emb = rearrange(vision_clip_emb, "b t d-> b t 1 d")
        if self.ip_adapter_image_proj is not None:
            # The projector is fed one frame per batch element.
            vision_clip_emb = rearrange(vision_clip_emb, "b t n d ->(b t) n d")
            vision_clip_emb = self.ip_adapter_image_proj(vision_clip_emb)
            if debug:
                logger.debug(
                    f"vision_clip_emb, after ip_adapter_image_proj shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
                )
            if vision_clip_emb.ndim == 2:
                # The projector dropped the token axis; restore it.
                vision_clip_emb = rearrange(vision_clip_emb, "b d-> b 1 d")
            vision_clip_emb = rearrange(
                vision_clip_emb, "(b t) n d -> b t n d", b=batch_size
            )
        vision_clip_emb = align_repeat_tensor_single_dim(
            vision_clip_emb, target_length=time_size, dim=1
        )
        if debug:
            logger.debug(
                f"vision_clip_emb, after reshape shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
            )
        return vision_clip_emb

    def _run_referencenet(
        self,
        referencenet_params: Optional[Dict],
        vision_clip_emb: Optional[torch.Tensor],
    ) -> Tuple[Any, Any, Any]:
        """Run the referencenet and return its three outputs.

        Returns ``(None, None, None)`` when there is no referencenet or no
        parameters for it. The caller's dict is copied, not modified.
        """
        if self.referencenet is None or referencenet_params is None:
            return None, None, None

        refer_hidden_states = align_repeat_tensor_single_dim(
            vision_clip_emb,
            target_length=referencenet_params["num_frames"],
            dim=1,
        )
        # Shallow copy: injecting the conditioning must not leak back into
        # the dict owned by the caller.
        referencenet_params = {
            **referencenet_params,
            "encoder_hidden_states": rearrange(
                refer_hidden_states, "b t n d->(b t) n d"
            ),
        }
        (
            down_block_refer_embs,
            mid_block_refer_emb,
            refer_self_attn_emb,
        ) = self.referencenet(**referencenet_params)

        if self._should_log_debug():
            if down_block_refer_embs is not None:
                logger.debug(
                    f"len(down_block_refer_embs)={len(down_block_refer_embs)}"
                )
                for i, down_emb in enumerate(down_block_refer_embs):
                    logger.debug(
                        f"down_emb, {i}, {down_emb.shape}, mean={down_emb.mean()}"
                    )
            else:
                logger.debug("down_block_refer_embs is None")

            if mid_block_refer_emb is not None:
                logger.debug(
                    f"mid_block_refer_emb, {mid_block_refer_emb.shape}, mean={mid_block_refer_emb.mean()}"
                )
            else:
                logger.debug("mid_block_refer_emb is None")

            if refer_self_attn_emb is not None:
                logger.debug(f"refer_self_attn_emb, num={len(refer_self_attn_emb)}")
                for i, self_attn_emb in enumerate(refer_self_attn_emb):
                    logger.debug(
                        f"referencenet, self_attn_emb, {i}th, shape={self_attn_emb.shape}, mean={self_attn_emb.mean()}"
                    )
            else:
                logger.debug("refer_self_attn_emb is None")

        return down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb

    def _run_controlnet(
        self,
        controlnet_params: Optional[Dict],
        encoder_hidden_states: Optional[torch.Tensor],
        controlnet_scale: float,
        time_size: int,
    ) -> Tuple[Any, Any]:
        """Run the controlnet and return its (optionally scaled) residuals.

        Returns ``(None, None)`` when there is no controlnet or no parameters
        for it. The caller's dict is copied, not modified.
        """
        if self.controlnet is None or controlnet_params is None:
            return None, None

        control_hidden_states = align_repeat_tensor_single_dim(
            encoder_hidden_states,
            target_length=time_size,
            dim=1,
        )
        # Shallow copy: keep the caller's dict untouched.
        controlnet_params = {
            **controlnet_params,
            "encoder_hidden_states": rearrange(
                control_hidden_states, " b t n d -> (b t) n d"
            ),
        }
        (
            down_block_additional_residuals,
            mid_block_additional_residual,
        ) = self.controlnet(**controlnet_params)

        if controlnet_scale != 1.0:
            down_block_additional_residuals = [
                x * controlnet_scale for x in down_block_additional_residuals
            ]
            mid_block_additional_residual = (
                mid_block_additional_residual * controlnet_scale
            )

        if self._should_log_debug():
            for i, residual in enumerate(down_block_additional_residuals):
                logger.debug(
                    f"{i}, down_block_additional_residual mean={torch.mean(residual)}"
                )
            logger.debug(
                f"mid_block_additional_residual mean={torch.mean(mid_block_additional_residual)}"
            )

        return down_block_additional_residuals, mid_block_additional_residual

    def forward(
        self,
        unet_params: Dict,
        encoder_hidden_states: torch.Tensor,
        referencenet_params: Optional[Dict] = None,
        controlnet_params: Optional[Dict] = None,
        controlnet_scale: float = 1.0,
        vision_clip_emb: Union[torch.Tensor, None] = None,
        prompt_only_use_image_prompt: bool = False,
    ):
        """Run referencenet and controlnet (when available), then the UNet.

        Args:
            unet_params: keyword arguments forwarded to the UNet. Must contain
                ``"sample"``; its dim 0 is read as the batch size and its
                dim 2 as the number of frames (presumably ``b c t h w`` --
                confirm against the UNet).
            encoder_hidden_states: text prompt embedding, ``b t n d``. When it
                is None, the image prompt embedding is used in its place.
            referencenet_params: keyword arguments for the referencenet. Must
                contain ``"num_frames"``; ``"encoder_hidden_states"`` is
                supplied by this method. Defaults to None (referencenet is
                skipped).
            controlnet_params: keyword arguments for the controlnet;
                ``"encoder_hidden_states"`` is supplied by this method.
                Defaults to None (controlnet is skipped).
            controlnet_scale: multiplier applied to every controlnet residual
                when different from 1.0. Defaults to 1.0.
            vision_clip_emb: image prompt embedding, ``b t d`` or
                ``b t n d``. When it is None, ``encoder_hidden_states`` is
                used in its place. Defaults to None.
            prompt_only_use_image_prompt: when True, the UNet receives the
                image prompt embedding as ``encoder_hidden_states``.
                Defaults to False.

        Returns:
            Whatever ``self.unet`` returns.

        Note:
            ``referencenet_params`` and ``controlnet_params`` are copied
            before the conditioning is injected, so the caller's dicts are
            left unmodified.
        """
        sample_shape = unet_params["sample"].shape
        batch_size = sample_shape[0]
        time_size = sample_shape[2]

        # ip_adapter_cross_attn: prepare the image prompt.
        vision_clip_emb = self._prepare_vision_clip_emb(
            vision_clip_emb, batch_size, time_size
        )

        # Let each prompt stand in for the other when only one was given.
        if vision_clip_emb is None and encoder_hidden_states is not None:
            vision_clip_emb = encoder_hidden_states
        if vision_clip_emb is not None and encoder_hidden_states is None:
            encoder_hidden_states = vision_clip_emb

        # When prompt_only_use_image_prompt is True:
        # 1. referencenet uses vision_clip_emb (it always does);
        # 2. unet uses vision_clip_emb if there is no dual_cross_attn,
        #    sometimes not updated;
        # 3. controlnet currently still uses the text prompt.
        (
            down_block_refer_embs,
            mid_block_refer_emb,
            refer_self_attn_emb,
        ) = self._run_referencenet(referencenet_params, vision_clip_emb)

        (
            down_block_additional_residuals,
            mid_block_additional_residual,
        ) = self._run_controlnet(
            controlnet_params, encoder_hidden_states, controlnet_scale, time_size
        )

        # Swap only after the controlnet has run, so that the controlnet is
        # still conditioned on the text prompt.
        if prompt_only_use_image_prompt and vision_clip_emb is not None:
            encoder_hidden_states = vision_clip_emb

        out = self.unet(
            **unet_params,
            down_block_refer_embs=down_block_refer_embs,
            mid_block_refer_emb=mid_block_refer_emb,
            refer_self_attn_emb=refer_self_attn_emb,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
            encoder_hidden_states=encoder_hidden_states,
            vision_clip_emb=vision_clip_emb,
        )
        # Creates/updates an instance attribute; the class-level default of 0
        # is left untouched.
        self.print_idx += 1
        return out

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``gradient_checkpointing`` on UNet / ReferenceNet modules.

        Modules of any other type are left unchanged.
        """
        if isinstance(module, (UNet3DConditionModel, ReferenceNet2D)):
            module.gradient_checkpointing = value