Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import inspect | |
| import math | |
| import time | |
| import warnings | |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
| from dataclasses import dataclass | |
| from einops import rearrange, repeat | |
| import PIL.Image | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer | |
| from diffusers.pipelines.controlnet.pipeline_controlnet import ( | |
| StableDiffusionSafetyChecker, | |
| EXAMPLE_DOC_STRING, | |
| ) | |
| from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import ( | |
| StableDiffusionControlNetImg2ImgPipeline as DiffusersStableDiffusionControlNetImg2ImgPipeline, | |
| ) | |
| from diffusers.configuration_utils import FrozenDict | |
| from diffusers.models import AutoencoderKL, ControlNetModel | |
| from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel | |
| from diffusers.pipelines.stable_diffusion.safety_checker import ( | |
| StableDiffusionSafetyChecker, | |
| ) | |
| from diffusers.schedulers import KarrasDiffusionSchedulers | |
| from diffusers.utils import ( | |
| deprecate, | |
| logging, | |
| BaseOutput, | |
| replace_example_docstring, | |
| ) | |
| from diffusers.utils.torch_utils import is_compiled_module | |
| from diffusers.loaders import TextualInversionLoaderMixin | |
| from diffusers.models.attention import ( | |
| BasicTransformerBlock as DiffusersBasicTransformerBlock, | |
| ) | |
| from mmcm.vision.process.correct_color import ( | |
| hist_match_color_video_batch, | |
| hist_match_video_bcthw, | |
| ) | |
| from ..models.attention import BasicTransformerBlock | |
| from ..models.unet_3d_condition import UNet3DConditionModel | |
| from ..utils.noise_util import random_noise, video_fusion_noise | |
| from ..data.data_util import ( | |
| adaptive_instance_normalization, | |
| align_repeat_tensor_single_dim, | |
| batch_adain_conditioned_tensor, | |
| batch_concat_two_tensor_with_index, | |
| batch_index_select, | |
| fuse_part_tensor, | |
| ) | |
| from ..utils.text_emb_util import encode_weighted_prompt | |
| from ..utils.tensor_util import his_match | |
| from ..utils.timesteps_util import generate_parameters_with_timesteps | |
| from .context import get_context_scheduler, prepare_global_context | |
| logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
@dataclass
class VideoPipelineOutput(BaseOutput):
    """Output container of the MuseV video pipeline.

    ``BaseOutput`` subclasses are declared as dataclasses: without the
    ``@dataclass`` decorator the annotated fields below are not turned into
    ``__init__`` parameters / instance attributes, so fields without a class
    default (``videos``, ``latents``, ``videos_mid``) could not be read back.

    Attributes:
        videos: generated videos (tensor or numpy array).
        latents: final latents (tensor or numpy array).
        videos_mid: intermediate videos (tensor or numpy array).
        down_block_res_samples: optional down-block residual samples.
        mid_block_res_samples: optional mid-block residual samples.
        up_block_res_samples: optional up-block residual samples.
        mid_video_latents: optional list of intermediate latents.
        mid_video_noises: optional list of intermediate noises.
    """

    videos: Union[torch.Tensor, np.ndarray]
    latents: Union[torch.Tensor, np.ndarray]
    videos_mid: Union[torch.Tensor, np.ndarray]
    down_block_res_samples: Tuple[torch.FloatTensor] = None
    mid_block_res_samples: torch.FloatTensor = None
    up_block_res_samples: torch.FloatTensor = None
    mid_video_latents: List[torch.FloatTensor] = None
    mid_video_noises: List[torch.FloatTensor] = None
def torch_dfs(model: torch.nn.Module):
    """Return ``model`` followed by all of its descendants in depth-first pre-order.

    Children are visited in ``children()`` order. A module reachable through
    several parents appears once per path (no de-duplication).
    """
    visited = []
    stack = [model]
    while stack:
        current = stack.pop()
        visited.append(current)
        # Push children reversed so the first child is popped (visited) next.
        stack.extend(reversed(list(current.children())))
    return visited
def prepare_image(
    image,  # b c t h w
    batch_size,
    device,
    dtype,
    image_processor: Callable,
    num_images_per_prompt: int = 1,
    width=None,
    height=None,
):
    """Turn raw frames into a resized, scaled (and optionally normalized) 4-D batch.

    Args:
        image: numpy array or torch tensor shaped ``(b, c, t, h, w)`` (frames are
            folded into the batch) or ``(n, c, h, w)``; or a list of numpy arrays
            that is concatenated along axis 0. Values are divided by 255, so the
            input is expected in ``[0, 255]``.
        batch_size: repeat count used when exactly one frame is given.
        device: device of the returned tensor.
        dtype: dtype of the returned tensor.
        image_processor: object providing ``vae_scale_factor``,
            ``config.do_normalize`` and ``normalize``.
        num_images_per_prompt: repeat count used when more than one frame is given.
        width: target width; defaults to the input width.
        height: target height; defaults to the input height.
            Both are rounded down to a multiple of ``vae_scale_factor``.

    Returns:
        torch.Tensor shaped ``(n * repeat, c, height, width)``; each frame is
        repeated consecutively (``repeat_interleave``).

    Raises:
        NotImplementedError: when given a list of strings (file paths).
    """
    if isinstance(image, list) and isinstance(image[0], str):
        raise NotImplementedError
    if isinstance(image, list) and isinstance(image[0], np.ndarray):
        image = np.concatenate(image, axis=0)
    if isinstance(image, np.ndarray):
        image = torch.from_numpy(image)
    if image.ndim == 5:
        image = rearrange(image, "b c t h w-> (b t) c h w")

    # Cast to float *before* resampling: bilinear interpolation of integer
    # (e.g. uint8) tensors either rounds the result to integers or is not
    # implemented at all, depending on the backend.
    image = image.to(dtype=torch.float32)

    if height is None:
        height = image.shape[-2]
    if width is None:
        width = image.shape[-1]
    # Round the target size down to a multiple of the VAE downsampling factor.
    scale_factor = image_processor.vae_scale_factor
    width, height = (x - x % scale_factor for x in (width, height))
    if height != image.shape[-2] or width != image.shape[-1]:
        image = torch.nn.functional.interpolate(
            image, size=(height, width), mode="bilinear"
        )
    image = image / 255.0

    do_normalize = image_processor.config.do_normalize
    if image.min() < 0:
        warnings.warn(
            "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
            f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
            FutureWarning,
        )
        do_normalize = False
    if do_normalize:
        image = image_processor.normalize(image)

    # A single frame is broadcast to the whole batch; otherwise the frames
    # already match the prompt batch and are repeated per prompt.
    repeat_by = batch_size if image.shape[0] == 1 else num_images_per_prompt
    image = image.repeat_interleave(repeat_by, dim=0)
    return image.to(device=device, dtype=dtype)
| class MusevControlNetPipeline( | |
| DiffusersStableDiffusionControlNetImg2ImgPipeline, TextualInversionLoaderMixin | |
| ): | |
| """ | |
| a union diffusers pipeline, support | |
| 1. text2image model only, or text2video model, by setting skip_temporal_layer | |
| 2. text2video, image2video, video2video; | |
| 3. multi controlnet | |
| 4. IPAdapter | |
| 5. referencenet | |
| 6. IPAdapterFaceID | |
| """ | |
| _optional_components = [ | |
| "safety_checker", | |
| "feature_extractor", | |
| ] | |
| print_idx = 0 | |
def __init__(
    self,
    vae: AutoencoderKL,
    unet: UNet3DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    controlnet: ControlNetModel
    | List[ControlNetModel]
    | Tuple[ControlNetModel]
    | MultiControlNetModel,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = False,
    referencenet: nn.Module = None,
    vision_clip_extractor: nn.Module = None,
    ip_adapter_image_proj: nn.Module = None,
    face_emb_extractor: nn.Module = None,
    facein_image_proj: nn.Module = None,
    ip_adapter_face_emb_extractor: nn.Module = None,
    ip_adapter_face_image_proj: nn.Module = None,
    pose_guider: nn.Module = None,
):
    """Build the pipeline and attach the optional MuseV sub-modules.

    The standard components are forwarded positionally to the diffusers
    ControlNet img2img base pipeline. ``referencenet`` is stored as-is. Each
    other extra module is first moved to the UNet's dtype and device when it
    is an ``nn.Module``, then stored as an attribute of the same name.
    ``None`` values are stored unchanged.
    """
    super().__init__(
        vae,
        text_encoder,
        tokenizer,
        unet,
        controlnet,
        scheduler,
        safety_checker,
        feature_extractor,
        requires_safety_checker,
    )
    self.referencenet = referencenet

    # (attribute name, module) in the original assignment order.
    unet_aligned_modules = (
        # ip_adapter
        ("vision_clip_extractor", vision_clip_extractor),
        ("ip_adapter_image_proj", ip_adapter_image_proj),
        # facein
        ("face_emb_extractor", face_emb_extractor),
        ("facein_image_proj", facein_image_proj),
        # ip_adapter_face
        ("ip_adapter_face_emb_extractor", ip_adapter_face_emb_extractor),
        ("ip_adapter_face_image_proj", ip_adapter_face_image_proj),
        # pose guider
        ("pose_guider", pose_guider),
    )
    for attr_name, module in unet_aligned_modules:
        if isinstance(module, nn.Module):
            module.to(dtype=self.unet.dtype, device=self.unet.device)
        setattr(self, attr_name, module)
def decode_latents(self, latents):
    """Decode 5-D video latents through the base pipeline's frame decoder.

    ``latents`` is laid out ``(b, c, f, h, w)``. Frames are folded into the
    batch axis for ``super().decode_latents``. Its result, which the unfold
    pattern below expects as ``(b*f, h, w, c)``, is reshaped to
    ``(b, c, f, h, w)``.
    """
    n_videos = latents.shape[0]
    frame_batch = rearrange(latents, "b c f h w -> (b f) c h w")
    decoded_frames = super().decode_latents(latents=frame_batch)
    return rearrange(decoded_frames, "(b f) h w c -> b c f h w", b=n_videos)
