Multi_Ref_Edit / app.py
bo.l
update model
3ba81e2
raw
history blame
6.04 kB
import gradio as gr
import numpy as np
import random
import spaces #[uncomment to use ZeroGPU]
from kontext.pipeline_flux_kontext import FluxKontextPipeline
from kontext.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers import FluxTransformer2DModel
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
def resize_by_bucket(images_pil, resolution=512):
    """Resize all reference images to one shared, aspect-matched bucket.

    A fixed table of (height, width) buckets, defined for a 512 base
    resolution, is scaled to ``resolution`` and floored to multiples of 16.
    The bucket whose height/width ratio is closest to the *mean* aspect
    ratio of all inputs is chosen (first one wins on ties), and every image
    is resized to it, so all references come out with identical dimensions.

    Args:
        images_pil: Non-empty sequence of images exposing ``height``,
            ``width`` and a PIL-style ``resize(size, resample=...)``.
        resolution: Side length of the square bucket; the other buckets
            scale proportionally with it.

    Returns:
        A list of resized images, in input order.

    Raises:
        AssertionError: if ``images_pil`` is empty.
    """
    # Message reads "images_pil must not be empty".
    assert len(images_pil) > 0, "images_pil 不能为空"

    # PIL's bicubic filter id (Image.Resampling.BICUBIC == 3). Written as an
    # int because PIL's ``Image`` module is not imported in this file — the
    # previous ``Image.BICUBIC`` reference raised NameError on every call.
    bicubic = 3

    # (height, width) buckets at the 512 base resolution, from tallest-wide
    # (landscape) through square to portrait.
    base_buckets = [
        (336, 784), (344, 752), (360, 728), (376, 696),
        (400, 664), (416, 624), (440, 592), (472, 552),
        (512, 512),
        (552, 472), (592, 440), (624, 416), (664, 400),
        (696, 376), (728, 360), (752, 344), (784, 336),
    ]
    # Scale to the requested resolution, then floor each side to a multiple
    # of 16.
    buckets = [
        (int(h / 512 * resolution) // 16 * 16, int(w / 512 * resolution) // 16 * 16)
        for h, w in base_buckets
    ]

    mean_aspect_ratio = float(np.mean([img.height / img.width for img in images_pil]))
    # min() returns the first minimal element, matching the original
    # strict-less-than scan.
    new_h, new_w = min(
        buckets, key=lambda hw: abs(hw[0] / hw[1] - mean_aspect_ratio)
    )

    # PIL's resize takes (width, height).
    return [img.resize((new_w, new_h), resample=bicubic) for img in images_pil]
# Run on the GPU when one is visible; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base FLUX.1 Kontext pipeline. Requesting bfloat16 at load time avoids first
# materializing every component in float32 on the host, which would roughly
# double peak CPU memory during startup. The explicit casts below still run,
# so the final dtypes match the previous float32-then-cast path.
flux_pipeline = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev",
    torch_dtype=torch.bfloat16,
)
# Swap in the project's own flow-matching Euler scheduler, reusing the base
# pipeline's scheduler configuration.
flux_pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    flux_pipeline.scheduler.config
)
flux_pipeline.vae.to(device).to(torch.bfloat16)
flux_pipeline.text_encoder.to(device).to(torch.bfloat16)
flux_pipeline.text_encoder_2.to(device).to(torch.bfloat16)
# NOTE(review): flag read by the custom scheduler in
# kontext.scheduling_flow_match_euler_discrete; presumably False selects
# deterministic sampling — confirm there.
flux_pipeline.scheduler.config.stochastic_sampling = False

# Replace the base transformer weights with the fine-tuned multi-reference
# checkpoint (EMA weights, judging by the file name).
ckpt_path = hf_hub_download("NoobDoge/Multi_Ref_Model", "full_ema_model.safetensors")
new_weight = load_file(ckpt_path)
flux_pipeline.transformer.load_state_dict(new_weight)
# load_state_dict copies into the transformer's own parameters, so this dict
# is a second full copy of the weights. Drop the module-level reference so it
# is not kept alive for the lifetime of the process.
del new_weight
flux_pipeline.transformer.to(device).to(torch.bfloat16)

