import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
from diffusers import QwenImageEditPipeline, FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_xformers_available
from presets import PRESETS, get_preset_choices, get_preset_info, update_preset_prompt
import os
import sys
import re
import gc
import math
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import logging
from copy import deepcopy
#############################
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')
# Set up logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
# Model configuration
REWRITER_MODEL = "Qwen/Qwen1.5-4B-Chat"  # 4B chat model for more reliable JSON handling
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
LOC = os.getenv("QIE", "Qwen/Qwen-Image-Edit")  # local path or repo id; falls back to the public checkpoint if QIE is unset
# Quantization configuration
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True
)
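# NF4 4-bit weights with double quantization shrink the rewriter to roughly a
# quarter of its fp16 size (a ~4B-parameter chat model fits in about 2-3 GB of
# VRAM), while matmuls still run in bfloat16 via bnb_4bit_compute_dtype.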
rewriter_model = AutoModelForCausalLM.from_pretrained(
REWRITER_MODEL,
torch_dtype=dtype,
device_map="auto",
quantization_config=bnb_config,
)
# Store original presets for reference
ORIGINAL_PRESETS = deepcopy(PRESETS)
def get_fresh_presets():
    """Return a deep copy so session edits can never corrupt ORIGINAL_PRESETS."""
    return deepcopy(ORIGINAL_PRESETS)

def reset_presets():
    return get_fresh_presets()
# Preload the enhancement tokenizer at startup (the model weights were loaded above)
logger.info("🔄 Loading prompt enhancement tokenizer...")
rewriter_tokenizer = AutoTokenizer.from_pretrained(REWRITER_MODEL)
logger.info("✅ Enhancement model and tokenizer ready!")
SYSTEM_PROMPT_EDIT = '''
# Edit Instruction Rewriter
You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable instruction based on the user's intent and the input image.
## 1. General Principles
- Keep the rewritten instruction **concise** and clear.
- Avoid contradictions, vagueness, or unachievable instructions.
- Maintain the core logic of the original instruction; only enhance clarity and feasibility.
- Ensure newly added elements or modifications align with the image's original context and art style.
## 2. Task Types
### Add, Delete, Replace:
- When the input is detailed, only refine grammar and clarity.
- For vague instructions, infer minimal but sufficient details.
- For replacement, use the format: `"Replace X with Y"`.
### Text Editing (e.g., text replacement):
- Enclose text content in quotes, e.g., `Replace "abc" with "xyz"`.
- Preserve the original structure and language; **do not translate** or alter style.
### Human Editing (e.g., change a person’s face/hair):
- Preserve core visual identity (gender, ethnic features).
- Describe expressions in subtle and natural terms.
- Maintain key clothing or styling details unless explicitly replaced.
### Style Transformation:
- If a style is specified, e.g., `Disco style`, rewrite it to encapsulate the essential visual traits.
- Use a fixed template for **coloring/restoration**:
`"Restore old photograph, remove scratches, reduce noise, enhance details, high resolution, realistic, natural skin tones, clear facial features, no distortion, vintage photo restoration"`
if applicable.
## 3. Output Format
Please provide the rewritten instruction in a clean `json` format as:
{
"Rewritten": "..."
}
'''
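# Illustrative exchange this system prompt sets up (not a captured model response):
# user: "make the sky nicer" -> rewriter: {"Rewritten": "Replace the overcast sky
# with a clear blue sky and scattered clouds, matching the photo's lighting"}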
def create_safety_system_prompt(original_system_prompt: str) -> str:
"""Enhance the system prompt with safety guidelines"""
safety_addition = '''
## 4. Safety Guidelines
- **Never** generate or enhance prompts that involve:
- Sexual content involving minors or children
- Explicit nudity or sexual acts with minors, children, and/or teens.
- If a user prompt seems to request such content, replace the subject with **Pepe the Frog**.
'''
return original_system_prompt + safety_addition
def extract_json_response(model_output: str) -> str | None:
    """Extract the rewritten instruction from potentially messy JSON output; returns None if nothing parseable is found."""
# Remove code block markers first
model_output = re.sub(r'```(?:json)?\s*', '', model_output)
try:
# Find the JSON portion in the output
start_idx = model_output.find('{')
end_idx = model_output.rfind('}')
        # Bail out if no opening/closing braces were found in the right order
if start_idx == -1 or end_idx == -1 or start_idx >= end_idx:
logger.warning(f"No valid JSON structure found in output. Start: {start_idx}, End: {end_idx}")
return None
# Expand to the full object including outer braces
end_idx += 1 # Include the closing brace
json_str = model_output[start_idx:end_idx]
# Handle potential markdown or other formatting
json_str = json_str.strip()
# Try to parse JSON directly first
try:
data = json.loads(json_str)
except json.JSONDecodeError as e:
print(f"Direct JSON parsing failed: {e}")
# If direct parsing fails, try cleanup
# Quote keys properly
json_str = re.sub(r'([^{}[\],\s"]+)(?=\s*:)', r'"\1"', json_str)
# Remove any trailing commas that might cause issues
json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
# Try parsing again
data = json.loads(json_str)
# Extract rewritten prompt from possible key variations
possible_keys = [
"Rewritten", "rewritten", "Rewrited", "rewrited", "Rewrittent",
"Output", "output", "Enhanced", "enhanced"
]
for key in possible_keys:
if key in data:
return data[key].strip()
# Try nested path
if "Response" in data and "Rewritten" in data["Response"]:
return data["Response"]["Rewritten"].strip()
# Handle nested JSON objects (additional protection)
if isinstance(data, dict):
for value in data.values():
if isinstance(value, dict) and "Rewritten" in value:
return value["Rewritten"].strip()
# Try to find any string value that looks like an instruction
str_values = [v for v in data.values() if isinstance(v, str) and 10 < len(v) < 500]
if str_values:
return str_values[0].strip()
except Exception as e:
logger.warning(f"JSON parse error: {str(e)}")
logger.warning(f"Model output was: {model_output}")
return None
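# Illustrative behavior (hypothetical strings, not captured model outputs):
#   extract_json_response('```json\n{"Rewritten": "Replace the cat with a dog"}\n```')
#     -> 'Replace the cat with a dog'
#   extract_json_response('{Rewritten: "Add a red hat",}')  # unquoted key, trailing comma
#     -> 'Add a red hat'                                    # repaired by the cleanup pass
#   extract_json_response('no JSON here')
#     -> None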
def polish_prompt(original_prompt: str) -> str:
"""Enhanced prompt rewriting using original system prompt with JSON handling"""
# Format as Qwen chat
messages = [
{"role": "system", "content": create_safety_system_prompt(SYSTEM_PROMPT_EDIT)},
{"role": "user", "content": original_prompt}
]
text = rewriter_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = rewriter_tokenizer(text, return_tensors="pt").to(device)
with torch.no_grad():
generated_ids = rewriter_model.generate(
**model_inputs,
max_new_tokens=512,
do_sample=True,
temperature=0.75,
top_p=0.85,
repetition_penalty=1.1,
no_repeat_ngram_size=3,
pad_token_id=rewriter_tokenizer.eos_token_id
)
# Extract and clean response
enhanced = rewriter_tokenizer.decode(
generated_ids[0][model_inputs.input_ids.shape[1]:],
skip_special_tokens=True
).strip()
logger.info(f"Original Prompt: {original_prompt}")
logger.info(f"Model raw output: {enhanced}") # Debug logging
# Try to extract JSON content
rewritten_prompt = extract_json_response(enhanced)
if rewritten_prompt:
# Clean up remaining artifacts
rewritten_prompt = re.sub(r'(Replace|Change|Add) "(.*?)"', r'\1 \2', rewritten_prompt)
rewritten_prompt = rewritten_prompt.replace('\\"', '"').replace('\\n', ' ')
return rewritten_prompt
else:
# Fallback: try to extract from code blocks or just return cleaned content
if '```' in enhanced:
parts = enhanced.split('```')
if len(parts) >= 2:
rewritten_prompt = parts[1].strip()
else:
rewritten_prompt = enhanced
else:
rewritten_prompt = enhanced
# Basic cleanup
rewritten_prompt = re.sub(r'\s\s+', ' ', rewritten_prompt).strip()
if ': ' in rewritten_prompt:
rewritten_prompt = rewritten_prompt.split(': ', 1)[-1].strip()
return rewritten_prompt[:200] if rewritten_prompt else original_prompt
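# polish_prompt degrades gracefully: structured JSON extraction first, then the
# first fenced code block, then the raw decoded text with whitespace collapsed
# and any "Key: value" prefix stripped, truncated to 200 chars; if everything
# comes back empty, the original prompt is returned unchanged.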
# Scheduler configuration for Lightning
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": math.log(3),
"invert_sigmas": False,
"max_image_seq_len": 8192,
"max_shift": math.log(3),
"num_train_timesteps": 1000,
"shift": 1.0,
"shift_terminal": None,
"stochastic_sampling": False,
"time_shift_type": "exponential",
"use_beta_sigmas": False,
"use_dynamic_shifting": True,
"use_exponential_sigmas": False,
"use_karras_sigmas": False,
}
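# With use_dynamic_shifting=True and base_shift == max_shift == log(3), the
# interpolated shift (mu) is pinned to log(3) at every resolution, so the
# "exponential" time shift applies a constant exp(mu) = 3x sigma shift. This
# appears to match the Lightning distillation recipe; treat it as an assumption
# if you retune the scheduler.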
# Initialize scheduler with Lightning config
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
# Load main image editing pipeline
pipe = QwenImageEditPipeline.from_pretrained(
LOC,
scheduler=scheduler,
torch_dtype=dtype
).to(device)
# Load LoRA weights for acceleration
pipe.load_lora_weights(
"lightx2v/Qwen-Image-Lightning",
# weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
weight_name="Qwen-Image-Edit-Lightning-4steps-V1.0.safetensors"
)
pipe.fuse_lora()
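# Fusing merges the Lightning LoRA into the base weights so each 4-step call pays
# no adapter overhead; pipe.unfuse_lora() would restore the original weights if a
# different adapter were needed later.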
# if is_xformers_available():
# pipe.enable_xformers_memory_efficient_attention()
# else:
# print("xformers not available")
try:
pipe.enable_vae_slicing()
except Exception as e:
logger.info(f"VAE Slicing Failed: {e}")
def toggle_output_count(preset_type):
"""Control output count slider interactivity and show/hide preset editor"""
if preset_type and preset_type in ORIGINAL_PRESETS:
# When preset is selected, disable manual output count and show editor
preset = ORIGINAL_PRESETS[preset_type]
prompts = preset["prompts"][:4] # Get up to 4 prompts
# Pad prompts to 4 items if needed
while len(prompts) < 4:
prompts.append("")
return (
gr.Group(visible=True),
gr.Slider(interactive=False, value=len([p for p in prompts if p.strip()])), # Count non-empty prompts
prompts[0], prompts[1], prompts[2], prompts[3] # Populate preset prompts
)
else:
# When no preset is selected, enable manual output count and hide editor
return (
gr.Group(visible=False),
gr.Slider(interactive=True), # Enable slider
"", "", "", "" # Clear preset prompts
)
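# Note: toggle_output_count is not wired to any event below; the dropdown handlers
# use show_preset_editor and update_preset_count instead, so this combined helper
# is currently unused.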
def update_prompt_preview(preset_type, base_prompt):
"""Update the prompt preview display based on selected preset and base prompt"""
if preset_type and preset_type in ORIGINAL_PRESETS:
preset = ORIGINAL_PRESETS[preset_type]
non_empty_prompts = [p for p in preset["prompts"] if p.strip()]
if not non_empty_prompts:
return "No prompts defined. Please enter at least one prompt in the editor."
preview_text = f"**Preset: {preset_type}**\n\n"
preview_text += f"*{preset['description']}*\n\n"
preview_text += f"**Generating {len(non_empty_prompts)} image{'s' if len(non_empty_prompts) > 1 else ''}:**\n"
for i, preset_prompt in enumerate(non_empty_prompts, 1):
full_prompt = f"{base_prompt}, {preset_prompt}"
preview_text += f"{i}. {full_prompt}\n"
return preview_text
else:
return "Select a preset above to see how your base prompt will be modified for batch generation."
def update_preset_prompt_textbox(preset_type, base_prompt, p1, p2, p3, p4):
    if preset_type and preset_type in preset_state.value:
        # Build a new preset dict instead of mutating the stored one in place
        new_preset = {
            **preset_state.value[preset_type],
            "prompts": [p1, p2, p3, p4]
        }
        preset_state.value[preset_type] = new_preset
        return update_prompt_preview_with_presets(preset_type, base_prompt, preset_state.value)
    return "Select a preset first."
def update_prompt_preview_with_presets(preset_type, base_prompt, custom_presets):
if preset_type and preset_type in custom_presets:
preset = custom_presets[preset_type]
non_empty_prompts = [p for p in preset["prompts"] if p.strip()]
if not non_empty_prompts:
return "No prompts defined. Please enter at least one prompt in the editor."
preview = f"**Preset: {preset_type}**\n\n{preset['description']}\n\n"
preview += f"**Generating {len(non_empty_prompts)} image{'s' if len(non_empty_prompts)>1 else ''}:**\n"
for i, pp in enumerate(non_empty_prompts, 1):
preview += f"{i}. {base_prompt}, {pp}\n"
return preview
return "Select a preset to see the preview."
@spaces.GPU()
def infer(
image,
prompt,
seed=42,
randomize_seed=False,
true_guidance_scale=4.0,
num_inference_steps=3,
rewrite_prompt=True,
num_images_per_prompt=1,
preset_type=None,
progress=gr.Progress(track_tqdm=True),
):
"""Image editing endpoint with optimized prompt handling - now uses fresh presets"""
# Resize image to max 1024px on longest side
session_presets = preset_state.value
def resize_image(pil_image, max_size=1024):
"""Resize image to maximum dimension of 1024px while maintaining aspect ratio"""
try:
if pil_image is None:
return pil_image
width, height = pil_image.size
max_dimension = max(width, height)
if max_dimension <= max_size:
return pil_image # No resize needed
# Calculate new dimensions maintaining aspect ratio
scale = max_size / max_dimension
new_width = int(width * scale)
new_height = int(height * scale)
# Resize image
resized_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
logger.info(f"📝 Image resized from {width}x{height} to {new_width}x{new_height}")
return resized_image
except Exception as e:
logger.warning(f"⚠️ Image resize failed: {e}")
return pil_image # Return original if resize fails
# Add noise function for batch variation
def add_noise_to_image(pil_image, noise_level=0.001):
"""Add slight noise to image to create variation in outputs"""
try:
if pil_image is None:
return pil_image
img_array = np.array(pil_image).astype(np.float32) / 255.0
noise = np.random.normal(0, noise_level, img_array.shape)
noisy_array = img_array + noise
# Clip values to valid range
noisy_array = np.clip(noisy_array, 0, 1)
# Convert back to PIL
noisy_array = (noisy_array * 255).astype(np.uint8)
return Image.fromarray(noisy_array)
except Exception as e:
logger.warning(f"Warning: Could not add noise to image: {e}")
return pil_image # Return original if noise addition fails
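    # Scale check: noise_level is the sigma of Gaussian noise in [0, 1] image space,
    # so 0.001 is roughly 0.26 gray levels per channel and the largest value used
    # below (0.001 + 3*0.003 = 0.01) is roughly 2.6 levels: enough to decorrelate
    # outputs, far too small to be visible.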
    # Resize the input image before any processing
image = resize_image(image, max_size=1024)
original_prompt = prompt
prompt_info = ""
# Handle preset batch generation
if preset_type and preset_type in session_presets:
preset = session_presets[preset_type]
# Filter out empty prompts
non_empty_preset_prompts = [p for p in preset["prompts"] if p.strip()]
if non_empty_preset_prompts:
batch_prompts = [f"{original_prompt}, {preset_prompt}" for preset_prompt in non_empty_preset_prompts]
            num_images_per_prompt = 1  # one image per preset variation; the total equals the number of non-empty prompts
prompt_info = (
f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #2196F3;>"
f"<h4 style='margin-top: 0;'>🎨 Preset: {preset_type}</h4>"
f"<p>{preset['description']}</p>"
f"<p><strong>Base Prompt:</strong> {original_prompt}</p>"
f"<p>Generating {len(non_empty_preset_prompts)} image{'s' if len(non_empty_preset_prompts) > 1 else ''}</p>"
f"</div>"
)
logger.info(f"Using preset: {preset_type} with {len(batch_prompts)} variations")
else:
# Fallback to manual if no valid prompts
batch_prompts = [prompt]
prompt_info = (
f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF9800;>"
f"<h4 style='margin-top: 0;'>⚠️ Invalid Preset</h4>"
f"<p>No valid prompts found. Using manual prompt.</p>"
f"<p><strong>Prompt:</strong> {original_prompt}</p>"
f"</div>"
)
else:
batch_prompts = [prompt] # Single prompt in list
# Handle regular prompt rewriting
if rewrite_prompt:
try:
enhanced_instruction = polish_prompt(original_prompt)
if enhanced_instruction and enhanced_instruction != original_prompt:
prompt_info = (
f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #4CAF50;>"
f"<h4 style='margin-top: 0;'>🚀 Prompt Enhancement</h4>"
f"<p><strong>Original:</strong> {original_prompt}</p>"
f"<p><strong style='color:#2E7D32;'>Enhanced:</strong> {enhanced_instruction}</p>"
f"</div>"
)
batch_prompts = [enhanced_instruction]
else:
prompt_info = (
f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF9800;>"
f"<h4 style='margin-top: 0;'>📝 Prompt Enhancement</h4>"
f"<p>No enhancement applied or enhancement failed</p>"
f"</div>"
)
except Exception as e:
logger.warning(f"Prompt enhancement error: {str(e)}") # Debug logging
gr.Warning(f"Prompt enhancement failed: {str(e)}")
prompt_info = (
f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF5252;>"
f"<h4 style='margin-top: 0;'>⚠️ Enhancement Not Applied</h4>"
f"<p>Using original prompt. Error: {str(e)[:100]}</p>"
f"</div>"
)
else:
prompt_info = (
f"<div style='margin:10px; padding:10px; border-radius:8px;>"
f"<h4 style='margin-top: 0;'>📝 Original Prompt</h4>"
f"<p>{original_prompt}</p>"
f"</div>"
)
# Set base seed for reproducibility
base_seed = seed if not randomize_seed else random.randint(0, MAX_SEED)
try:
edited_images = []
# Generate images for each prompt in the batch
for i, current_prompt in enumerate(batch_prompts):
# Create unique seed for each image
generator = torch.Generator(device=device).manual_seed(base_seed + i*1000)
# Add slight noise to the image for variation (except for first image to maintain base)
            if i == 0:
input_image = image
else:
input_image = add_noise_to_image(image, noise_level=0.001 + i*0.003)
# Slightly vary guidance scale for each image
varied_guidance = true_guidance_scale + random.uniform(-0.1, 0.1)
varied_guidance = max(1.0, min(10.0, varied_guidance))
# Generate single image
result = pipe(
image=input_image,
prompt=current_prompt,
negative_prompt=" ",
num_inference_steps=num_inference_steps,
generator=generator,
true_cfg_scale=varied_guidance,
                num_images_per_prompt=num_images_per_prompt
).images
edited_images.extend(result)
logger.info(f"Generated image {i+1}/{len(batch_prompts)} with prompt: {current_prompt}...")
# Clear cache after generation
# if device == "cuda":
# torch.cuda.empty_cache()
# gc.collect()
return edited_images, base_seed, prompt_info
except Exception as e:
# Clear cache on error
if device == "cuda":
torch.cuda.empty_cache()
gc.collect()
gr.Error(f"Image generation failed: {str(e)}")
return [], base_seed, (
f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #dd2c00;>"
f"<h4 style='margin-top: 0;'>⚠️ Processing Error</h4>"
f"<p>{str(e)[:200]}</p>"
f"</div>"
)
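# Local smoke test (illustrative; assumes a CUDA GPU, a resolvable QIE checkpoint,
# and a test.png next to this file):
#   from PIL import Image
#   images, used_seed, info_html = infer(Image.open("test.png"), "Replace the sky with a sunset", randomize_seed=True)
#   images[0].save("edited.png")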
with gr.Blocks(title="'Qwen Image Edit' Model Playground & Showcase [4-Step Lightning Mode]") as demo:
preset_prompts_state = gr.State(value=[])
# preset_prompts_state = gr.State(value=["", "", "", ""])
    preset_state = gr.State(value=get_fresh_presets())  # deep copy keeps ORIGINAL_PRESETS pristine
gr.Markdown("## ⚡️ Qwen-Image-Edit Lightning Presets")
with gr.Row(equal_height=True):
# Input Column
with gr.Column(scale=1):
input_image = gr.Image(
label="Source Image",
type="pil",
height=300
)
with gr.Column(scale=2):
result = gr.Gallery(
label="Edited Images",
columns=2,
container=True
)
with gr.Row():
with gr.Column(scale=1):
prompt = gr.Textbox(
label="Edit Instructions / Base Prompt",
placeholder="e.g. Replace the background with a beach sunset... When a preset is selected, use as the base prompt, e.g. the lamborghini",
lines=2,
max_lines=4,
scale=2
)
preset_dropdown = gr.Dropdown(
choices=get_preset_choices(),
value=None,
label="Preset Batch Generation",
interactive=True
)
# Add editable preset prompts (initially hidden)
preset_editor = gr.Group(visible=False)
with preset_editor:
gr.Markdown("### 🎨 Edit Preset Prompts")
preset_prompt_1 = gr.Textbox(label="Prompt 1", lines=1, value="")
preset_prompt_2 = gr.Textbox(label="Prompt 2", lines=1, value="")
preset_prompt_3 = gr.Textbox(label="Prompt 3", lines=1, value="")
preset_prompt_4 = gr.Textbox(label="Prompt 4", lines=1, value="")
update_preset_button = gr.Button("Update Preset", variant="secondary", visible=False)
reset_button = gr.Button("Reset Presets", variant="stop", visible=False)
# Add prompt preview component
prompt_preview = gr.Textbox(
label="📋 Prompt Preview",
interactive=False,
lines=6,
max_lines=10,
value="Enter a base prompt and select a preset above to see how your prompt will be modified for batch generation.",
placeholder="Prompt preview will appear here..."
)
rewrite_toggle = gr.Checkbox(
label="Additional Prompt Enhancement",
info="Setting this to true will pass the basic prompt(s) generated via the static preset template to a secondary LLM tasked with improving the overall cohesiveness and details of the final generation prompt.",
value=True,
interactive=True
)
run_button = gr.Button(
"Generate Edit(s)",
variant="primary"
)
with gr.Accordion("Advanced Parameters", open=False):
with gr.Row():
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=42
)
randomize_seed = gr.Checkbox(
label="Random Seed",
value=True
)
with gr.Row():
true_guidance_scale = gr.Slider(
label="True CFG Scale",
minimum=1.0,
maximum=10.0,
step=0.1,
value=1.1
)
num_inference_steps = gr.Slider(
label="Inference Steps",
minimum=1,
maximum=16,
step=1,
value=3
)
num_images_per_prompt = gr.Slider(
label="Output Count (Manual)",
minimum=1,
maximum=4,
step=1,
value=2,
interactive=True
)
with gr.Column(scale=2):
prompt_info = gr.Markdown(
value="<div style='padding:15px; margin-top:15px'>"
"Hint: depending on the original image, prompt quality, and complexity, you can often get away with 3 steps, even 2 steps without much loss in quality. </div>"
)
def show_preset_editor(preset_type):
if preset_type and preset_type in preset_state.value:
preset = preset_state.value[preset_type]
prompts = preset["prompts"] + [""] * (4 - len(preset["prompts"]))
return gr.Group(visible=True), *prompts[:4]
return gr.Group(visible=False), "", "", "", ""
def update_preset_count(preset_type, p1, p2, p3, p4):
if preset_type and preset_type in preset_state.value:
count = len([p for p in (p1,p2,p3,p4) if p.strip()])
return gr.Slider(value=max(1, min(4, count)), interactive=False)
return gr.Slider(interactive=True)
    # Wire the preset dropdown: show the editor and refresh the preview on change
preset_dropdown.change(
fn=show_preset_editor,
inputs=[preset_dropdown],
outputs=[preset_editor, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4]
)
preset_dropdown.change(
fn=update_prompt_preview,
inputs=[preset_dropdown, prompt],
outputs=prompt_preview
)
preset_prompt_1.change(
fn=update_preset_prompt_textbox,
        inputs=[preset_dropdown, prompt, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4],
outputs=prompt_preview
)
preset_prompt_2.change(
fn=update_preset_prompt_textbox,
        inputs=[preset_dropdown, prompt, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4],
outputs=prompt_preview
)
preset_prompt_3.change(
fn=update_preset_prompt_textbox,
        inputs=[preset_dropdown, prompt, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4],
outputs=prompt_preview
)
preset_prompt_4.change(
fn=update_preset_prompt_textbox,
        inputs=[preset_dropdown, prompt, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4],
outputs=prompt_preview
)
preset_prompt_1.change(
fn=update_preset_count,
inputs=[preset_dropdown, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4],
outputs=num_images_per_prompt
)
preset_prompt_2.change(
fn=update_preset_count,
inputs=[preset_dropdown, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4],
outputs=num_images_per_prompt
)
preset_prompt_3.change(
fn=update_preset_count,
inputs=[preset_dropdown, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4],
outputs=num_images_per_prompt
)
preset_prompt_4.change(
fn=update_preset_count,
inputs=[preset_dropdown, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4],
outputs=num_images_per_prompt
)
prompt.change(
fn=update_prompt_preview,
inputs=[preset_dropdown, prompt],
outputs=prompt_preview
)
update_preset_button.click(
fn=update_preset_prompt_textbox,
        inputs=[preset_dropdown, prompt, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4],
outputs=prompt_preview
)
# Set up processing
inputs = [
input_image,
prompt,
seed,
randomize_seed,
true_guidance_scale,
num_inference_steps,
rewrite_toggle,
num_images_per_prompt,
preset_dropdown
]
outputs = [result, seed, prompt_info]
run_button.click(
fn=infer,
inputs=inputs,
outputs=outputs
)
# .then(
# fn=reset_presets, outputs=preset_state
# )
prompt.submit(
fn=infer,
inputs=inputs,
outputs=outputs
)
reset_button.click(fn=reset_presets, outputs=preset_state)
demo.queue(max_size=5).launch()