import contextlib
import logging
import os
import random
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
import torch.utils.checkpoint as checkpoint
from diffusers import FluxTransformer2DModel
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS, calculate_shift, retrieve_timesteps
from transformers.utils import is_peft_available
from trl.core import randn_tensor
from trl.models.sd_utils import convert_state_dict_to_diffusers

from pipeline_flux_kontext import FluxKontextPipeline
from scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

if is_peft_available():
    from peft import LoraConfig, get_peft_model
    from peft.utils import get_peft_model_state_dict

# Module logger; both step functions warn through it when they adjust the requested size.
logger = logging.getLogger(__name__)

# Kontext's preferred (width, height) resolutions, halved in each dimension for this training setup.
PREFERRED_KONTEXT_RESOLUTIONS = [(x[0] // 2, x[1] // 2) for x in PREFERRED_KONTEXT_RESOLUTIONS]


@dataclass
class FluxPipelineOutput:
    """
    Output of a DDPO Flux Kontext pipeline call.

    Args:
        images:
            The generated images, in the format selected by `output_type`. When `output_type == "latent"`
            these are the final (still packed) latents.
        latents (`list[torch.Tensor]`):
            The packed latents at every denoising step, starting with the initial noise.
        log_probs (`list[torch.Tensor]`):
            The per-step log probabilities returned by the scheduler's `step`.
        latent_ids (`torch.Tensor`):
            Positional ids passed to the transformer as `img_ids`. These include the conditioning-image ids
            when an image was given.
        timesteps (`list[torch.Tensor]`):
            The batch-expanded timestep used at every step.
        image_latents (`torch.Tensor` or `None`):
            The conditioning-image latents, or `None` when no image was given.
    """

    images: Any
    latents: List[torch.Tensor]
    log_probs: List[torch.Tensor]
    latent_ids: torch.Tensor
    timesteps: List[torch.Tensor]
    image_latents: Optional[torch.Tensor]


class DDPOFluxPipeline:
    """
    Interface of a Flux pipeline that can be finetuned with the DDPO trainer.

    Every member raises `NotImplementedError`; concrete subclasses provide the implementation.
    """

    def __call__(self, *args, **kwargs) -> FluxPipelineOutput:
        raise NotImplementedError

    @property
    def transformer(self):
        """
        Returns the diffusion transformer used for denoising.
        """
        raise NotImplementedError

    @property
    def vae(self):
        """
        Returns the Variational Autoencoder model used for mapping images to and from the latent space.
        """
        raise NotImplementedError

    @property
    def tokenizer(self):
        """
        Returns the tokenizer used for tokenizing text inputs.
        """
        raise NotImplementedError

    @property
    def tokenizer_2(self):
        """
        Returns the second tokenizer used for tokenizing text inputs.
        """
        raise NotImplementedError

    @property
    def scheduler(self):
        """
        Returns the scheduler used for the diffusion process.
        """
        raise NotImplementedError

    @property
    def text_encoder(self):
        """
        Returns the text encoder used for encoding text inputs.
        """
        raise NotImplementedError

    @property
    def text_encoder_2(self):
        """
        Returns the second text encoder used for encoding text inputs.
        """
        raise NotImplementedError

    @property
    def image_encoder(self):
        """
        Returns the image encoder (used for IP-Adapter image inputs).
        """
        raise NotImplementedError

    @property
    def feature_extractor(self):
        """
        Returns the feature extractor that preprocesses images for the image encoder.
        """
        raise NotImplementedError

    @property
    def autocast(self):
        """
        Returns the autocast context manager.
        """
        raise NotImplementedError

    def set_progress_bar_config(self, *args, **kwargs):
        """
        Sets the progress bar config for the pipeline.
        """
        raise NotImplementedError

    def save_pretrained(self, *args, **kwargs):
        """
        Saves all of the model weights.
        """
        raise NotImplementedError

    def save_checkpoint(self, *args, **kwargs):
        """
        Light wrapper around accelerate's register_save_state_pre_hook which is run before saving state.
        """
        raise NotImplementedError

    def load_checkpoint(self, *args, **kwargs):
        """
        Light wrapper around accelerate's register_load_state_pre_hook which is run before loading state.
        """
        raise NotImplementedError


def _fit_size_to_max_area(pipe, height, width, max_area):
    """Rescale the requested (height, width) to cover about `max_area` pixels at the same aspect ratio.

    Missing sides default to `pipe.default_sample_size * pipe.vae_scale_factor`. Both sides are rounded
    down to a multiple of `2 * pipe.vae_scale_factor`, and a warning is logged when the result differs
    from the request.

    Returns:
        `(height, width, multiple_of)`.
    """
    height = height or pipe.default_sample_size * pipe.vae_scale_factor
    width = width or pipe.default_sample_size * pipe.vae_scale_factor
    original_height, original_width = height, width

    aspect_ratio = width / height
    width = round((max_area * aspect_ratio) ** 0.5)
    height = round((max_area / aspect_ratio) ** 0.5)

    multiple_of = pipe.vae_scale_factor * 2
    width = width // multiple_of * multiple_of
    height = height // multiple_of * multiple_of

    if height != original_height or width != original_width:
        logger.warning(
            f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
        )
    return height, width, multiple_of


def _encode_prompts(
    pipe,
    *,
    prompt,
    prompt_2,
    negative_prompt,
    negative_prompt_2,
    prompt_embeds,
    pooled_prompt_embeds,
    negative_prompt_embeds,
    negative_pooled_prompt_embeds,
    do_true_cfg,
    device,
    num_images_per_prompt,
    max_sequence_length,
    lora_scale,
):
    """Encode the positive prompt and, when `do_true_cfg` is set, the negative prompt.

    Returns:
        `(prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds,
        negative_pooled_prompt_embeds, negative_text_ids)`. When `do_true_cfg` is False, the negative
        embeddings are returned as given and `negative_text_ids` is `None`.
    """
    prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )
    negative_text_ids = None
    if do_true_cfg:
        negative_prompt_embeds, negative_pooled_prompt_embeds, negative_text_ids = pipe.encode_prompt(
            prompt=negative_prompt,
            prompt_2=negative_prompt_2,
            prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )
    return (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
        negative_prompt_embeds,
        negative_pooled_prompt_embeds,
        negative_text_ids,
    )


def _preprocess_kontext_images(pipe, image, multiple_of, auto_resize):
    """Resize and preprocess the conditioning image(s) for `prepare_latents`.

    `image` may be a single image or a list; each entry may itself be a list, whose first element decides
    the target size. With `auto_resize`, the size snaps to the closest-aspect-ratio entry of
    `PREFERRED_KONTEXT_RESOLUTIONS`. Both sides are then rounded down to a multiple of `multiple_of`.

    Returns:
        A list of preprocessed images. If `image` is `None`, or a tensor whose dim 1 equals
        `pipe.latent_channels` (already latents), it is returned unchanged.
    """
    is_latent_input = isinstance(image, torch.Tensor) and image.size(1) == pipe.latent_channels
    if image is None or is_latent_input:
        # NOTE(review): assumes the project's `prepare_latents` accepts None / pre-encoded latents like
        # upstream diffusers does - TODO confirm against pipeline_flux_kontext.
        return image

    processed_images = []
    for img in image if isinstance(image, list) else [image]:
        reference = img[0] if isinstance(img, list) else img
        image_height, image_width = pipe.image_processor.get_default_height_width(reference)
        if auto_resize:
            aspect_ratio = image_width / image_height
            _, image_width, image_height = min(
                (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
            )
        image_width = image_width // multiple_of * multiple_of
        image_height = image_height // multiple_of * multiple_of
        resized = pipe.image_processor.resize(img, image_height, image_width)
        processed_images.append(pipe.image_processor.preprocess(resized, image_height, image_width))
    return processed_images


def _prepare_timesteps(pipe, image_seq_len, num_inference_steps, sigmas, device):
    """Configure the scheduler's shifted sigma schedule and return the timesteps.

    The default sigmas are `linspace(1, 1/num_inference_steps, num_inference_steps)`. The shift `mu` is
    derived from the latent sequence length, and `pipe._num_timesteps` is set as a side effect.

    Returns:
        `(timesteps, num_inference_steps, num_warmup_steps)`.
    """
    if sigmas is None:
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    config = pipe.scheduler.config
    mu = calculate_shift(
        image_seq_len,
        config.get("base_image_seq_len", 256),
        config.get("max_image_seq_len", 4096),
        config.get("base_shift", 0.5),
        config.get("max_shift", 1.15),
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        pipe.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        mu=mu,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * pipe.scheduler.order, 0)
    pipe._num_timesteps = len(timesteps)
    return timesteps, num_inference_steps, num_warmup_steps


def _prepare_ip_adapter_embeds(
    pipe,
    ip_adapter_image,
    ip_adapter_image_embeds,
    negative_ip_adapter_image,
    negative_ip_adapter_image_embeds,
    height,
    width,
    device,
    batch_size,
):
    """Compute positive/negative IP-Adapter image embeddings.

    If only one side is provided, the other side gets an all-zero placeholder image (one per adapter).

    Returns:
        `(image_embeds, negative_image_embeds)`. Either value is `None` when that side has no input.
    """
    has_positive = ip_adapter_image is not None or ip_adapter_image_embeds is not None
    has_negative = negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
    # NOTE(review): the placeholder shape (width, height, 3) mirrors upstream diffusers rather than HWC.
    if has_positive and not has_negative:
        negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
        negative_ip_adapter_image = [negative_ip_adapter_image] * pipe.transformer.encoder_hid_proj.num_ip_adapters
        has_negative = True
    elif has_negative and not has_positive:
        ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
        ip_adapter_image = [ip_adapter_image] * pipe.transformer.encoder_hid_proj.num_ip_adapters
        has_positive = True

    image_embeds = None
    negative_image_embeds = None
    if has_positive:
        image_embeds = pipe.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size,
        )
    if has_negative:
        negative_image_embeds = pipe.prepare_ip_adapter_image_embeds(
            negative_ip_adapter_image,
            negative_ip_adapter_image_embeds,
            device,
            batch_size,
        )
    return image_embeds, negative_image_embeds


def _decode_latents(pipe, latents, height, width, output_type):
    """Turn packed latents into images of `output_type`; for `"latent"`, return the latents as-is."""
    if output_type == "latent":
        return latents
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    return pipe.image_processor.postprocess(image, output_type=output_type)


@torch.no_grad()
def pipeline_step(
    self,
    image: Optional[PipelineImageInput] = None,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt: Union[str, List[str]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    true_cfg_scale: float = 1.0,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    negative_ip_adapter_image: Optional[PipelineImageInput] = None,
    negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    max_area: int = 1024**2,
    _auto_resize: bool = True,
):
    r"""
    Sample from a Flux Kontext pipeline without gradients, recording every latent, log-prob and timestep.

    `self` is the underlying diffusers-style pipeline (e.g. `FluxKontextPipeline`), not a
    `DDPOFluxPipeline`.

    Args:
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
            `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
            numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
            list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
            a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
            latents as `image`; latents passed directly are not encoded again.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
            be used instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale`
            is not greater than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
        true_cfg_scale (`float`, *optional*, defaults to 1.0):
            When > 1.0 and a negative prompt is provided, enables true classifier-free guidance.
        height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            Used only for its aspect ratio: the final size is rescaled to about `max_area` pixels.
        width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            Used only for its aspect ratio: the final size is rescaled to about `max_area` pixels.
        num_inference_steps (`int`, *optional*, defaults to 28):
            The number of denoising steps.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process. Defaults to
            `np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)`.
        guidance_scale (`float`, *optional*, defaults to 3.5):
            Embedded guidance value passed to the transformer when its config has `guidance_embeds`.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of torch generator(s) to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents. If not provided, a latents tensor will be generated by sampling using
            the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, they are generated from `prompt`.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. If not provided, they are generated from `prompt`.
        ip_adapter_image (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter, one per adapter.
        negative_ip_adapter_image (`PipelineImageInput`, *optional*):
            Optional negative image input to work with IP Adapters.
        negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated negative image embeddings for IP-Adapter, one per adapter.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image; `"latent"` skips VAE decoding.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a `FluxPipelineOutput` instead of a plain tuple.
        joint_attention_kwargs (`dict`, *optional*):
            Extra kwargs forwarded to the attention processors; a `"scale"` entry is used as the LoRA scale
            for prompt encoding.
        callback_on_step_end (`Callable`, *optional*):
            Called after each step as `callback_on_step_end(pipeline, step, timestep, callback_kwargs)`. It may
            return replacement `latents` / `prompt_embeds`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            Names of local tensors passed to `callback_on_step_end` in `callback_kwargs`.
        max_sequence_length (`int`, defaults to 512):
            Maximum sequence length to use with the `prompt`.
        max_area (`int`, defaults to `1024 ** 2`):
            The target area of the generated image in pixels. Height and width are adjusted to fit this area
            while keeping the aspect ratio.
        _auto_resize (`bool`, defaults to `True`):
            Snap the conditioning image to the nearest entry of `PREFERRED_KONTEXT_RESOLUTIONS`.

    Returns:
        `FluxPipelineOutput` if `return_dict` is True, otherwise a one-element tuple `(images,)`.
    """
    height, width, multiple_of = _fit_size_to_max_area(self, height, width, max_area)

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    # 2. Define call parameters
    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    has_neg_prompt = negative_prompt is not None or (
        negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
    )
    do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
    (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
        negative_prompt_embeds,
        negative_pooled_prompt_embeds,
        negative_text_ids,
    ) = _encode_prompts(
        self,
        prompt=prompt,
        prompt_2=prompt_2,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        do_true_cfg=do_true_cfg,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 3. Preprocess image
    images = _preprocess_kontext_images(self, image, multiple_of, _auto_resize)

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, image_latents, latent_ids, image_ids = self.prepare_latents(
        images,
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    if image_ids is not None:
        latent_ids = torch.cat([latent_ids, image_ids], dim=0)  # dim 0 is sequence dimension

    # 5. Prepare timesteps
    timesteps, num_inference_steps, num_warmup_steps = _prepare_timesteps(
        self, latents.shape[1], num_inference_steps, sigmas, device
    )

    # handle guidance
    if self.transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(latents.shape[0])
    else:
        guidance = None

    if self.joint_attention_kwargs is None:
        self._joint_attention_kwargs = {}
    image_embeds, negative_image_embeds = _prepare_ip_adapter_embeds(
        self,
        ip_adapter_image,
        ip_adapter_image_embeds,
        negative_ip_adapter_image,
        negative_ip_adapter_image_embeds,
        height,
        width,
        device,
        batch_size * num_images_per_prompt,
    )

    # 6. Denoising loop
    all_latents = [latents]
    all_log_probs = []
    all_timesteps = []
    # Setting the begin index up front avoids a device-to-host sync inside the loop
    # (see https://github.com/huggingface/diffusers/pull/11696).
    self.scheduler.set_begin_index(0)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            self._current_timestep = t
            if image_embeds is not None:
                self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

            latent_model_input = latents
            if image_latents is not None:
                latent_model_input = torch.cat([latents, image_latents], dim=1)
            # Move after concatenation so the device move also applies to the image-conditioned input.
            latent_model_input = latent_model_input.to(self.transformer.device)
            timestep = t.expand(latents.shape[0]).to(torch.float32)

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]
            # Keep only the predictions for the generated tokens, dropping the conditioning-image tokens.
            noise_pred = noise_pred[:, : latents.size(1)]

            if do_true_cfg:
                if negative_image_embeds is not None:
                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
                neg_noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=negative_pooled_prompt_embeds,
                    encoder_hidden_states=negative_prompt_embeds,
                    txt_ids=negative_text_ids,
                    img_ids=latent_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
                neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
                noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

            # compute the previous noisy sample x_t -> x_t-1 and its log-prob
            latents_dtype = latents.dtype
            scheduler_output = self.scheduler.step(noise_pred, t, latents, return_dict=True)
            latents = scheduler_output.latents
            log_probs = scheduler_output.log_probs

            all_latents.append(latents)
            all_log_probs.append(log_probs)
            all_timesteps.append(timestep)

            if latents.dtype != latents_dtype:
                latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                # Plain loop on purpose: locals() inside a comprehension would see the comprehension's scope
                # on Python < 3.12.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    self._current_timestep = None
    image = _decode_latents(self, latents, height, width, output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)
    return FluxPipelineOutput(image, all_latents, all_log_probs, latent_ids, all_timesteps, image_latents)


def pipeline_step_with_grad(
    pipeline,
    image: Optional[PipelineImageInput] = None,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt: Union[str, List[str]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    true_cfg_scale: float = 1.0,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 3.5,
    truncated_backprop: bool = True,
    truncated_backprop_rand: bool = True,
    gradient_checkpoint: bool = True,
    truncated_backprop_timestep: int = 49,
    truncated_rand_backprop_minmax: tuple = (0, 50),
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    negative_ip_adapter_image: Optional[PipelineImageInput] = None,
    negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    max_area: int = 512**2,
    _auto_resize: bool = True,
):
    r"""
    Sample from a Flux Kontext pipeline while keeping the autograd graph, for reward backpropagation.

    The arguments shared with `pipeline_step` behave identically (note the smaller default `max_area`).
    Additional arguments:

    Args:
        truncated_backprop (`bool`, defaults to `True`):
            Detach the (positive) noise prediction of early steps so gradients only flow through later steps.
        truncated_backprop_rand (`bool`, defaults to `True`):
            Randomize the truncation point (https://huggingface.co/papers/2310.03739). A fresh threshold is
            drawn with `random.randint(*truncated_rand_backprop_minmax)` at every step, and the step is
            detached when its index is below it.
        gradient_checkpoint (`bool`, defaults to `True`):
            Run the positive transformer pass under `torch.utils.checkpoint` (non-reentrant) to save memory.
        truncated_backprop_timestep (`int`, defaults to 49):
            Fixed truncation point used when `truncated_backprop_rand` is False. Steps with index below it are
            detached.
        truncated_rand_backprop_minmax (`tuple`, defaults to `(0, 50)`):
            Inclusive `(min, max)` range of the random truncation point. A low minimum lets earlier steps
            receive gradients; a high one reduces memory usage.

    NOTE: with the default 28 steps and `truncated_backprop_timestep=49`, fixed truncation detaches every step.

    Returns:
        `FluxPipelineOutput` if `return_dict` is True, otherwise a one-element tuple `(images,)`.
    """
    height, width, multiple_of = _fit_size_to_max_area(pipeline, height, width, max_area)

    # 1. Check inputs. Raise error if not correct
    pipeline.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    pipeline._guidance_scale = guidance_scale
    pipeline._joint_attention_kwargs = joint_attention_kwargs
    pipeline._current_timestep = None
    pipeline._interrupt = False

    # 2. Define call parameters
    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = pipeline._execution_device
    lora_scale = (
        pipeline.joint_attention_kwargs.get("scale", None) if pipeline.joint_attention_kwargs is not None else None
    )
    has_neg_prompt = negative_prompt is not None or (
        negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
    )
    do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
    (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
        negative_prompt_embeds,
        negative_pooled_prompt_embeds,
        negative_text_ids,
    ) = _encode_prompts(
        pipeline,
        prompt=prompt,
        prompt_2=prompt_2,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        do_true_cfg=do_true_cfg,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 3. Preprocess image
    images = _preprocess_kontext_images(pipeline, image, multiple_of, _auto_resize)

    # 4. Prepare latent variables
    num_channels_latents = pipeline.transformer.config.in_channels // 4
    latents, image_latents, latent_ids, image_ids = pipeline.prepare_latents(
        images,
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    if image_ids is not None:
        latent_ids = torch.cat([latent_ids, image_ids], dim=0)  # dim 0 is sequence dimension

    # 5. Prepare timesteps
    timesteps, num_inference_steps, num_warmup_steps = _prepare_timesteps(
        pipeline, latents.shape[1], num_inference_steps, sigmas, device
    )

    # handle guidance
    if pipeline.transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(latents.shape[0])
    else:
        guidance = None

    if pipeline.joint_attention_kwargs is None:
        pipeline._joint_attention_kwargs = {}
    image_embeds, negative_image_embeds = _prepare_ip_adapter_embeds(
        pipeline,
        ip_adapter_image,
        ip_adapter_image_embeds,
        negative_ip_adapter_image,
        negative_ip_adapter_image_embeds,
        height,
        width,
        device,
        batch_size * num_images_per_prompt,
    )

    # 6. Denoising loop (with autograd)
    all_latents = [latents]
    all_log_probs = []
    all_timesteps = []
    pipeline.scheduler.set_begin_index(0)
    with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if pipeline.interrupt:
                continue

            pipeline._current_timestep = t
            if image_embeds is not None:
                pipeline._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

            latent_model_input = latents
            if image_latents is not None:
                latent_model_input = torch.cat([latents, image_latents], dim=1)
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            transformer_kwargs = dict(
                hidden_states=latent_model_input,
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_ids,
                joint_attention_kwargs=pipeline.joint_attention_kwargs,
                return_dict=False,
            )
            if gradient_checkpoint:
                # Non-reentrant checkpointing is required: the reentrant variant rejects keyword arguments.
                noise_pred = checkpoint.checkpoint(pipeline.transformer, use_reentrant=False, **transformer_kwargs)[0]
            else:
                noise_pred = pipeline.transformer(**transformer_kwargs)[0]
            noise_pred = noise_pred[:, : latents.size(1)]

            if truncated_backprop:
                if truncated_backprop_rand:
                    rand_timestep = random.randint(
                        truncated_rand_backprop_minmax[0], truncated_rand_backprop_minmax[1]
                    )
                    if i < rand_timestep:
                        noise_pred = noise_pred.detach()
                elif i < truncated_backprop_timestep:
                    # fixed truncation process
                    noise_pred = noise_pred.detach()

            if do_true_cfg:
                if negative_image_embeds is not None:
                    pipeline._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
                neg_noise_pred = pipeline.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=negative_pooled_prompt_embeds,
                    encoder_hidden_states=negative_prompt_embeds,
                    txt_ids=negative_text_ids,
                    img_ids=latent_ids,
                    joint_attention_kwargs=pipeline.joint_attention_kwargs,
                    return_dict=False,
                )[0]
                neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
                noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

            # compute the previous noisy sample x_t -> x_t-1 and its log-prob
            latents_dtype = latents.dtype
            scheduler_output = pipeline.scheduler.step(noise_pred, t, latents, return_dict=True)
            latents = scheduler_output.latents
            log_probs = scheduler_output.log_probs

            all_latents.append(latents)
            all_log_probs.append(log_probs)
            all_timesteps.append(timestep)

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                # Plain loop on purpose: locals() inside a comprehension would see the comprehension's scope
                # on Python < 3.12.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(pipeline, i, t, callback_kwargs)
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                progress_bar.update()

    pipeline._current_timestep = None
    image = _decode_latents(pipeline, latents, height, width, output_type)

    # Offload all models
    pipeline.maybe_free_model_hooks()

    if not return_dict:
        return (image,)
    return FluxPipelineOutput(image, all_latents, all_log_probs, latent_ids, all_timesteps, image_latents)


class DefaultDDPOFluxPipeline(DDPOFluxPipeline):
    """`DDPOFluxPipeline` backed by a `FluxKontextPipeline` with the project's stochastic flow-match scheduler."""

    def __init__(self, pretrained_model_name: str, finetuned_model_path: str = ""):
        """Load the pipeline and optionally replace its transformer with a single-file checkpoint.

        Args:
            pretrained_model_name: Name or path passed to `FluxKontextPipeline.from_pretrained`.
            finetuned_model_path: Optional single-file transformer checkpoint, loaded in bfloat16.
        """
        self.flux_pipeline = FluxKontextPipeline.from_pretrained(pretrained_model_name)
        self.pretrained_model = pretrained_model_name
        self.flux_pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(self.flux_pipeline.scheduler.config)
        self.flux_pipeline.scheduler.config.stochastic_sampling = True

        if finetuned_model_path:
            print(f"load finetuned model from {finetuned_model_path}")
            # `torch_dtype` must be a torch.dtype, not the string "bfloat16".
            self.flux_pipeline.transformer = FluxTransformer2DModel.from_single_file(
                finetuned_model_path, torch_dtype=torch.bfloat16
            )

        # memory optimization: freeze every sub-model. This runs after the optional transformer swap so the
        # loaded checkpoint is frozen as well.
        self.flux_pipeline.vae.requires_grad_(False)
        self.flux_pipeline.text_encoder.requires_grad_(False)
        self.flux_pipeline.text_encoder_2.requires_grad_(False)
        self.flux_pipeline.transformer.requires_grad_(False)

    def __call__(self, *args, **kwargs) -> FluxPipelineOutput:
        return pipeline_step(self.flux_pipeline, *args, **kwargs)

    def rgb_with_grad(self, *args, **kwargs) -> FluxPipelineOutput:
        return pipeline_step_with_grad(self.flux_pipeline, *args, **kwargs)

    @property
    def transformer(self):
        return self.flux_pipeline.transformer

    @property
    def vae(self):
        return self.flux_pipeline.vae

    @property
    def tokenizer(self):
        return self.flux_pipeline.tokenizer

    @property
    def tokenizer_2(self):
        return self.flux_pipeline.tokenizer_2

    @property
    def scheduler(self):
        return self.flux_pipeline.scheduler

    @property
    def text_encoder(self):
        return self.flux_pipeline.text_encoder

    @property
    def text_encoder_2(self):
        return self.flux_pipeline.text_encoder_2

    @property
    def image_encoder(self):
        return self.flux_pipeline.image_encoder

    @property
    def feature_extractor(self):
        return self.flux_pipeline.feature_extractor

    @property
    def autocast(self):
        return contextlib.nullcontext

    def save_pretrained(self, output_dir):
        """Save the transformer weights to `output_dir` with its own `save_pretrained`."""
        # NOTE: no PEFT/LoRA state dict is built here; only the transformer's save_pretrained output is
        # written.
        self.flux_pipeline.transformer.save_pretrained(output_dir)

    def set_progress_bar_config(self, *args, **kwargs):
        self.flux_pipeline.set_progress_bar_config(*args, **kwargs)