# Largest int32; upper bound for the seed slider and for random seeds.
MAX_SEED = np.iinfo(np.int32).max
# Side-length bound: maximum of the size sliders, and squared as the
# pipeline's max_area.
MAX_IMAGE_SIZE = 512
@spaces.GPU  # [uncomment to use ZeroGPU]
def infer(
    prompt,
    raw_images,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an edited image from a text prompt and reference images.

    Args:
        prompt: Text instruction forwarded to the pipeline.
        raw_images: Sequence of reference images (PIL, from
            ``gr.Image(type="pil")``). ``None`` entries — empty upload slots —
            are ignored.
        seed: Seed for the torch generator. It is replaced by a random value
            in ``[0, MAX_SEED]`` when ``randomize_seed`` is true.
        randomize_seed: Whether to draw a fresh seed.
        width, height: Requested output size in pixels.
        guidance_scale: Guidance value forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker. ``track_tqdm`` mirrors tqdm bars
            (e.g. the pipeline's step loop) in the UI.

    Returns:
        ``(image, seed)``: the first generated image and the seed actually
        used, so the UI can show the seed that reproduces the result.

    Raises:
        gr.Error: if no reference image was provided.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Drop empty upload slots, then resize all references *together*:
    # resize_by_bucket expects a list and picks one shared bucket from their
    # mean aspect ratio. (Calling it per image passed a bare image instead.)
    images = [img for img in (raw_images or []) if img is not None]
    if not images:
        raise gr.Error("Please upload at least one input image.")
    images = resize_by_bucket(images)

    # Slider values may arrive as floats; the generator and pipeline want ints.
    generator = torch.Generator().manual_seed(int(seed))
    with torch.no_grad():
        output_img = flux_pipeline(
            image=images,
            prompt=prompt,
            height=int(height),
            width=int(width),
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            max_area=MAX_IMAGE_SIZE**2,
            generator=generator,
        ).images[0]
    return output_img, seed
# Page styling: center the main column (elem_id="col-container") and cap its
# width. (A text-to-image `examples` list used to live here; it was rebound
# inside the UI definition before ever being read, so it was dead code.)
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
def _infer_with_refs(
    prompt_text,
    image_1,
    image_2,
    seed_value,
    randomize,
    width_px,
    height_px,
    guidance,
    steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Adapt Gradio's one-value-per-component call to ``infer``'s signature.

    Gradio event ``inputs`` must be a flat list of components, so the two
    reference images arrive as separate arguments. ``infer`` expects them as
    a single list. The Progress tracker Gradio injects here is forwarded so
    tqdm tracking still reaches the UI.
    """
    return infer(
        prompt_text,
        [image_1, image_2],
        seed_value,
        randomize,
        width_px,
        height_px,
        guidance,
        steps,
        progress=progress,
    )


with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Text-to-Image Gradio Template")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")
        # Two reference images fed to the pipeline.
        with gr.Row():
            ref1 = gr.Image(label="Input Image 1", type="pil")
            ref2 = gr.Image(label="Input Image 2", type="pil")
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                # Default must lie within [minimum, maximum]; the previous
                # 1024 exceeded maximum=MAX_IMAGE_SIZE.
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=MAX_IMAGE_SIZE,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=MAX_IMAGE_SIZE,
                )
            with gr.Row():
                # NOTE(review): 2.5 / 28 follow the FLUX.1-Kontext-dev
                # model-card example and diffusers' default step count. The
                # previous 0.0 / 2 were turbo-model template leftovers. Retune
                # for this fine-tuned checkpoint if needed.
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=2.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
        # These examples only fill in the text prompt; the reference images
        # must still be uploaded.
        examples = [
            ["a cute corgi in a wizard hat"],
            ["a watercolor painting of yosemite valley at sunrise"],
        ]
        gr.Examples(examples=examples, inputs=[prompt])
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=_infer_with_refs,
        # Flat component list: Gradio does not accept a nested list here.
        inputs=[
            prompt,
            ref1,
            ref2,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()