bo.l committed
Commit 46657b2 · 1 Parent(s): ee3f2c0

update input1

Files changed (1):
  1. app.py +129 -84
app.py CHANGED
@@ -1,7 +1,9 @@
 import gradio as gr
 import numpy as np
 import random
-import spaces  #[uncomment to use ZeroGPU]
 from kontext.pipeline_flux_kontext import FluxKontextPipeline
 from kontext.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
 from diffusers import FluxTransformer2DModel
@@ -9,6 +11,9 @@ import torch
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 
 def resize_by_bucket(images_pil, resolution=512):
     assert len(images_pil) > 0, "images_pil must not be empty"
     bucket_override = [
@@ -18,14 +23,9 @@ def resize_by_bucket(images_pil, resolution=512):
         (552, 472), (592, 440), (624, 416), (664, 400),
         (696, 376), (728, 360), (752, 344), (784, 336),
     ]
-    bucket_override = [
-        (int(h / 512 * resolution), int(w / 512 * resolution))
-        for h, w in bucket_override
-    ]
-    bucket_override = [
-        (h // 16 * 16, w // 16 * 16)
-        for h, w in bucket_override
-    ]
 
     aspect_ratios = [img.height / img.width for img in images_pil]
     mean_aspect_ratio = float(np.mean(aspect_ratios))
@@ -38,60 +38,88 @@ def resize_by_bucket(images_pil, resolution=512):
             min_aspect_diff = aspect_diff
             new_h, new_w = h, w
 
-    resized_images = [
-        img.resize((new_w, new_h), resample=Image.BICUBIC) for img in images_pil
-    ]
     return resized_images
 
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 flux_pipeline = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev")
 flux_pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(flux_pipeline.scheduler.config)
 flux_pipeline.vae.to(device).to(torch.bfloat16)
 flux_pipeline.text_encoder.to(device).to(torch.bfloat16)
 flux_pipeline.text_encoder_2.to(device).to(torch.bfloat16)
-flux_pipeline.scheduler.config.stochastic_sampling = False
 ckpt_path = hf_hub_download("NoobDoge/Multi_Ref_Model", "full_ema_model.safetensors")
-# new_weight = load_file(ckpt_path)
 flux_pipeline.transformer.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
 flux_pipeline.transformer.to(device).to(torch.bfloat16)
 
 MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 512
 
-
-@spaces.GPU  #[uncomment to use ZeroGPU]
 def infer(
     prompt,
-    ref1,
-    ref2,
     seed,
     randomize_seed,
     width,
     height,
-    guidance_scale,
     num_inference_steps,
     progress=gr.Progress(track_tqdm=True),
 ):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-    raw_images = [resize_by_bucket(x) for x in raw_images]
-    generator = torch.Generator().manual_seed(seed)
 
     with torch.no_grad():
-        output_img = flux_pipeline(
-            image = raw_images,
-            prompt = prompts,
-            height = height,
-            width = width,
-            num_inference_steps = num_inference_steps,
-            max_area=MAX_IMAGE_SIZE**2,
             generator=generator,
-        ).images[0]
-
-    return image, seed
 
 
 examples = [
     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
     "An astronaut riding a green horse",
@@ -110,66 +138,83 @@ with gr.Blocks(css=css) as demo:
     gr.Markdown("# Text-to-Image Gradio Template")
 
     with gr.Row():
-        prompt = gr.Text(label="Prompt", show_label=False, max_lines=1,
-                         placeholder="Enter your prompt", container=False)
         run_button = gr.Button("Run", scale=0, variant="primary")
 
-    # two input images; ref2 may be left empty
     with gr.Row():
-        ref1 = gr.Image(label="Input Image 1", type="pil")
-        ref2 = gr.Image(label="Input Image 2", type="pil")
 
     result = gr.Image(label="Result", show_label=False)
 
     with gr.Accordion("Advanced Settings", open=False):
-        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
-        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
         with gr.Row():
-            width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
-            height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
         with gr.Row():
-            guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=0.0)
-            num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=2)
-
-        examples = [
-            ["a cute corgi in a wizard hat"],
-            ["a watercolor painting of yosemite valley at sunrise"],
-        ]
-        gr.Examples(examples=examples, inputs=[prompt])
-
-        # state holding the variable-length list of reference images
-        refs_state = gr.State([])
-
-        # pack both images into state first, filtering out None, so ref2 is optional
-        def pack_refs(a, b):
-            return [x for x in (a, b) if x is not None]
-
-        # your inference function takes a *list* of refs
-        def infer(prompt, refs, seed, randomize_seed, width, height, guidance_scale, num_steps):
-            # pad to length 2 if needed: [ref1, None]
-            if len(refs) == 0:
-                refs = [None, None]
-            elif len(refs) == 1:
-                refs = [refs[0], None]
-
-            # TODO: call your model here with refs[0] and refs[1] (the second may be None)
-            # out_img = ...
-            # used_seed = ...
-            return out_img, used_seed
-
-        # step 1: pack ref1/ref2 into refs_state
-        dep = gr.on(
-            triggers=[run_button.click, prompt.submit],
-            fn=pack_refs,
-            inputs=[ref1, ref2],
-            outputs=refs_state,
-        )
-        # step 2: pass the packed list on to infer
-        dep.then(
-            fn=infer,
-            inputs=[prompt, refs_state, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-            outputs=[result, seed],
-        )
 
 if __name__ == "__main__":
-    demo.launch()
 
app.py (new version):
 import gradio as gr
 import numpy as np
 import random
+import spaces  # [uncomment to use ZeroGPU]
+from PIL import Image
+
 from kontext.pipeline_flux_kontext import FluxKontextPipeline
 from kontext.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
 from diffusers import FluxTransformer2DModel
 
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 
+# ---------------------------
+# utils
+# ---------------------------
 def resize_by_bucket(images_pil, resolution=512):
     assert len(images_pil) > 0, "images_pil must not be empty"
     bucket_override = [
 
         (552, 472), (592, 440), (624, 416), (664, 400),
         (696, 376), (728, 360), (752, 344), (784, 336),
     ]
+    # scale the buckets to the target resolution, then snap to multiples of 16
+    bucket_override = [(int(h / 512 * resolution), int(w / 512 * resolution)) for h, w in bucket_override]
+    bucket_override = [(h // 16 * 16, w // 16 * 16) for h, w in bucket_override]
 
     aspect_ratios = [img.height / img.width for img in images_pil]
     mean_aspect_ratio = float(np.mean(aspect_ratios))
 
         min_aspect_diff = aspect_diff
         new_h, new_w = h, w
 
+    resized_images = [img.resize((new_w, new_h), resample=Image.BICUBIC) for img in images_pil]
     return resized_images
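
As a quick sanity check of the bucketing logic, a minimal sketch (assuming the full bucket list from the file, not just the lines shown in this hunk): every image in the batch is resized to the single bucket whose aspect ratio is closest to the batch's mean aspect ratio.

imgs = [Image.new("RGB", (600, 400)), Image.new("RGB", (640, 480))]
resized = resize_by_bucket(imgs, resolution=512)
assert len({im.size for im in resized}) == 1        # one shared bucket for the whole batch
assert all(s % 16 == 0 for s in resized[0].size)    # both sides snapped to multiples of 16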
 
+# ---------------------------
+# pipeline init
+# ---------------------------
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 flux_pipeline = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev")
 flux_pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(flux_pipeline.scheduler.config)
+flux_pipeline.scheduler.config.stochastic_sampling = False
+
+# precision & device placement
 flux_pipeline.vae.to(device).to(torch.bfloat16)
 flux_pipeline.text_encoder.to(device).to(torch.bfloat16)
 flux_pipeline.text_encoder_2.to(device).to(torch.bfloat16)
+
+# replace the transformer weights
 ckpt_path = hf_hub_download("NoobDoge/Multi_Ref_Model", "full_ema_model.safetensors")
 flux_pipeline.transformer.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
 flux_pipeline.transformer.to(device).to(torch.bfloat16)
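
One caveat worth flagging: in diffusers, from_single_file is a classmethod that builds and returns a new model rather than loading weights into the existing instance, so the call above appears to discard the downloaded checkpoint. A minimal corrected sketch (my reading of the intent, not what the commit ships):

# assign the freshly loaded transformer back onto the pipeline
flux_pipeline.transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path, torch_dtype=torch.bfloat16
).to(device)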
 
+# ---------------------------
+# constants
+# ---------------------------
 MAX_SEED = np.iinfo(np.int32).max
+MAX_IMAGE_SIZE = 512  # upper bound for the Width/Height sliders below
 
+# ---------------------------
+# inference
+# ---------------------------
+@spaces.GPU  # [uncomment to use ZeroGPU]
 def infer(
     prompt,
+    ref1,  # PIL.Image or None
+    ref2,  # PIL.Image or None (optional)
     seed,
     randomize_seed,
     width,
     height,
+    guidance_scale,  # not passed to the pipeline yet; add it to the call below if needed
     num_inference_steps,
     progress=gr.Progress(track_tqdm=True),
 ):
+    # collect the optional reference images
+    refs = [x for x in (ref1, ref2) if x is not None]
+    if len(refs) == 0:
+        raise gr.Error("Please upload at least one reference image (ref1 or ref2).")
+
+    # normalize width/height: clamp to MAX_IMAGE_SIZE and snap to multiples of 16
+    width = max(16, min(width, MAX_IMAGE_SIZE)) // 16 * 16
+    height = max(16, min(height, MAX_IMAGE_SIZE)) // 16 * 16
+
+    # seeding
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
+    generator = torch.Generator(device=device).manual_seed(int(seed))
 
+    # bucket-resize the reference images
+    base_res = min(width, height, MAX_IMAGE_SIZE)
+    raw_images = resize_by_bucket(refs, resolution=base_res)
+
+    # run the pipeline
     with torch.no_grad():
+        out = flux_pipeline(
+            image=raw_images,
+            prompt=prompt,
+            height=height,
+            width=width,
+            num_inference_steps=int(num_inference_steps),
+            max_area=MAX_IMAGE_SIZE ** 2,
             generator=generator,
+            # enable guidance_scale only after confirming the pipeline supports the argument:
+            # guidance_scale=float(guidance_scale),
+        )
+        output_img = out.images[0]
 
+    return output_img, int(seed)
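
For a quick smoke test of infer outside the UI, something like the following sketch should work (ref.png is a placeholder path; a CUDA device is assumed, and the gr.Progress default is simply unused outside a Gradio event):

ref = Image.open("ref.png").convert("RGB")  # placeholder reference image
img, used_seed = infer(
    prompt="turn the subject into a watercolor painting",
    ref1=ref,
    ref2=None,            # the second reference is optional
    seed=0,
    randomize_seed=True,
    width=512,
    height=512,
    guidance_scale=2.5,   # currently ignored inside infer
    num_inference_steps=28,
)
img.save("out.png")
print("seed:", used_seed)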
 
+# ---------------------------
+# UI
+# ---------------------------
 examples = [
     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
     "An astronaut riding a green horse",
 
     gr.Markdown("# Text-to-Image Gradio Template")
 
     with gr.Row():
+        prompt = gr.Text(
+            label="Prompt",
+            show_label=False,
+            max_lines=1,
+            placeholder="Enter your prompt",
+            container=False,
+        )
         run_button = gr.Button("Run", scale=0, variant="primary")
 
+    # two input images (ref2 may be left empty)
     with gr.Row():
+        ref1_comp = gr.Image(label="Input Image 1", type="pil")
+        ref2_comp = gr.Image(label="Input Image 2 (optional)", type="pil")
 
     result = gr.Image(label="Result", show_label=False)
 
     with gr.Accordion("Advanced Settings", open=False):
+        seed_comp = gr.Slider(
+            label="Seed",
+            minimum=0,
+            maximum=MAX_SEED,
+            step=1,
+            value=0,
+        )
+        randomize_seed_comp = gr.Checkbox(label="Randomize seed", value=True)
+
         with gr.Row():
+            width_comp = gr.Slider(
+                label="Width",
+                minimum=256,
+                maximum=MAX_IMAGE_SIZE,
+                step=32,
+                value=512,
+            )
+            height_comp = gr.Slider(
+                label="Height",
+                minimum=256,
+                maximum=MAX_IMAGE_SIZE,
+                step=32,
+                value=512,
+            )
+
         with gr.Row():
+            guidance_scale_comp = gr.Slider(
+                label="Guidance scale",
+                minimum=0.0,
+                maximum=10.0,
+                step=0.1,
+                value=2.5,
+            )
+            num_inference_steps_comp = gr.Slider(
+                label="Number of inference steps",
+                minimum=1,
+                maximum=50,
+                step=1,
+                value=28,
+            )
+
+    gr.Examples(examples=[[e] for e in examples], inputs=[prompt])
+
+    # Note: don't pass [ref1_comp, ref2_comp] as a single list to inputs; each component maps to one argument!
+    gr.on(
+        triggers=[run_button.click, prompt.submit],
+        fn=infer,
+        inputs=[
+            prompt,
+            ref1_comp,
+            ref2_comp,  # may be empty
+            seed_comp,
+            randomize_seed_comp,
+            width_comp,
+            height_comp,
+            guidance_scale_comp,
+            num_inference_steps_comp,
+        ],
+        outputs=[result, seed_comp],
+    )
 
 if __name__ == "__main__":
+    demo.launch()