bo.l committed on
Commit
fed9294
·
1 Parent(s): 72f5ee7
app.py CHANGED
@@ -1,13 +1,46 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
-
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
8
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
 
12
  if torch.cuda.is_available():
13
  torch_dtype = torch.float16
@@ -17,14 +50,24 @@ else:
17
  pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
  pipe = pipe.to(device)
19
20
  MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
 
23
 
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
  def infer(
26
  prompt,
27
- negative_prompt,
28
  seed,
29
  randomize_seed,
30
  width,
@@ -38,15 +81,16 @@ def infer(
38
 
39
  generator = torch.Generator().manual_seed(seed)
40
 
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
 
50
 
51
  return image, seed
52
 
@@ -66,7 +110,7 @@ css = """
66
 
67
  with gr.Blocks(css=css) as demo:
68
  with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
 
71
  with gr.Row():
72
  prompt = gr.Text(
@@ -76,19 +120,16 @@ with gr.Blocks(css=css) as demo:
76
  placeholder="Enter your prompt",
77
  container=False,
78
  )
79
-
80
  run_button = gr.Button("Run", scale=0, variant="primary")
81
 
82
  result = gr.Image(label="Result", show_label=False)
83
 
84
  with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
-
92
  seed = gr.Slider(
93
  label="Seed",
94
  minimum=0,
@@ -96,7 +137,6 @@ with gr.Blocks(css=css) as demo:
96
  step=1,
97
  value=0,
98
  )
99
-
100
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
 
102
  with gr.Row():
@@ -105,15 +145,14 @@ with gr.Blocks(css=css) as demo:
105
  minimum=256,
106
  maximum=MAX_IMAGE_SIZE,
107
  step=32,
108
- value=1024, # Replace with defaults that work for your model
109
  )
110
-
111
  height = gr.Slider(
112
  label="Height",
113
  minimum=256,
114
  maximum=MAX_IMAGE_SIZE,
115
  step=32,
116
- value=1024, # Replace with defaults that work for your model
117
  )
118
 
119
  with gr.Row():
@@ -122,24 +161,30 @@ with gr.Blocks(css=css) as demo:
122
  minimum=0.0,
123
  maximum=10.0,
124
  step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
  )
127
-
128
  num_inference_steps = gr.Slider(
129
  label="Number of inference steps",
130
  minimum=1,
131
  maximum=50,
132
  step=1,
133
- value=2, # Replace with defaults that work for your model
134
  )
135
 
136
  gr.Examples(examples=examples, inputs=[prompt])
 
 
137
  gr.on(
138
  triggers=[run_button.click, prompt.submit],
139
  fn=infer,
140
  inputs=[
141
  prompt,
142
- negative_prompt,
143
  seed,
144
  randomize_seed,
145
  width,
@@ -151,4 +196,4 @@ with gr.Blocks(css=css) as demo:
151
  )
152
 
153
  if __name__ == "__main__":
154
- demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
+ import spaces  # needed for ZeroGPU Spaces
+ from PIL import Image  # used by resize_by_bucket (Image.BICUBIC)
+ from diffusers import FluxTransformer2DModel  # used below to load the finetuned transformer
+ from kontext.pipeline_flux_kontext import FluxKontextPipeline
+ from kontext.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
7
  import torch
8
 
9
+ def resize_by_bucket(images_pil, resolution=512):
10
+ assert len(images_pil) > 0, "images_pil must not be empty"
11
+ bucket_override = [
12
+ (336, 784), (344, 752), (360, 728), (376, 696),
13
+ (400, 664), (416, 624), (440, 592), (472, 552),
14
+ (512, 512),
15
+ (552, 472), (592, 440), (624, 416), (664, 400),
16
+ (696, 376), (728, 360), (752, 344), (784, 336),
17
+ ]
18
+ bucket_override = [
19
+ (int(h / 512 * resolution), int(w / 512 * resolution))
20
+ for h, w in bucket_override
21
+ ]
22
+ bucket_override = [
23
+ (h // 16 * 16, w // 16 * 16)
24
+ for h, w in bucket_override
25
+ ]
26
+
27
+ aspect_ratios = [img.height / img.width for img in images_pil]
28
+ mean_aspect_ratio = float(np.mean(aspect_ratios))
29
+
30
+ new_h, new_w = bucket_override[0]
31
+ min_aspect_diff = abs(new_h / new_w - mean_aspect_ratio)
32
+ for h, w in bucket_override:
33
+ aspect_diff = abs(h / w - mean_aspect_ratio)
34
+ if aspect_diff < min_aspect_diff:
35
+ min_aspect_diff = aspect_diff
36
+ new_h, new_w = h, w
37
+
38
+ resized_images = [
39
+ img.resize((new_w, new_h), resample=Image.BICUBIC) for img in images_pil
40
+ ]
41
+ return resized_images
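For context, a quick standalone sketch of what `resize_by_bucket` does, assuming the function above is in scope and using blank placeholder images (the printed size follows from the default bucket list at `resolution=512`):

```python
from PIL import Image

# Two wide 1024x512 placeholders; mean aspect ratio (height/width) is 0.5.
refs = [Image.new("RGB", (1024, 512)), Image.new("RGB", (1024, 512))]

resized = resize_by_bucket(refs, resolution=512)
# The closest bucket by aspect ratio is (360, 728), floored to multiples of 16 -> (352, 720),
# so both images come back at the same size; PIL reports size as (width, height).
print(resized[0].size)  # (720, 352)
```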
42
+
43
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
44
 
45
  if torch.cuda.is_available():
46
  torch_dtype = torch.float16
 
50
  pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
51
  pipe = pipe.to(device)
52
 
53
+ flux_pipeline = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev")
54
+ flux_pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(flux_pipeline.scheduler.config)
55
+ flux_pipeline.vae.to(device).to(torch.bfloat16)
56
+ flux_pipeline.text_encoder.to(device).to(torch.bfloat16)
57
+ flux_pipeline.text_encoder_2.to(device).to(torch.bfloat16)
58
+ flux_pipeline.scheduler.config.stochastic_sampling = False
59
+ finetuned_path = "NoobDoge/Multi_Ref"
60
+ flux_pipeline.transformer = FluxTransformer2DModel.from_pretrained(finetuned_path,subfolder='transformer', torch_dtype=torch.bfloat16)
61
+ flux_pipeline.transformer.to(device).to(torch.bfloat16)
62
+
63
  MAX_SEED = np.iinfo(np.int32).max
64
+ MAX_IMAGE_SIZE = 512
65
 
66
 
67
+ @spaces.GPU  # run infer on GPU when hosted on ZeroGPU
68
  def infer(
69
  prompt,
70
+ ref1,
+ ref2,
71
  seed,
72
  randomize_seed,
73
  width,
 
81
 
82
  generator = torch.Generator().manual_seed(seed)
83
 
84
+ # Drop any missing reference and resize the rest to a shared bucket before inference
+ raw_images = resize_by_bucket([img for img in (ref1, ref2) if img is not None])
+ with torch.no_grad():
+     image = flux_pipeline(
+         image=raw_images,
+         prompt=prompt,
+         height=height,
+         width=width,
+         num_inference_steps=num_inference_steps,
+         max_area=MAX_IMAGE_SIZE**2,
+         generator=generator,
+     ).images[0]
94
 
95
  return image, seed
96
 
 
110
 
111
  with gr.Blocks(css=css) as demo:
112
  with gr.Column(elem_id="col-container"):
113
+ gr.Markdown("# Text-to-Image Gradio Template")
114
 
115
  with gr.Row():
116
  prompt = gr.Text(
 
120
  placeholder="Enter your prompt",
121
  container=False,
122
  )
 
123
  run_button = gr.Button("Run", scale=0, variant="primary")
124
 
125
+ # New: two reference image inputs
126
+ with gr.Row():
127
+ ref1 = gr.Image(label="Input Image 1", type="pil")
128
+ ref2 = gr.Image(label="Input Image 2", type="pil")
129
+
130
  result = gr.Image(label="Result", show_label=False)
131
 
132
  with gr.Accordion("Advanced Settings", open=False):
133
  seed = gr.Slider(
134
  label="Seed",
135
  minimum=0,
 
137
  step=1,
138
  value=0,
139
  )
 
140
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
141
 
142
  with gr.Row():
 
145
  minimum=256,
146
  maximum=MAX_IMAGE_SIZE,
147
  step=32,
148
+ value=512,  # must not exceed MAX_IMAGE_SIZE (512)
149
  )
 
150
  height = gr.Slider(
151
  label="Height",
152
  minimum=256,
153
  maximum=MAX_IMAGE_SIZE,
154
  step=32,
155
+ value=512,  # must not exceed MAX_IMAGE_SIZE (512)
156
  )
157
 
158
  with gr.Row():
 
161
  minimum=0.0,
162
  maximum=10.0,
163
  step=0.1,
164
+ value=0.0,
165
  )
 
166
  num_inference_steps = gr.Slider(
167
  label="Number of inference steps",
168
  minimum=1,
169
  maximum=50,
170
  step=1,
171
+ value=2,
172
  )
173
 
174
+ # If the examples only contain text prompts, the following is sufficient
175
+ examples = [
176
+ ["a cute corgi in a wizard hat"],
177
+ ["a watercolor painting of yosemite valley at sunrise"],
178
+ ]
179
  gr.Examples(examples=examples, inputs=[prompt])
180
+ # The two reference images are passed to infer as separate inputs; infer resizes them with resize_by_bucket.
182
  gr.on(
183
  triggers=[run_button.click, prompt.submit],
184
  fn=infer,
185
  inputs=[
186
  prompt,
187
+ ref1,  # New: the two reference images
+ ref2,
188
  seed,
189
  randomize_seed,
190
  width,
 
196
  )
197
 
198
  if __name__ == "__main__":
199
+ demo.launch()
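For reference, a minimal sketch of driving the same multi-reference setup outside of Gradio, reusing `flux_pipeline`, `resize_by_bucket`, and `MAX_IMAGE_SIZE` from app.py; the file names and prompt are placeholders:

```python
import torch
from PIL import Image

from app import flux_pipeline, resize_by_bucket, MAX_IMAGE_SIZE  # assumes app.py is importable

# Load and bucket-resize the two reference images (paths are placeholders).
refs = resize_by_bucket([Image.open("ref1.png"), Image.open("ref2.png")])

with torch.no_grad():
    result = flux_pipeline(
        image=refs,
        prompt="combine the two references into a single coherent scene",
        height=512,
        width=512,
        num_inference_steps=28,
        max_area=MAX_IMAGE_SIZE**2,
        generator=torch.Generator().manual_seed(0),
    ).images[0]

result.save("result.png")
```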
kontext/__pycache__/pipeline_flux_kontext.cpython-311.pyc ADDED
Binary file (57.4 kB).
 
kontext/__pycache__/scheduling_flow_match_euler_discrete.cpython-311.pyc ADDED
Binary file (28.6 kB).
 
kontext/ddpo_edit_trainer.py ADDED
@@ -0,0 +1,601 @@
1
+ import os, pickle, random, json, base64, io
2
+ import torch
3
+ import torch.nn as nn
4
+ import matplotlib.pyplot as plt
5
+ from PIL import Image
6
+ import numpy as np
7
+ from glob import glob
8
+ from tqdm import tqdm, trange
9
+ from concurrent.futures import ThreadPoolExecutor, as_completed
10
+
11
+ from collections import defaultdict
12
+ from concurrent import futures
13
+ from pathlib import Path
14
+ from accelerate import Accelerator
15
+ from typing import Any, Callable, Optional, Union
16
+ from warnings import warn
17
+ from peft import LoraConfig, get_peft_model
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import ProjectConfiguration, set_seed
20
+ from huggingface_hub import PyTorchModelHubMixin
21
+
22
+ from modeling_flux_base import DefaultDDPOFluxPipeline
23
+ from ddpo_flux_config import DDPOFluxConfig
24
+
25
+ from transformers import is_wandb_available
26
+ if is_wandb_available():
27
+ import wandb
28
+
29
+ logger = get_logger(__name__)
30
+
31
+ class DDPOTrainer_edit(PyTorchModelHubMixin):
32
+
33
+ def __init__(
34
+ self,
35
+ config: DDPOFluxConfig,
36
+ reward_function: Callable[[], tuple[str, Any]],
37
+ prompt_function: Callable[[], tuple[str, Any]],
38
+ edit_pipeline: DefaultDDPOFluxPipeline,
39
+ image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
40
+ ):
41
+ if image_samples_hook is None:
42
+ warn("No image_samples_hook provided; no images will be logged")
43
+
44
+ self.prompt_fn = prompt_function
45
+ self.reward_fn = reward_function
46
+ self.config = config
47
+ self.image_samples_callback = image_samples_hook
48
+ accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
49
+ self.project_dir = accelerator_project_config.project_dir
50
+
51
+ if self.config.resume_from:
52
+ if self.config.resume_from == "latest":
53
+ dirs = os.listdir(self.project_dir)
54
+ dirs = [d for d in dirs if d.startswith("checkpoint_lora")]
55
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
56
+ if len(dirs) == 0:
57
+ print(f"Checkpoint '{self.config.resume_from}' does not exist. Starting a new training run.")
58
+ self.config.resume_from = ""
59
+ path = dirs[-1]
60
+ else:
61
+ path = os.path.basename(self.config.resume_from)
62
+ self.config.resume_from = os.path.join(self.project_dir, path)
63
+ accelerator_project_config.iteration = int(path.split("-")[1])+1
64
+
65
+ # number of timesteps within each trajectory to train on
66
+ self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction - 1)
67
+
68
+ self.accelerator = Accelerator(
69
+ log_with=self.config.log_with,
70
+ mixed_precision=self.config.mixed_precision,
71
+ project_config=accelerator_project_config,
72
+ # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
73
+ # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
74
+ # the total number of optimizer steps to accumulate across.
75
+ gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
76
+ **self.config.accelerator_kwargs,
77
+ )
78
+ is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
79
+
80
+ if self.accelerator.is_main_process:
81
+ self.accelerator.init_trackers(
82
+ self.config.tracker_project_name,
83
+ config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
84
+ init_kwargs=self.config.tracker_kwargs,
85
+ )
86
+
87
+ is_okay, message = self._config_check()
88
+ if not is_okay:
89
+ raise ValueError(message)
90
+
91
+ logger.info(f"\n{config}")
92
+
93
+ set_seed(self.config.seed, device_specific=True)
94
+
95
+ self.edit_pipeline = edit_pipeline
96
+
97
+ self.edit_pipeline.set_progress_bar_config(
98
+ position=1,
99
+ disable=not self.accelerator.is_local_main_process,
100
+ leave=False,
101
+ desc="Timestep",
102
+ dynamic_ncols=True,
103
+ )
104
+
105
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision
106
+ # as these weights are only used for inference, keeping weights in full precision is not required.
107
+ if self.accelerator.mixed_precision == "fp16":
108
+ inference_dtype = torch.float16
109
+ elif self.accelerator.mixed_precision == "bf16":
110
+ inference_dtype = torch.bfloat16
111
+ else:
112
+ inference_dtype = torch.float32
113
+
114
+ self.edit_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
115
+ self.edit_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
116
+ self.edit_pipeline.text_encoder_2.to(self.accelerator.device, dtype=inference_dtype)
117
+
118
+ lora_config = LoraConfig(
119
+ r=self.config.lora_rank,
120
+ lora_alpha=self.config.lora_alpha,
121
+ init_lora_weights="gaussian",
122
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
123
+ )
124
+ self.edit_pipeline.flux_pipeline.transformer.requires_grad_(False)
125
+ self.edit_pipeline.flux_pipeline.transformer = get_peft_model(self.edit_pipeline.flux_pipeline.transformer, lora_config)
126
+ trainable_params = [p for p in list(self.edit_pipeline.flux_pipeline.transformer.parameters()) if p.requires_grad]
127
+ total_params = sum(p.numel() for p in trainable_params)
128
+
129
+ self.optimizer = torch.optim.AdamW(
130
+ trainable_params,
131
+ lr=self.config.train_learning_rate,
132
+ betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
133
+ weight_decay=self.config.train_adam_weight_decay,
134
+ eps=self.config.train_adam_epsilon,
135
+ )
136
+
137
+ (
138
+ self.negative_prompt_embeds,
139
+ self.negative_pooled_prompt_embeds,
140
+ self.negative_text_ids,
141
+ ) = self.edit_pipeline.flux_pipeline.encode_prompt(
142
+ prompt=[""] if self.config.negative_prompts is None else self.config.negative_prompts,
143
+ prompt_2=[""],
144
+ device=self.accelerator.device,
145
+ )
146
+
147
+
148
+ # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
149
+ # more memory
150
+ self.autocast = self.edit_pipeline.autocast or self.accelerator.autocast
151
+
152
+ if self.config.resume_from:
153
+ print(f"Resuming from {self.config.resume_from}")
154
+ logger.info(f"Resuming from {self.config.resume_from}")
155
+ self.edit_pipeline.flux_pipeline.transformer.load_adapter(self.config.resume_from, adapter_name="default", is_trainable=True)
156
+ self.edit_pipeline.flux_pipeline.transformer.train()
157
+ self.first_epoch = accelerator_project_config.iteration
158
+ else:
159
+ self.first_epoch = 0
160
+
161
+ self.edit_pipeline.flux_pipeline.transformer, self.optimizer = self.accelerator.prepare(self.edit_pipeline.flux_pipeline.transformer, self.optimizer)
162
+
163
+ self.trainable_layers = list(filter(lambda p: p.requires_grad, self.edit_pipeline.flux_pipeline.transformer.parameters()))
164
+
165
+ self.executor = futures.ThreadPoolExecutor(max_workers=self.config.max_workers)#config.max_workers
166
+
167
+
168
+
169
+ def compute_rewards(self, prompt_image_pairs):
170
+ all_rewards = []
171
+ all_meta_data = []
172
+ for img, prompt, raw_img, img_path in prompt_image_pairs:
173
+
174
+ data_pair_vllm = []
175
+ for idx in range(len(img)):
176
+ data_pair_vllm.append((raw_img[idx][0],raw_img[idx][1], prompt[idx], img[idx]))
177
+ # rewards = self.executor.map(lambda x: self.reward_fn(*x), data_pair_vllm)
178
+ # -------- submit + as_completed --------
179
+ fut_to_idx = {
180
+ self.executor.submit(self.reward_fn, *triple): idx
181
+ for idx, triple in enumerate(data_pair_vllm)
182
+ }
183
+
184
+ # Collect results in original order
185
+ rewards = [None] * len(data_pair_vllm)
186
+ for fut in futures.as_completed(fut_to_idx):
187
+ idx = fut_to_idx[fut]
188
+ rewards[idx] = fut.result()
189
+
190
+ rewards_ = [torch.as_tensor(reward, device=self.accelerator.device) for reward, reward_metadata in rewards]
191
+ rewards_ = torch.stack(rewards_)
192
+ all_rewards.append(rewards_)
193
+ all_meta_data.append(img_path)
194
+ return all_rewards, all_meta_data
195
+
196
+ def step(self, epoch: int, global_step: int):
197
+
198
+ """
199
+ Perform a single step of training.
200
+
201
+ Args:
202
+ epoch (int): The current epoch.
203
+ global_step (int): The current global step.
204
+
205
+ Side Effects:
206
+ - Model weights are updated
207
+ - Logs the statistics to the accelerator trackers.
208
+ - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step,
209
+ and the accelerator tracker.
210
+
211
+ Returns:
212
+ global_step (int): The updated global step.
213
+
214
+ """
215
+ samples, prompt_image_data = self._generate_samples(
216
+ iterations=self.config.sample_num_batches_per_epoch,
217
+ batch_size=self.config.sample_batch_size,
218
+ )
219
+ # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
220
+ local_rank = self.accelerator.local_process_index
221
+ samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
222
+ rewards, rewards_metadata = self.compute_rewards(prompt_image_data)
223
+ for i, image_data in enumerate(prompt_image_data):
224
+ image_data.extend([rewards[i], rewards_metadata[i]])
225
+
226
+ if self.image_samples_callback is not None and self.accelerator.is_main_process:
227
+ self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
228
+ rewards = torch.cat(rewards)
229
+ rewards = self.accelerator.gather(rewards).cpu().numpy()
230
+ if self.accelerator.is_main_process:
231
+ print(rewards.mean())
232
+
233
+ self.accelerator.log(
234
+ {
235
+ "reward": rewards,
236
+ "epoch": epoch,
237
+ "reward_mean": rewards.mean(),
238
+ "reward_std": rewards.std(),
239
+ },
240
+ step=global_step,
241
+ )
242
+ advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
243
+ # ungather advantages; keep the entries corresponding to the samples on this process
244
+ samples["advantages"] = (
245
+ torch.as_tensor(advantages)
246
+ .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
247
+ .to(self.accelerator.device)
248
+ )
249
+
250
+ del samples["prompt_ids"]
251
+ del samples["text_ids"]
252
+ del samples["latent_ids"]
253
+ del samples["negative_text_ids"]
254
+
255
+ total_batch_size, num_timesteps = samples["timesteps"].shape
256
+ self.accelerator.wait_for_everyone()
257
+ for inner_epoch in range(self.config.train_num_inner_epochs):
258
+ # shuffle samples along batch dimension
259
+ perm = torch.randperm(total_batch_size, device=self.accelerator.device)
260
+ samples = {k: v[perm] for k, v in samples.items()}
261
+
262
+ # shuffle along time dimension independently for each sample
263
+ # still trying to understand the code below
264
+ perms = torch.stack(
265
+ [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
266
+ )
267
+
268
+ for key in ["timesteps", "latents", "next_latents", "log_probs"]:
269
+ samples[key] = samples[key][
270
+ torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
271
+ perms,
272
+ ]
273
+
274
+ original_keys = samples.keys()
275
+ original_values = samples.values()
276
+ # rebatch them as user defined train_batch_size is different from sample_batch_size
277
+ reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
278
+
279
+ # Transpose the list of original values
280
+ transposed_values = zip(*reshaped_values)
281
+ # Create new dictionaries for each row of transposed values
282
+ samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
283
+
284
+ self.edit_pipeline.transformer.train()
285
+ global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
286
+ # ensure optimization step at the end of the inner epoch
287
+ if not self.accelerator.sync_gradients:
288
+ raise ValueError(
289
+ "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
290
+ )
291
+
292
+ if self.accelerator.sync_gradients:
293
+ if self.accelerator.is_main_process:
294
+ print("Save checkpoint on epoch", epoch)
295
+ save_model = self.edit_pipeline.flux_pipeline.transformer
296
+ unwrapped_model = self.accelerator.unwrap_model(save_model)
297
+ unwrapped_model.save_pretrained(
298
+ f"{self.project_dir}/checkpoint_lora-{epoch}",
299
+ is_main_process=self.accelerator.is_main_process,
300
+ save_function=self.accelerator.save,
301
+ state_dict=self.accelerator.get_state_dict(save_model),
302
+ )
303
+
304
+ self.accelerator.wait_for_everyone()
305
+
306
+ return global_step, rewards.mean()
307
+
308
+ def calculate_loss(self, latents, image_latents, timestep, next_latents, log_probs, advantages, pooled_prompt_embeds, prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_embeds):
309
+ """
310
+ Calculate the loss for a batch of an unpacked sample
311
+
312
+ Args:
313
+ latents (torch.Tensor):
314
+ The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
315
+ timesteps (torch.Tensor):
316
+ The timesteps sampled from the diffusion model, shape: [batch_size]
317
+ next_latents (torch.Tensor):
318
+ The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height,
319
+ width]
320
+ log_probs (torch.Tensor):
321
+ The log probabilities of the latents, shape: [batch_size]
322
+ advantages (torch.Tensor):
323
+ The advantages of the latents, shape: [batch_size]
324
+ embeds (torch.Tensor):
325
+ The embeddings of the prompts, shape: [2*batch_size or batch_size, ...] Note: the "or" is because if
326
+ train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
327
+
328
+ Returns:
329
+ loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor) (all of these are of shape (1,))
330
+ """
331
+ torch.autograd.set_detect_anomaly(True)
332
+ with self.autocast():
333
+
334
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
335
+ latent_model_input = latent_model_input.detach()
336
+ pooled_prompt_embeds = pooled_prompt_embeds.detach()
337
+ prompt_embeds = prompt_embeds.detach()
338
+ guidance = torch.full([1], self.config.sample_guidance, device=self.edit_pipeline.transformer.device, dtype=torch.bfloat16)
339
+ guidance = guidance.expand(latent_model_input.shape[0])
340
+ noise_pred = self.edit_pipeline.transformer(
341
+ hidden_states=latent_model_input,
342
+ timestep=timestep.detach() / 1000,
343
+ guidance=guidance.detach(),
344
+ pooled_projections=pooled_prompt_embeds,
345
+ encoder_hidden_states=prompt_embeds,
346
+ txt_ids=self.text_ids.detach(),
347
+ img_ids=self.latent_ids.detach(),
348
+ return_dict=False,
349
+ )[0]
350
+ noise_pred = noise_pred[:, : latents.size(1)]
351
+ if self.config.train_cfg:
352
+ neg_noise_pred = self.edit_pipeline.transformer(
353
+ hidden_states=latent_model_input,
354
+ timestep=timestep / 1000,
355
+ guidance=self.config.sample_guidance,
356
+ pooled_projections=negative_pooled_prompt_embeds,
357
+ encoder_hidden_states=negative_prompt_embeds,
358
+ txt_ids=self.negative_text_ids,
359
+ img_ids=self.latent_ids,
360
+ return_dict=False,
361
+ )[0]
362
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
363
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)  # NOTE: true_cfg_scale is not defined in this method; it needs to be passed in or read from the config before train_cfg can be enabled
364
+
365
+ # compute the log prob of next_latents given latents under the current model
366
+ scheduler_step_output = self.edit_pipeline.scheduler.step(
367
+ noise_pred,
368
+ timestep.detach(),
369
+ latents.detach(),
370
+ prev_sample=next_latents.detach(),
371
+ return_dict=True,
372
+ init_step=True,
373
+ )
374
+
375
+ log_prob = scheduler_step_output.log_probs
376
+ advantages = torch.clamp(
377
+ advantages,
378
+ -self.config.train_adv_clip_max,
379
+ self.config.train_adv_clip_max,
380
+ )
381
+
382
+ ratio = torch.exp(log_prob - log_probs)
383
+
384
+ loss = self.loss(advantages, self.config.train_clip_range, ratio)
385
+
386
+ approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
387
+
388
+ clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
389
+
390
+ return loss, approx_kl, clipfrac
391
+
392
+ def loss(
393
+ self,
394
+ advantages: torch.Tensor,
395
+ clip_range: float,
396
+ ratio: torch.Tensor,
397
+ ):
398
+ unclipped_loss = -advantages * ratio
399
+ clipped_loss = -advantages * torch.clamp(
400
+ ratio,
401
+ 1.0 - clip_range,
402
+ 1.0 + clip_range,
403
+ )
404
+ return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
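For intuition, a tiny standalone check of the clipped surrogate computed by `loss` above, with toy numbers (advantage 1.0, ratio 1.2, clip range 0.1):

```python
import torch

adv = torch.tensor([1.0])
ratio = torch.tensor([1.2])
clip_range = 0.1

unclipped = -adv * ratio                                              # -1.2
clipped = -adv * torch.clamp(ratio, 1 - clip_range, 1 + clip_range)  # -1.1
loss = torch.mean(torch.maximum(unclipped, clipped))

# The clipped branch wins (-1.1 > -1.2), so the positive-advantage update is capped
# once the probability ratio moves past 1 + clip_range.
print(loss.item())  # ≈ -1.1
```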
405
+
406
+ def _generate_samples(self, iterations, batch_size):
407
+ """
408
+ Generate samples from the model
409
+
410
+ Args:
411
+ iterations (int): Number of iterations to generate samples for
412
+ batch_size (int): Batch size to use for sampling
413
+
414
+ Returns:
415
+ samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
416
+ """
417
+ samples = []
418
+ prompt_image_pairs = []
419
+ self.edit_pipeline.transformer.eval()
420
+
421
+ sample_neg_prompt_embeds = self.negative_prompt_embeds.repeat(batch_size, 1, 1)
422
+ sample_neg_pooled_prompt_embeds = self.negative_pooled_prompt_embeds.repeat(batch_size, 1)
423
+ sample_neg_text_ids = self.negative_text_ids
424
+
425
+ for iters in range(iterations):
426
+ prompts, raw_images, img_paths = map(list, zip(*[self.prompt_fn('multi') for _ in range(batch_size)]))
427
+ if len(raw_images) == batch_size:
428
+ raw_images = list(map(list, zip(*raw_images)))
429
+
430
+ (
431
+ prompt_embeds,
432
+ pooled_prompt_embeds,
433
+ text_ids,
434
+ ) = self.edit_pipeline.flux_pipeline.encode_prompt(
435
+ prompt=prompts,
436
+ prompt_2=prompts,
437
+ device=self.accelerator.device,
438
+ )
439
+
440
+ prompt_ids = self.edit_pipeline.tokenizer(
441
+ prompts,
442
+ padding="max_length",
443
+ max_length=self.edit_pipeline.flux_pipeline.tokenizer_max_length,
444
+ truncation=True,
445
+ return_tensors="pt",
446
+ ).input_ids.to(self.accelerator.device)
447
+ generator = torch.Generator(device='cuda')
448
+ generator.seed()
449
+ with self.autocast():
450
+ with torch.no_grad():
451
+ edit_output = self.edit_pipeline(
452
+ image=raw_images,
453
+ height=self.config.height,
454
+ width=self.config.width,
455
+ prompt_embeds=prompt_embeds,
456
+ pooled_prompt_embeds=pooled_prompt_embeds,
457
+ negative_prompt_embeds=sample_neg_prompt_embeds,
458
+ negative_pooled_prompt_embeds=sample_neg_pooled_prompt_embeds,
459
+ num_inference_steps=self.config.sample_num_steps,
460
+ guidance_scale=self.config.sample_guidance,
461
+ generator=generator,
462
+ output_type="pt",
463
+ max_area=self.config.max_size**2,
464
+ )
465
+
466
+ images = edit_output.images
467
+ latents = edit_output.latents
468
+ log_probs = edit_output.log_probs
469
+ timesteps = edit_output.timesteps
470
+ latent_ids = edit_output.latent_ids
471
+ image_latents = edit_output.image_latents
472
+
473
+ latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
474
+ log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
475
+ timesteps = torch.stack(timesteps, dim=1)
476
+
477
+ samples.append(
478
+ {
479
+ "prompt_ids": prompt_ids.float(),
480
+ "timesteps": timesteps[:, :-1],
481
+ "latents": latents[:, :-2], # each entry is the latent before timestep t
482
+ "next_latents": latents[:, 1:-1], # each entry is the latent after timestep t
483
+ "log_probs": log_probs[:, :-1],
484
+ "pooled_prompt_embeds":pooled_prompt_embeds,
485
+ "prompt_embeds":prompt_embeds,
486
+ "negative_prompt_embeds":sample_neg_prompt_embeds,
487
+ "negative_pooled_prompt_embeds":sample_neg_pooled_prompt_embeds,
488
+ "text_ids":text_ids,
489
+ "latent_ids":latent_ids,
490
+ "negative_text_ids":sample_neg_text_ids,
491
+ "image_latents":image_latents,
492
+ }
493
+ )
494
+ raw_images = [list(x) for x in zip(*raw_images)]
495
+ prompt_image_pairs.append([images, prompts, raw_images, img_paths])
496
+ local_rank = self.accelerator.local_process_index
497
+ self.text_ids = samples[0]['text_ids']
498
+ self.latent_ids = samples[0]['latent_ids']
499
+ self.negative_text_ids = samples[0]['negative_text_ids']
500
+ return samples, prompt_image_pairs
501
+
502
+ def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
503
+ """
504
+ Train on a batch of samples. Main training segment
505
+
506
+ Args:
507
+ inner_epoch (int): The current inner epoch
508
+ epoch (int): The current epoch
509
+ global_step (int): The current global step
510
+ batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on
511
+
512
+ Side Effects:
513
+ - Model weights are updated
514
+ - Logs the statistics to the accelerator trackers.
515
+
516
+ Returns:
517
+ global_step (int): The updated global step
518
+ """
519
+ info = defaultdict(list)
520
+ for _i, sample in enumerate(batched_samples):
521
+
522
+ for j in trange(self.num_train_timesteps):
523
+ with self.accelerator.accumulate(self.edit_pipeline.transformer):
524
+ loss, approx_kl, clipfrac = self.calculate_loss(
525
+ sample["latents"][:, j],
526
+ sample["image_latents"],
527
+ sample["timesteps"][:, j],
528
+ sample["next_latents"][:, j],
529
+ sample["log_probs"][:, j],
530
+ sample["advantages"],
531
+ sample["pooled_prompt_embeds"],
532
+ sample["prompt_embeds"],
533
+ sample["negative_pooled_prompt_embeds"],
534
+ sample["negative_prompt_embeds"],
535
+ )
536
+ info["approx_kl"].append(approx_kl)
537
+ info["clipfrac"].append(clipfrac)
538
+ info["loss"].append(loss)
539
+
540
+ self.accelerator.backward(loss)
541
+ if self.accelerator.sync_gradients:
542
+ self.accelerator.clip_grad_norm_(
543
+ self.trainable_layers.parameters()
544
+ if not isinstance(self.trainable_layers, list)
545
+ else self.trainable_layers,
546
+ self.config.train_max_grad_norm,
547
+ )
548
+ self.optimizer.step()
549
+ self.optimizer.zero_grad()
550
+
551
+ # Checks if the accelerator has performed an optimization step behind the scenes
552
+ if self.accelerator.sync_gradients:
553
+ # log training-related stuff
554
+ info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
555
+ info = self.accelerator.reduce(info, reduction="mean")
556
+ info.update({"epoch": epoch, "inner_epoch": inner_epoch})
557
+ self.accelerator.log(info, step=global_step)
558
+ global_step += 1
559
+ info = defaultdict(list)
560
+ return global_step
561
+
562
+ def _config_check(self) -> tuple[bool, str]:
563
+ samples_per_epoch = (
564
+ self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
565
+ )
566
+ total_train_batch_size = (
567
+ self.config.train_batch_size
568
+ * self.accelerator.num_processes
569
+ * self.config.train_gradient_accumulation_steps
570
+ )
571
+
572
+ if not self.config.sample_batch_size >= self.config.train_batch_size:
573
+ return (
574
+ False,
575
+ f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
576
+ )
577
+ if not self.config.sample_batch_size % self.config.train_batch_size == 0:
578
+ return (
579
+ False,
580
+ f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
581
+ )
582
+ if not samples_per_epoch % total_train_batch_size == 0:
583
+ return (
584
+ False,
585
+ f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
586
+ )
587
+ return True, ""
588
+
589
+
590
+ def train(self, epochs: Optional[int] = None):
591
+ """
592
+ Train the model for a given number of epochs
593
+ """
594
+ global_step = 0
595
+ rewards_curve = []
596
+ if epochs is None:
597
+ epochs = self.config.num_epochs
598
+ for epoch in range(self.first_epoch, epochs):
599
+ global_step, reward_mean = self.step(epoch, global_step)
600
+ rewards_curve.append(reward_mean)
601
+ return rewards_curve
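To tie the pieces together, a hedged sketch of how this trainer might be wired up. The `DefaultDDPOFluxPipeline` constructor arguments and the reward/prompt stubs are assumptions, but the callback signatures follow `compute_rewards` and `_generate_samples` above:

```python
from PIL import Image

from ddpo_flux_config import DDPOFluxConfig
from ddpo_edit_trainer import DDPOTrainer_edit
from modeling_flux_base import DefaultDDPOFluxPipeline  # concrete pipeline defined in modeling_flux_base.py

def reward_fn(ref1, ref2, prompt, generated):
    # Hypothetical reward: score the generated image against the prompt and references.
    return 1.0, {}

def prompt_fn(mode):
    # Hypothetical prompt source: returns (prompt, [ref_image_1, ref_image_2], image_path).
    refs = [Image.new("RGB", (512, 512)), Image.new("RGB", (512, 512))]
    return "make the scene snowy", refs, "placeholder.png"

config = DDPOFluxConfig(num_epochs=10, sample_batch_size=1, train_batch_size=1)
pipeline = DefaultDDPOFluxPipeline("black-forest-labs/FLUX.1-Kontext-dev")  # constructor signature assumed
trainer = DDPOTrainer_edit(config, reward_fn, prompt_fn, pipeline)
rewards_curve = trainer.train()
```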
kontext/ddpo_flux_config.py ADDED
@@ -0,0 +1,311 @@
1
+ # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import sys
17
+ from dataclasses import dataclass, field
18
+ from typing import Optional
19
+
20
+ from transformers import is_bitsandbytes_available
21
+
22
+ from trl.core import flatten_dict
23
+
24
+
25
+ @dataclass
26
+ class DDPOFluxConfig:
27
+ r"""
28
+ Configuration class for the [`DDPOTrainer_edit`].
29
+
30
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
31
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
32
+ command line.
33
+
34
+ Parameters:
35
+ exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
36
+ Name of this experiment (by default is the file name without the extension name).
37
+ run_name (`str`, *optional*, defaults to `""`):
38
+ Name of this run.
39
+ seed (`int`, *optional*, defaults to `0`):
40
+ Random seed.
41
+ log_with (`Literal["wandb", "tensorboard"]]` or `None`, *optional*, defaults to `None`):
42
+ Log with either 'wandb' or 'tensorboard', check
43
+ https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
44
+ tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
45
+ Keyword arguments for the tracker (e.g. wandb_project).
46
+ accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
47
+ Keyword arguments for the accelerator.
48
+ project_kwargs (`Dict`, *optional*, defaults to `{}`):
49
+ Keyword arguments for the accelerator project config (e.g. `logging_dir`).
50
+ tracker_project_name (`str`, *optional*, defaults to `"trl"`):
51
+ Name of project to use for tracking.
52
+ logdir (`str`, *optional*, defaults to `"logs"`):
53
+ Top-level logging directory for checkpoint saving.
54
+ num_epochs (`int`, *optional*, defaults to `100`):
55
+ Number of epochs to train.
56
+ save_freq (`int`, *optional*, defaults to `1`):
57
+ Number of epochs between saving model checkpoints.
58
+ num_checkpoint_limit (`int`, *optional*, defaults to `5`):
59
+ Number of checkpoints to keep before overwriting old ones.
60
+ mixed_precision (`str`, *optional*, defaults to `"fp16"`):
61
+ Mixed precision training.
62
+ allow_tf32 (`bool`, *optional*, defaults to `True`):
63
+ Allow `tf32` on Ampere GPUs.
64
+ resume_from (`str`, *optional*, defaults to `""`):
65
+ Resume training from a checkpoint.
66
+ sample_num_steps (`int`, *optional*, defaults to `10`):
+ Number of sampler inference steps.
+ sample_guidance (`float`, *optional*, defaults to `3.5`):
+ Classifier-free guidance weight.
72
+ sample_batch_size (`int`, *optional*, defaults to `1`):
73
+ Batch size (per GPU) to use for sampling.
74
+ sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
75
+ Number of batches to sample per epoch.
76
+ train_batch_size (`int`, *optional*, defaults to `1`):
77
+ Batch size (per GPU) to use for training.
78
+ train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
79
+ Use 8bit Adam optimizer from bitsandbytes.
80
+ train_learning_rate (`float`, *optional*, defaults to `3e-4`):
81
+ Learning rate.
82
+ train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
83
+ Adam beta1.
84
+ train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
85
+ Adam beta2.
86
+ train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
87
+ Adam weight decay.
88
+ train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
89
+ Adam epsilon.
90
+ train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
91
+ Number of gradient accumulation steps.
92
+ train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
93
+ Maximum gradient norm for gradient clipping.
94
+ train_num_inner_epochs (`int`, *optional*, defaults to `1`):
95
+ Number of inner epochs per outer epoch.
96
+ train_cfg (`bool`, *optional*, defaults to `True`):
97
+ Whether to use classifier-free guidance during training.
98
+ train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
99
+ Clip advantages to the range.
100
+ train_clip_range (`float`, *optional*, defaults to `1e-4`):
101
+ PPO clip range.
102
+ train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
103
+ Fraction of timesteps to train on.
104
+ per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
105
+ Whether to track statistics for each prompt separately.
106
+ per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
107
+ Number of reward values to store in the buffer for each prompt.
108
+ per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
109
+ Minimum number of reward values to store in the buffer.
110
+ max_workers (`int`, *optional*, defaults to `8`):
+ Maximum number of workers to use for reward computation.
114
+ negative_prompts (`str`, *optional*, defaults to `""`):
115
+ Comma-separated list of prompts to use as negative examples.
116
+ push_to_hub (`bool`, *optional*, defaults to `False`):
117
+ Whether to push the final model checkpoint to the Hub.
118
+ """
119
+
120
+ exp_name: str = field(
121
+ default=os.path.basename(sys.argv[0])[: -len(".py")],
122
+ metadata={"help": "Name of this experiment (by default is the file name without the extension name)."},
123
+ )
124
+ run_name: str = field(
125
+ default="",
126
+ metadata={"help": "Name of this run."},
127
+ )
128
+ seed: int = field(
129
+ default=0,
130
+ metadata={"help": "Random seed."},
131
+ )
132
+ log_with: Optional[str] = field(
133
+ default="wandb",
134
+ metadata={
135
+ "help": "Log with either 'wandb' or 'tensorboard'.",
136
+ "choices": ["wandb", "tensorboard"],
137
+ },
138
+ )
139
+ tracker_kwargs: dict = field(
140
+ default_factory=dict,
141
+ metadata={"help": "Keyword arguments for the tracker (e.g. wandb_project)."},
142
+ )
143
+ accelerator_kwargs: dict = field(
144
+ default_factory=dict,
145
+ metadata={"help": "Keyword arguments for the accelerator."},
146
+ )
147
+ project_kwargs: dict = field(
148
+ default_factory=dict,
149
+ metadata={"help": "Keyword arguments for the accelerator project config (e.g. `logging_dir`)."},
150
+ )
151
+ tracker_project_name: str = field(
152
+ default="trl_flux_ddpo",
153
+ metadata={"help": "Name of project to use for tracking."},
154
+ )
155
+ logdir: str = field(
156
+ default="logs",
157
+ metadata={"help": "Top-level logging directory for checkpoint saving."},
158
+ )
159
+ num_epochs: int = field(
160
+ default=100,
161
+ metadata={"help": "Number of epochs to train."},
162
+ )
163
+ save_freq: int = field(
164
+ default=1,
165
+ metadata={"help": "Number of epochs between saving model checkpoints."},
166
+ )
167
+ num_checkpoint_limit: int = field(
168
+ default=5,
169
+ metadata={"help": "Number of checkpoints to keep before overwriting old ones."},
170
+ )
171
+ mixed_precision: str = field(
172
+ default="bf16",
173
+ metadata={"help": "Mixed precision training."},
174
+ )
175
+ allow_tf32: bool = field(
176
+ default=True,
177
+ metadata={"help": "Allow `tf32` on Ampere GPUs."},
178
+ )
179
+ resume_from: str = field(
180
+ default="",
181
+ metadata={"help": "Resume training from a checkpoint."},
182
+ )
183
+ sample_num_steps: int = field(
184
+ default=10,
185
+ metadata={"help": "Number of sampler inference steps."},
186
+ )
187
+ sample_guidance: float = field(
188
+ default=3.5,
189
+ metadata={"help": "Classifier-free guidance weight."},
190
+ )
191
+ sample_batch_size: int = field(
192
+ default=1,
193
+ metadata={"help": "Batch size (per GPU) to use for sampling."},
194
+ )
195
+ sample_num_batches_per_epoch: int = field(
196
+ default=2,
197
+ metadata={"help": "Number of batches to sample per epoch."},
198
+ )
199
+ train_batch_size: int = field(
200
+ default=1,
201
+ metadata={"help": "Batch size (per GPU) to use for training. Only support 1 now"},
202
+ )
203
+ lora_rank: int = field(
204
+ default=4,
205
+ metadata={"help": "Lora rank for training."},
206
+ )
207
+ lora_alpha: int = field(
208
+ default=4,
209
+ metadata={"help": "Lora alpha for training."},
210
+ )
211
+ train_use_8bit_adam: bool = field(
212
+ default=False,
213
+ metadata={"help": "Use 8bit Adam optimizer from bitsandbytes."},
214
+ )
215
+ train_learning_rate: float = field(
216
+ default=1e-4,
217
+ metadata={"help": "Learning rate."},
218
+ )
219
+ train_adam_beta1: float = field(
220
+ default=0.9,
221
+ metadata={"help": "Adam beta1."},
222
+ )
223
+ train_adam_beta2: float = field(
224
+ default=0.999,
225
+ metadata={"help": "Adam beta2."},
226
+ )
227
+ train_adam_weight_decay: float = field(
228
+ default=1e-4,
229
+ metadata={"help": "Adam weight decay."},
230
+ )
231
+ train_adam_epsilon: float = field(
232
+ default=1e-8,
233
+ metadata={"help": "Adam epsilon."},
234
+ )
235
+ train_gradient_accumulation_steps: int = field(
236
+ default=1,
237
+ metadata={"help": "Number of gradient accumulation steps."},
238
+ )
239
+ train_max_grad_norm: float = field(
240
+ default=1.0,
241
+ metadata={"help": "Maximum gradient norm for gradient clipping."},
242
+ )
243
+ train_num_inner_epochs: int = field(
244
+ default=1,
245
+ metadata={"help": "Number of inner epochs per outer epoch."},
246
+ )
247
+ train_cfg: bool = field(
248
+ default=False,
249
+ metadata={"help": "Whether to use classifier-free guidance during training."},
250
+ )
251
+ train_adv_clip_max: float = field(
252
+ default=5.0,
253
+ metadata={"help": "Clip advantages to the range."},
254
+ )
255
+ train_clip_range: float = field(
256
+ default=1e-4,
257
+ metadata={"help": "PPO clip range."},
258
+ )
259
+ train_timestep_fraction: float = field(
260
+ default=1.0,
261
+ metadata={"help": "Fraction of timesteps to train on."},
262
+ )
263
+ per_prompt_stat_tracking: bool = field(
264
+ default=False,
265
+ metadata={"help": "Whether to track statistics for each prompt separately."},
266
+ )
267
+ per_prompt_stat_tracking_buffer_size: int = field(
268
+ default=16,
269
+ metadata={"help": "Number of reward values to store in the buffer for each prompt."},
270
+ )
271
+ per_prompt_stat_tracking_min_count: int = field(
272
+ default=16,
273
+ metadata={"help": "Minimum number of reward values to store in the buffer."},
274
+ )
275
+ height: int = field(
276
+ default=512,
277
+ metadata={"help": "Height of gene image."},
278
+ )
279
+ width: int = field(
280
+ default=512,
281
+ metadata={"help": "Width of gene image."},
282
+ )
283
+ max_size: int = field(
284
+ default=512,
285
+ metadata={"help": "Max size of gene image."},
286
+ )
287
+ max_workers: int = field(
288
+ default=8,
289
+ metadata={"help": "Maximum number of workers to use for async reward computation."},
290
+ )
291
+ negative_prompts: str = field(
292
+ default="",
293
+ metadata={"help": "Comma-separated list of prompts to use as negative examples."},
294
+ )
295
+ push_to_hub: bool = field(
296
+ default=False,
297
+ metadata={"help": "Whether to push the final model checkpoint to the Hub."},
298
+ )
299
+
300
+ def to_dict(self):
301
+ output_dict = {}
302
+ for key, value in self.__dict__.items():
303
+ output_dict[key] = value
304
+ return flatten_dict(output_dict)
305
+
306
+ def __post_init__(self):
307
+ if self.train_use_8bit_adam and not is_bitsandbytes_available():
308
+ raise ImportError(
309
+ "You need to install bitsandbytes to use 8bit Adam. "
310
+ "You can install it with `pip install bitsandbytes`."
311
+ )
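As the class docstring notes, the config can be driven from the command line via `HfArgumentParser`; a minimal sketch:

```python
from transformers import HfArgumentParser
from ddpo_flux_config import DDPOFluxConfig

# e.g. `python train.py --num_epochs 50 --mixed_precision bf16 --tracker_project_name my_run`
parser = HfArgumentParser(DDPOFluxConfig)
(config,) = parser.parse_args_into_dataclasses()
print(config.num_epochs, config.mixed_precision)
```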
kontext/modeling_flux_base.py ADDED
@@ -0,0 +1,997 @@
1
+ import contextlib
2
+ import os
3
+ import random
4
+ import warnings
5
+ from dataclasses import dataclass
6
+ from typing import Any, Callable, Dict, List, Optional, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.utils.checkpoint as checkpoint
11
+ from diffusers import FluxTransformer2DModel
12
+ from diffusers.image_processor import PipelineImageInput
13
+ from diffusers.pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS, calculate_shift, retrieve_timesteps
14
+ from scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
15
+ from pipeline_flux_kontext import FluxKontextPipeline
16
+ from transformers.utils import is_peft_available
17
+
18
+ from trl.core import randn_tensor
19
+ from trl.models.sd_utils import convert_state_dict_to_diffusers
20
+
21
+ if is_peft_available():
22
+ from peft import LoraConfig, get_peft_model
23
+ from peft.utils import get_peft_model_state_dict
24
+
25
+ PREFERRED_KONTEXT_RESOLUTIONS = [(x[0]//2,x[1]//2) for x in PREFERRED_KONTEXT_RESOLUTIONS]
26
+ @dataclass
27
+ class FluxPipelineOutput:
28
+ """
29
+ Output class for the diffusers pipeline to be finetuned with the DDPO trainer
30
+
31
+ Args:
32
+ images (`torch.Tensor`):
33
+ The generated images.
34
+ latents (`list[torch.Tensor]`):
35
+ The latents used to generate the images.
36
+ log_probs (`list[torch.Tensor]`):
37
+ The log probabilities of the latents.
38
+
39
+ """
40
+
41
+ images: torch.Tensor
42
+ latents: torch.Tensor
43
+ log_probs: torch.Tensor
44
+ latent_ids: torch.Tensor
45
+ timesteps: torch.Tensor
46
+ image_latents: torch.Tensor
47
+
48
+ class DDPOFluxPipeline:
49
+ """
50
+ Main class for the diffusers pipeline to be finetuned with the DDPO trainer
51
+ """
52
+
53
+ def __call__(self, *args, **kwargs) -> FluxPipelineOutput:
54
+ raise NotImplementedError
55
+
56
+ @property
57
+ def transformer(self):
58
+ """
59
+ Returns the transformer model used for diffusion.
60
+ """
61
+ raise NotImplementedError
62
+
63
+ @property
64
+ def vae(self):
65
+ """
66
+ Returns the Variational Autoencoder model used from mapping images to and from the latent space
67
+ """
68
+ raise NotImplementedError
69
+
70
+ @property
71
+ def tokenizer(self):
72
+ """
73
+ Returns the tokenizer used for tokenizing text inputs
74
+ """
75
+ raise NotImplementedError
76
+
77
+ @property
78
+ def tokenizer_2(self):
79
+ """
80
+ Returns the second tokenizer used for tokenizing text inputs
81
+ """
82
+ raise NotImplementedError
83
+
84
+ @property
85
+ def scheduler(self):
86
+ """
87
+ Returns the scheduler associated with the pipeline used for the diffusion process
88
+ """
89
+ raise NotImplementedError
90
+
91
+ @property
92
+ def text_encoder(self):
93
+ """
94
+ Returns the text encoder used for encoding text inputs
95
+ """
96
+ raise NotImplementedError
97
+
98
+ @property
99
+ def text_encoder_2(self):
100
+ """
101
+ Returns the second text encoder used for encoding text inputs
102
+ """
103
+ raise NotImplementedError
104
+
105
+ @property
106
+ def image_encoder(self):
107
+ """
108
+ Returns the image encoder used for encoding image inputs
109
+ """
110
+ raise NotImplementedError
111
+
112
+ @property
113
+ def feature_extractor(self):
114
+ """
115
+ Returns the feature extractor used for preprocessing image inputs
116
+ """
117
+ raise NotImplementedError
118
+
119
+ @property
120
+ def autocast(self):
121
+ """
122
+ Returns the autocast context manager
123
+ """
124
+ raise NotImplementedError
125
+
126
+ def set_progress_bar_config(self, *args, **kwargs):
127
+ """
128
+ Sets the progress bar config for the pipeline
129
+ """
130
+ raise NotImplementedError
131
+
132
+ def save_pretrained(self, *args, **kwargs):
133
+ """
134
+ Saves all of the model weights
135
+ """
136
+ raise NotImplementedError
137
+
138
+ def save_checkpoint(self, *args, **kwargs):
139
+ """
140
+ Light wrapper around accelerate's register_save_state_pre_hook which is run before saving state
141
+ """
142
+ raise NotImplementedError
143
+
144
+ def load_checkpoint(self, *args, **kwargs):
145
+ """
146
+ Light wrapper around accelerate's register_load_state_pre_hook which is run before loading state
147
+ """
148
+ raise NotImplementedError
149
+
150
+ @torch.no_grad()
151
+ def pipeline_step(
152
+ self,
153
+ image: Optional[PipelineImageInput] = None,
154
+ prompt: Union[str, List[str]] = None,
155
+ prompt_2: Optional[Union[str, List[str]]] = None,
156
+ negative_prompt: Union[str, List[str]] = None,
157
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
158
+ true_cfg_scale: float = 1.0,
159
+ height: Optional[int] = None,
160
+ width: Optional[int] = None,
161
+ num_inference_steps: int = 28,
162
+ sigmas: Optional[List[float]] = None,
163
+ guidance_scale: float = 3.5,
164
+ num_images_per_prompt: Optional[int] = 1,
165
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
166
+ latents: Optional[torch.FloatTensor] = None,
167
+ prompt_embeds: Optional[torch.FloatTensor] = None,
168
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
169
+ ip_adapter_image: Optional[PipelineImageInput] = None,
170
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
171
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
172
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
173
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
174
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
175
+ output_type: Optional[str] = "pil",
176
+ return_dict: bool = True,
177
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
178
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
179
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
180
+ max_sequence_length: int = 512,
181
+ max_area: int = 1024**2,
182
+ _auto_resize: bool = True,
183
+ ):
184
+ r"""
185
+ Function invoked when calling the pipeline for generation.
186
+
187
+ Args:
188
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
189
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
190
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
191
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
192
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
193
+ latents as `image`, but if passing latents directly it is not encoded again.
194
+ prompt (`str` or `List[str]`, *optional*):
195
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
196
+ instead.
197
+ prompt_2 (`str` or `List[str]`, *optional*):
198
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
199
+ will be used instead.
200
+ negative_prompt (`str` or `List[str]`, *optional*):
201
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
202
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
203
+ not greater than `1`).
204
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
205
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
206
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
207
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
208
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
209
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
210
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
211
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
212
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
213
+            num_inference_steps (`int`, *optional*, defaults to 28):
214
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
215
+ expense of slower inference.
216
+ sigmas (`List[float]`, *optional*):
217
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
218
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
219
+ will be used.
220
+ guidance_scale (`float`, *optional*, defaults to 3.5):
221
+ Guidance scale as defined in [Classifier-Free Diffusion
222
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
223
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
224
+                `guidance_scale > 1`. A higher guidance scale encourages the model to generate images closely linked to
225
+ the text `prompt`, usually at the expense of lower image quality.
226
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
227
+ The number of images to generate per prompt.
228
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
229
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
230
+ to make generation deterministic.
231
+ latents (`torch.FloatTensor`, *optional*):
232
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
233
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
234
+                tensor will be generated by sampling using the supplied random `generator`.
235
+ prompt_embeds (`torch.FloatTensor`, *optional*):
236
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
237
+ provided, text embeddings will be generated from `prompt` input argument.
238
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
239
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
240
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
241
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
242
+ Optional image input to work with IP Adapters.
243
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
244
+                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the
245
+                number of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`.
246
+                If not provided, embeddings are computed from the `ip_adapter_image` input argument.
247
+            negative_ip_adapter_image (`PipelineImageInput`, *optional*):
248
+                Optional negative image input to work with IP Adapters.
249
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
250
+                Pre-generated negative image embeddings for IP-Adapter. It should be a list with the same length as
251
+                the number of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`.
252
+                If not provided, embeddings are computed from the `negative_ip_adapter_image` input argument.
253
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
254
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
255
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
256
+ argument.
257
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
258
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
259
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
260
+ input argument.
261
+ output_type (`str`, *optional*, defaults to `"pil"`):
262
+                The output format of the generated image. Choose between
263
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
264
+ return_dict (`bool`, *optional*, defaults to `True`):
265
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
266
+ joint_attention_kwargs (`dict`, *optional*):
267
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
268
+ `self.processor` in
269
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
270
+ callback_on_step_end (`Callable`, *optional*):
271
+                A function that is called at the end of each denoising step during inference. The function is called
272
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
273
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
274
+ `callback_on_step_end_tensor_inputs`.
275
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
276
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
277
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
278
+ `._callback_tensor_inputs` attribute of your pipeline class.
279
+            max_sequence_length (`int`, *optional*, defaults to 512):
280
+ Maximum sequence length to use with the `prompt`.
281
+ max_area (`int`, defaults to `1024 ** 2`):
282
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
283
+ area while maintaining the aspect ratio.
284
+
285
+ Examples:
286
+
287
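+            A minimal illustrative call, assuming this function is reached through a wrapper such as
+            `DefaultDDPOFluxPipeline` defined below (variable names are placeholders, not part of the API):
+
+            ```py
+            >>> pipe = DefaultDDPOFluxPipeline("black-forest-labs/FLUX.1-Kontext-dev")
+            >>> out = pipe(image=reference_image, prompt="turn the sketch into a watercolor painting")
+            >>> result = out.images[0]      # decoded image(s)
+            >>> trajectory = out.latents    # initial noise plus one latent per denoising step
+            >>> log_probs = out.log_probs   # per-step log-probabilities consumed by the DDPO trainer
+            ```
+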
+ Returns:
288
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
289
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
290
+ images.
291
+ """
292
+
293
+ height = height or self.default_sample_size * self.vae_scale_factor
294
+ width = width or self.default_sample_size * self.vae_scale_factor
295
+
296
+ original_height, original_width = height, width
297
+ aspect_ratio = width / height
298
+ width = round((max_area * aspect_ratio) ** 0.5)
299
+ height = round((max_area / aspect_ratio) ** 0.5)
300
+
301
+ multiple_of = self.vae_scale_factor * 2
302
+ width = width // multiple_of * multiple_of
303
+ height = height // multiple_of * multiple_of
304
+
305
+ if height != original_height or width != original_width:
306
+ logger.warning(
307
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
308
+ )
309
+
310
+ # 1. Check inputs. Raise error if not correct
311
+ self.check_inputs(
312
+ prompt,
313
+ prompt_2,
314
+ height,
315
+ width,
316
+ negative_prompt=negative_prompt,
317
+ negative_prompt_2=negative_prompt_2,
318
+ prompt_embeds=prompt_embeds,
319
+ negative_prompt_embeds=negative_prompt_embeds,
320
+ pooled_prompt_embeds=pooled_prompt_embeds,
321
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
322
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
323
+ max_sequence_length=max_sequence_length,
324
+ )
325
+
326
+ self._guidance_scale = guidance_scale
327
+ self._joint_attention_kwargs = joint_attention_kwargs
328
+ self._current_timestep = None
329
+ self._interrupt = False
330
+
331
+ # 2. Define call parameters
332
+ if prompt is not None and isinstance(prompt, str):
333
+ batch_size = 1
334
+ elif prompt is not None and isinstance(prompt, list):
335
+ batch_size = len(prompt)
336
+ else:
337
+ batch_size = prompt_embeds.shape[0]
338
+
339
+ device = self._execution_device
340
+
341
+ lora_scale = (
342
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
343
+ )
344
+ has_neg_prompt = negative_prompt is not None or (
345
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
346
+ )
347
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
348
+ (
349
+ prompt_embeds,
350
+ pooled_prompt_embeds,
351
+ text_ids,
352
+ ) = self.encode_prompt(
353
+ prompt=prompt,
354
+ prompt_2=prompt_2,
355
+ prompt_embeds=prompt_embeds,
356
+ pooled_prompt_embeds=pooled_prompt_embeds,
357
+ device=device,
358
+ num_images_per_prompt=num_images_per_prompt,
359
+ max_sequence_length=max_sequence_length,
360
+ lora_scale=lora_scale,
361
+ )
362
+ if do_true_cfg:
363
+ (
364
+ negative_prompt_embeds,
365
+ negative_pooled_prompt_embeds,
366
+ negative_text_ids,
367
+ ) = self.encode_prompt(
368
+ prompt=negative_prompt,
369
+ prompt_2=negative_prompt_2,
370
+ prompt_embeds=negative_prompt_embeds,
371
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
372
+ device=device,
373
+ num_images_per_prompt=num_images_per_prompt,
374
+ max_sequence_length=max_sequence_length,
375
+ lora_scale=lora_scale,
376
+ )
377
+
378
+ # 3. Preprocess image
379
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
380
+ imgs = image if isinstance(image, list) else [image]
381
+
382
+ images = []
383
+ for img in imgs:
384
+ img_0 = img[0] if isinstance(img, list) else img
385
+ image_height, image_width = self.image_processor.get_default_height_width(img_0)
386
+ aspect_ratio = image_width / image_height
387
+
388
+ if _auto_resize:
389
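+                    # Snap to the preferred Kontext resolution whose aspect ratio is closest to the input image's.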
+ _, image_width, image_height = min(
390
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
391
+ )
392
+
393
+ image_width = image_width // multiple_of * multiple_of
394
+ image_height = image_height // multiple_of * multiple_of
395
+
396
+ resized = self.image_processor.resize(img, image_height, image_width)
397
+                logger.debug(f"Resized conditioning image to {image_height}x{image_width}")
398
+ processed = self.image_processor.preprocess(resized, image_height, image_width)
399
+ images.append(processed)
400
+ # 4. Prepare latent variables
401
+ num_channels_latents = self.transformer.config.in_channels // 4
402
+ latents, image_latents, latent_ids, image_ids = self.prepare_latents(
403
+ images,
404
+ batch_size * num_images_per_prompt,
405
+ num_channels_latents,
406
+ height,
407
+ width,
408
+ prompt_embeds.dtype,
409
+ device,
410
+ generator,
411
+ latents,
412
+ )
413
+ if image_ids is not None:
414
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
415
+ # 5. Prepare timesteps
416
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
417
+ image_seq_len = latents.shape[1]
418
+ mu = calculate_shift(
419
+ image_seq_len,
420
+ self.scheduler.config.get("base_image_seq_len", 256),
421
+ self.scheduler.config.get("max_image_seq_len", 4096),
422
+ self.scheduler.config.get("base_shift", 0.5),
423
+ self.scheduler.config.get("max_shift", 1.15),
424
+ )
425
+ timesteps, num_inference_steps = retrieve_timesteps(
426
+ self.scheduler,
427
+ num_inference_steps,
428
+ device,
429
+ sigmas=sigmas,
430
+ mu=mu,
431
+ )
432
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
433
+ self._num_timesteps = len(timesteps)
434
+
435
+ # handle guidance
436
+ if self.transformer.config.guidance_embeds:
437
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
438
+ guidance = guidance.expand(latents.shape[0])
439
+ else:
440
+ guidance = None
441
+
442
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
443
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
444
+ ):
445
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
446
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
447
+
448
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
449
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
450
+ ):
451
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
452
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
453
+
454
+ if self.joint_attention_kwargs is None:
455
+ self._joint_attention_kwargs = {}
456
+
457
+ image_embeds = None
458
+ negative_image_embeds = None
459
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
460
+ image_embeds = self.prepare_ip_adapter_image_embeds(
461
+ ip_adapter_image,
462
+ ip_adapter_image_embeds,
463
+ device,
464
+ batch_size * num_images_per_prompt,
465
+ )
466
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
467
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
468
+ negative_ip_adapter_image,
469
+ negative_ip_adapter_image_embeds,
470
+ device,
471
+ batch_size * num_images_per_prompt,
472
+ )
473
+
474
+ # 6. Denoising loop
475
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
476
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
477
+ all_latents = [latents]
478
+ all_log_probs = []
479
+ all_timesteps = []
480
+ self.scheduler.set_begin_index(0)
481
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
482
+ for i, t in enumerate(timesteps):
483
+ if self.interrupt:
484
+ continue
485
+
486
+ self._current_timestep = t
487
+ if image_embeds is not None:
488
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
489
+
490
+ latent_model_input = latents
491
+ latent_model_input = latent_model_input.to(self.transformer.device)
492
+ if image_latents is not None:
493
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
494
+ timestep = t.expand(latents.shape[0]).to(torch.float32)
495
+ noise_pred = self.transformer(
496
+ hidden_states=latent_model_input,
497
+ timestep=timestep / 1000,
498
+ guidance=guidance,
499
+ pooled_projections=pooled_prompt_embeds,
500
+ encoder_hidden_states=prompt_embeds,
501
+ txt_ids=text_ids,
502
+ img_ids=latent_ids,
503
+ joint_attention_kwargs=self.joint_attention_kwargs,
504
+ return_dict=False,
505
+ )[0]
506
+ noise_pred = noise_pred[:, : latents.size(1)]
507
+
508
+ if do_true_cfg:
509
+ if negative_image_embeds is not None:
510
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
511
+ neg_noise_pred = self.transformer(
512
+ hidden_states=latent_model_input,
513
+ timestep=timestep / 1000,
514
+ guidance=guidance,
515
+ pooled_projections=negative_pooled_prompt_embeds,
516
+ encoder_hidden_states=negative_prompt_embeds,
517
+ txt_ids=negative_text_ids,
518
+ img_ids=latent_ids,
519
+ joint_attention_kwargs=self.joint_attention_kwargs,
520
+ return_dict=False,
521
+ )[0]
522
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
523
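+                    # True CFG: start from the negative prediction and move toward the positive one,
+                    # scaled by true_cfg_scale (a value of 1.0 recovers the positive prediction exactly).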
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
524
+
525
+ # compute the previous noisy sample x_t -> x_t-1
526
+ latents_dtype = latents.dtype
527
+ scheduler_output = self.scheduler.step(noise_pred, t, latents, return_dict=True)
528
+ latents = scheduler_output.latents
529
+ log_probs = scheduler_output.log_probs
530
+ all_latents.append(latents)
531
+ all_log_probs.append(log_probs)
532
+
533
+ all_timesteps.append(timestep)
534
+
535
+ if latents.dtype != latents_dtype:
536
+ latents = latents.to(latents_dtype)
537
+
538
+ if callback_on_step_end is not None:
539
+ callback_kwargs = {}
540
+ for k in callback_on_step_end_tensor_inputs:
541
+ callback_kwargs[k] = locals()[k]
542
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
543
+
544
+ latents = callback_outputs.pop("latents", latents)
545
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
546
+
547
+ # call the callback, if provided
548
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
549
+ progress_bar.update()
550
+
551
+ self._current_timestep = None
552
+ if output_type == "latent":
553
+ image = latents
554
+ else:
555
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
556
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
557
+ image = self.vae.decode(latents, return_dict=False)[0]
558
+ image = self.image_processor.postprocess(image, output_type=output_type)
559
+
560
+ # Offload all models
561
+ self.maybe_free_model_hooks()
562
+
563
+ if not return_dict:
564
+ return (image,)
565
+
566
+ return FluxPipelineOutput(image, all_latents, all_log_probs, latent_ids, all_timesteps, image_latents)
567
+
568
+ def pipeline_step_with_grad(
569
+ pipeline,
570
+ image: Optional[PipelineImageInput] = None,
571
+ prompt: Union[str, List[str]] = None,
572
+ prompt_2: Optional[Union[str, List[str]]] = None,
573
+ negative_prompt: Union[str, List[str]] = None,
574
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
575
+ true_cfg_scale: float = 1.0,
576
+ height: Optional[int] = None,
577
+ width: Optional[int] = None,
578
+ num_inference_steps: int = 28,
579
+ sigmas: Optional[List[float]] = None,
580
+ guidance_scale: float = 3.5,
581
+ truncated_backprop: bool = True,
582
+ truncated_backprop_rand: bool = True,
583
+ gradient_checkpoint: bool = True,
584
+ truncated_backprop_timestep: int = 49,
585
+ truncated_rand_backprop_minmax: tuple = (0, 50),
586
+ num_images_per_prompt: Optional[int] = 1,
587
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
588
+ latents: Optional[torch.FloatTensor] = None,
589
+ prompt_embeds: Optional[torch.FloatTensor] = None,
590
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
591
+ ip_adapter_image: Optional[PipelineImageInput] = None,
592
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
593
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
594
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
595
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
596
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
597
+ output_type: Optional[str] = "pil",
598
+ return_dict: bool = True,
599
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
600
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
601
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
602
+ max_sequence_length: int = 512,
603
+ max_area: int = 512**2,
604
+ _auto_resize: bool = True,
605
+ ):
606
+ height = height or pipeline.default_sample_size * pipeline.vae_scale_factor
607
+ width = width or pipeline.default_sample_size * pipeline.vae_scale_factor
608
+
609
+ original_height, original_width = height, width
610
+ aspect_ratio = width / height
611
+ width = round((max_area * aspect_ratio) ** 0.5)
612
+ height = round((max_area / aspect_ratio) ** 0.5)
613
+
614
+ multiple_of = pipeline.vae_scale_factor * 2
615
+ width = width // multiple_of * multiple_of
616
+ height = height // multiple_of * multiple_of
617
+
618
+ if height != original_height or width != original_width:
619
+ logger.warning(
620
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
621
+ )
622
+
623
+ # 1. Check inputs. Raise error if not correct
624
+ pipeline.check_inputs(
625
+ prompt,
626
+ prompt_2,
627
+ height,
628
+ width,
629
+ negative_prompt=negative_prompt,
630
+ negative_prompt_2=negative_prompt_2,
631
+ prompt_embeds=prompt_embeds,
632
+ negative_prompt_embeds=negative_prompt_embeds,
633
+ pooled_prompt_embeds=pooled_prompt_embeds,
634
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
635
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
636
+ max_sequence_length=max_sequence_length,
637
+ )
638
+
639
+ pipeline._guidance_scale = guidance_scale
640
+ pipeline._joint_attention_kwargs = joint_attention_kwargs
641
+ pipeline._current_timestep = None
642
+ pipeline._interrupt = False
643
+
644
+ # 2. Define call parameters
645
+ if prompt is not None and isinstance(prompt, str):
646
+ batch_size = 1
647
+ elif prompt is not None and isinstance(prompt, list):
648
+ batch_size = len(prompt)
649
+ else:
650
+ batch_size = prompt_embeds.shape[0]
651
+
652
+ device = pipeline._execution_device
653
+
654
+ lora_scale = (
655
+ pipeline.joint_attention_kwargs.get("scale", None) if pipeline.joint_attention_kwargs is not None else None
656
+ )
657
+ has_neg_prompt = negative_prompt is not None or (
658
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
659
+ )
660
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
661
+ (
662
+ prompt_embeds,
663
+ pooled_prompt_embeds,
664
+ text_ids,
665
+ ) = pipeline.encode_prompt(
666
+ prompt=prompt,
667
+ prompt_2=prompt_2,
668
+ prompt_embeds=prompt_embeds,
669
+ pooled_prompt_embeds=pooled_prompt_embeds,
670
+ device=device,
671
+ num_images_per_prompt=num_images_per_prompt,
672
+ max_sequence_length=max_sequence_length,
673
+ lora_scale=lora_scale,
674
+ )
675
+ if do_true_cfg:
676
+ (
677
+ negative_prompt_embeds,
678
+ negative_pooled_prompt_embeds,
679
+ negative_text_ids,
680
+ ) = pipeline.encode_prompt(
681
+ prompt=negative_prompt,
682
+ prompt_2=negative_prompt_2,
683
+ prompt_embeds=negative_prompt_embeds,
684
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
685
+ device=device,
686
+ num_images_per_prompt=num_images_per_prompt,
687
+ max_sequence_length=max_sequence_length,
688
+ lora_scale=lora_scale,
689
+ )
690
+
691
+ # 3. Preprocess image
692
+ # if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == pipeline.latent_channels):
693
+ # img = image[0] if isinstance(image, list) else image
694
+ # image_height, image_width = pipeline.image_processor.get_default_height_width(img)
695
+ # aspect_ratio = image_width / image_height
696
+ # if _auto_resize:
697
+ # # Kontext is trained on specific resolutions, using one of them is recommended
698
+ # _, image_width, image_height = min(
699
+ # (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
700
+ # )
701
+ # image_width = image_width // multiple_of * multiple_of
702
+ # image_height = image_height // multiple_of * multiple_of
703
+ # image = pipeline.image_processor.resize(image, image_height, image_width)
704
+ # image = pipeline.image_processor.preprocess(image, image_height, image_width)
705
+
706
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == pipeline.latent_channels):
707
+ imgs = image if isinstance(image, list) else [image]
708
+
709
+ images = []
710
+ for img in imgs:
711
+ img_0 = img[0] if isinstance(img, list) else img
712
+ image_height, image_width = pipeline.image_processor.get_default_height_width(img_0)
713
+ aspect_ratio = image_width / image_height
714
+
715
+ if _auto_resize:
716
+ _, image_width, image_height = min(
717
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
718
+ )
719
+
720
+ image_width = image_width // multiple_of * multiple_of
721
+ image_height = image_height // multiple_of * multiple_of
722
+
723
+ resized = pipeline.image_processor.resize(img, image_height, image_width)
724
+ processed = pipeline.image_processor.preprocess(resized, image_height, image_width)
725
+ images.append(processed)
726
+
727
+ # 4. Prepare latent variables
728
+ # num_channels_latents = pipeline.transformer.module.config.in_channels // 4
729
+ num_channels_latents = pipeline.transformer.config.in_channels // 4
730
+ latents, image_latents, latent_ids, image_ids = pipeline.prepare_latents(
731
+ images,
732
+ batch_size * num_images_per_prompt,
733
+ num_channels_latents,
734
+ height,
735
+ width,
736
+ prompt_embeds.dtype,
737
+ device,
738
+ generator,
739
+ latents,
740
+ )
741
+ if image_ids is not None:
742
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
743
+
744
+ # 5. Prepare timesteps
745
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
746
+ image_seq_len = latents.shape[1]
747
+ mu = calculate_shift(
748
+ image_seq_len,
749
+ pipeline.scheduler.config.get("base_image_seq_len", 256),
750
+ pipeline.scheduler.config.get("max_image_seq_len", 4096),
751
+ pipeline.scheduler.config.get("base_shift", 0.5),
752
+ pipeline.scheduler.config.get("max_shift", 1.15),
753
+ )
754
+ timesteps, num_inference_steps = retrieve_timesteps(
755
+ pipeline.scheduler,
756
+ num_inference_steps,
757
+ device,
758
+ sigmas=sigmas,
759
+ mu=mu,
760
+ )
761
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0)
762
+ pipeline._num_timesteps = len(timesteps)
763
+
764
+ # handle guidance
765
+ if pipeline.transformer.config.guidance_embeds:
766
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
767
+ guidance = guidance.expand(latents.shape[0])
768
+ else:
769
+ guidance = None
770
+
771
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
772
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
773
+ ):
774
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
775
+ negative_ip_adapter_image = [negative_ip_adapter_image] * pipeline.transformer.encoder_hid_proj.num_ip_adapters
776
+
777
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
778
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
779
+ ):
780
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
781
+ ip_adapter_image = [ip_adapter_image] * pipeline.transformer.encoder_hid_proj.num_ip_adapters
782
+
783
+ if pipeline.joint_attention_kwargs is None:
784
+ pipeline._joint_attention_kwargs = {}
785
+
786
+ image_embeds = None
787
+ negative_image_embeds = None
788
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
789
+ image_embeds = pipeline.prepare_ip_adapter_image_embeds(
790
+ ip_adapter_image,
791
+ ip_adapter_image_embeds,
792
+ device,
793
+ batch_size * num_images_per_prompt,
794
+ )
795
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
796
+ negative_image_embeds = pipeline.prepare_ip_adapter_image_embeds(
797
+ negative_ip_adapter_image,
798
+ negative_ip_adapter_image_embeds,
799
+ device,
800
+ batch_size * num_images_per_prompt,
801
+ )
802
+ all_latents = [latents]
803
+ all_log_probs = []
804
+ all_timesteps = []
805
+ pipeline.scheduler.set_begin_index(0)
806
+ with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
807
+ for i, t in enumerate(timesteps):
808
+ if pipeline.interrupt:
809
+ continue
810
+
811
+ pipeline._current_timestep = t
812
+ if image_embeds is not None:
813
+ pipeline._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
814
+
815
+ latent_model_input = latents
816
+ if image_latents is not None:
817
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
818
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
819
+
820
+
821
+ if gradient_checkpoint:
822
+ noise_pred = checkpoint.checkpoint(
823
+ pipeline.transformer,
824
+ hidden_states=latent_model_input,
825
+ timestep=timestep / 1000,
826
+ guidance=guidance,
827
+ pooled_projections=pooled_prompt_embeds,
828
+ encoder_hidden_states=prompt_embeds,
829
+ txt_ids=text_ids,
830
+ img_ids=latent_ids,
831
+ joint_attention_kwargs=pipeline.joint_attention_kwargs,
832
+ return_dict=False,
833
+ )[0]
834
+ else:
835
+ noise_pred = pipeline.transformer(
836
+ hidden_states=latent_model_input,
837
+ timestep=timestep / 1000,
838
+ guidance=guidance,
839
+ pooled_projections=pooled_prompt_embeds,
840
+ encoder_hidden_states=prompt_embeds,
841
+ txt_ids=text_ids,
842
+ img_ids=latent_ids,
843
+ joint_attention_kwargs=pipeline.joint_attention_kwargs,
844
+ return_dict=False,
845
+ )[0]
846
+ noise_pred = noise_pred[:, : latents.size(1)]
847
+
848
+ if truncated_backprop:
849
+ # Randomized truncation randomizes the truncation process (https://huggingface.co/papers/2310.03739)
850
+ # the range of truncation is defined by truncated_rand_backprop_minmax
851
+                    # Setting truncated_rand_backprop_minmax[0] low allows the model to update earlier timesteps in the diffusion chain, while setting it high reduces memory usage.
852
+ if truncated_backprop_rand:
853
+ rand_timestep = random.randint(
854
+ truncated_rand_backprop_minmax[0], truncated_rand_backprop_minmax[1]
855
+ )
856
+ if i < rand_timestep:
857
+ noise_pred = noise_pred.detach()
858
+ else:
859
+ # fixed truncation process
860
+ if i < truncated_backprop_timestep:
861
+ noise_pred = noise_pred.detach()
862
+ if do_true_cfg:
863
+ if negative_image_embeds is not None:
864
+ pipeline._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
865
+ neg_noise_pred = pipeline.transformer(
866
+ hidden_states=latent_model_input,
867
+ timestep=timestep / 1000,
868
+ guidance=guidance,
869
+ pooled_projections=negative_pooled_prompt_embeds,
870
+ encoder_hidden_states=negative_prompt_embeds,
871
+ txt_ids=negative_text_ids,
872
+ img_ids=latent_ids,
873
+ joint_attention_kwargs=pipeline.joint_attention_kwargs,
874
+ return_dict=False,
875
+ )[0]
876
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
877
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
878
+
879
+ # compute the previous noisy sample x_t -> x_t-1
880
+ latents_dtype = latents.dtype
881
+ scheduler_output = pipeline.scheduler.step(noise_pred, t, latents, return_dict=True)
882
+ latents = scheduler_output.latents
883
+ log_probs = scheduler_output.log_probs
884
+
885
+ all_latents.append(latents)
886
+ all_log_probs.append(log_probs)
887
+ all_timesteps.append(timestep)
888
+
889
+ if latents.dtype != latents_dtype:
890
+ if torch.backends.mps.is_available():
891
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
892
+ latents = latents.to(latents_dtype)
893
+
894
+ if callback_on_step_end is not None:
895
+ callback_kwargs = {}
896
+ for k in callback_on_step_end_tensor_inputs:
897
+ callback_kwargs[k] = locals()[k]
898
+                    callback_outputs = callback_on_step_end(pipeline, i, t, callback_kwargs)
899
+
900
+ latents = callback_outputs.pop("latents", latents)
901
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
902
+
903
+ # call the callback, if provided
904
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
905
+ progress_bar.update()
906
+
907
+
908
+
909
+ pipeline._current_timestep = None
910
+
911
+ if output_type == "latent":
912
+ image = latents
913
+ else:
914
+ latents = pipeline._unpack_latents(latents, height, width, pipeline.vae_scale_factor)
915
+ latents = (latents / pipeline.vae.config.scaling_factor) + pipeline.vae.config.shift_factor
916
+ image = pipeline.vae.decode(latents, return_dict=False)[0]
917
+ image = pipeline.image_processor.postprocess(image, output_type=output_type)
918
+
919
+ # Offload all models
920
+ pipeline.maybe_free_model_hooks()
921
+
922
+ if not return_dict:
923
+ return (image,)
924
+
925
+ return FluxPipelineOutput(image, all_latents, all_log_probs, latent_ids, all_timesteps, image_latents)
926
+
927
+ class DefaultDDPOFluxPipeline(DDPOFluxPipeline):
928
+ def __init__(self, pretrained_model_name: str, finetuned_model_path: str=''):
929
+ self.flux_pipeline = FluxKontextPipeline.from_pretrained(
930
+ pretrained_model_name
931
+ )
932
+
933
+ self.pretrained_model = pretrained_model_name
934
+ self.flux_pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(self.flux_pipeline.scheduler.config)
935
+ self.flux_pipeline.scheduler.config.stochastic_sampling = True
936
+
937
+ # memory optimization
938
+ self.flux_pipeline.vae.requires_grad_(False)
939
+ self.flux_pipeline.text_encoder.requires_grad_(False)
940
+ self.flux_pipeline.text_encoder_2.requires_grad_(False)
941
+ self.flux_pipeline.transformer.requires_grad_(False)
942
+ if finetuned_model_path:
943
+ print(f"load finetuned model from {finetuned_model_path}")
944
+ self.flux_pipeline.transformer = FluxTransformer2DModel.from_single_file(finetuned_model_path, torch_dtype="bfloat16")
945
+
946
+ def __call__(self, *args, **kwargs) -> FluxPipelineOutput:
947
+ return pipeline_step(self.flux_pipeline, *args, **kwargs)
948
+
949
+ def rgb_with_grad(self, *args, **kwargs) -> FluxPipelineOutput:
950
+ return pipeline_step_with_grad(self.flux_pipeline, *args, **kwargs)
951
+
952
+ @property
953
+ def transformer(self):
954
+ return self.flux_pipeline.transformer
955
+
956
+ @property
957
+ def vae(self):
958
+ return self.flux_pipeline.vae
959
+
960
+ @property
961
+ def tokenizer(self):
962
+ return self.flux_pipeline.tokenizer
963
+
964
+ @property
965
+ def tokenizer_2(self):
966
+ return self.flux_pipeline.tokenizer_2
967
+
968
+ @property
969
+ def scheduler(self):
970
+ return self.flux_pipeline.scheduler
971
+
972
+ @property
973
+ def text_encoder(self):
974
+ return self.flux_pipeline.text_encoder
975
+
976
+ @property
977
+ def text_encoder_2(self):
978
+ return self.flux_pipeline.text_encoder_2
979
+
980
+ @property
981
+ def image_encoder(self):
982
+ return self.flux_pipeline.image_encoder
983
+
984
+ @property
985
+ def feature_extractor(self):
986
+ return self.flux_pipeline.feature_extractor
987
+
988
+ @property
989
+ def autocast(self):
990
+ return contextlib.nullcontext
991
+
992
+ def save_pretrained(self, output_dir):
993
+ state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(self.flux_pipeline.transformer))
994
+ self.flux_pipeline.transformer.save_pretrained(output_dir)
995
+
+        def set_progress_bar_config(self, *args, **kwargs):
996
+ self.flux_pipeline.set_progress_bar_config(*args, **kwargs)
997
+
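For orientation, a minimal sketch of how the wrapper class above might be driven; the checkpoint id, image path, device placement, and prompt are illustrative placeholders, and only arguments that appear in the signatures above are passed:

```py
import torch
from PIL import Image

# Illustrative only: checkpoint id and image path are placeholders.
pipe = DefaultDDPOFluxPipeline("black-forest-labs/FLUX.1-Kontext-dev")
pipe.flux_pipeline.to("cuda")
ref_image = Image.open("reference.png").convert("RGB")

# Sampling pass: records latents, log-probs and timesteps for the DDPO trainer.
with torch.no_grad():
    out = pipe(image=ref_image, prompt="make it a watercolor painting",
               num_inference_steps=28, max_area=512**2)

# Differentiable pass: truncated backprop through the later denoising steps.
out_grad = pipe.rgb_with_grad(image=ref_image, prompt="make it a watercolor painting",
                              truncated_backprop=True, gradient_checkpoint=True)
```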
kontext/pipeline_flux_kontext.py ADDED
@@ -0,0 +1,1189 @@
1
+ # Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+ from dataclasses import dataclass
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPImageProcessor,
22
+ CLIPTextModel,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ T5EncoderModel,
26
+ T5TokenizerFast,
27
+ )
28
+
29
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
31
+ from diffusers.models import AutoencoderKL, FluxTransformer2DModel
32
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
33
+ from diffusers.utils import (
34
+ USE_PEFT_BACKEND,
35
+ is_torch_xla_available,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from diffusers.utils import BaseOutput
42
+ from diffusers.utils.torch_utils import randn_tensor
43
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
44
+
45
+
46
+ if is_torch_xla_available():
47
+ import torch_xla.core.xla_model as xm
48
+
49
+ XLA_AVAILABLE = True
50
+ else:
51
+ XLA_AVAILABLE = False
52
+
53
+ @dataclass
54
+ class FluxPipelineOutput:
55
+ """
56
+ Output class for the diffusers pipeline to be finetuned with the DDPO trainer
57
+
58
+ Args:
59
+ images (`torch.Tensor`):
60
+ The generated images.
61
+ latents (`list[torch.Tensor]`):
62
+ The latents used to generate the images.
63
+ log_probs (`list[torch.Tensor]`):
64
+ The log probabilities of the latents.
65
+        latent_ids (`torch.Tensor`):
+            The positional ids associated with the latent tokens.
+        timesteps (`list[torch.Tensor]`):
+            The timesteps used at each denoising step.
+        image_latents (`torch.Tensor`):
+            The packed latents of the conditioning image(s).
+
66
+    """
67
+
68
+ images: torch.Tensor
69
+ latents: torch.Tensor
70
+ log_probs: torch.Tensor
71
+ latent_ids: torch.Tensor
72
+ timesteps: torch.Tensor
73
+ image_latents: torch.Tensor
74
+
75
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
76
+
77
+ EXAMPLE_DOC_STRING = """
78
+ Examples:
79
+ ```py
80
+ >>> import torch
81
+ >>> from diffusers import FluxKontextPipeline
82
+ >>> from diffusers.utils import load_image
83
+
84
+ >>> pipe = FluxKontextPipeline.from_pretrained(
85
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
86
+ ... )
87
+ >>> pipe.to("cuda")
88
+
89
+ >>> image = load_image(
90
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
91
+ ... ).convert("RGB")
92
+ >>> prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
93
+ >>> image = pipe(
94
+ ... image=image,
95
+ ... prompt=prompt,
96
+ ... guidance_scale=2.5,
97
+ ... generator=torch.Generator().manual_seed(42),
98
+ ... ).images[0]
99
+ >>> image.save("output.png")
100
+ ```
101
+ """
102
+
103
+ PREFERRED_KONTEXT_RESOLUTIONS = [
104
+ (672, 1568),
105
+ (688, 1504),
106
+ (720, 1456),
107
+ (752, 1392),
108
+ (800, 1328),
109
+ (832, 1248),
110
+ (880, 1184),
111
+ (944, 1104),
112
+ (1024, 1024),
113
+ (1104, 944),
114
+ (1184, 880),
115
+ (1248, 832),
116
+ (1328, 800),
117
+ (1392, 752),
118
+ (1456, 720),
119
+ (1504, 688),
120
+ (1568, 672),
121
+ ]
122
+ PREFERRED_KONTEXT_RESOLUTIONS = [(x[0]//2,x[1]//2) for x in PREFERRED_KONTEXT_RESOLUTIONS]
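+    # NOTE: the preferred resolutions are halved here (e.g. 1024x1024 becomes 512x512), presumably to
+    # reduce memory and compute during fine-tuning; the reasoning is an assumption, not documented upstream.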
123
+
124
+
125
+ def calculate_shift(
126
+ image_seq_len,
127
+ base_seq_len: int = 256,
128
+ max_seq_len: int = 4096,
129
+ base_shift: float = 0.5,
130
+ max_shift: float = 1.15,
131
+ ):
132
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
133
+ b = base_shift - m * base_seq_len
134
+ mu = image_seq_len * m + b
135
+ return mu
136
+
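+    # With the defaults, calculate_shift is a linear map: image_seq_len = 256 gives mu = 0.5,
+    # image_seq_len = 4096 gives mu = 1.15, and lengths in between interpolate linearly.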
137
+
138
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
139
+ def retrieve_timesteps(
140
+ scheduler,
141
+ num_inference_steps: Optional[int] = None,
142
+ device: Optional[Union[str, torch.device]] = None,
143
+ timesteps: Optional[List[int]] = None,
144
+ sigmas: Optional[List[float]] = None,
145
+ **kwargs,
146
+ ):
147
+ r"""
148
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
149
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
150
+
151
+ Args:
152
+ scheduler (`SchedulerMixin`):
153
+ The scheduler to get timesteps from.
154
+ num_inference_steps (`int`):
155
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
156
+ must be `None`.
157
+ device (`str` or `torch.device`, *optional*):
158
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
159
+ timesteps (`List[int]`, *optional*):
160
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
161
+ `num_inference_steps` and `sigmas` must be `None`.
162
+ sigmas (`List[float]`, *optional*):
163
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
164
+ `num_inference_steps` and `timesteps` must be `None`.
165
+
166
+ Returns:
167
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
168
+ second element is the number of inference steps.
169
+ """
170
+ if timesteps is not None and sigmas is not None:
171
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
172
+ if timesteps is not None:
173
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
174
+ if not accepts_timesteps:
175
+ raise ValueError(
176
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
177
+ f" timestep schedules. Please check whether you are using the correct scheduler."
178
+ )
179
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
180
+ timesteps = scheduler.timesteps
181
+ num_inference_steps = len(timesteps)
182
+ elif sigmas is not None:
183
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
184
+ if not accept_sigmas:
185
+ raise ValueError(
186
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
187
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
188
+ )
189
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
190
+ timesteps = scheduler.timesteps
191
+ num_inference_steps = len(timesteps)
192
+ else:
193
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
194
+ timesteps = scheduler.timesteps
195
+ return timesteps, num_inference_steps
196
+
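+    # Typical usage in this file: the custom `sigmas` schedule and the `mu` shift are forwarded to
+    # `scheduler.set_timesteps`, e.g.
+    #   timesteps, num_inference_steps = retrieve_timesteps(scheduler, 28, device, sigmas=sigmas, mu=mu)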
197
+
198
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
199
+ def retrieve_latents(
200
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
201
+ ):
202
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
203
+ return encoder_output.latent_dist.sample(generator)
204
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
205
+ return encoder_output.latent_dist.mode()
206
+ elif hasattr(encoder_output, "latents"):
207
+ return encoder_output.latents
208
+ else:
209
+ raise AttributeError("Could not access latents of provided encoder_output")
210
+
211
+
212
+ class FluxKontextPipeline(
213
+ DiffusionPipeline,
214
+ FluxLoraLoaderMixin,
215
+ FromSingleFileMixin,
216
+ TextualInversionLoaderMixin,
217
+ FluxIPAdapterMixin,
218
+ ):
219
+ r"""
220
+    The Flux Kontext pipeline for text-guided image editing and generation.
221
+
222
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
223
+
224
+ Args:
225
+ transformer ([`FluxTransformer2DModel`]):
226
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
227
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
228
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
229
+ vae ([`AutoencoderKL`]):
230
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
231
+ text_encoder ([`CLIPTextModel`]):
232
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
233
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
234
+ text_encoder_2 ([`T5EncoderModel`]):
235
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
236
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
237
+ tokenizer (`CLIPTokenizer`):
238
+ Tokenizer of class
239
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
240
+ tokenizer_2 (`T5TokenizerFast`):
241
+ Second Tokenizer of class
242
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
243
+ """
244
+
245
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
246
+ _optional_components = ["image_encoder", "feature_extractor"]
247
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
248
+
249
+ def __init__(
250
+ self,
251
+ scheduler: FlowMatchEulerDiscreteScheduler,
252
+ vae: AutoencoderKL,
253
+ text_encoder: CLIPTextModel,
254
+ tokenizer: CLIPTokenizer,
255
+ text_encoder_2: T5EncoderModel,
256
+ tokenizer_2: T5TokenizerFast,
257
+ transformer: FluxTransformer2DModel,
258
+ image_encoder: CLIPVisionModelWithProjection = None,
259
+ feature_extractor: CLIPImageProcessor = None,
260
+ ):
261
+ super().__init__()
262
+
263
+ self.register_modules(
264
+ vae=vae,
265
+ text_encoder=text_encoder,
266
+ text_encoder_2=text_encoder_2,
267
+ tokenizer=tokenizer,
268
+ tokenizer_2=tokenizer_2,
269
+ transformer=transformer,
270
+ scheduler=scheduler,
271
+ image_encoder=image_encoder,
272
+ feature_extractor=feature_extractor,
273
+ )
274
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
275
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
276
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
277
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
278
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
279
+ self.tokenizer_max_length = (
280
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
281
+ )
282
+ self.default_sample_size = 128
283
+
284
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
285
+ def _get_t5_prompt_embeds(
286
+ self,
287
+ prompt: Union[str, List[str]] = None,
288
+ num_images_per_prompt: int = 1,
289
+ max_sequence_length: int = 512,
290
+ device: Optional[torch.device] = None,
291
+ dtype: Optional[torch.dtype] = None,
292
+ ):
293
+ device = device or self._execution_device
294
+ dtype = dtype or self.text_encoder.dtype
295
+
296
+ prompt = [prompt] if isinstance(prompt, str) else prompt
297
+ batch_size = len(prompt)
298
+
299
+ if isinstance(self, TextualInversionLoaderMixin):
300
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
301
+
302
+ text_inputs = self.tokenizer_2(
303
+ prompt,
304
+ padding="max_length",
305
+ max_length=max_sequence_length,
306
+ truncation=True,
307
+ return_length=False,
308
+ return_overflowing_tokens=False,
309
+ return_tensors="pt",
310
+ )
311
+ text_input_ids = text_inputs.input_ids
312
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
313
+
314
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
315
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
316
+ logger.warning(
317
+ "The following part of your input was truncated because `max_sequence_length` is set to "
318
+ f" {max_sequence_length} tokens: {removed_text}"
319
+ )
320
+
321
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
322
+
323
+ dtype = self.text_encoder_2.dtype
324
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
325
+
326
+ _, seq_len, _ = prompt_embeds.shape
327
+
328
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
329
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
330
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
331
+
332
+ return prompt_embeds
333
+
334
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
335
+ def _get_clip_prompt_embeds(
336
+ self,
337
+ prompt: Union[str, List[str]],
338
+ num_images_per_prompt: int = 1,
339
+ device: Optional[torch.device] = None,
340
+ ):
341
+ device = device or self._execution_device
342
+
343
+ prompt = [prompt] if isinstance(prompt, str) else prompt
344
+ batch_size = len(prompt)
345
+
346
+ if isinstance(self, TextualInversionLoaderMixin):
347
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
348
+
349
+ text_inputs = self.tokenizer(
350
+ prompt,
351
+ padding="max_length",
352
+ max_length=self.tokenizer_max_length,
353
+ truncation=True,
354
+ return_overflowing_tokens=False,
355
+ return_length=False,
356
+ return_tensors="pt",
357
+ )
358
+
359
+ text_input_ids = text_inputs.input_ids
360
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
361
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
362
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
363
+ # logger.warning(
364
+ # "The following part of your input was truncated because CLIP can only handle sequences up to"
365
+ # f" {self.tokenizer_max_length} tokens: {removed_text}"
366
+ # )
367
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
368
+
369
+ # Use pooled output of CLIPTextModel
370
+ prompt_embeds = prompt_embeds.pooler_output
371
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
372
+
373
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
374
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
375
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
376
+
377
+ return prompt_embeds
378
+
379
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
380
+ def encode_prompt(
381
+ self,
382
+ prompt: Union[str, List[str]],
383
+ prompt_2: Union[str, List[str]],
384
+ device: Optional[torch.device] = None,
385
+ num_images_per_prompt: int = 1,
386
+ prompt_embeds: Optional[torch.FloatTensor] = None,
387
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
388
+ max_sequence_length: int = 512,
389
+ lora_scale: Optional[float] = None,
390
+ ):
391
+ r"""
392
+
393
+ Args:
394
+ prompt (`str` or `List[str]`, *optional*):
395
+ prompt to be encoded
396
+ prompt_2 (`str` or `List[str]`, *optional*):
397
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
398
+ used in all text-encoders
399
+ device: (`torch.device`):
400
+ torch device
401
+ num_images_per_prompt (`int`):
402
+ number of images that should be generated per prompt
403
+ prompt_embeds (`torch.FloatTensor`, *optional*):
404
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
405
+ provided, text embeddings will be generated from `prompt` input argument.
406
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
407
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
408
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
409
+ lora_scale (`float`, *optional*):
410
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
411
+ """
412
+ device = device or self._execution_device
413
+
414
+ # set lora scale so that monkey patched LoRA
415
+ # function of text encoder can correctly access it
416
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
417
+ self._lora_scale = lora_scale
418
+
419
+ # dynamically adjust the LoRA scale
420
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
421
+ scale_lora_layers(self.text_encoder, lora_scale)
422
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
423
+ scale_lora_layers(self.text_encoder_2, lora_scale)
424
+
425
+ prompt = [prompt] if isinstance(prompt, str) else prompt
426
+
427
+ if prompt_embeds is None:
428
+ prompt_2 = prompt_2 or prompt
429
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
430
+
431
+ # We only use the pooled prompt output from the CLIPTextModel
432
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
433
+ prompt=prompt,
434
+ device=device,
435
+ num_images_per_prompt=num_images_per_prompt,
436
+ )
437
+ prompt_embeds = self._get_t5_prompt_embeds(
438
+ prompt=prompt_2,
439
+ num_images_per_prompt=num_images_per_prompt,
440
+ max_sequence_length=max_sequence_length,
441
+ device=device,
442
+ )
443
+
444
+ if self.text_encoder is not None:
445
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
446
+ # Retrieve the original scale by scaling back the LoRA layers
447
+ unscale_lora_layers(self.text_encoder, lora_scale)
448
+
449
+ if self.text_encoder_2 is not None:
450
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
451
+ # Retrieve the original scale by scaling back the LoRA layers
452
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
453
+
454
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
455
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
456
+
457
+ return prompt_embeds, pooled_prompt_embeds, text_ids
458
+
459
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
460
+ def encode_image(self, image, device, num_images_per_prompt):
461
+ dtype = next(self.image_encoder.parameters()).dtype
462
+
463
+ if not isinstance(image, torch.Tensor):
464
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
465
+
466
+ image = image.to(device=device, dtype=dtype)
467
+ image_embeds = self.image_encoder(image).image_embeds
468
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
469
+ return image_embeds
470
+
471
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
472
+ def prepare_ip_adapter_image_embeds(
473
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
474
+ ):
475
+ image_embeds = []
476
+ if ip_adapter_image_embeds is None:
477
+ if not isinstance(ip_adapter_image, list):
478
+ ip_adapter_image = [ip_adapter_image]
479
+
480
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
481
+ raise ValueError(
482
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
483
+ )
484
+
485
+ for single_ip_adapter_image in ip_adapter_image:
486
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
487
+ image_embeds.append(single_image_embeds[None, :])
488
+ else:
489
+ if not isinstance(ip_adapter_image_embeds, list):
490
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
491
+
492
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
493
+ raise ValueError(
494
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
495
+ )
496
+
497
+ for single_image_embeds in ip_adapter_image_embeds:
498
+ image_embeds.append(single_image_embeds)
499
+
500
+ ip_adapter_image_embeds = []
501
+ for single_image_embeds in image_embeds:
502
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
503
+ single_image_embeds = single_image_embeds.to(device=device)
504
+ ip_adapter_image_embeds.append(single_image_embeds)
505
+
506
+ return ip_adapter_image_embeds
507
+
508
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
509
+ def check_inputs(
510
+ self,
511
+ prompt,
512
+ prompt_2,
513
+ height,
514
+ width,
515
+ negative_prompt=None,
516
+ negative_prompt_2=None,
517
+ prompt_embeds=None,
518
+ negative_prompt_embeds=None,
519
+ pooled_prompt_embeds=None,
520
+ negative_pooled_prompt_embeds=None,
521
+ callback_on_step_end_tensor_inputs=None,
522
+ max_sequence_length=None,
523
+ ):
524
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
525
+ logger.warning(
526
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
527
+ )
528
+
529
+ if callback_on_step_end_tensor_inputs is not None and not all(
530
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
531
+ ):
532
+ raise ValueError(
533
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
534
+ )
535
+
536
+ if prompt is not None and prompt_embeds is not None:
537
+ raise ValueError(
538
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
539
+ " only forward one of the two."
540
+ )
541
+ elif prompt_2 is not None and prompt_embeds is not None:
542
+ raise ValueError(
543
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
544
+ " only forward one of the two."
545
+ )
546
+ elif prompt is None and prompt_embeds is None:
547
+ raise ValueError(
548
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
549
+ )
550
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
551
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
552
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
553
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
554
+
555
+ if negative_prompt is not None and negative_prompt_embeds is not None:
556
+ raise ValueError(
557
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
558
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
559
+ )
560
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
561
+ raise ValueError(
562
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
563
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
564
+ )
565
+
566
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
567
+ raise ValueError(
568
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
569
+ )
570
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
571
+ raise ValueError(
572
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
573
+ )
574
+
575
+ if max_sequence_length is not None and max_sequence_length > 512:
576
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
577
+
578
+ @staticmethod
579
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
580
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype, h_offset=0, w_offset=0):
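+ # Each latent token gets a 3-component position id: a modality flag plus (row, col) coordinates.
+ # The optional `h_offset` / `w_offset` shift those coordinates; this pipeline uses them to place
+ # reference-image tokens at positions that do not collide with the image being generated.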
581
+ latent_image_ids = torch.zeros(height, width, 3)
582
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + h_offset
583
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + w_offset
584
+
585
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
586
+
587
+ latent_image_ids = latent_image_ids.reshape(
588
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
589
+ )
590
+
591
+ return latent_image_ids.to(device=device, dtype=dtype)
592
+
593
+ @staticmethod
594
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
595
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
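+ # Fold each 2x2 spatial patch of the (B, C, H, W) latent grid into the channel dimension,
+ # producing a (B, (H/2)*(W/2), 4*C) token sequence for the transformer.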
596
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
597
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
598
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
599
+
600
+ return latents
601
+
602
+ @staticmethod
603
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
604
+ def _unpack_latents(latents, height, width, vae_scale_factor):
605
+ batch_size, num_patches, channels = latents.shape
606
+
607
+ # VAE applies 8x compression on images but we must also account for packing which requires
608
+ # latent height and width to be divisible by 2.
609
+ height = 2 * (int(height) // (vae_scale_factor * 2))
610
+ width = 2 * (int(width) // (vae_scale_factor * 2))
611
+
612
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
613
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
614
+
615
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
616
+
617
+ return latents
618
+
619
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
620
+ if isinstance(generator, list):
621
+ image_latents = [
622
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
623
+ for i in range(image.shape[0])
624
+ ]
625
+ image_latents = torch.cat(image_latents, dim=0)
626
+ else:
627
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
628
+
629
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
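+ # The shift/scale above maps raw VAE posterior latents into the normalized latent space the
+ # transformer operates in; the inverse transform is applied before decoding.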
630
+
631
+ return image_latents
632
+
633
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
634
+ def enable_vae_slicing(self):
635
+ r"""
636
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
637
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
638
+ """
639
+ self.vae.enable_slicing()
640
+
641
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
642
+ def disable_vae_slicing(self):
643
+ r"""
644
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
645
+ computing decoding in one step.
646
+ """
647
+ self.vae.disable_slicing()
648
+
649
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
650
+ def enable_vae_tiling(self):
651
+ r"""
652
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
653
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
654
+ processing larger images.
655
+ """
656
+ self.vae.enable_tiling()
657
+
658
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
659
+ def disable_vae_tiling(self):
660
+ r"""
661
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
662
+ computing decoding in one step.
663
+ """
664
+ self.vae.disable_tiling()
665
+
666
+ def prepare_latents(
667
+ self,
668
+ images: list[torch.Tensor],
669
+ batch_size: int,
670
+ num_channels_latents: int,
671
+ height: int,
672
+ width: int,
673
+ dtype: torch.dtype,
674
+ device: torch.device,
675
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
676
+ latents: Optional[torch.Tensor] = None,
677
+ ):
678
+ if isinstance(generator, list) and len(generator) != batch_size:
679
+ raise ValueError(
680
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
681
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
682
+ )
683
+
684
+ # VAE applies 8x compression on images but we must also account for packing which requires
685
+ # latent height and width to be divisible by 2.
686
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
687
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
688
+ shape = (batch_size, num_channels_latents, height, width)
689
+
690
+ image_latents = image_ids = None
691
+ ref_img_ids = []
692
+ ref_img_latents = []
693
+ # Start the reference-image positional offset at half the target latent size so reference tokens
+ # do not overlap the coordinates of the image being generated.
+ pe_shift_w, pe_shift_h = width // 2, height // 2
695
+ for image in images:
696
+ if image is not None:
697
+ image = image.to(device=device, dtype=dtype)
698
+ if image.shape[1] != self.latent_channels:
699
+ image_latents = self._encode_vae_image(image=image, generator=generator)
700
+ else:
701
+ image_latents = image
702
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
703
+ # expand init_latents for batch_size
704
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
705
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
706
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
707
+ raise ValueError(
708
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
709
+ )
710
+ else:
711
+ image_latents = torch.cat([image_latents], dim=0)
712
+
713
+ image_latent_height, image_latent_width = image_latents.shape[2:]
714
+ image_latents = self._pack_latents(
715
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
716
+ )
717
+ image_ids = self._prepare_latent_image_ids(
718
+ batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype, h_offset=pe_shift_h, w_offset=pe_shift_w
719
+ )
720
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
721
+ image_ids[..., 0] = 1
722
+
723
+ pe_shift_h += image_latent_height // 2
724
+ pe_shift_w += image_latent_width // 2
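+ # Advance the offsets so that, when several reference images are passed, each one occupies a
+ # disjoint coordinate range in the position-id grid.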
725
+ ref_img_latents.append(image_latents)
726
+ ref_img_ids.append(image_ids)
727
+
728
+ latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
729
+
730
+ if latents is None:
731
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
732
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
733
+ else:
734
+ latents = latents.to(device=device, dtype=dtype)
735
+
736
+ if len(ref_img_latents) == 1:
737
+ image_latents = ref_img_latents[0]
738
+ image_ids = ref_img_ids[0]
739
+ else:
740
+ image_latents = torch.cat(ref_img_latents, dim=1)
+ image_ids = torch.cat(ref_img_ids, dim=0)
742
+
743
+ return latents, image_latents, latent_ids, image_ids
744
+
745
+ @property
746
+ def guidance_scale(self):
747
+ return self._guidance_scale
748
+
749
+ @property
750
+ def joint_attention_kwargs(self):
751
+ return self._joint_attention_kwargs
752
+
753
+ @property
754
+ def num_timesteps(self):
755
+ return self._num_timesteps
756
+
757
+ @property
758
+ def current_timestep(self):
759
+ return self._current_timestep
760
+
761
+ @property
762
+ def interrupt(self):
763
+ return self._interrupt
764
+
765
+ @torch.no_grad()
766
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
767
+ def __call__(
768
+ self,
769
+ image: Optional[PipelineImageInput] = None,
770
+ prompt: Union[str, List[str]] = None,
771
+ prompt_2: Optional[Union[str, List[str]]] = None,
772
+ negative_prompt: Union[str, List[str]] = None,
773
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
774
+ true_cfg_scale: float = 1.0,
775
+ height: Optional[int] = None,
776
+ width: Optional[int] = None,
777
+ num_inference_steps: int = 28,
778
+ sigmas: Optional[List[float]] = None,
779
+ guidance_scale: float = 3.5,
780
+ num_images_per_prompt: Optional[int] = 1,
781
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
782
+ latents: Optional[torch.FloatTensor] = None,
783
+ prompt_embeds: Optional[torch.FloatTensor] = None,
784
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
785
+ ip_adapter_image: Optional[PipelineImageInput] = None,
786
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
787
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
788
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
789
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
790
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
791
+ output_type: Optional[str] = "pil",
792
+ return_dict: bool = True,
793
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
794
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
795
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
796
+ max_sequence_length: int = 512,
797
+ max_area: int = 1024**2,
798
+ _auto_resize: bool = True,
799
+ ):
800
+ r"""
801
+ Function invoked when calling the pipeline for generation.
802
+
803
+ Args:
804
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
805
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
806
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
809
+ latents as `image`, but if passing latents directly it is not encoded again.
810
+ prompt (`str` or `List[str]`, *optional*):
811
+ The prompt or prompts to guide the image generation. If not defined, `prompt_embeds` has to be passed
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
+ be used instead.
816
+ negative_prompt (`str` or `List[str]`, *optional*):
817
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
818
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
819
+ not greater than `1`).
820
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
821
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
822
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
823
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
824
+ When `true_cfg_scale` > 1.0 and a `negative_prompt` is provided, true classifier-free guidance is enabled.
825
+ height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
828
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
829
+ num_inference_steps (`int`, *optional*, defaults to 28):
830
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
831
+ expense of slower inference.
832
+ sigmas (`List[float]`, *optional*):
833
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
834
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
835
+ will be used.
836
+ guidance_scale (`float`, *optional*, defaults to 3.5):
837
+ Guidance scale as defined in [Classifier-Free Diffusion
838
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
839
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
840
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate images closely linked to
841
+ the text `prompt`, usually at the expense of lower image quality.
842
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
843
+ The number of images to generate per prompt.
844
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
845
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
846
+ to make generation deterministic.
847
+ latents (`torch.FloatTensor`, *optional*):
848
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
849
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
850
+ tensor will be generated by sampling using the supplied random `generator`.
851
+ prompt_embeds (`torch.FloatTensor`, *optional*):
852
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
853
+ provided, text embeddings will be generated from `prompt` input argument.
854
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
855
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
856
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
857
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
858
+ Optional image input to work with IP Adapters.
859
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
860
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
861
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
862
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
863
+ negative_ip_adapter_image:
864
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
865
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
866
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
867
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
868
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
869
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
870
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
871
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
872
+ argument.
873
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
874
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
875
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
876
+ input argument.
877
+ output_type (`str`, *optional*, defaults to `"pil"`):
878
+ The output format of the generated image. Choose between
879
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
880
+ return_dict (`bool`, *optional*, defaults to `True`):
881
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
882
+ joint_attention_kwargs (`dict`, *optional*):
883
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
884
+ `self.processor` in
885
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
886
+ callback_on_step_end (`Callable`, *optional*):
887
+ A function that calls at the end of each denoising steps during the inference. The function is called
888
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
889
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
890
+ `callback_on_step_end_tensor_inputs`.
891
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
892
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
893
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
894
+ `._callback_tensor_inputs` attribute of your pipeline class.
895
+ max_sequence_length (`int` defaults to 512):
896
+ Maximum sequence length to use with the `prompt`.
897
+ max_area (`int`, defaults to `1024 ** 2`):
898
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
899
+ area while maintaining the aspect ratio.
900
+
901
+ Examples:
902
+
903
+ Returns:
904
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
905
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
906
+ images.
907
+ """
908
+
909
+ height = height or self.default_sample_size * self.vae_scale_factor
910
+ width = width or self.default_sample_size * self.vae_scale_factor
911
+
912
+ original_height, original_width = height, width
913
+ aspect_ratio = width / height
914
+ width = round((max_area * aspect_ratio) ** 0.5)
915
+ height = round((max_area / aspect_ratio) ** 0.5)
916
+
917
+ multiple_of = self.vae_scale_factor * 2
918
+ width = width // multiple_of * multiple_of
919
+ height = height // multiple_of * multiple_of
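+ # Example with the defaults: a 1920x1080 request keeps its 16:9 aspect ratio, is rescaled to roughly
+ # max_area = 1024**2 pixels (~1365x768), and is then snapped down to a multiple of
+ # `vae_scale_factor * 2` (16 for the standard Flux VAE), giving 1360x768.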
920
+
921
+ if height != original_height or width != original_width:
922
+ logger.warning(
923
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
924
+ )
925
+
926
+ # 1. Check inputs. Raise error if not correct
927
+ self.check_inputs(
928
+ prompt,
929
+ prompt_2,
930
+ height,
931
+ width,
932
+ negative_prompt=negative_prompt,
933
+ negative_prompt_2=negative_prompt_2,
934
+ prompt_embeds=prompt_embeds,
935
+ negative_prompt_embeds=negative_prompt_embeds,
936
+ pooled_prompt_embeds=pooled_prompt_embeds,
937
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
938
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
939
+ max_sequence_length=max_sequence_length,
940
+ )
941
+
942
+ self._guidance_scale = guidance_scale
943
+ self._joint_attention_kwargs = joint_attention_kwargs
944
+ self._current_timestep = None
945
+ self._interrupt = False
946
+
947
+ # 2. Define call parameters
948
+ if prompt is not None and isinstance(prompt, str):
949
+ batch_size = 1
950
+ elif prompt is not None and isinstance(prompt, list):
951
+ batch_size = len(prompt)
952
+ else:
953
+ batch_size = prompt_embeds.shape[0]
954
+
955
+ device = self._execution_device
956
+
957
+ lora_scale = (
958
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
959
+ )
960
+ has_neg_prompt = negative_prompt is not None or (
961
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
962
+ )
963
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
964
+ (
965
+ prompt_embeds,
966
+ pooled_prompt_embeds,
967
+ text_ids,
968
+ ) = self.encode_prompt(
969
+ prompt=prompt,
970
+ prompt_2=prompt_2,
971
+ prompt_embeds=prompt_embeds,
972
+ pooled_prompt_embeds=pooled_prompt_embeds,
973
+ device=device,
974
+ num_images_per_prompt=num_images_per_prompt,
975
+ max_sequence_length=max_sequence_length,
976
+ lora_scale=lora_scale,
977
+ )
978
+ if do_true_cfg:
979
+ (
980
+ negative_prompt_embeds,
981
+ negative_pooled_prompt_embeds,
982
+ negative_text_ids,
983
+ ) = self.encode_prompt(
984
+ prompt=negative_prompt,
985
+ prompt_2=negative_prompt_2,
986
+ prompt_embeds=negative_prompt_embeds,
987
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
988
+ device=device,
989
+ num_images_per_prompt=num_images_per_prompt,
990
+ max_sequence_length=max_sequence_length,
991
+ lora_scale=lora_scale,
992
+ )
993
+
994
+ # 3. Preprocess image
995
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
996
+ imgs = image if isinstance(image, list) else [image]
997
+
998
+ images = []
999
+ for img in imgs:
1000
+ img_0 = img[0] if isinstance(img, list) else img
1001
+ image_height, image_width = self.image_processor.get_default_height_width(img_0)
1002
+ aspect_ratio = image_width / image_height
1003
+
1004
+ if _auto_resize:
1005
+ _, image_width, image_height = min(
1006
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
1007
+ )
1008
+
1009
+ image_width = image_width // multiple_of * multiple_of
1010
+ image_height = image_height // multiple_of * multiple_of
1011
+
1012
+ resized = self.image_processor.resize(img, image_height, image_width)
1013
+ processed = self.image_processor.preprocess(resized, image_height, image_width)
1014
+ images.append(processed)
1015
+
1016
+ # 4. Prepare latent variables
1017
+ num_channels_latents = self.transformer.config.in_channels // 4
1018
+ latents, image_latents, latent_ids, image_ids = self.prepare_latents(
1019
+ images,
1020
+ batch_size * num_images_per_prompt,
1021
+ num_channels_latents,
1022
+ height,
1023
+ width,
1024
+ prompt_embeds.dtype,
1025
+ device,
1026
+ generator,
1027
+ latents,
1028
+ )
1029
+ if image_ids is not None:
1030
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
1031
+
1032
+ # 5. Prepare timesteps
1033
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
1034
+ image_seq_len = latents.shape[1]
1035
+ mu = calculate_shift(
1036
+ image_seq_len,
1037
+ self.scheduler.config.get("base_image_seq_len", 256),
1038
+ self.scheduler.config.get("max_image_seq_len", 4096),
1039
+ self.scheduler.config.get("base_shift", 0.5),
1040
+ self.scheduler.config.get("max_shift", 1.15),
1041
+ )
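+ # `mu` interpolates between `base_shift` and `max_shift` according to the packed image sequence length,
+ # so larger images get a stronger timestep shift toward high noise levels.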
1042
+ timesteps, num_inference_steps = retrieve_timesteps(
1043
+ self.scheduler,
1044
+ num_inference_steps,
1045
+ device,
1046
+ sigmas=sigmas,
1047
+ mu=mu,
1048
+ )
1049
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1050
+ self._num_timesteps = len(timesteps)
1051
+
1052
+ # handle guidance
1053
+ if self.transformer.config.guidance_embeds:
1054
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1055
+ guidance = guidance.expand(latents.shape[0])
1056
+ else:
1057
+ guidance = None
1058
+
1059
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
1060
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
1061
+ ):
1062
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1063
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
1064
+
1065
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
1066
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
1067
+ ):
1068
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1069
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
1070
+
1071
+ if self.joint_attention_kwargs is None:
1072
+ self._joint_attention_kwargs = {}
1073
+
1074
+ image_embeds = None
1075
+ negative_image_embeds = None
1076
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1077
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1078
+ ip_adapter_image,
1079
+ ip_adapter_image_embeds,
1080
+ device,
1081
+ batch_size * num_images_per_prompt,
1082
+ )
1083
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
1084
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
1085
+ negative_ip_adapter_image,
1086
+ negative_ip_adapter_image_embeds,
1087
+ device,
1088
+ batch_size * num_images_per_prompt,
1089
+ )
1090
+
1091
+ # 6. Denoising loop
1092
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
1093
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
1094
+ all_latents = [latents]
1095
+ all_log_probs = []
1096
+ all_timesteps = []
1097
+ self.scheduler.set_begin_index(0)
1098
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1099
+ for i, t in enumerate(timesteps):
1100
+ if self.interrupt:
1101
+ continue
1102
+
1103
+ self._current_timestep = t
1104
+ if image_embeds is not None:
1105
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
1106
+
1107
+ latent_model_input = latents
1108
+ if image_latents is not None:
1109
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
1110
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
1111
+
1112
+ noise_pred = self.transformer(
1113
+ hidden_states=latent_model_input,
1114
+ timestep=timestep / 1000,
1115
+ guidance=guidance,
1116
+ pooled_projections=pooled_prompt_embeds,
1117
+ encoder_hidden_states=prompt_embeds,
1118
+ txt_ids=text_ids,
1119
+ img_ids=latent_ids,
1120
+ joint_attention_kwargs=self.joint_attention_kwargs,
1121
+ return_dict=False,
1122
+ )[0]
1123
+ noise_pred = noise_pred[:, : latents.size(1)]
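+ # The transformer was fed the target tokens concatenated with the reference-image tokens along the
+ # sequence dimension, so only the first `latents.size(1)` positions (the target image) are kept.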
1124
+
1125
+ if do_true_cfg:
1126
+ if negative_image_embeds is not None:
1127
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1128
+ neg_noise_pred = self.transformer(
1129
+ hidden_states=latent_model_input,
1130
+ timestep=timestep / 1000,
1131
+ guidance=guidance,
1132
+ pooled_projections=negative_pooled_prompt_embeds,
1133
+ encoder_hidden_states=negative_prompt_embeds,
1134
+ txt_ids=negative_text_ids,
1135
+ img_ids=latent_ids,
1136
+ joint_attention_kwargs=self.joint_attention_kwargs,
1137
+ return_dict=False,
1138
+ )[0]
1139
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
1140
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
1141
+
1142
+ # compute the previous noisy sample x_t -> x_t-1
1143
+ latents_dtype = latents.dtype
1144
+ scheduler_output = self.scheduler.step(noise_pred, t, latents, return_dict=True)
1145
+ latents = scheduler_output.latents
1146
+ log_probs = scheduler_output.log_probs
1147
+
1148
+ all_latents.append(latents)
1149
+ all_log_probs.append(log_probs)
1150
+ all_timesteps.append(timesteps)
1151
+
1152
+ if latents.dtype != latents_dtype:
1153
+ if torch.backends.mps.is_available():
1154
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1155
+ latents = latents.to(latents_dtype)
1156
+
1157
+ if callback_on_step_end is not None:
1158
+ callback_kwargs = {}
1159
+ for k in callback_on_step_end_tensor_inputs:
1160
+ callback_kwargs[k] = locals()[k]
1161
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1162
+
1163
+ latents = callback_outputs.pop("latents", latents)
1164
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1165
+
1166
+ # call the callback, if provided
1167
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1168
+ progress_bar.update()
1169
+
1170
+ if XLA_AVAILABLE:
1171
+ xm.mark_step()
1172
+
1173
+ self._current_timestep = None
1174
+
1175
+ if output_type == "latent":
1176
+ image = latents
1177
+ else:
1178
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1179
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1180
+ image = self.vae.decode(latents, return_dict=False)[0]
1181
+ image = self.image_processor.postprocess(image, output_type=output_type)
1182
+
1183
+ # Offload all models
1184
+ self.maybe_free_model_hooks()
1185
+
1186
+ if not return_dict:
1187
+ return (image,)
1188
+
1189
+ return FluxPipelineOutput(image, all_latents, all_log_probs, latent_ids, all_timesteps, image_latents)
kontext/scheduling_flow_match_euler_discrete.py ADDED
@@ -0,0 +1,604 @@
+ # Copyright 2025 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.utils import BaseOutput, is_scipy_available, logging
24
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
25
+
26
+
27
+ if is_scipy_available():
28
+ import scipy.stats
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ @dataclass
34
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
35
+ """
36
+ Output class for the scheduler's `step` function output.
37
+
38
+ Args:
39
+ latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of the previous timestep, to be used as the next model input in the
+ denoising loop.
+ log_probs (`torch.FloatTensor`):
+ Per-sample log-probability of the returned sample under the stochastic update (`None` when
+ stochastic sampling is disabled).
42
+ """
43
+
44
+ latents: torch.FloatTensor
45
+ log_probs: torch.FloatTensor
46
+
47
+
48
+ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
49
+ """
50
+ Euler scheduler.
51
+
52
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
53
+ methods the library implements for all schedulers such as loading and saving.
54
+
55
+ Args:
56
+ num_train_timesteps (`int`, defaults to 1000):
57
+ The number of diffusion steps to train the model.
58
+ shift (`float`, defaults to 1.0):
59
+ The shift value for the timestep schedule.
60
+ use_dynamic_shifting (`bool`, defaults to False):
61
+ Whether to apply timestep shifting on-the-fly based on the image resolution.
62
+ base_shift (`float`, defaults to 0.5):
63
+ Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent
64
+ with desired output.
65
+ max_shift (`float`, defaults to 1.15):
66
+ Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be
67
+ more exaggerated or stylized.
68
+ base_image_seq_len (`int`, defaults to 256):
69
+ The base image sequence length.
70
+ max_image_seq_len (`int`, defaults to 4096):
71
+ The maximum image sequence length.
72
+ invert_sigmas (`bool`, defaults to False):
73
+ Whether to invert the sigmas.
74
+ shift_terminal (`float`, defaults to None):
75
+ The end value of the shifted timestep schedule.
76
+ use_karras_sigmas (`bool`, defaults to False):
77
+ Whether to use Karras sigmas for step sizes in the noise schedule during sampling.
78
+ use_exponential_sigmas (`bool`, defaults to False):
79
+ Whether to use exponential sigmas for step sizes in the noise schedule during sampling.
80
+ use_beta_sigmas (`bool`, defaults to False):
81
+ Whether to use beta sigmas for step sizes in the noise schedule during sampling.
82
+ time_shift_type (`str`, defaults to "exponential"):
83
+ The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
84
+ stochastic_sampling (`bool`, defaults to `True`):
85
+ Whether to use stochastic sampling.
86
+ """
87
+
88
+ _compatibles = []
89
+ order = 1
90
+
91
+ @register_to_config
92
+ def __init__(
93
+ self,
94
+ num_train_timesteps: int = 1000,
95
+ shift: float = 1.0,
96
+ use_dynamic_shifting: bool = False,
97
+ base_shift: Optional[float] = 0.5,
98
+ max_shift: Optional[float] = 1.15,
99
+ base_image_seq_len: Optional[int] = 256,
100
+ max_image_seq_len: Optional[int] = 4096,
101
+ invert_sigmas: bool = False,
102
+ shift_terminal: Optional[float] = None,
103
+ use_karras_sigmas: Optional[bool] = False,
104
+ use_exponential_sigmas: Optional[bool] = False,
105
+ use_beta_sigmas: Optional[bool] = False,
106
+ time_shift_type: str = "exponential",
107
+ stochastic_sampling: bool = True,
108
+ ):
109
+ if self.config.use_beta_sigmas and not is_scipy_available():
110
+ raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
111
+ if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
112
+ raise ValueError(
113
+ "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
114
+ )
115
+ if time_shift_type not in {"exponential", "linear"}:
116
+ raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.")
117
+
118
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
119
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
120
+
121
+ sigmas = timesteps / num_train_timesteps
122
+ if not use_dynamic_shifting:
123
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
124
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
125
+
126
+ self.timesteps = sigmas * num_train_timesteps
127
+
128
+ self._step_index = None
129
+ self._begin_index = None
130
+
131
+ self._shift = shift
132
+
133
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
134
+ self.sigma_min = self.sigmas[-1].item()
135
+ self.sigma_max = self.sigmas[0].item()
136
+
137
+ @property
138
+ def shift(self):
139
+ """
140
+ The value used for shifting.
141
+ """
142
+ return self._shift
143
+
144
+ @property
145
+ def step_index(self):
146
+ """
147
+ The index counter for current timestep. It will increase 1 after each scheduler step.
148
+ """
149
+ return self._step_index
150
+
151
+ @property
152
+ def begin_index(self):
153
+ """
154
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
155
+ """
156
+ return self._begin_index
157
+
158
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
159
+ def set_begin_index(self, begin_index: int = 0):
160
+ """
161
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
162
+
163
+ Args:
164
+ begin_index (`int`):
165
+ The begin index for the scheduler.
166
+ """
167
+ self._begin_index = begin_index
168
+
169
+ def set_shift(self, shift: float):
170
+ self._shift = shift
171
+
172
+ def scale_noise(
173
+ self,
174
+ sample: torch.FloatTensor,
175
+ timestep: Union[float, torch.FloatTensor],
176
+ noise: Optional[torch.FloatTensor] = None,
177
+ ) -> torch.FloatTensor:
178
+ """
179
+ Forward process in flow-matching
180
+
181
+ Args:
182
+ sample (`torch.FloatTensor`):
183
+ The input sample.
184
+ timestep (`int`, *optional*):
185
+ The current timestep in the diffusion chain.
186
+
187
+ Returns:
188
+ `torch.FloatTensor`:
189
+ A scaled input sample.
190
+ """
191
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
192
+ sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
193
+
194
+ if sample.device.type == "mps" and torch.is_floating_point(timestep):
195
+ # mps does not support float64
196
+ schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
197
+ timestep = timestep.to(sample.device, dtype=torch.float32)
198
+ else:
199
+ schedule_timesteps = self.timesteps.to(sample.device)
200
+ timestep = timestep.to(sample.device)
201
+
202
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
203
+ if self.begin_index is None:
204
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
205
+ elif self.step_index is not None:
206
+ # add_noise is called after first denoising step (for inpainting)
207
+ step_indices = [self.step_index] * timestep.shape[0]
208
+ else:
209
+ # add noise is called before first denoising step to create initial latent(img2img)
210
+ step_indices = [self.begin_index] * timestep.shape[0]
211
+
212
+ sigma = sigmas[step_indices].flatten()
213
+ while len(sigma.shape) < len(sample.shape):
214
+ sigma = sigma.unsqueeze(-1)
215
+
216
+ sample = sigma * noise + (1.0 - sigma) * sample
217
+
218
+ return sample
219
+
220
+ def _sigma_to_t(self, sigma):
221
+ return sigma * self.config.num_train_timesteps
222
+
223
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
224
+ if self.config.time_shift_type == "exponential":
225
+ return self._time_shift_exponential(mu, sigma, t)
226
+ elif self.config.time_shift_type == "linear":
227
+ return self._time_shift_linear(mu, sigma, t)
228
+
229
+ def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
230
+ r"""
231
+ Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
232
+ value.
233
+
234
+ Reference:
235
+ https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
236
+
237
+ Args:
238
+ t (`torch.Tensor`):
239
+ A tensor of timesteps to be stretched and shifted.
240
+
241
+ Returns:
242
+ `torch.Tensor`:
243
+ A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
244
+ """
245
+ one_minus_z = 1 - t
246
+ scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
247
+ stretched_t = 1 - (one_minus_z / scale_factor)
248
+ return stretched_t
249
+
250
+ def set_timesteps(
251
+ self,
252
+ num_inference_steps: Optional[int] = None,
253
+ device: Union[str, torch.device] = None,
254
+ sigmas: Optional[List[float]] = None,
255
+ mu: Optional[float] = None,
256
+ timesteps: Optional[List[float]] = None,
257
+ ):
258
+ """
259
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
260
+
261
+ Args:
262
+ num_inference_steps (`int`, *optional*):
263
+ The number of diffusion steps used when generating samples with a pre-trained model.
264
+ device (`str` or `torch.device`, *optional*):
265
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
266
+ sigmas (`List[float]`, *optional*):
267
+ Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
268
+ automatically.
269
+ mu (`float`, *optional*):
270
+ Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
271
+ shifting.
272
+ timesteps (`List[float]`, *optional*):
273
+ Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
274
+ automatically.
275
+ """
276
+ if self.config.use_dynamic_shifting and mu is None:
277
+ raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
278
+
279
+ if sigmas is not None and timesteps is not None:
280
+ if len(sigmas) != len(timesteps):
281
+ raise ValueError("`sigmas` and `timesteps` should have the same length")
282
+
283
+ if num_inference_steps is not None:
284
+ if (sigmas is not None and len(sigmas) != num_inference_steps) or (
285
+ timesteps is not None and len(timesteps) != num_inference_steps
286
+ ):
287
+ raise ValueError(
288
+ "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
289
+ )
290
+ else:
291
+ num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)
292
+
293
+ self.num_inference_steps = num_inference_steps
294
+
295
+ # 1. Prepare default sigmas
296
+ is_timesteps_provided = timesteps is not None
297
+
298
+ if is_timesteps_provided:
299
+ timesteps = np.array(timesteps).astype(np.float32)
300
+
301
+ if sigmas is None:
302
+ if timesteps is None:
303
+ timesteps = np.linspace(
304
+ self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
305
+ )
306
+ sigmas = timesteps / self.config.num_train_timesteps
307
+ else:
308
+ sigmas = np.array(sigmas).astype(np.float32)
309
+ num_inference_steps = len(sigmas)
310
+
311
+ # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
312
+ # "exponential" or "linear" type is applied
313
+ if self.config.use_dynamic_shifting:
314
+ sigmas = self.time_shift(mu, 1.0, sigmas)
315
+ else:
316
+ sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
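+ # With a static shift s this is the flow-matching warp sigma' = s * sigma / (1 + (s - 1) * sigma):
+ # s > 1 moves the schedule toward higher noise levels, s = 1 leaves it unchanged.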
317
+
318
+ # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
319
+ if self.config.shift_terminal:
320
+ sigmas = self.stretch_shift_to_terminal(sigmas)
321
+
322
+ # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
323
+ if self.config.use_karras_sigmas:
324
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
325
+ elif self.config.use_exponential_sigmas:
326
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
327
+ elif self.config.use_beta_sigmas:
328
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
329
+
330
+ # 5. Convert sigmas and timesteps to tensors and move to specified device
331
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
332
+ if not is_timesteps_provided:
333
+ timesteps = sigmas * self.config.num_train_timesteps
334
+ else:
335
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
336
+
337
+ # 6. Append the terminal sigma value.
338
+ # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
339
+ # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi
340
+ if self.config.invert_sigmas:
341
+ sigmas = 1.0 - sigmas
342
+ timesteps = sigmas * self.config.num_train_timesteps
343
+ sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
344
+ else:
345
+ sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
346
+
347
+ self.timesteps = timesteps
348
+ self.sigmas = sigmas
349
+ self._step_index = None
350
+ self._begin_index = None
351
+
352
+ # def index_for_timestep(self, timestep, schedule_timesteps=None):
353
+ # if schedule_timesteps is None:
354
+ # schedule_timesteps = self.timesteps
355
+ # indices = (schedule_timesteps == timestep).nonzero()
356
+
357
+ # # The sigma index that is taken for the **very** first `step`
358
+ # # is always the second index (or the last index if there is only 1)
359
+ # # This way we can ensure we don't accidentally skip a sigma in
360
+ # # case we start in the middle of the denoising schedule (e.g. for image-to-image)
361
+ # pos = 1 if len(indices) > 1 else 0
362
+
363
+ # return indices[pos].item()
364
+
365
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
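+ # Vectorized variant of the stock diffusers implementation (kept commented out above): it compares a
+ # batch of timesteps against the schedule in one shot and returns, for each entry, the index of the
+ # last matching schedule position.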
366
+ if schedule_timesteps is None:
367
+ schedule_timesteps = self.timesteps
368
+ match = (schedule_timesteps[None, :] == timestep[:, None])
369
+
370
+ cols = torch.arange(schedule_timesteps.numel())
371
+ cols = cols.expand(timestep.numel(), -1)
372
+ match = match.to(cols.device)
373
+
374
+ idx_last = torch.where(match, cols, torch.full_like(cols, -1)).max(dim=1).values
375
+
376
+ return idx_last
377
+
378
+ def _init_step_index(self, timestep):
379
+ if self.begin_index is None:
380
+ if isinstance(timestep, torch.Tensor):
381
+ timestep = timestep.to(self.timesteps.device)
382
+ self._step_index = self.index_for_timestep(timestep)
383
+ else:
384
+ self._step_index = self._begin_index
385
+
386
+ def step(
387
+ self,
388
+ model_output: torch.FloatTensor,
389
+ timestep: Union[float, torch.FloatTensor],
390
+ sample: torch.FloatTensor,
391
+ s_churn: float = 0.0,
392
+ s_tmin: float = 0.0,
393
+ s_tmax: float = float("inf"),
394
+ s_noise: float = 1.0,
395
+ generator: Optional[torch.Generator] = None,
396
+ prev_sample: Optional[torch.FloatTensor] = None,
397
+ per_token_timesteps: Optional[torch.Tensor] = None,
398
+ return_dict: bool = True,
399
+ init_step: bool = False,
400
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
401
+ """
402
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
403
+ process from the learned model outputs (most often the predicted noise).
404
+
405
+ Args:
406
+ model_output (`torch.FloatTensor`):
407
+ The direct output from learned diffusion model.
408
+ timestep (`float`):
409
+ The current discrete timestep in the diffusion chain.
410
+ sample (`torch.FloatTensor`):
411
+ A current instance of a sample created by the diffusion process.
412
+ s_churn (`float`):
413
+ s_tmin (`float`):
414
+ s_tmax (`float`):
415
+ s_noise (`float`, defaults to 1.0):
416
+ Scaling factor for noise added to the sample.
417
+ generator (`torch.Generator`, *optional*):
418
+ A random number generator.
419
+ per_token_timesteps (`torch.Tensor`, *optional*):
420
+ The timesteps for each token in the sample.
421
+ return_dict (`bool`):
422
+ Whether or not to return a
423
+ [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
424
+
425
+ Returns:
426
+ [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
427
+ If return_dict is `True`,
428
+ [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned,
429
+ otherwise a tuple is returned where the first element is the sample tensor.
430
+ """
431
+
432
+ if (
433
+ isinstance(timestep, int)
434
+ or isinstance(timestep, torch.IntTensor)
435
+ or isinstance(timestep, torch.LongTensor)
436
+ ):
437
+ raise ValueError(
438
+ (
439
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
440
+ " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
441
+ " one of the `scheduler.timesteps` as a timestep."
442
+ ),
443
+ )
444
+
445
+ if self.step_index is None:
446
+ self._init_step_index(timestep)
447
+ # if init_step:
448
+ # self._init_step_index(timestep)
449
+ # Upcast to avoid precision issues when computing prev_sample
450
+ # sample = sample.to(torch.float32)
451
+
452
+ if per_token_timesteps is not None:
453
+ per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps
454
+
455
+ sigmas = self.sigmas[:, None, None]
456
+ lower_mask = sigmas < per_token_sigmas[None] - 1e-6
457
+ lower_sigmas = lower_mask * sigmas
458
+ lower_sigmas, _ = lower_sigmas.max(dim=0)
459
+
460
+ current_sigma = per_token_sigmas[..., None]
461
+ next_sigma = lower_sigmas[..., None]
462
+ dt = current_sigma - next_sigma
463
+ else:
464
+ sigma_idx = self.step_index
465
+ if init_step:
466
+ sigma_idx = self.index_for_timestep(timestep)
467
+ sigma = self.sigmas[sigma_idx]
468
+ sigma_next = self.sigmas[sigma_idx + 1]
469
+
470
+ current_sigma = sigma
471
+ next_sigma = sigma_next
472
+ if len(current_sigma.shape) > 0:
+ current_sigma = current_sigma[:, None, None]
+ next_sigma = next_sigma[:, None, None]
+ dt = sigma_next - sigma
+ log_prob = None
+ if self.config.stochastic_sampling:
483
+ x0 = sample - current_sigma * model_output
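+ # Under the flow-matching parameterization x_t = (1 - sigma) * x0 + sigma * noise, and assuming the
+ # model predicts the flow velocity (noise - x0) as in Flux, the clean-sample estimate is
+ # x0 = x_t - sigma * v, as computed above.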
484
+
485
+ if prev_sample is None:
486
+ if generator is None:
487
+ generator = torch.Generator(device=sample.device)
488
+ generator.seed()
489
+ noise = torch.randn(sample.size(), generator=generator, device=sample.device, dtype=sample.dtype)
490
+ prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
491
+
492
+
493
+ prev_sample_mean = (1.0 - next_sigma) * x0
+ prev_sample_ = prev_sample.clone()
495
+ diff = prev_sample_.detach() - prev_sample_mean
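+ # The lines below evaluate the per-dimension Gaussian log-density of `prev_sample` under
+ # N(prev_sample_mean, next_sigma**2) and average it over the latent dimensions; this per-sample
+ # log-probability is what the pipeline collects alongside the denoised latents.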
496
+
497
+ log_prob = - (diff**2) / (2 * (next_sigma+1e-7)**2) \
498
+ - torch.log(next_sigma) \
499
+ - 0.5 * torch.log(torch.tensor(2 * np.pi))
500
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
501
+ else:
502
+ prev_sample = sample + dt * model_output
503
+
504
+ # upon completion increase step index by one
505
+ self._step_index += 1
506
+ if per_token_timesteps is None:
507
+ # Cast sample back to model compatible dtype
508
+ prev_sample = prev_sample.to(model_output.dtype)
509
+
510
+ if not return_dict:
+ return (prev_sample, log_prob)
514
+
515
+ return FlowMatchEulerDiscreteSchedulerOutput(latents=prev_sample, log_probs=log_prob)
516
+
517
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
518
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
519
+ """Constructs the noise schedule of Karras et al. (2022)."""
520
+
521
+ # Hack to make sure that other schedulers which copy this function don't break
522
+ # TODO: Add this logic to the other schedulers
523
+ if hasattr(self.config, "sigma_min"):
524
+ sigma_min = self.config.sigma_min
525
+ else:
526
+ sigma_min = None
527
+
528
+ if hasattr(self.config, "sigma_max"):
529
+ sigma_max = self.config.sigma_max
530
+ else:
531
+ sigma_max = None
532
+
533
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
534
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
535
+
536
+ rho = 7.0 # 7.0 is the value used in the paper
537
+ ramp = np.linspace(0, 1, num_inference_steps)
538
+ min_inv_rho = sigma_min ** (1 / rho)
539
+ max_inv_rho = sigma_max ** (1 / rho)
540
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
541
+ return sigmas
542
+
543
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
544
+ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
545
+ """Constructs an exponential noise schedule."""
546
+
547
+ # Hack to make sure that other schedulers which copy this function don't break
548
+ # TODO: Add this logic to the other schedulers
549
+ if hasattr(self.config, "sigma_min"):
550
+ sigma_min = self.config.sigma_min
551
+ else:
552
+ sigma_min = None
553
+
554
+ if hasattr(self.config, "sigma_max"):
555
+ sigma_max = self.config.sigma_max
556
+ else:
557
+ sigma_max = None
558
+
559
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
560
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
561
+
562
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
563
+ return sigmas
564
+
565
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
566
+ def _convert_to_beta(
567
+ self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
568
+ ) -> torch.Tensor:
569
+ """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
570
+
571
+ # Hack to make sure that other schedulers which copy this function don't break
572
+ # TODO: Add this logic to the other schedulers
573
+ if hasattr(self.config, "sigma_min"):
574
+ sigma_min = self.config.sigma_min
575
+ else:
576
+ sigma_min = None
577
+
578
+ if hasattr(self.config, "sigma_max"):
579
+ sigma_max = self.config.sigma_max
580
+ else:
581
+ sigma_max = None
582
+
583
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
584
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
585
+
586
+ sigmas = np.array(
587
+ [
588
+ sigma_min + (ppf * (sigma_max - sigma_min))
589
+ for ppf in [
590
+ scipy.stats.beta.ppf(timestep, alpha, beta)
591
+ for timestep in 1 - np.linspace(0, 1, num_inference_steps)
592
+ ]
593
+ ]
594
+ )
595
+ return sigmas
596
+
597
+ def _time_shift_exponential(self, mu, sigma, t):
598
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
599
+
600
+ def _time_shift_linear(self, mu, sigma, t):
601
+ return mu / (mu + (1 / t - 1) ** sigma)
602
+
603
+ def __len__(self):
604
+ return self.config.num_train_timesteps