def prepare_latents(
    self,
    batch_size: int,
    num_channels_latents: int,
    video_length: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: torch.Generator,
    latents: torch.Tensor = None,
    w_ind_noise: float = 0.5,
    image: torch.Tensor = None,
    timestep: int = None,
    initial_common_latent: torch.Tensor = None,
    noise_type: str = "random",
    add_latents_noise: bool = False,
    need_img_based_video_noise: bool = False,
    condition_latents: torch.Tensor = None,
    img_weight=1e-3,
) -> torch.Tensor:
    """Prepare the initial video latents for denoising.

    Supported cases:
      * img_based_latents: ``image`` and ``latents`` are None and
        ``need_img_based_video_noise`` is set with ``condition_latents``
        given. The temporal mean of ``condition_latents`` is repeated over
        ``video_length`` frames and blended into the noise. Suitable for
        text2video and middle2video.
      * video_based_latents: ``image`` is given (VAE-encoded, then noised to
        ``timestep``), or ``latents`` is given (noised to ``timestep`` when
        ``add_latents_noise``, otherwise scaled by ``init_noise_sigma``).
        Suitable for video2video.
      * noise_latents: ``image`` and ``latents`` are both None. The result is
        pure noise scaled by ``scheduler.init_noise_sigma``. Suitable for
        text2video.

    Args:
        batch_size: number of videos.
        num_channels_latents: latent channel count.
        video_length: number of frames.
        height: pixel height; the latent height is ``height // vae_scale_factor``.
        width: pixel width; the latent width is ``width // vae_scale_factor``.
        dtype: dtype of the returned latents.
        device: device of the returned latents.
        generator: a torch generator, or a list with one generator per batch item.
        latents: optional precomputed latents (tensor or numpy array).
        w_ind_noise: forwarded to ``video_fusion_noise``.
        image: optional pixel frames, ``(b, c, t, h, w)`` or ``(b*t, c, h, w)``,
            encoded with the VAE posterior mean.
        timestep: timestep passed to ``scheduler.add_noise``.
        initial_common_latent: forwarded to ``video_fusion_noise`` as
            ``initial_common_noise``.
        noise_type: ``"random"`` or ``"video_fusion"``.
        add_latents_noise: noise the given ``latents`` to ``timestep``.
        need_img_based_video_noise: enable the img_based_latents blend.
        condition_latents: ``(b, c, t, h, w)`` latents used by that blend.
        img_weight: blend weight, ``sqrt(w) * cond + sqrt(1 - w) * noise``.

    Raises:
        ValueError: in any of these cases:
            * ``noise_type`` is unsupported while noise is needed;
            * the generator list length differs from ``batch_size``;
            * the image batch cannot be duplicated up to ``batch_size``;
            * the final latents do not have the expected shape.

    Returns:
        torch.Tensor: latents shaped
        ``(batch_size, num_channels_latents, video_length, h, w)``.
    """
    # ref https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py#L691
    # ref https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet.py#L659
    shape = (
        batch_size,
        num_channels_latents,
        video_length,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )
    if latents is None or add_latents_noise:
        if noise_type == "random":
            noise = random_noise(
                shape=shape, dtype=dtype, device=device, generator=generator
            )
        elif noise_type == "video_fusion":
            noise = video_fusion_noise(
                shape=shape,
                dtype=dtype,
                device=device,
                generator=generator,
                w_ind_noise=w_ind_noise,
                initial_common_noise=initial_common_latent,
            )
        else:
            # Fail early instead of hitting an unbound `noise` further down.
            raise ValueError(
                f"Unsupported noise_type={noise_type!r}, expected 'random' or 'video_fusion'"
            )
        if (
            need_img_based_video_noise
            and condition_latents is not None
            and image is None
            and latents is None
        ):
            if self.print_idx == 0:
                logger.debug(
                    (
                        f"need_img_based_video_noise, condition_latents={condition_latents.shape},"
                        f"batch_size={batch_size}, noise={noise.shape}, video_length={video_length}"
                    )
                )
            # Repeat the temporal mean of the condition over all frames and
            # blend it into the noise: sqrt(w) * cond + sqrt(1 - w) * noise.
            condition_latents = condition_latents.mean(dim=2, keepdim=True)
            condition_latents = repeat(
                condition_latents, "b c t h w->b c (t x) h w", x=video_length
            )
            noise = (
                img_weight**0.5 * condition_latents
                + (1 - img_weight) ** 0.5 * noise
            )
            if self.print_idx == 0:
                logger.debug(f"noise={noise.shape}")

    if image is not None:
        if image.ndim == 5:
            image = rearrange(image, "b c t h w->(b t) c h w")
        image = image.to(device=device, dtype=dtype)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        # The posterior mean is used (deterministic) rather than a sample.
        if isinstance(generator, list):
            init_latents = torch.cat(
                [
                    self.vae.encode(image[i : i + 1]).latent_dist.mean
                    for i in range(batch_size)
                ],
                dim=0,
            )
        else:
            init_latents = self.vae.encode(image).latent_dist.mean
        init_latents = self.vae.config.scaling_factor * init_latents

        n_init = init_latents.shape[0]
        if batch_size > n_init and batch_size % n_init == 0:
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate(
                "len(prompt) != len(image)",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            init_latents = torch.cat([init_latents] * (batch_size // n_init), dim=0)
        elif batch_size > n_init:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )

        # Resize when EITHER spatial dimension differs from the latent size
        # (an `and` here skipped resizing when only one dimension differed).
        if init_latents.shape[2] != shape[3] or init_latents.shape[3] != shape[4]:
            init_latents = torch.nn.functional.interpolate(
                init_latents,
                size=(shape[3], shape[4]),
                mode="bilinear",
            )
        init_latents = rearrange(
            init_latents, "(b t) c h w-> b c t h w", t=video_length
        )
        if self.print_idx == 0:
            logger.debug(f"init_latensts={init_latents.shape}")

    if latents is None:
        if image is None:
            latents = noise * self.scheduler.init_noise_sigma
        else:
            if self.print_idx == 0:
                logger.debug(f"prepare latents, image is not None")
            latents = self.scheduler.add_noise(init_latents, noise, timestep)
    else:
        if isinstance(latents, np.ndarray):
            latents = torch.from_numpy(latents)
        latents = latents.to(device=device, dtype=dtype)
        if add_latents_noise:
            latents = self.scheduler.add_noise(latents, noise, timestep)
        else:
            latents = latents * self.scheduler.init_noise_sigma
    if latents.shape != shape:
        raise ValueError(
            f"Unexpected latents shape, got {latents.shape}, expected {shape}"
        )
    latents = latents.to(device, dtype=dtype)
    return latents
def prepare_image(
    self,
    image,  # b c t h w
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    width=None,
    height=None,
):
    """Preprocess VAE input frames using this pipeline's ``image_processor``.

    Delegates to the module-level ``prepare_image`` function with
    ``image_processor=self.image_processor``; all other arguments are passed
    through unchanged.
    """
    forwarded = dict(
        batch_size=batch_size,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        dtype=dtype,
        width=width,
        height=height,
    )
    return prepare_image(
        image=image, image_processor=self.image_processor, **forwarded
    )
def prepare_control_image(
    self,
    image,  # b c t h w
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """Preprocess ControlNet condition frames with ``control_image_processor``.

    Delegates to the module-level ``prepare_image`` function. When
    classifier-free guidance is on and ``guess_mode`` is off, the result is
    concatenated with itself along dim 0. This presumably matches an
    ``[uncond, cond]`` batch layout.
    """
    control = prepare_image(
        image=image,
        batch_size=batch_size,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        dtype=dtype,
        width=width,
        height=height,
        image_processor=self.control_image_processor,
    )
    duplicate_for_cfg = do_classifier_free_guidance and not guess_mode
    return torch.cat([control, control]) if duplicate_for_cfg else control
def check_inputs(
    self,
    prompt,
    image,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    controlnet_conditioning_scale=1,
    control_guidance_start=0,
    control_guidance_end=1,
):
    """Validate call arguments through the diffusers base pipeline.

    Validation is only delegated when a control ``image`` is supplied. With
    ``image=None`` nothing is checked and ``None`` is returned (TODO:
    implement checks for that case).

    Arguments are forwarded by keyword rather than position. Newer diffusers
    releases add parameters (e.g. ``ip_adapter_image``) to the base
    ``check_inputs`` ahead of ``controlnet_conditioning_scale``, so
    positional forwarding can bind values to the wrong parameters.
    """
    if image is None:
        # TODO: to implement
        return None
    return super().check_inputs(
        prompt=prompt,
        image=image,
        callback_steps=callback_steps,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_start=control_guidance_start,
        control_guidance_end=control_guidance_end,
    )
def hist_match_with_vis_cond(
    self, video: np.ndarray, target: np.ndarray
) -> np.ndarray:
    """Color-match ``video`` against ``target`` via ``hist_match_video_bcthw``.

    Args:
        video: array laid out ``b c t1 h w``.
        target: reference array laid out ``b c t2(=1) h w``.

    Returns:
        The result of ``hist_match_video_bcthw(video, target, value=255.0)``.
    """
    return hist_match_video_bcthw(video, target, value=255.0)
def get_facein_image_emb(
    self, refer_face_image, device, dtype, batch_size, do_classifier_free_guidance
):
    """Compute FaceIn embeddings for the reference face image(s).

    Args:
        refer_face_image: ``b c t h w`` numpy array or tensor, or ``None``.
        device: target device of the embedding.
        dtype: target dtype of the embedding.
        batch_size: target size of dim 0 (via ``align_repeat_tensor_single_dim``).
        do_classifier_free_guidance: if True, an unconditional embedding
            (``facein_image_proj`` applied to all-zero features) is
            concatenated in front along dim 0.

    Returns:
        A ``(batch_size, t * n, q)`` tensor (``2 * batch_size`` rows under
        CFG), or ``None`` when the extractor, the projection or the image is
        missing.
    """
    # refer_face_image and its face_emb
    if self.print_idx == 0:
        logger.debug(
            f"face_emb_extractor={type(self.face_emb_extractor)}, facein_image_proj={type(self.facein_image_proj)}, refer_face_image={type(refer_face_image)}, "
        )
    if (
        self.face_emb_extractor is not None
        and self.facein_image_proj is not None
        and refer_face_image is not None
    ):
        if self.print_idx == 0:
            logger.debug(f"refer_face_image={refer_face_image.shape}")
        if isinstance(refer_face_image, np.ndarray):
            refer_face_image = torch.from_numpy(refer_face_image)
        n_refer_face_image = refer_face_image.shape[2]
        # Frames are handed to the extractor channel-last, one per row.
        refer_face_image_facein = rearrange(
            refer_face_image, "b c t h w-> (b t) h w c"
        )
        # refer_face_image_emb: (bt d) or (bt h w d); aligned faces are unused.
        refer_face_image_emb, _ = self.face_emb_extractor.extract_images(
            refer_face_image_facein, return_type="torch"
        )
        refer_face_image_emb = refer_face_image_emb.to(device=device, dtype=dtype)
        if self.print_idx == 0:
            logger.debug(f"refer_face_image_emb={refer_face_image_emb.shape}")
        # Bring the embedding to token layout (bt n d). Use `.ndim`: comparing
        # `.shape` (a torch.Size) with an int is always False.
        if refer_face_image_emb.ndim == 2:
            refer_face_image_emb = rearrange(refer_face_image_emb, "bt d-> bt 1 d")
        elif refer_face_image_emb.ndim == 4:
            refer_face_image_emb = rearrange(
                refer_face_image_emb, "bt h w d-> bt (h w) d"
            )
        # Keep the pre-projection features to build the unconditional branch.
        refer_face_image_emb_bk = refer_face_image_emb
        refer_face_image_emb = self.facein_image_proj(refer_face_image_emb)
        # TODO: IPAdapterPlus vision_clip outputs are not supported yet.
        refer_face_image_emb = rearrange(
            refer_face_image_emb,
            "(b t) n q-> b (t n) q",
            t=n_refer_face_image,
        )
        refer_face_image_emb = align_repeat_tensor_single_dim(
            refer_face_image_emb, target_length=batch_size, dim=0
        )
        if do_classifier_free_guidance:
            # TODO: the unconditional feature is fixed (zeros); room for optimization.
            uncond_refer_face_image_emb = self.facein_image_proj(
                torch.zeros_like(refer_face_image_emb_bk).to(
                    device=device, dtype=dtype
                )
            )
            # TODO: IPAdapterPlus vision_clip outputs may not be supported.
            uncond_refer_face_image_emb = rearrange(
                uncond_refer_face_image_emb,
                "(b t) n q-> b (t n) q",
                t=n_refer_face_image,
            )
            uncond_refer_face_image_emb = align_repeat_tensor_single_dim(
                uncond_refer_face_image_emb, target_length=batch_size, dim=0
            )
            if self.print_idx == 0:
                logger.debug(
                    f"uncond_refer_face_image_emb, {uncond_refer_face_image_emb.shape}"
                )
                logger.debug(f"refer_face_image_emb, {refer_face_image_emb.shape}")
            refer_face_image_emb = torch.concat(
                [
                    uncond_refer_face_image_emb,
                    refer_face_image_emb,
                ],
            )
    else:
        refer_face_image_emb = None
    if self.print_idx == 0:
        logger.debug(f"refer_face_image_emb={type(refer_face_image_emb)}")
    return refer_face_image_emb
def get_ip_adapter_face_emb(
    self, refer_face_image, device, dtype, batch_size, do_classifier_free_guidance
):
    """Compute IP-Adapter-FaceID embeddings for the reference face image(s).

    Args:
        refer_face_image: ``b c t h w`` numpy array or tensor, or ``None``.
        device: target device of the embedding.
        dtype: target dtype of the embedding.
        batch_size: target size of dim 0 (via ``align_repeat_tensor_single_dim``).
        do_classifier_free_guidance: if True, an unconditional embedding
            (``ip_adapter_face_image_proj`` applied to all-zero features) is
            concatenated in front along dim 0.

    Returns:
        A ``(batch_size, t * n, q)`` tensor (``2 * batch_size`` rows under
        CFG), or ``None`` when the extractor, the projection or the image is
        missing.
    """
    # refer_face_image and its ip_adapter_face_emb
    if self.print_idx == 0:
        # Log the ip_adapter_face components (not the FaceIn ones).
        logger.debug(
            f"ip_adapter_face_emb_extractor={type(self.ip_adapter_face_emb_extractor)}, ip_adapter_face_image_proj={type(self.ip_adapter_face_image_proj)}, refer_face_image={type(refer_face_image)}, "
        )
    if (
        self.ip_adapter_face_emb_extractor is not None
        and self.ip_adapter_face_image_proj is not None
        and refer_face_image is not None
    ):
        if self.print_idx == 0:
            logger.debug(f"refer_face_image={refer_face_image.shape}")
        if isinstance(refer_face_image, np.ndarray):
            refer_face_image = torch.from_numpy(refer_face_image)
        n_refer_face_image = refer_face_image.shape[2]
        # Frames are handed to the extractor channel-last, one per row.
        refer_ip_adapter_face_image = rearrange(
            refer_face_image, "b c t h w-> (b t) h w c"
        )
        # refer_face_image_emb: (bt d) or (bt h w d); aligned faces are unused.
        refer_face_image_emb, _ = self.ip_adapter_face_emb_extractor.extract_images(
            refer_ip_adapter_face_image, return_type="torch"
        )
        refer_face_image_emb = refer_face_image_emb.to(device=device, dtype=dtype)
        if self.print_idx == 0:
            logger.debug(f"refer_face_image_emb={refer_face_image_emb.shape}")
        # Bring the embedding to token layout (bt n d). Use `.ndim`: comparing
        # `.shape` (a torch.Size) with an int is always False.
        if refer_face_image_emb.ndim == 2:
            refer_face_image_emb = rearrange(refer_face_image_emb, "bt d-> bt 1 d")
        elif refer_face_image_emb.ndim == 4:
            refer_face_image_emb = rearrange(
                refer_face_image_emb, "bt h w d-> bt (h w) d"
            )
        # Keep the pre-projection features to build the unconditional branch.
        refer_face_image_emb_bk = refer_face_image_emb
        refer_face_image_emb = self.ip_adapter_face_image_proj(refer_face_image_emb)
        refer_face_image_emb = rearrange(
            refer_face_image_emb,
            "(b t) n q-> b (t n) q",
            t=n_refer_face_image,
        )
        refer_face_image_emb = align_repeat_tensor_single_dim(
            refer_face_image_emb, target_length=batch_size, dim=0
        )
        if do_classifier_free_guidance:
            # TODO: the unconditional feature is fixed (zeros); room for optimization.
            uncond_refer_face_image_emb = self.ip_adapter_face_image_proj(
                torch.zeros_like(refer_face_image_emb_bk).to(
                    device=device, dtype=dtype
                )
            )
            # TODO: IPAdapterPlus vision_clip outputs may not be supported.
            uncond_refer_face_image_emb = rearrange(
                uncond_refer_face_image_emb,
                "(b t) n q-> b (t n) q",
                t=n_refer_face_image,
            )
            uncond_refer_face_image_emb = align_repeat_tensor_single_dim(
                uncond_refer_face_image_emb, target_length=batch_size, dim=0
            )
            if self.print_idx == 0:
                logger.debug(
                    f"uncond_refer_face_image_emb, {uncond_refer_face_image_emb.shape}"
                )
                logger.debug(f"refer_face_image_emb, {refer_face_image_emb.shape}")
            refer_face_image_emb = torch.concat(
                [
                    uncond_refer_face_image_emb,
                    refer_face_image_emb,
                ],
            )
    else:
        refer_face_image_emb = None
    if self.print_idx == 0:
        logger.debug(f"ip_adapter_face_emb={type(refer_face_image_emb)}")
    return refer_face_image_emb
def get_ip_adapter_image_emb(
    self,
    ip_adapter_image,
    device,
    dtype,
    batch_size,
    do_classifier_free_guidance,
    height,
    width,
):
    """Compute IP-Adapter image-prompt embeddings with the vision CLIP extractor.

    Args:
        ip_adapter_image: ``b c t h w`` numpy array or tensor, or ``None``.
        device: device for the unconditional projection input.
        dtype: dtype for the unconditional projection input.
        batch_size: target size of dim 0 (via ``align_repeat_tensor_single_dim``).
        do_classifier_free_guidance: if True, an unconditional embedding is
            concatenated in front along dim 0. It is the projection of all-zero
            features, or plain zeros when there is no projection.
        height: forwarded to the extractor as ``target_height``.
        width: forwarded to the extractor as ``target_width``.

    Returns:
        A ``(batch_size, t * n, q)`` tensor (``2 * batch_size`` rows under
        CFG), or ``None`` when the extractor or the image is missing.
    """
    # refer_image vision_clip and its ipadapter_emb
    if self.print_idx == 0:
        logger.debug(
            f"vision_clip_extractor={type(self.vision_clip_extractor)},"
            f"ip_adapter_image_proj={type(self.ip_adapter_image_proj)},"
            f"ip_adapter_image={type(ip_adapter_image)},"
        )
    if self.vision_clip_extractor is not None and ip_adapter_image is not None:
        if self.print_idx == 0:
            logger.debug(f"ip_adapter_image={ip_adapter_image.shape}")
        if isinstance(ip_adapter_image, np.ndarray):
            ip_adapter_image = torch.from_numpy(ip_adapter_image)
        n_ip_adapter_image = ip_adapter_image.shape[2]
        ip_adapter_image = rearrange(ip_adapter_image, "b c t h w-> (b t) h w c")
        ip_adapter_image_emb = self.vision_clip_extractor.extract_images(
            ip_adapter_image,
            target_height=height,
            target_width=width,
            return_type="torch",
        )
        if ip_adapter_image_emb.ndim == 2:
            ip_adapter_image_emb = rearrange(ip_adapter_image_emb, "b q-> b 1 q")
        # Pre-projection, pre-rearrange features laid out (bt n q); the
        # unconditional branch is built from this so both branches go through
        # the same "(b t) n q" rearrange exactly once.
        ip_adapter_image_emb_bk = ip_adapter_image_emb
        # Some scenarios need only the image prompt without projection,
        # e.g. using image_prompt in place of text_prompt.
        if self.ip_adapter_image_proj is not None:
            if self.print_idx == 0:
                logger.debug("use ip_adapter_image_proj")
            ip_adapter_image_emb = self.ip_adapter_image_proj(ip_adapter_image_emb)
        # TODO: IPAdapterPlus vision_clip outputs are not supported yet.
        ip_adapter_image_emb = rearrange(
            ip_adapter_image_emb,
            "(b t) n q-> b (t n) q",
            t=n_ip_adapter_image,
        )
        ip_adapter_image_emb = align_repeat_tensor_single_dim(
            ip_adapter_image_emb, target_length=batch_size, dim=0
        )
        if do_classifier_free_guidance:
            # TODO: the unconditional feature is fixed (zeros); room for optimization.
            if self.ip_adapter_image_proj is not None:
                uncond_ip_adapter_image_emb = self.ip_adapter_image_proj(
                    torch.zeros_like(ip_adapter_image_emb_bk).to(
                        device=device, dtype=dtype
                    )
                )
                if self.print_idx == 0:
                    logger.debug(
                        f"uncond_ip_adapter_image_emb use ip_adapter_image_proj(zero_like)"
                    )
            else:
                # Zeros must have the pre-rearrange (bt n q) layout: the
                # rearranged/aligned embedding would be rearranged a second
                # time below, giving a wrong shape whenever t > 1.
                uncond_ip_adapter_image_emb = torch.zeros_like(ip_adapter_image_emb_bk)
                if self.print_idx == 0:
                    logger.debug(f"uncond_ip_adapter_image_emb use zero_like")
            # TODO: IPAdapterPlus vision_clip outputs may not be supported.
            uncond_ip_adapter_image_emb = rearrange(
                uncond_ip_adapter_image_emb,
                "(b t) n q-> b (t n) q",
                t=n_ip_adapter_image,
            )
            uncond_ip_adapter_image_emb = align_repeat_tensor_single_dim(
                uncond_ip_adapter_image_emb, target_length=batch_size, dim=0
            )
            if self.print_idx == 0:
                logger.debug(
                    f"uncond_ip_adapter_image_emb, {uncond_ip_adapter_image_emb.shape}"
                )
                logger.debug(f"ip_adapter_image_emb, {ip_adapter_image_emb.shape}")
            ip_adapter_image_emb = torch.concat(
                [
                    uncond_ip_adapter_image_emb,
                    ip_adapter_image_emb,
                ],
            )
    else:
        ip_adapter_image_emb = None
    if self.print_idx == 0:
        logger.debug(f"ip_adapter_image_emb={type(ip_adapter_image_emb)}")
    return ip_adapter_image_emb
def get_referencenet_image_vae_emb(
    self,
    refer_image,
    batch_size,
    num_videos_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance,
    width: int = None,
    height: int = None,
):
    """Encode reference image(s) into scaled VAE latents for the ReferenceNet.

    Returns ``None`` when there is no referencenet or no ``refer_image``.
    Otherwise ``refer_image`` is used as follows:

    * The frame count ``t`` is read from its dim 2.
    * It is preprocessed with ``self.prepare_image``.
    * It is encoded with the VAE posterior mean and multiplied by
      ``vae.config.scaling_factor``.

    Under classifier-free guidance the "unconditional" latents are a copy of
    the conditional ones. Latents are regrouped per video with
    ``(b t) c h w -> b c t h w``, duplicated along the batch axis and
    flattened back to ``(b t) c h w``.
    """
    if self.print_idx == 0:
        logger.debug(
            f"referencenet={type(self.referencenet)}, refer_image={type(refer_image)}"
        )
    if self.referencenet is None or refer_image is None:
        return None

    frames_per_ref = refer_image.shape[2]
    ref_pixels = self.prepare_image(
        refer_image,
        batch_size=batch_size * num_videos_per_prompt,
        num_images_per_prompt=num_videos_per_prompt,
        device=device,
        dtype=dtype,
        width=width,
        height=height,
    )
    posterior_mean = self.vae.encode(ref_pixels).latent_dist.mean
    ref_latents = self.vae.config.scaling_factor * posterior_mean
    logger.debug(f"refer_image_vae_emb={ref_latents.shape}")

    if do_classifier_free_guidance:
        # Alternatives (encoding an all-zero image, or all-zero latents) are
        # possible; here the unconditional latents reuse the conditional ones.
        per_video = rearrange(
            ref_latents, "(b t) c h w-> b c t h w", t=frames_per_ref
        )
        paired = torch.concat([per_video, per_video], dim=0)
        ref_latents = rearrange(paired, "b c t h w-> (b t) c h w")
        logger.debug(f"refer_image_vae_emb={ref_latents.shape}")
    return ref_latents
def get_referencenet_emb(
    self,
    refer_image_vae_emb,
    refer_image,
    batch_size,
    num_videos_per_prompt,
    device,
    dtype,
    ip_adapter_image_emb,
    do_classifier_free_guidance,
    prompt_embeds,
    ref_timestep_int: int = 0,
):
    """Run the reference network on the reference-image VAE embedding.

    The referencenet is only run when ``self.referencenet``,
    ``refer_image_vae_emb`` and ``refer_image`` are all not None; otherwise
    three ``None`` values are returned.

    Args:
        refer_image_vae_emb: VAE embedding of the reference image(s), passed to
            the referencenet as ``sample``.
        refer_image: reference image tensor; only ``refer_image.shape[2]`` is
            read here, as the number of reference frames (``num_frames``).
        batch_size, num_videos_per_prompt, dtype, do_classifier_free_guidance:
            accepted for interface compatibility; not used by this method.
        device: device of the zero timestep when ``ref_timestep_int`` is not a
            tensor.
        ip_adapter_image_emb: when not None, used as the referencenet's
            ``encoder_hidden_states`` in preference to ``prompt_embeds``.
        prompt_embeds: fallback ``encoder_hidden_states``.
        ref_timestep_int: int or tensor. The referencenet is always run at
            timestep 0; for a tensor, its shape/dtype/device are kept.

    Returns:
        ``(down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb)``
        as returned by ``self.referencenet``, or ``(None, None, None)``.
    """
    if self.print_idx == 0:
        logger.debug(
            f"referencenet={type(self.referencenet)}, refer_image={type(refer_image)}"
        )
    if (
        self.referencenet is not None
        and refer_image_vae_emb is not None
        and refer_image is not None
    ):
        n_refer_image = refer_image.shape[2]
        if isinstance(ref_timestep_int, torch.Tensor):
            # Keep the caller's timestep shape/dtype/device, but force value 0.
            ref_timestep = torch.zeros_like(ref_timestep_int)
        else:
            # torch.zeros_like rejects Python ints (the annotated type and the
            # default), so build a 0-dim zero timestep tensor directly.
            ref_timestep = torch.zeros((), dtype=torch.long, device=device)
        # The referencenet prefers the CLIP vision embedding extracted from the
        # ip_adapter image; fall back to the text prompt embeddings.
        if ip_adapter_image_emb is not None:
            refer_prompt_embeds = ip_adapter_image_emb
        else:
            refer_prompt_embeds = prompt_embeds
        if self.print_idx == 0:
            logger.debug(
                f"use referencenet: n_refer_image={n_refer_image}, refer_image_vae_emb={refer_image_vae_emb.shape}, ref_timestep={ref_timestep.shape}"
            )
            if prompt_embeds is not None:
                logger.debug(f"prompt_embeds={prompt_embeds.shape},")
        referencenet_params = {
            "sample": refer_image_vae_emb,
            "encoder_hidden_states": refer_prompt_embeds,
            "timestep": ref_timestep,
            "num_frames": n_refer_image,
            "return_ndim": 5,
        }
        (
            down_block_refer_embs,
            mid_block_refer_emb,
            refer_self_attn_emb,
        ) = self.referencenet(**referencenet_params)
        # NOTE: no separate negative (unconditional) referencenet embedding is
        # built here; do_classifier_free_guidance is not consulted.
    else:
        down_block_refer_embs = None
        mid_block_refer_emb = None
        refer_self_attn_emb = None
    if self.print_idx == 0:
        logger.debug(f"down_block_refer_embs={type(down_block_refer_embs)}")
        logger.debug(f"mid_block_refer_emb={type(mid_block_refer_emb)}")
        logger.debug(f"refer_self_attn_emb={type(refer_self_attn_emb)}")
    return down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb
def prepare_condition_latents_and_index(
    self,
    condition_images,
    condition_latents,
    video_length,
    batch_size,
    dtype,
    device,
    latent_index,
    vision_condition_latent_index,
):
    """Prepare vision-condition latents and the timeline indices they occupy.

    The combined timeline has ``condition_latents.shape[2] + video_length``
    frames. Condition frames sit at ``vision_condition_latent_index``, and
    generated frames at ``latent_index``.

    Args:
        condition_images: optional images; encoded with ``self.vae`` (posterior
            mean, scaled by ``vae.config.scaling_factor``) and rearranged from
            ``(b t) c h w`` to ``b c t h w`` when ``condition_latents`` is None.
        condition_latents: optional latents (tensor or ndarray) indexed so that
            ``shape[2]`` is the number of condition frames.
        video_length: number of frames to generate.
        batch_size: used as ``b`` when rearranging encoded condition images.
        dtype, device: target dtype/device of the condition latents and
            device of the index tensors.
        latent_index: optional indices of generated frames. When None and
            condition latents exist, it defaults to every timeline position not
            used by a condition frame, in ascending order.
        vision_condition_latent_index: optional sequence of condition-frame
            positions, one per condition frame. Negative entries count from the
            end of the timeline (``-1`` -> ``total_frames - 1``). When None,
            the condition frames occupy the head of the timeline.

    Returns:
        ``(condition_latents, latent_index, vision_condition_latent_index)``;
        any of them may be None when no condition latents are available.
    """
    if condition_images is not None and condition_latents is None:
        # Use the posterior mean rather than a sample for a deterministic encoding.
        condition_latents = self.vae.encode(condition_images).latent_dist.mean
        condition_latents = self.vae.config.scaling_factor * condition_latents
        condition_latents = rearrange(
            condition_latents, "(b t) c h w-> b c t h w", b=batch_size
        )
        if self.print_idx == 0:
            logger.debug(
                f"condition_latents from condition_images, shape is condition_latents={condition_latents.shape}",
            )
    if condition_latents is not None:
        total_frames = condition_latents.shape[2] + video_length
        if isinstance(condition_latents, np.ndarray):
            condition_latents = torch.from_numpy(condition_latents)
        condition_latents = condition_latents.to(dtype=dtype, device=device)
        if vision_condition_latent_index is not None:
            # Resolve negative positions against the combined timeline so they
            # are also excluded from the default latent_index below.
            vision_condition_latent_index_lst = [
                int(i_v) + total_frames if int(i_v) < 0 else int(i_v)
                for i_v in vision_condition_latent_index
            ]
            vision_condition_latent_index = torch.tensor(
                vision_condition_latent_index_lst, dtype=torch.long, device=device
            )
            if self.print_idx == 0:
                logger.debug(
                    f"vision_condition_latent_index {type(vision_condition_latent_index)}, {vision_condition_latent_index}"
                )
        else:
            # Default: condition frames occupy [0, condition_latents.shape[2]).
            vision_condition_latent_index = torch.arange(
                condition_latents.shape[2], dtype=torch.long, device=device
            )
            vision_condition_latent_index_lst = (
                vision_condition_latent_index.tolist()
            )
        if latent_index is None:
            # Generated frames are every position not used by a condition frame;
            # iterating range() keeps them in ascending order.
            used_positions = set(vision_condition_latent_index_lst)
            latent_index = torch.tensor(
                [i for i in range(total_frames) if i not in used_positions],
                dtype=torch.long,
                device=device,
            )
    if vision_condition_latent_index is not None:
        vision_condition_latent_index = vision_condition_latent_index.to(
            device=device
        )
        if self.print_idx == 0:
            logger.debug(
                f"pipeline vision_condition_latent_index ={vision_condition_latent_index.shape}, {vision_condition_latent_index}"
            )
    if latent_index is not None:
        latent_index = latent_index.to(device=device)
        if self.print_idx == 0:
            logger.debug(
                f"pipeline latent_index ={latent_index.shape}, {latent_index}"
            )
    if self.print_idx == 0:
        # Gated like every other debug log in this pipeline.
        logger.debug(f"condition_latents={type(condition_latents)}")
        logger.debug(f"latent_index={type(latent_index)}")
        logger.debug(
            f"vision_condition_latent_index={type(vision_condition_latent_index)}"
        )
    return condition_latents, latent_index, vision_condition_latent_index
def prepare_controlnet_and_guidance_parameter(
    self, control_guidance_start, control_guidance_end
):
    """Unwrap the controlnet and align the guidance start/end schedules.

    ``self.controlnet`` is replaced by its ``_orig_mod`` when
    ``is_compiled_module`` reports it as compiled.

    The start/end values are brought to lists of equal length:
      * only one is a list -> the scalar one is repeated to that length;
      * neither is a list -> both are repeated once per sub-net of a
        ``MultiControlNetModel`` (once for any other controlnet);
      * both are lists -> returned unchanged.

    Returns:
        ``(controlnet, control_guidance_start, control_guidance_end)``.
    """
    controlnet = self.controlnet
    if is_compiled_module(controlnet):
        # Unwrap the original module from the compiled wrapper.
        controlnet = controlnet._orig_mod

    start_is_list = isinstance(control_guidance_start, list)
    end_is_list = isinstance(control_guidance_end, list)

    if end_is_list and not start_is_list:
        control_guidance_start = [control_guidance_start] * len(control_guidance_end)
    elif start_is_list and not end_is_list:
        control_guidance_end = [control_guidance_end] * len(control_guidance_start)
    elif not (start_is_list or end_is_list):
        if isinstance(controlnet, MultiControlNetModel):
            n_nets = len(controlnet.nets)
        else:
            n_nets = 1
        control_guidance_start = [control_guidance_start] * n_nets
        control_guidance_end = [control_guidance_end] * n_nets
    return controlnet, control_guidance_start, control_guidance_end
def prepare_controlnet_guess_mode(self, controlnet, guess_mode):
    """Turn on guess mode if requested or if the controlnet config requires it.

    ``global_pool_conditions`` is read from ``controlnet.config`` for a single
    ``ControlNetModel``, otherwise from ``controlnet.nets[0].config``.

    Returns:
        ``guess_mode or global_pool_conditions``.
    """
    if isinstance(controlnet, ControlNetModel):
        config = controlnet.config
    else:
        config = controlnet.nets[0].config
    # Read the flag before combining, so the attribute is always accessed.
    pool_conditions = config.global_pool_conditions
    return guess_mode or pool_conditions
def prepare_controlnet_image_and_latents(
    self,
    controlnet,
    width,
    height,
    batch_size,
    num_videos_per_prompt,
    device,
    dtype,
    controlnet_latents=None,
    controlnet_condition_latents=None,
    control_image=None,
    controlnet_condition_images=None,
    guess_mode=False,
    do_classifier_free_guidance=False,
):
    """Prepare ControlNet conditioning inputs (pixel images or latents).

    Single ``ControlNetModel``:
      * with ``controlnet_latents``: ndarray inputs become tensors,
        ``controlnet_condition_latents`` (if given) is prepended along dim 2,
        the batch is duplicated along dim 0 for classifier-free guidance
        unless ``guess_mode`` is set, and the result is rearranged from
        ``b c t h w`` to ``(b t) c h w`` and moved to ``device``/``dtype``;
      * otherwise ``controlnet_condition_images`` (if given) is prepended to
        ``control_image`` along dim 2 and the result goes through
        ``self.prepare_control_image``.

    ``MultiControlNetModel``: each entry of ``control_image`` goes through
    ``self.prepare_control_image``, after prepending (along axis/dim 2) the
    matching entry of ``controlnet_condition_images`` when that is a list.

    Returns:
        ``(control_image, controlnet_latents)``; either may be None.

    Raises:
        NotImplementedError: multi-controlnet with both ``controlnet_latents``
            and ``controlnet_condition_latents``.
        TypeError: ``controlnet`` is neither supported type.
    """
    if isinstance(controlnet, ControlNetModel):
        if controlnet_latents is not None:
            if isinstance(controlnet_latents, np.ndarray):
                controlnet_latents = torch.from_numpy(controlnet_latents)
            if controlnet_condition_latents is not None:
                if isinstance(controlnet_condition_latents, np.ndarray):
                    controlnet_condition_latents = torch.from_numpy(
                        controlnet_condition_latents
                    )
                # TODO: concat using an index instead of always prepending.
                controlnet_latents = torch.concat(
                    [controlnet_condition_latents, controlnet_latents], dim=2
                )
            if not guess_mode and do_classifier_free_guidance:
                # Duplicate for the unconditional + conditional CFG halves.
                controlnet_latents = torch.concat([controlnet_latents] * 2, dim=0)
            controlnet_latents = rearrange(
                controlnet_latents, "b c t h w->(b t) c h w"
            )
            controlnet_latents = controlnet_latents.to(device=device, dtype=dtype)
            if self.print_idx == 0:
                logger.debug(
                    f"call, controlnet_latents.shape, {controlnet_latents.shape}"
                )
        else:
            # TODO: concat with index
            if isinstance(control_image, np.ndarray):
                control_image = torch.from_numpy(control_image)
            if controlnet_condition_images is not None:
                if isinstance(controlnet_condition_images, np.ndarray):
                    controlnet_condition_images = torch.from_numpy(
                        controlnet_condition_images
                    )
                control_image = torch.concatenate(
                    [controlnet_condition_images, control_image], dim=2
                )
            control_image = self.prepare_control_image(
                image=control_image,
                width=width,
                height=height,
                batch_size=batch_size * num_videos_per_prompt,
                num_images_per_prompt=num_videos_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            if self.print_idx == 0:
                logger.debug(f"call, control_image.shape , {control_image.shape}")
    elif isinstance(controlnet, MultiControlNetModel):
        control_images = []
        # TODO: directly support controlnet_latents instead of frames
        if (
            controlnet_latents is not None
            and controlnet_condition_latents is not None
        ):
            raise NotImplementedError
        for i, control_image_ in enumerate(control_image):
            if isinstance(controlnet_condition_images, list):
                condition_image = controlnet_condition_images[i]
                if isinstance(condition_image, np.ndarray):
                    control_image_ = np.concatenate(
                        [condition_image, control_image_], axis=2
                    )
                elif isinstance(condition_image, torch.Tensor):
                    # Tensor condition frames used to be silently dropped here;
                    # prepend them along the time dim like the ndarray case.
                    if isinstance(control_image_, np.ndarray):
                        control_image_ = torch.from_numpy(control_image_)
                    control_image_ = torch.cat(
                        [condition_image, control_image_], dim=2
                    )
            control_image_ = self.prepare_control_image(
                image=control_image_,
                width=width,
                height=height,
                batch_size=batch_size * num_videos_per_prompt,
                num_images_per_prompt=num_videos_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            control_images.append(control_image_)
        control_image = control_images
    else:
        # An explicit error instead of `assert False`, which is stripped under -O.
        raise TypeError(f"unsupported controlnet type: {type(controlnet)}")
    if control_image is not None and self.print_idx == 0:
        if isinstance(control_image, list):
            logger.debug(f"control_image shape is {control_image[0].shape}")
        else:
            logger.debug(f"control_image shape is {control_image.shape}")
    return control_image, controlnet_latents
def get_controlnet_emb(
    self,
    run_controlnet,
    guess_mode,
    do_classifier_free_guidance,
    latents,
    prompt_embeds,
    latent_model_input,
    controlnet_keep,
    controlnet_conditioning_scale,
    control_image,
    controlnet_latents,
    i,
    t,
):
    """Compute ControlNet residuals for denoising step ``i`` at timestep ``t``.

    Returns ``(None, None)`` without running anything when ``run_controlnet``
    is falsy or ``self.pose_guider`` is set.

    Otherwise the controlnet input is ``latent_model_input`` with
    ``prompt_embeds``. In guess mode with classifier-free guidance, it is
    instead ``self.scheduler.scale_model_input(latents, t)`` with the second
    (conditional) chunk of ``prompt_embeds``. In that case the returned
    residuals are zero-padded in front along dim 0, so the unconditional half
    is left unchanged.

    The ``b c t h w`` input is flattened to ``(b t) c h w`` and the prompt
    embeddings are repeated to that length along dim 0. The conditioning
    scale is ``controlnet_conditioning_scale * controlnet_keep[i]``
    (element-wise when ``controlnet_keep[i]`` is a list).

    Returns:
        ``(down_block_res_samples, mid_block_res_sample)`` or ``(None, None)``.
    """
    if not run_controlnet or self.pose_guider is not None:
        return None, None

    if guess_mode and do_classifier_free_guidance:
        # Infer ControlNet only for the conditional batch.
        control_model_input = self.scheduler.scale_model_input(latents, t)
        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
    else:
        control_model_input = latent_model_input
        controlnet_prompt_embeds = prompt_embeds

    step_keep = controlnet_keep[i]
    if isinstance(step_keep, list):
        # One scale per sub-net of a multi-controlnet.
        cond_scale = [
            c * s for c, s in zip(controlnet_conditioning_scale, step_keep)
        ]
    else:
        cond_scale = controlnet_conditioning_scale * step_keep

    control_model_input_reshape = rearrange(
        control_model_input, "b c t h w -> (b t) c h w"
    )
    encoder_hidden_states_repeat = align_repeat_tensor_single_dim(
        controlnet_prompt_embeds,
        target_length=control_model_input_reshape.shape[0],
        dim=0,
    )
    if self.print_idx == 0:
        # Gated like the other debug logs, so the shape strings are not
        # formatted on every denoising step.
        logger.debug(
            f"control_model_input_reshape={control_model_input_reshape.shape}, controlnet_prompt_embeds={controlnet_prompt_embeds.shape}"
        )
        logger.debug(
            f"control_model_input_reshape={control_model_input_reshape.shape}, "
            f"encoder_hidden_states_repeat={encoder_hidden_states_repeat.shape}, "
        )

    down_block_res_samples, mid_block_res_sample = self.controlnet(
        control_model_input_reshape,
        t,
        encoder_hidden_states_repeat,
        controlnet_cond=control_image,
        controlnet_cond_latents=controlnet_latents,
        conditioning_scale=cond_scale,
        guess_mode=guess_mode,
        return_dict=False,
    )
    if self.print_idx == 0:
        logger.debug(
            f"controlnet, len(down_block_res_samples, {len(down_block_res_samples)}",
        )
        for i_tmp, tmp in enumerate(down_block_res_samples):
            logger.debug(
                f"controlnet down_block_res_samples i={i_tmp}, down_block_res_sample={tmp.shape}"
            )
        logger.debug(
            f"controlnet mid_block_res_sample, {mid_block_res_sample.shape}"
        )

    if guess_mode and do_classifier_free_guidance:
        # ControlNet ran only on the conditional batch; prepend zeros so that
        # adding the residuals leaves the unconditional batch unchanged.
        down_block_res_samples = [
            torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples
        ]
        mid_block_res_sample = torch.cat(
            [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
        )
    return down_block_res_samples, mid_block_res_sample
| def __call__( | |
| self, | |
| video_length: Optional[int], | |
| prompt: Union[str, List[str]] = None, | |
| # b c t h w | |
| image: Union[ | |
| torch.FloatTensor, | |
| PIL.Image.Image, | |
| np.ndarray, | |
| List[torch.FloatTensor], | |
| List[PIL.Image.Image], | |
| List[np.ndarray], | |
| ] = None, | |
| control_image: Union[ | |
| torch.FloatTensor, | |
| PIL.Image.Image, | |
| np.ndarray, | |
| List[torch.FloatTensor], | |
| List[PIL.Image.Image], | |
| List[np.ndarray], | |
| ] = None, | |
| # b c t(1) ho wo | |
| condition_images: Optional[torch.FloatTensor] = None, | |
| condition_latents: Optional[torch.FloatTensor] = None, | |
| latents: Optional[torch.FloatTensor] = None, | |
| add_latents_noise: bool = False, | |
| height: Optional[int] = None, | |
| width: Optional[int] = None, | |
| strength: float = 0.8, | |
| num_inference_steps: int = 50, | |
| guidance_scale: float = 7.5, | |
| guidance_scale_end: float = None, | |
| guidance_scale_method: str = "linear", | |
| negative_prompt: Optional[Union[str, List[str]]] = None, | |
| num_videos_per_prompt: Optional[int] = 1, | |
| eta: float = 0.0, | |
| generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | |
| # b c t(1) hi wi | |
| controlnet_condition_images: Optional[torch.FloatTensor] = None, | |
| # b c t(1) ho wo | |
| controlnet_condition_latents: Optional[torch.FloatTensor] = None, | |
| controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None, | |
| prompt_embeds: Optional[torch.FloatTensor] = None, | |
| negative_prompt_embeds: Optional[torch.FloatTensor] = None, | |
| output_type: Optional[str] = "tensor", | |
| return_dict: bool = True, | |
| callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | |
| callback_steps: int = 1, | |
| cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
| controlnet_conditioning_scale: Union[float, List[float]] = 1.0, | |
| guess_mode: bool = False, | |
| control_guidance_start: Union[float, List[float]] = 0.0, | |
| control_guidance_end: Union[float, List[float]] = 1.0, | |
| need_middle_latents: bool = False, | |
| w_ind_noise: float = 0.5, | |
| initial_common_latent: Optional[torch.FloatTensor] = None, | |
| latent_index: torch.LongTensor = None, | |
| vision_condition_latent_index: torch.LongTensor = None, | |
| # noise parameters | |
| noise_type: str = "random", | |
| need_img_based_video_noise: bool = False, | |
| skip_temporal_layer: bool = False, | |
| img_weight: float = 1e-3, | |
| need_hist_match: bool = False, | |
| motion_speed: float = 8.0, | |
| refer_image: Optional[Tuple[torch.Tensor, np.array]] = None, | |
| ip_adapter_image: Optional[Tuple[torch.Tensor, np.array]] = None, | |
| refer_face_image: Optional[Tuple[torch.Tensor, np.array]] = None, | |
| ip_adapter_scale: float = 1.0, | |
| facein_scale: float = 1.0, | |
| ip_adapter_face_scale: float = 1.0, | |
| ip_adapter_face_image: Optional[Tuple[torch.Tensor, np.array]] = None, | |
| prompt_only_use_image_prompt: bool = False, | |
| # serial_denoise parameter start | |
| record_mid_video_noises: bool = False, | |
| last_mid_video_noises: List[torch.Tensor] = None, | |
| record_mid_video_latents: bool = False, | |
| last_mid_video_latents: List[torch.TensorType] = None, | |
| video_overlap: int = 1, | |
| # serial_denoise parameter end | |
| # parallel_denoise parameter start | |
| # refer to https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/src/pipelines/pipeline_pose2vid_long.py#L354 | |
| context_schedule="uniform", | |
| context_frames=12, | |
| context_stride=1, | |
| context_overlap=4, | |
| context_batch_size=1, | |
| interpolation_factor=1, | |
| # parallel_denoise parameter end | |
| ): | |
| r""" | |
| 旨在兼容text2video、text2image、img2img、video2video、是否有controlnet等的通用pipeline。目前仅不支持img2img、video2video。 | |
| 支持多片段同时denoise,交叉部分加权平均 | |
| 当 skip_temporal_layer 为 False 时, unet 起 video 生成作用;skip_temporal_layer为True时,unet起原image作用。 | |
| 当controlnet的所有入参为None,等价于走的是text2video pipeline; | |
| 当 condition_latents、controlnet_condition_images、controlnet_condition_latents为None时,表示不走首帧条件生成的时序condition pipeline | |
| 现在没有考虑对 `num_videos_per_prompt` 的兼容性,不是1可能报错; | |
| if skip_temporal_layer is False, unet motion layer works, else unet only run text2image layers. | |
| if parameters about controlnet are None, means text2video pipeline; | |
| if ondition_latents、controlnet_condition_images、controlnet_condition_latents are None, means only run text2video without vision condition images. | |
| By now, code works well with `num_videos_per_prpmpt=1`, !=1 may be wrong. | |
| Args: | |
| prompt (`str` or `List[str]`, *optional*): | |
| The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. | |
| instead. | |
| image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: | |
| `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): | |
| The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If | |
| the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can | |
| also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If | |
| height and/or width are passed, `image` is resized according to them. If multiple ControlNets are | |
| specified in init, images must be passed as a list such that each element of the list can be correctly | |
| batched for input to a single controlnet. | |
| condition_latents: | |
| 与latents相对应,是Latents的时序condition,一般为首帧,b c t(1) ho wo | |
| be corresponding to latents, vision condtion latents, usually first frame, should be b c t(1) ho wo. | |
| controlnet_latents: | |
| 与image二选一,image会被转化成controlnet_latents | |
| Choose either image or controlnet_latents. If image is chosen, it will be converted to controlnet_latents. | |
| controlnet_condition_images: | |
| Optional[torch.FloatTensor]# b c t(1) ho wo,与image相对应,会和image在t通道concat一起,然后转化成 controlnet_latents | |
| b c t(1) ho wo, corresponding to image, will be concatenated along the t channel with image and then converted to controlnet_latents. | |
| controlnet_condition_latents: Optional[torch.FloatTensor]:# | |
| b c t(1) ho wo,会和 controlnet_latents 在t 通道concat一起,转化成 controlnet_latents | |
| b c t(1) ho wo will be concatenated along the t channel with controlnet_latents and converted to controlnet_latents. | |
| height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): | |
| The height in pixels of the generated image. | |
| width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): | |
| The width in pixels of the generated image. | |
| num_inference_steps (`int`, *optional*, defaults to 50): | |
| The number of denoising steps. More denoising steps usually lead to a higher quality image at the | |
| expense of slower inference. | |
| guidance_scale (`float`, *optional*, defaults to 7.5): | |
| Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). | |
| `guidance_scale` is defined as `w` of equation 2. of [Imagen | |
| Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > | |
| 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, | |
| usually at the expense of lower image quality. | |
| negative_prompt (`str` or `List[str]`, *optional*): | |
| The prompt or prompts not to guide the image generation. If not defined, one has to pass | |
| `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is | |
| less than `1`). | |
| strength (`float`, *optional*, defaults to 0.8): | |
| Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a | |
| starting point and more noise is added the higher the `strength`. The number of denoising steps depends | |
| on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising | |
| process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 | |
| essentially ignores `image`. | |
| num_images_per_prompt (`int`, *optional*, defaults to 1): | |
| The number of images to generate per prompt. | |
| eta (`float`, *optional*, defaults to 0.0): | |
| Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to | |
| [`schedulers.DDIMScheduler`], will be ignored for others. | |
| generator (`torch.Generator` or `List[torch.Generator]`, *optional*): | |
| One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) | |
| to make generation deterministic. | |
| latents (`torch.FloatTensor`, *optional*): | |
| Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image | |
| generation. Can be used to tweak the same generation with different prompts. If not provided, a latents | |
| tensor will ge generated by sampling using the supplied random `generator`. | |
| prompt_embeds (`torch.FloatTensor`, *optional*): | |
| Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not | |
| provided, text embeddings will be generated from `prompt` input argument. | |
| negative_prompt_embeds (`torch.FloatTensor`, *optional*): | |
| Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt | |
| weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input | |
| argument. | |
| output_type (`str`, *optional*, defaults to `"pil"`): | |
| The output format of the generate image. Choose between | |
| [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. | |
| return_dict (`bool`, *optional*, defaults to `True`): | |
| Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | |
| plain tuple. | |
| callback (`Callable`, *optional*): | |
| A function that will be called every `callback_steps` steps during inference. The function will be | |
| called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. | |
| callback_steps (`int`, *optional*, defaults to 1): | |
| The frequency at which the `callback` function will be called. If not specified, the callback will be | |
| called at every step. | |
| cross_attention_kwargs (`dict`, *optional*): | |
| A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under | |
| `self.processor` in | |
| [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). | |
| controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): | |
| The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added | |
| to the residual in the original unet. If multiple ControlNets are specified in init, you can set the | |
| corresponding scale as a list. | |
| guess_mode (`bool`, *optional*, defaults to `False`): | |
| In this mode, the ControlNet encoder will try best to recognize the content of the input image even if | |
| you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. | |
| control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): | |
| The percentage of total steps at which the controlnet starts applying. | |
| control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): | |
| The percentage of total steps at which the controlnet stops applying. | |
| skip_temporal_layer (`bool`: default to False) 为False时,unet起video生成作用,会运行时序生成的block;skip_temporal_layer为True时,unet起原image作用,跳过时序生成的block。 | |
| need_img_based_video_noise: bool = False, 当只有首帧latents时,是否需要扩展为video noise; | |
| num_videos_per_prompt: now only support 1. | |
| Examples: | |
| Returns: | |
| [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: | |
| [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. | |
| When returning a tuple, the first element is a list with the generated images, and the second element is a | |
| list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" | |
| (nsfw) content, according to the `safety_checker`. | |
| """ | |
| run_controlnet = control_image is not None or controlnet_latents is not None | |
| if run_controlnet: | |
| ( | |
| controlnet, | |
| control_guidance_start, | |
| control_guidance_end, | |
| ) = self.prepare_controlnet_and_guidance_parameter( | |
| control_guidance_start=control_guidance_start, | |
| control_guidance_end=control_guidance_end, | |
| ) | |
| # 1. Check inputs. Raise error if not correct | |
| self.check_inputs( | |
| prompt, | |
| control_image, | |
| callback_steps, | |
| negative_prompt, | |
| prompt_embeds, | |
| negative_prompt_embeds, | |
| controlnet_conditioning_scale, | |
| control_guidance_start, | |
| control_guidance_end, | |
| ) | |
| # 2. Define call parameters | |
| if prompt is not None and isinstance(prompt, str): | |
| batch_size = 1 | |
| elif prompt is not None and isinstance(prompt, list): | |
| batch_size = len(prompt) | |
| else: | |
| batch_size = prompt_embeds.shape[0] | |
| device = self._execution_device | |
| dtype = self.unet.dtype | |
| # print("pipeline unet dtype", dtype) | |
| # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | |
| # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | |
| # corresponds to doing no classifier free guidance. | |
| do_classifier_free_guidance = guidance_scale > 1.0 | |
| if run_controlnet: | |
| if isinstance(controlnet, MultiControlNetModel) and isinstance( | |
| controlnet_conditioning_scale, float | |
| ): | |
| controlnet_conditioning_scale = [controlnet_conditioning_scale] * len( | |
| controlnet.nets | |
| ) | |
| guess_mode = self.prepare_controlnet_guess_mode( | |
| controlnet=controlnet, | |
| guess_mode=guess_mode, | |
| ) | |
| # 3. Encode input prompt | |
| text_encoder_lora_scale = ( | |
| cross_attention_kwargs.get("scale", None) | |
| if cross_attention_kwargs is not None | |
| else None | |
| ) | |
| if self.text_encoder is not None: | |
| prompt_embeds = encode_weighted_prompt( | |
| self, | |
| prompt, | |
| device, | |
| num_videos_per_prompt, | |
| do_classifier_free_guidance, | |
| negative_prompt, | |
| prompt_embeds=prompt_embeds, | |
| negative_prompt_embeds=negative_prompt_embeds, | |
| # lora_scale=text_encoder_lora_scale, | |
| ) | |
| logger.debug(f"use text_encoder prepare prompt_emb={prompt_embeds.shape}") | |
| else: | |
| prompt_embeds = None | |
| if image is not None: | |
| image = self.prepare_image( | |
| image, | |
| width=width, | |
| height=height, | |
| batch_size=batch_size * num_videos_per_prompt, | |
| num_images_per_prompt=num_videos_per_prompt, | |
| device=device, | |
| dtype=dtype, | |
| ) | |
| if self.print_idx == 0: | |
| logger.debug(f"image={image.shape}") | |
| if condition_images is not None: | |
| condition_images = self.prepare_image( | |
| condition_images, | |
| width=width, | |
| height=height, | |
| batch_size=batch_size * num_videos_per_prompt, | |
| num_images_per_prompt=num_videos_per_prompt, | |
| device=device, | |
| dtype=dtype, | |
| ) | |
| if self.print_idx == 0: | |
| logger.debug(f"condition_images={condition_images.shape}") | |
| # 4. Prepare image | |
| if run_controlnet: | |
| ( | |
| control_image, | |
| controlnet_latents, | |
| ) = self.prepare_controlnet_image_and_latents( | |
| controlnet=controlnet, | |
| width=width, | |
| height=height, | |
| batch_size=batch_size, | |
| num_videos_per_prompt=num_videos_per_prompt, | |
| device=device, | |
| dtype=dtype, | |
| controlnet_condition_latents=controlnet_condition_latents, | |
| control_image=control_image, | |
| controlnet_condition_images=controlnet_condition_images, | |
| guess_mode=guess_mode, | |
| do_classifier_free_guidance=do_classifier_free_guidance, | |
| controlnet_latents=controlnet_latents, | |
| ) | |
| # 5. Prepare timesteps | |
| self.scheduler.set_timesteps(num_inference_steps, device=device) | |
| if strength and (image is not None and latents is not None): | |
| if self.print_idx == 0: | |
| logger.debug( | |
| f"prepare timesteps, with get_timesteps strength={strength}, num_inference_steps={num_inference_steps}" | |
| ) | |
| timesteps, num_inference_steps = self.get_timesteps( | |
| num_inference_steps, strength, device | |
| ) | |
| else: | |
| if self.print_idx == 0: | |
| logger.debug(f"prepare timesteps, without get_timesteps") | |
| timesteps = self.scheduler.timesteps | |
| latent_timestep = timesteps[:1].repeat( | |
| batch_size * num_videos_per_prompt | |
| ) # 6. Prepare latent variables | |
| ( | |
| condition_latents, | |
| latent_index, | |
| vision_condition_latent_index, | |
| ) = self.prepare_condition_latents_and_index( | |
| condition_images=condition_images, | |
| condition_latents=condition_latents, | |
| video_length=video_length, | |
| batch_size=batch_size, | |
| dtype=dtype, | |
| device=device, | |
| latent_index=latent_index, | |
| vision_condition_latent_index=vision_condition_latent_index, | |
| ) | |
| if vision_condition_latent_index is None: | |
| n_vision_cond = 0 | |
| else: | |
| n_vision_cond = vision_condition_latent_index.shape[0] | |
| num_channels_latents = self.unet.config.in_channels | |
| if self.print_idx == 0: | |
| logger.debug(f"pipeline controlnet, start prepare latents") | |
| latents = self.prepare_latents( | |
| batch_size=batch_size * num_videos_per_prompt, | |
| num_channels_latents=num_channels_latents, | |
| video_length=video_length, | |
| height=height, | |
| width=width, | |
| dtype=dtype, | |
| device=device, | |
| generator=generator, | |
| latents=latents, | |
| image=image, | |
| timestep=latent_timestep, | |
| w_ind_noise=w_ind_noise, | |
| initial_common_latent=initial_common_latent, | |
| noise_type=noise_type, | |
| add_latents_noise=add_latents_noise, | |
| need_img_based_video_noise=need_img_based_video_noise, | |
| condition_latents=condition_latents, | |
| img_weight=img_weight, | |
| ) | |
| if self.print_idx == 0: | |
| logger.debug(f"pipeline controlnet, finish prepare latents={latents.shape}") | |
| # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | |
| extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | |
| if noise_type == "video_fusion" and "noise_type" in set( | |
| inspect.signature(self.scheduler.step).parameters.keys() | |
| ): | |
| extra_step_kwargs["w_ind_noise"] = w_ind_noise | |
| extra_step_kwargs["noise_type"] = noise_type | |
| # extra_step_kwargs["noise_offset"] = noise_offset | |
| # 7.1 Create tensor stating which controlnets to keep | |
| if run_controlnet: | |
| controlnet_keep = [] | |
| for i in range(len(timesteps)): | |
| keeps = [ | |
| 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) | |
| for s, e in zip(control_guidance_start, control_guidance_end) | |
| ] | |
| controlnet_keep.append( | |
| keeps[0] if isinstance(controlnet, ControlNetModel) else keeps | |
| ) | |
| else: | |
| controlnet_keep = None | |
| # 8. Denoising loop | |
| num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order | |
| if skip_temporal_layer: | |
| self.unet.set_skip_temporal_layers(True) | |
| n_timesteps = len(timesteps) | |
| guidance_scale_lst = generate_parameters_with_timesteps( | |
| start=guidance_scale, | |
| stop=guidance_scale_end, | |
| num=n_timesteps, | |
| method=guidance_scale_method, | |
| ) | |
| if self.print_idx == 0: | |
| logger.debug( | |
| f"guidance_scale_lst, {guidance_scale_method}, {guidance_scale}, {guidance_scale_end}, {guidance_scale_lst}" | |
| ) | |
| ip_adapter_image_emb = self.get_ip_adapter_image_emb( | |
| ip_adapter_image=ip_adapter_image, | |
| batch_size=batch_size, | |
| device=device, | |
| dtype=dtype, | |
| do_classifier_free_guidance=do_classifier_free_guidance, | |
| height=height, | |
| width=width, | |
| ) | |
| # 当前仅当没有ip_adapter时,按照参数 prompt_only_use_image_prompt 要求是否完全替换 image_prompt_emb | |
| # only if ip_adapter is None and prompt_only_use_image_prompt is True, use image_prompt_emb replace text_prompt | |
| if ( | |
| ip_adapter_image_emb is not None | |
| and prompt_only_use_image_prompt | |
| and not self.unet.ip_adapter_cross_attn | |
| ): | |
| prompt_embeds = ip_adapter_image_emb | |
| logger.debug(f"use ip_adapter_image_emb replace prompt_embeds") | |
| refer_face_image_emb = self.get_facein_image_emb( | |
| refer_face_image=refer_face_image, | |
| batch_size=batch_size, | |
| device=device, | |
| dtype=dtype, | |
| do_classifier_free_guidance=do_classifier_free_guidance, | |
| ) | |
| ip_adapter_face_emb = self.get_ip_adapter_face_emb( | |
| refer_face_image=ip_adapter_face_image, | |
| batch_size=batch_size, | |
| device=device, | |
| dtype=dtype, | |
| do_classifier_free_guidance=do_classifier_free_guidance, | |
| ) | |
| refer_image_vae_emb = self.get_referencenet_image_vae_emb( | |
| refer_image=refer_image, | |
| device=device, | |
| dtype=dtype, | |
| do_classifier_free_guidance=do_classifier_free_guidance, | |
| num_videos_per_prompt=num_videos_per_prompt, | |
| batch_size=batch_size, | |
| width=width, | |
| height=height, | |
| ) | |
| if self.pose_guider is not None and control_image is not None: | |
| if self.print_idx == 0: | |
| logger.debug(f"pose_guider, controlnet_image={control_image.shape}") | |
| control_image = rearrange( | |
| control_image, " (b t) c h w->b c t h w", t=video_length | |
| ) | |
| pose_guider_emb = self.pose_guider(control_image) | |
| pose_guider_emb = rearrange(pose_guider_emb, "b c t h w-> (b t) c h w") | |
| else: | |
| pose_guider_emb = None | |
| logger.debug(f"prompt_embeds={prompt_embeds.shape}") | |
| if control_image is not None: | |
| if isinstance(control_image, list): | |
| logger.debug(f"control_imageis list, num={len(control_image)}") | |
| control_image = [ | |
| rearrange( | |
| control_image_tmp, | |
| " (b t) c h w->b c t h w", | |
| b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, | |
| ) | |
| for control_image_tmp in control_image | |
| ] | |
| else: | |
| logger.debug(f"control_image={control_image.shape}, before") | |
| control_image = rearrange( | |
| control_image, | |
| " (b t) c h w->b c t h w", | |
| b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, | |
| ) | |
| logger.debug(f"control_image={control_image.shape}, after") | |
| if controlnet_latents is not None: | |
| if isinstance(controlnet_latents, list): | |
| logger.debug( | |
| f"controlnet_latents is list, num={len(controlnet_latents)}" | |
| ) | |
| controlnet_latents = [ | |
| rearrange( | |
| controlnet_latents_tmp, | |
| " (b t) c h w->b c t h w", | |
| b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, | |
| ) | |
| for controlnet_latents_tmp in controlnet_latents | |
| ] | |
| else: | |
| logger.debug(f"controlnet_latents={controlnet_latents.shape}, before") | |
| controlnet_latents = rearrange( | |
| controlnet_latents, | |
| " (b t) c h w->b c t h w", | |
| b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, | |
| ) | |
| logger.debug(f"controlnet_latents={controlnet_latents.shape}, after") | |
| videos_mid = [] | |
| mid_video_noises = [] if record_mid_video_noises else None | |
| mid_video_latents = [] if record_mid_video_latents else None | |
| global_context = prepare_global_context( | |
| context_schedule=context_schedule, | |
| num_inference_steps=num_inference_steps, | |
| time_size=latents.shape[2], | |
| context_frames=context_frames, | |
| context_stride=context_stride, | |
| context_overlap=context_overlap, | |
| context_batch_size=context_batch_size, | |
| ) | |
| logger.debug( | |
| f"context_schedule={context_schedule}, time_size={latents.shape[2]}, context_frames={context_frames}, context_stride={context_stride}, context_overlap={context_overlap}, context_batch_size={context_batch_size}" | |
| ) | |
| logger.debug(f"global_context={global_context}") | |
| # iterative denoise | |
| with self.progress_bar(total=num_inference_steps) as progress_bar: | |
| for i, t in enumerate(timesteps): | |
| # 使用 last_mid_video_latents 来影响初始化latent,该部分效果较差,暂留代码 | |
| # use last_mide_video_latents to affect initial latent. works bad, Temporarily reserved | |
| if i == 0: | |
| if record_mid_video_latents: | |
| mid_video_latents.append(latents[:, :, -video_overlap:]) | |
| if record_mid_video_noises: | |
| mid_video_noises.append(None) | |
| if ( | |
| last_mid_video_latents is not None | |
| and len(last_mid_video_latents) > 0 | |
| ): | |
| if self.print_idx == 1: | |
| logger.debug( | |
| f"{i}, last_mid_video_latents={last_mid_video_latents[i].shape}" | |
| ) | |
| latents = fuse_part_tensor( | |
| last_mid_video_latents[0], | |
| latents, | |
| video_overlap, | |
| weight=0.1, | |
| skip_step=0, | |
| ) | |
| noise_pred = torch.zeros( | |
| ( | |
| latents.shape[0] * (2 if do_classifier_free_guidance else 1), | |
| *latents.shape[1:], | |
| ), | |
| device=latents.device, | |
| dtype=latents.dtype, | |
| ) | |
| counter = torch.zeros( | |
| (1, 1, latents.shape[2], 1, 1), | |
| device=latents.device, | |
| dtype=latents.dtype, | |
| ) | |
| if i == 0: | |
| ( | |
| down_block_refer_embs, | |
| mid_block_refer_emb, | |
| refer_self_attn_emb, | |
| ) = self.get_referencenet_emb( | |
| refer_image_vae_emb=refer_image_vae_emb, | |
| refer_image=refer_image, | |
| device=device, | |
| dtype=dtype, | |
| do_classifier_free_guidance=do_classifier_free_guidance, | |
| num_videos_per_prompt=num_videos_per_prompt, | |
| prompt_embeds=prompt_embeds, | |
| ip_adapter_image_emb=ip_adapter_image_emb, | |
| batch_size=batch_size, | |
| ref_timestep_int=t, | |
| ) | |
| for context in global_context: | |
| # expand the latents if we are doing classifier free guidance | |
| latents_c = torch.cat([latents[:, :, c] for c in context]) | |
| latent_index_c = ( | |
| torch.cat([latent_index[c] for c in context]) | |
| if latent_index is not None | |
| else None | |
| ) | |
| latent_model_input = latents_c.to(device).repeat( | |
| 2 if do_classifier_free_guidance else 1, 1, 1, 1, 1 | |
| ) | |
| latent_model_input = self.scheduler.scale_model_input( | |
| latent_model_input, t | |
| ) | |
| sub_latent_index_c = ( | |
| torch.LongTensor( | |
| torch.arange(latent_index_c.shape[-1]) + n_vision_cond | |
| ).to(device=latents_c.device) | |
| if latent_index is not None | |
| else None | |
| ) | |
| if condition_latents is not None: | |
| latent_model_condition = ( | |
| torch.cat([condition_latents] * 2) | |
| if do_classifier_free_guidance | |
| else latents | |
| ) | |
| if self.print_idx == 0: | |
| logger.debug( | |
| f"vision_condition_latent_index, {vision_condition_latent_index.shape}, vision_condition_latent_index" | |
| ) | |
| logger.debug( | |
| f"latent_model_condition, {latent_model_condition.shape}" | |
| ) | |
| logger.debug(f"latent_index, {latent_index_c.shape}") | |
| logger.debug( | |
| f"latent_model_input, {latent_model_input.shape}" | |
| ) | |
| logger.debug(f"sub_latent_index_c, {sub_latent_index_c}") | |
| latent_model_input = batch_concat_two_tensor_with_index( | |
| data1=latent_model_condition, | |
| data1_index=vision_condition_latent_index, | |
| data2=latent_model_input, | |
| data2_index=sub_latent_index_c, | |
| dim=2, | |
| ) | |
| if control_image is not None: | |
| if vision_condition_latent_index is not None: | |
| # 获取 vision_condition 对应的 control_imgae/control_latent 部分 | |
| # generate control_image/control_latent corresponding to vision_condition | |
| controlnet_condtion_latent_index = ( | |
| vision_condition_latent_index.clone().cpu().tolist() | |
| ) | |
| if self.print_idx == 0: | |
| logger.debug( | |
| f"context={context}, controlnet_condtion_latent_index={controlnet_condtion_latent_index}" | |
| ) | |
| controlnet_context = [ | |
| controlnet_condtion_latent_index | |
| + [c_i + n_vision_cond for c_i in c] | |
| for c in context | |
| ] | |
| else: | |
| controlnet_context = context | |
| if self.print_idx == 0: | |
| logger.debug( | |
| f"controlnet_context={controlnet_context}, latent_model_input={latent_model_input.shape}" | |
| ) | |
| if isinstance(control_image, list): | |
| control_image_c = [ | |
| torch.cat( | |
| [ | |
| control_image_tmp[:, :, c] | |
| for c in controlnet_context | |
| ] | |
| ) | |
| for control_image_tmp in control_image | |
| ] | |
| control_image_c = [ | |
| rearrange(control_image_tmp, " b c t h w-> (b t) c h w") | |
| for control_image_tmp in control_image_c | |
| ] | |
| else: | |
| control_image_c = torch.cat( | |
| [control_image[:, :, c] for c in controlnet_context] | |
| ) | |
| control_image_c = rearrange( | |
| control_image_c, " b c t h w-> (b t) c h w" | |
| ) | |
| else: | |
| control_image_c = None | |
| if controlnet_latents is not None: | |
| if vision_condition_latent_index is not None: | |
| # 获取 vision_condition 对应的 control_imgae/control_latent 部分 | |
| # generate control_image/control_latent corresponding to vision_condition | |
| controlnet_condtion_latent_index = ( | |
| vision_condition_latent_index.clone().cpu().tolist() | |
| ) | |
| if self.print_idx == 0: | |
| logger.debug( | |
| f"context={context}, controlnet_condtion_latent_index={controlnet_condtion_latent_index}" | |
| ) | |
| controlnet_context = [ | |
| controlnet_condtion_latent_index | |
| + [c_i + n_vision_cond for c_i in c] | |
| for c in context | |
| ] | |
| else: | |
| controlnet_context = context | |
| if self.print_idx == 0: | |
| logger.debug( | |
| f"controlnet_context={controlnet_context}, controlnet_latents={controlnet_latents.shape}, latent_model_input={latent_model_input.shape}," | |
| ) | |
| controlnet_latents_c = torch.cat( | |
| [controlnet_latents[:, :, c] for c in controlnet_context] | |
| ) | |
| controlnet_latents_c = rearrange( | |
| controlnet_latents_c, " b c t h w-> (b t) c h w" | |
| ) | |
| else: | |
| controlnet_latents_c = None | |
| ( | |
| down_block_res_samples, | |
| mid_block_res_sample, | |
| ) = self.get_controlnet_emb( | |
| run_controlnet=run_controlnet, | |
| guess_mode=guess_mode, | |
| do_classifier_free_guidance=do_classifier_free_guidance, | |
| latents=latents_c, | |
| prompt_embeds=prompt_embeds, | |
| latent_model_input=latent_model_input, | |
| control_image=control_image_c, | |
| controlnet_latents=controlnet_latents_c, | |
| controlnet_keep=controlnet_keep, | |
| t=t, | |
| i=i, | |
| controlnet_conditioning_scale=controlnet_conditioning_scale, | |
| ) | |
| if self.print_idx == 0: | |
| logger.debug( | |
| f"{i}, latent_model_input={latent_model_input.shape}, sub_latent_index_c={sub_latent_index_c}" | |
| f"{vision_condition_latent_index}" | |
| ) | |
| # time.sleep(10) | |
| noise_pred_c = self.unet( | |
| latent_model_input, | |
| t, | |
| encoder_hidden_states=prompt_embeds, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| down_block_additional_residuals=down_block_res_samples, | |
| mid_block_additional_residual=mid_block_res_sample, | |
| return_dict=False, | |
| sample_index=sub_latent_index_c, | |
| vision_conditon_frames_sample_index=vision_condition_latent_index, | |
| sample_frame_rate=motion_speed, | |
| down_block_refer_embs=down_block_refer_embs, | |
| mid_block_refer_emb=mid_block_refer_emb, | |
| refer_self_attn_emb=refer_self_attn_emb, | |
| vision_clip_emb=ip_adapter_image_emb, | |
| face_emb=refer_face_image_emb, | |
| ip_adapter_scale=ip_adapter_scale, | |
| facein_scale=facein_scale, | |
| ip_adapter_face_emb=ip_adapter_face_emb, | |
| ip_adapter_face_scale=ip_adapter_face_scale, | |
| do_classifier_free_guidance=do_classifier_free_guidance, | |
| pose_guider_emb=pose_guider_emb, | |
| )[0] | |
| if condition_latents is not None: | |
| noise_pred_c = batch_index_select( | |
| noise_pred_c, dim=2, index=sub_latent_index_c | |
| ).contiguous() | |
| if self.print_idx == 0: | |
| logger.debug( | |
| f"{i}, latent_model_input={latent_model_input.shape}, noise_pred_c={noise_pred_c.shape}, {len(context)}, {len(context[0])}" | |
| ) | |
| for j, c in enumerate(context): | |
| noise_pred[:, :, c] = noise_pred[:, :, c] + noise_pred_c | |
| counter[:, :, c] = counter[:, :, c] + 1 | |
| noise_pred = noise_pred / counter | |
| if ( | |
| last_mid_video_noises is not None | |
| and len(last_mid_video_noises) > 0 | |
| and i <= num_inference_steps // 2 # 是个超参数 super paramter | |
| ): | |
| if self.print_idx == 1: | |
| logger.debug( | |
| f"{i}, last_mid_video_noises={last_mid_video_noises[i].shape}" | |
| ) | |
| noise_pred = fuse_part_tensor( | |
| last_mid_video_noises[i + 1], | |
| noise_pred, | |
| video_overlap, | |
| weight=0.01, | |
| skip_step=1, | |
| ) | |
| if record_mid_video_noises: | |
| mid_video_noises.append(noise_pred[:, :, -video_overlap:]) | |
| # perform guidance | |
| if do_classifier_free_guidance: | |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
| noise_pred = noise_pred_uncond + guidance_scale_lst[i] * ( | |
| noise_pred_text - noise_pred_uncond | |
| ) | |
| if self.print_idx == 0: | |
| logger.debug( | |
| f"before step, noise_pred={noise_pred.shape}, {noise_pred.device}, latents={latents.shape}, {latents.device}, t={t}" | |
| ) | |
| # compute the previous noisy sample x_t -> x_t-1 | |
| latents = self.scheduler.step( | |
| noise_pred, | |
| t, | |
| latents, | |
| **extra_step_kwargs, | |
| ).prev_sample | |
| if ( | |
| last_mid_video_latents is not None | |
| and len(last_mid_video_latents) > 0 | |
| and i <= 1 # 超参数, super parameter | |
| ): | |
| if self.print_idx == 1: | |
| logger.debug( | |
| f"{i}, last_mid_video_latents={last_mid_video_latents[i].shape}" | |
| ) | |
| latents = fuse_part_tensor( | |
| last_mid_video_latents[i + 1], | |
| latents, | |
| video_overlap, | |
| weight=0.1, | |
| skip_step=0, | |
| ) | |
| if record_mid_video_latents: | |
| mid_video_latents.append(latents[:, :, -video_overlap:]) | |
| if need_middle_latents is True: | |
| videos_mid.append(self.decode_latents(latents)) | |
| # call the callback, if provided | |
| if i == len(timesteps) - 1 or ( | |
| (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 | |
| ): | |
| progress_bar.update() | |
| if callback is not None and i % callback_steps == 0: | |
| callback(i, t, latents) | |
| self.print_idx += 1 | |
| if condition_latents is not None: | |
| latents = batch_concat_two_tensor_with_index( | |
| data1=condition_latents, | |
| data1_index=vision_condition_latent_index, | |
| data2=latents, | |
| data2_index=latent_index, | |
| dim=2, | |
| ) | |
| video = self.decode_latents(latents) | |
| if skip_temporal_layer: | |
| self.unet.set_skip_temporal_layers(False) | |
| if need_hist_match: | |
| video[:, :, latent_index, :, :] = self.hist_match_with_vis_cond( | |
| batch_index_select(video, index=latent_index, dim=2), | |
| batch_index_select(video, index=vision_condition_latent_index, dim=2), | |
| ) | |
| # Convert to tensor | |
| if output_type == "tensor": | |
| videos_mid = [torch.from_numpy(x) for x in videos_mid] | |
| video = torch.from_numpy(video) | |
| else: | |
| latents = latents.cpu().numpy() | |
| if not return_dict: | |
| return ( | |
| video, | |
| latents, | |
| videos_mid, | |
| mid_video_latents, | |
| mid_video_noises, | |
| ) | |
| return VideoPipelineOutput( | |
| videos=video, | |
| latents=latents, | |
| videos_mid=videos_mid, | |
| mid_video_latents=mid_video_latents, | |
| mid_video_noises=mid_video_noises, | |
| ) | |