# Multi_Ref_Edit/kontext/modeling_flux_base.py
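"""DDPO-style wrappers around the FLUX.1 Kontext pipeline.

This module mirrors trl's DDPO stable-diffusion wrappers: it defines the abstract
`DDPOFluxPipeline` interface, sampling loops that record per-step latents and
log-probabilities (with and without gradients), and `DefaultDDPOFluxPipeline`,
which wraps `FluxKontextPipeline` for reward finetuning with reference images.
"""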
import contextlib
import os
import random
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
import torch.utils.checkpoint as checkpoint
from diffusers import FluxTransformer2DModel
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS, calculate_shift, retrieve_timesteps
from diffusers.utils import logging
from scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from pipeline_flux_kontext import FluxKontextPipeline
from transformers.utils import is_peft_available
from trl.core import randn_tensor
from trl.models.sd_utils import convert_state_dict_to_diffusers
if is_peft_available():
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict
logger = logging.get_logger(__name__)

# Halve the preferred Kontext resolutions so reference images are processed at a lower resolution.
PREFERRED_KONTEXT_RESOLUTIONS = [(w // 2, h // 2) for w, h in PREFERRED_KONTEXT_RESOLUTIONS]
@dataclass
class FluxPipelineOutput:
"""
Output class for the diffusers pipeline to be finetuned with the DDPO trainer
Args:
images (`torch.Tensor`):
The generated images.
latents (`list[torch.Tensor]`):
The latents used to generate the images.
log_probs (`list[torch.Tensor]`):
The log probabilities of the latents.
"""
images: torch.Tensor
latents: torch.Tensor
log_probs: torch.Tensor
latent_ids: torch.Tensor
timesteps: torch.Tensor
image_latents: torch.Tensor
class DDPOFluxPipeline:
"""
Main class for the diffusers pipeline to be finetuned with the DDPO trainer
"""
def __call__(self, *args, **kwargs) -> FluxPipelineOutput:
raise NotImplementedError
@property
def transformer(self):
"""
        Returns the transformer model used for the diffusion denoising process.
"""
raise NotImplementedError
@property
def vae(self):
"""
        Returns the Variational Autoencoder model used for mapping images to and from the latent space
"""
raise NotImplementedError
@property
def tokenizer(self):
"""
Returns the tokenizer used for tokenizing text inputs
"""
raise NotImplementedError
@property
def tokenizer_2(self):
"""
        Returns the second tokenizer used for tokenizing text inputs
"""
raise NotImplementedError
@property
def scheduler(self):
"""
Returns the scheduler associated with the pipeline used for the diffusion process
"""
raise NotImplementedError
@property
def text_encoder(self):
"""
Returns the text encoder used for encoding text inputs
"""
raise NotImplementedError
@property
def text_encoder_2(self):
"""
        Returns the second text encoder used for encoding text inputs
"""
raise NotImplementedError
@property
def image_encoder(self):
"""
        Returns the image encoder used for encoding image inputs
"""
raise NotImplementedError
@property
def feature_extractor(self):
"""
        Returns the feature extractor used for preprocessing image inputs
"""
raise NotImplementedError
@property
def autocast(self):
"""
Returns the autocast context manager
"""
raise NotImplementedError
def set_progress_bar_config(self, *args, **kwargs):
"""
Sets the progress bar config for the pipeline
"""
raise NotImplementedError
def save_pretrained(self, *args, **kwargs):
"""
Saves all of the model weights
"""
raise NotImplementedError
def save_checkpoint(self, *args, **kwargs):
"""
Light wrapper around accelerate's register_save_state_pre_hook which is run before saving state
"""
raise NotImplementedError
def load_checkpoint(self, *args, **kwargs):
"""
        Light wrapper around accelerate's register_load_state_pre_hook which is run before loading state
"""
raise NotImplementedError
@torch.no_grad()
def pipeline_step(
self,
image: Optional[PipelineImageInput] = None,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt: Union[str, List[str]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
true_cfg_scale: float = 1.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
max_area: int = 1024**2,
_auto_resize: bool = True,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
            `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
            numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
            list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
            a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
            latents as `image`; if latents are passed directly, they are not encoded again.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
            be used instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
not greater than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
        true_cfg_scale (`float`, *optional*, defaults to 1.0):
            If > 1.0 and a `negative_prompt` is provided, true classifier-free guidance is enabled.
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
        num_inference_steps (`int`, *optional*, defaults to 28):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
            `guidance_scale > 1`. A higher guidance scale encourages the model to generate images closely linked to
the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
        ip_adapter_image (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
        negative_ip_adapter_image (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512):
Maximum sequence length to use with the `prompt`.
max_area (`int`, defaults to `1024 ** 2`):
The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
area while maintaining the aspect ratio.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
original_height, original_width = height, width
aspect_ratio = width / height
width = round((max_area * aspect_ratio) ** 0.5)
height = round((max_area / aspect_ratio) ** 0.5)
multiple_of = self.vae_scale_factor * 2
width = width // multiple_of * multiple_of
height = height // multiple_of * multiple_of
if height != original_height or width != original_width:
logger.warning(
f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if do_true_cfg:
(
negative_prompt_embeds,
negative_pooled_prompt_embeds,
negative_text_ids,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 3. Preprocess image
    images = None  # stays None when no reference image (or pre-encoded latents) is provided
    if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
imgs = image if isinstance(image, list) else [image]
images = []
for img in imgs:
img_0 = img[0] if isinstance(img, list) else img
image_height, image_width = self.image_processor.get_default_height_width(img_0)
aspect_ratio = image_width / image_height
if _auto_resize:
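                # Kontext is trained on specific resolutions; snap to the closest preferred aspect ratio.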
_, image_width, image_height = min(
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
)
image_width = image_width // multiple_of * multiple_of
image_height = image_height // multiple_of * multiple_of
resized = self.image_processor.resize(img, image_height, image_width)
            logger.debug(f"Reference image resized to {image_height}x{image_width}")
processed = self.image_processor.preprocess(resized, image_height, image_width)
images.append(processed)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
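    # prepare_latents is expected to return the packed noise latents, the encoded reference-image
    # latents, and the positional ids for both (see the accompanying pipeline_flux_kontext module).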
latents, image_latents, latent_ids, image_ids = self.prepare_latents(
images,
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
if image_ids is not None:
latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
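    # The timestep shift `mu` scales with the latent sequence length, following the Flux
    # resolution-dependent scheduler shifting.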
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
):
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
):
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
image_embeds = None
negative_image_embeds = None
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
negative_ip_adapter_image,
negative_ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
# 6. Denoising loop
# We set the index here to remove DtoH sync, helpful especially during compilation.
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
all_latents = [latents]
all_log_probs = []
all_timesteps = []
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
if image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
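            # Kontext conditioning: the packed reference-image latents are appended to the noisy latents
            # along the token (sequence) dimension so the transformer can attend to them.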
latent_model_input = latents
latent_model_input = latent_model_input.to(self.transformer.device)
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1)
timestep = t.expand(latents.shape[0]).to(torch.float32)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
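            # The transformer predicts over both the noisy-latent tokens and the reference tokens;
            # keep only the first `latents.size(1)` tokens, which correspond to the noisy latents.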
noise_pred = noise_pred[:, : latents.size(1)]
if do_true_cfg:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
scheduler_output = self.scheduler.step(noise_pred, t, latents, return_dict=True)
latents = scheduler_output.latents
log_probs = scheduler_output.log_probs
all_latents.append(latents)
all_log_probs.append(log_probs)
all_timesteps.append(timestep)
if latents.dtype != latents_dtype:
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
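            # Unpack the 2x2-packed latent tokens back into spatial latents, undo the VAE
            # scaling/shift, and decode to pixel space.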
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(image, all_latents, all_log_probs, latent_ids, all_timesteps, image_latents)
def pipeline_step_with_grad(
pipeline,
image: Optional[PipelineImageInput] = None,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt: Union[str, List[str]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
true_cfg_scale: float = 1.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 3.5,
truncated_backprop: bool = True,
truncated_backprop_rand: bool = True,
gradient_checkpoint: bool = True,
truncated_backprop_timestep: int = 49,
truncated_rand_backprop_minmax: tuple = (0, 50),
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
max_area: int = 512**2,
_auto_resize: bool = True,
):
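    r"""
    Same as `pipeline_step`, but keeps gradients through the denoising loop so a reward can be
    backpropagated into the transformer.

    Args (beyond those of `pipeline_step`):
        truncated_backprop (`bool`, defaults to `True`):
            If `True`, gradients are only propagated through the later denoising steps.
        truncated_backprop_rand (`bool`, defaults to `True`):
            If `True`, the truncation cutoff is sampled uniformly from `truncated_rand_backprop_minmax`
            at every step; otherwise `truncated_backprop_timestep` is used as a fixed cutoff.
        gradient_checkpoint (`bool`, defaults to `True`):
            If `True`, the transformer forward pass is wrapped in `torch.utils.checkpoint` to save memory.
    """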
height = height or pipeline.default_sample_size * pipeline.vae_scale_factor
width = width or pipeline.default_sample_size * pipeline.vae_scale_factor
original_height, original_width = height, width
aspect_ratio = width / height
width = round((max_area * aspect_ratio) ** 0.5)
height = round((max_area / aspect_ratio) ** 0.5)
multiple_of = pipeline.vae_scale_factor * 2
width = width // multiple_of * multiple_of
height = height // multiple_of * multiple_of
if height != original_height or width != original_width:
logger.warning(
f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
)
# 1. Check inputs. Raise error if not correct
pipeline.check_inputs(
prompt,
prompt_2,
height,
width,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
pipeline._guidance_scale = guidance_scale
pipeline._joint_attention_kwargs = joint_attention_kwargs
pipeline._current_timestep = None
pipeline._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = pipeline._execution_device
lora_scale = (
pipeline.joint_attention_kwargs.get("scale", None) if pipeline.joint_attention_kwargs is not None else None
)
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = pipeline.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if do_true_cfg:
(
negative_prompt_embeds,
negative_pooled_prompt_embeds,
negative_text_ids,
) = pipeline.encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 3. Preprocess image
    images = None  # stays None when no reference image (or pre-encoded latents) is provided
    if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == pipeline.latent_channels):
imgs = image if isinstance(image, list) else [image]
images = []
for img in imgs:
img_0 = img[0] if isinstance(img, list) else img
image_height, image_width = pipeline.image_processor.get_default_height_width(img_0)
aspect_ratio = image_width / image_height
if _auto_resize:
_, image_width, image_height = min(
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
)
image_width = image_width // multiple_of * multiple_of
image_height = image_height // multiple_of * multiple_of
resized = pipeline.image_processor.resize(img, image_height, image_width)
processed = pipeline.image_processor.preprocess(resized, image_height, image_width)
images.append(processed)
# 4. Prepare latent variables
# num_channels_latents = pipeline.transformer.module.config.in_channels // 4
num_channels_latents = pipeline.transformer.config.in_channels // 4
latents, image_latents, latent_ids, image_ids = pipeline.prepare_latents(
images,
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
if image_ids is not None:
latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
pipeline.scheduler.config.get("base_image_seq_len", 256),
pipeline.scheduler.config.get("max_image_seq_len", 4096),
pipeline.scheduler.config.get("base_shift", 0.5),
pipeline.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
pipeline.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0)
pipeline._num_timesteps = len(timesteps)
# handle guidance
if pipeline.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
):
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
negative_ip_adapter_image = [negative_ip_adapter_image] * pipeline.transformer.encoder_hid_proj.num_ip_adapters
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
):
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
ip_adapter_image = [ip_adapter_image] * pipeline.transformer.encoder_hid_proj.num_ip_adapters
if pipeline.joint_attention_kwargs is None:
pipeline._joint_attention_kwargs = {}
image_embeds = None
negative_image_embeds = None
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = pipeline.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
negative_image_embeds = pipeline.prepare_ip_adapter_image_embeds(
negative_ip_adapter_image,
negative_ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
all_latents = [latents]
all_log_probs = []
all_timesteps = []
pipeline.scheduler.set_begin_index(0)
with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if pipeline.interrupt:
continue
pipeline._current_timestep = t
if image_embeds is not None:
pipeline._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
latent_model_input = latents
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
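            # Gradient checkpointing trades extra compute for memory: activations inside the
            # transformer are recomputed during the backward pass instead of being stored.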
if gradient_checkpoint:
                # use_reentrant=False is required because the transformer is called with keyword arguments
                noise_pred = checkpoint.checkpoint(
                    pipeline.transformer,
                    use_reentrant=False,
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_ids,
joint_attention_kwargs=pipeline.joint_attention_kwargs,
return_dict=False,
)[0]
else:
noise_pred = pipeline.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_ids,
joint_attention_kwargs=pipeline.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]
if truncated_backprop:
# Randomized truncation randomizes the truncation process (https://huggingface.co/papers/2310.03739)
# the range of truncation is defined by truncated_rand_backprop_minmax
                # Setting truncated_rand_backprop_minmax[0] low allows the model to update earlier timesteps in the diffusion chain, while setting it high reduces memory usage.
if truncated_backprop_rand:
rand_timestep = random.randint(
truncated_rand_backprop_minmax[0], truncated_rand_backprop_minmax[1]
)
if i < rand_timestep:
noise_pred = noise_pred.detach()
else:
# fixed truncation process
if i < truncated_backprop_timestep:
noise_pred = noise_pred.detach()
if do_true_cfg:
if negative_image_embeds is not None:
pipeline._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
neg_noise_pred = pipeline.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_ids,
joint_attention_kwargs=pipeline.joint_attention_kwargs,
return_dict=False,
)[0]
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
scheduler_output = pipeline.scheduler.step(noise_pred, t, latents, return_dict=True)
latents = scheduler_output.latents
log_probs = scheduler_output.log_probs
all_latents.append(latents)
all_log_probs.append(log_probs)
all_timesteps.append(timestep)
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(pipeline, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
progress_bar.update()
pipeline._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = pipeline._unpack_latents(latents, height, width, pipeline.vae_scale_factor)
latents = (latents / pipeline.vae.config.scaling_factor) + pipeline.vae.config.shift_factor
image = pipeline.vae.decode(latents, return_dict=False)[0]
image = pipeline.image_processor.postprocess(image, output_type=output_type)
# Offload all models
pipeline.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(image, all_latents, all_log_probs, latent_ids, all_timesteps, image_latents)
class DefaultDDPOFluxPipeline(DDPOFluxPipeline):
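    """
    Default DDPO pipeline wrapping `FluxKontextPipeline`.

    `__call__` runs the sampling loop without gradients (for collecting rollouts), while
    `rgb_with_grad` re-runs it with gradients, optionally truncated and/or checkpointed,
    so a differentiable reward can be backpropagated into the transformer.
    """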
    def __init__(self, pretrained_model_name: str, finetuned_model_path: str = ""):
self.flux_pipeline = FluxKontextPipeline.from_pretrained(
pretrained_model_name
)
self.pretrained_model = pretrained_model_name
self.flux_pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(self.flux_pipeline.scheduler.config)
self.flux_pipeline.scheduler.config.stochastic_sampling = True
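        # Stochastic sampling turns each scheduler step into a sampled transition, which is what
        # allows the custom scheduler's `step` to report per-step log-probabilities for DDPO.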
# memory optimization
self.flux_pipeline.vae.requires_grad_(False)
self.flux_pipeline.text_encoder.requires_grad_(False)
self.flux_pipeline.text_encoder_2.requires_grad_(False)
self.flux_pipeline.transformer.requires_grad_(False)
if finetuned_model_path:
print(f"load finetuned model from {finetuned_model_path}")
self.flux_pipeline.transformer = FluxTransformer2DModel.from_single_file(finetuned_model_path, torch_dtype="bfloat16")
def __call__(self, *args, **kwargs) -> FluxPipelineOutput:
return pipeline_step(self.flux_pipeline, *args, **kwargs)
def rgb_with_grad(self, *args, **kwargs) -> FluxPipelineOutput:
return pipeline_step_with_grad(self.flux_pipeline, *args, **kwargs)
@property
def transformer(self):
return self.flux_pipeline.transformer
@property
def vae(self):
return self.flux_pipeline.vae
@property
def tokenizer(self):
return self.flux_pipeline.tokenizer
@property
def tokenizer_2(self):
return self.flux_pipeline.tokenizer_2
@property
def scheduler(self):
return self.flux_pipeline.scheduler
@property
def text_encoder(self):
return self.flux_pipeline.text_encoder
@property
def text_encoder_2(self):
return self.flux_pipeline.text_encoder_2
@property
def image_encoder(self):
return self.flux_pipeline.image_encoder
@property
def feature_extractor(self):
return self.flux_pipeline.feature_extractor
@property
def autocast(self):
return contextlib.nullcontext
    def save_pretrained(self, output_dir):
        if is_peft_available() and hasattr(self.flux_pipeline.transformer, "peft_config"):
            # If the transformer carries PEFT/LoRA adapters, save only the adapter weights
            # (mirroring trl's DDPO pipelines); otherwise save the full transformer.
            state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(self.flux_pipeline.transformer))
            self.flux_pipeline.save_lora_weights(save_directory=output_dir, transformer_lora_layers=state_dict)
        else:
            self.flux_pipeline.transformer.save_pretrained(output_dir)
def set_progress_bar_config(self, *args, **kwargs):
self.flux_pipeline.set_progress_bar_config(*args, **kwargs)
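# ---------------------------------------------------------------------------
# Minimal usage sketch. The checkpoint id, device, and image paths below are
# illustrative assumptions, not values used elsewhere in this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    pipe = DefaultDDPOFluxPipeline("black-forest-labs/FLUX.1-Kontext-dev")  # assumed checkpoint id
    pipe.flux_pipeline.to("cuda")
    reference = Image.open("reference.png").convert("RGB")  # hypothetical reference image
    output = pipe(image=[reference], prompt="turn the sky into a golden sunset", num_inference_steps=8)
    output.images[0].save("edited.png")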