Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import torch | |
| import spaces | |
| from PIL import Image | |
| from diffusers import QwenImageEditPipeline, FlowMatchEulerDiscreteScheduler | |
| from diffusers.utils import is_xformers_available | |
| from presets import PRESETS, get_preset_choices, get_preset_info, update_preset_prompt | |
| import os | |
| import sys | |
| import re | |
| import gc | |
| import math | |
| import json # Added json import | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
| import logging | |
| import copy | |
| from copy import deepcopy | |
#############################
# Opt out of Gradio analytics and Hugging Face Hub telemetry unless the
# environment has already chosen a value.
for _env_name, _env_default in (
    ('GRADIO_ANALYTICS_ENABLED', 'False'),
    ('HF_HUB_DISABLE_TELEMETRY', '1'),
):
    os.environ.setdefault(_env_name, _env_default)

# Route INFO-and-above log records to stdout with a timestamped format.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Model configuration
REWRITER_MODEL = "Qwen/Qwen1.5-4B-Chat"  # chat LLM used to rewrite edit instructions
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider / random seeds
LOC = os.getenv("QIE")  # location of the image-edit pipeline; None when QIE is unset

# 4-bit NF4 quantization (with double quantization) for the rewriter LLM,
# computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    load_in_4bit=True,
)

# Load the quantized rewriter model at import time; device_map="auto" lets
# transformers decide where the weights are placed.
rewriter_model = AutoModelForCausalLM.from_pretrained(
    REWRITER_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=dtype,
)
# Deep copy of the imported presets, so nothing done through the UI can modify
# the ``presets`` module's own PRESETS dict.
ORIGINAL_PRESETS = deepcopy(PRESETS)


def get_fresh_presets():
    """Return an independent deep copy of ``ORIGINAL_PRESETS``.

    Callers may mutate the result freely: it shares no objects with
    ORIGINAL_PRESETS or with any other caller's copy. Returning the shared dict
    itself would let one caller's edits leak into everyone else's.
    """
    return deepcopy(ORIGINAL_PRESETS)


# Initial module-level state holder. NOTE(review): this name is rebound to a
# new gr.State inside the Blocks UI later in the file.
preset_state = gr.State(value=get_fresh_presets())


def reset_presets():
    """Return a fresh copy of the original presets (wired to the Reset button)."""
    return get_fresh_presets()


# Preload the rewriter's tokenizer at startup (the rewriter model weights are
# loaded earlier in this file).
logger.info("🔄 Loading prompt enhancement model...")
rewriter_tokenizer = AutoTokenizer.from_pretrained(REWRITER_MODEL)
logger.info("✅ Enhancement model loaded and ready!")
# System prompt for the prompt-rewriter LLM, used by polish_prompt. It asks the
# model to answer with a JSON object keyed "Rewritten", which
# extract_json_response then parses.
# NOTE(review): the section headings jump from "## 2." to "## 4." (there is no
# "## 3."), and create_safety_system_prompt appends a "## 5." section.
SYSTEM_PROMPT_EDIT = '''
# Edit Instruction Rewriter
You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable instruction based on the user's intent and the input image.
## 1. General Principles
- Keep the rewritten instruction **concise** and clear.
- Avoid contradictions, vagueness, or unachievable instructions.
- Maintain the core logic of the original instruction; only enhance clarity and feasibility.
- Ensure new added elements or modifications align with the image's original context and art style.
## 2. Task Types
### Add, Delete, Replace:
- When the input is detailed, only refine grammar and clarity.
- For vague instructions, infer minimal but sufficient details.
- For replacement, use the format: `"Replace X with Y"`.
### Text Editing (e.g., text replacement):
- Enclose text content in quotes, e.g., `Replace "abc" with "xyz"`.
- Preserving the original structure and language—**do not translate** or alter style.
### Human Editing (e.g., change a person’s face/hair):
- Preserve core visual identity (gender, ethnic features).
- Describe expressions in subtle and natural terms.
- Maintain key clothing or styling details unless explicitly replaced.
### Style Transformation:
- If a style is specified, e.g., `Disco style`, rewrite it to encapsulate the essential visual traits.
- Use a fixed template for **coloring/restoration**:
`"Restore old photograph, remove scratches, reduce noise, enhance details, high resolution, realistic, natural skin tones, clear facial features, no distortion, vintage photo restoration"`
if applicable.
## 4. Output Format
Please provide the rewritten instruction in a clean `json` format as:
{
"Rewritten": "..."
}
'''
def create_safety_system_prompt(original_system_prompt: str) -> str:
    """Append the safety-guidelines section to a rewriter system prompt.

    Args:
        original_system_prompt: Base instructions for the rewriter model.

    Returns:
        ``original_system_prompt`` immediately followed by a
        "## 5. Safety Guidelines" markdown section. No separator is inserted;
        the section text itself begins with a newline.
    """
    guidelines = (
        "\n"
        "## 5. Safety Guidelines\n"
        "- **Never** generate or enhance prompts that involve:\n"
        "- Sexual content involving minors or children\n"
        "- Explicit nudity or sexual acts with minors, children, and/or teens.\n"
        "- If a user prompt seems to request such content, replace the subject with **Pepe the Frog**.\n"
    )
    return "".join((original_system_prompt, guidelines))
def extract_json_response(model_output: str) -> "str | None":
    """Pull the rewritten instruction out of the rewriter model's raw output.

    The model is asked to reply with a JSON object such as
    ``{"Rewritten": "..."}``. In practice it may wrap the object in a Markdown
    code fence, use a variant key name, nest the value, or emit almost-JSON
    (bare keys, trailing commas). This function tolerates those variations.

    Args:
        model_output: Decoded text produced by the rewriter model.

    Returns:
        The stripped instruction string, or ``None`` when nothing usable could
        be extracted. That covers empty input, no ``{...}`` span, unparseable
        JSON, a non-object payload, or no suitable string value.
    """
    if not model_output:
        return None

    # Remove Markdown code-fence markers (``` or ```json) first.
    cleaned = re.sub(r'```(?:json)?\s*', '', model_output)

    # Take everything from the first "{" through the last "}" (inclusive).
    start_idx = cleaned.find('{')
    end_idx = cleaned.rfind('}')
    if start_idx == -1 or end_idx == -1 or start_idx >= end_idx:
        logger.warning(f"No valid JSON structure found in output. Start: {start_idx}, End: {end_idx}")
        return None
    json_str = cleaned[start_idx:end_idx + 1].strip()

    try:
        data = json.loads(json_str)
    except json.JSONDecodeError as e:
        logger.info(f"Direct JSON parsing failed: {e}")
        # Best-effort repair: quote bare keys, then drop trailing commas.
        repaired = re.sub(r'([^{}[\],\s"]+)(?=\s*:)', r'"\1"', json_str)
        repaired = re.sub(r',(\s*[}\]])', r'\1', repaired)
        try:
            data = json.loads(repaired)
        except json.JSONDecodeError as e2:
            logger.warning(f"JSON parse error: {str(e2)}")
            logger.warning(f"Model output was: {model_output}")
            return None

    if not isinstance(data, dict):
        logger.warning(f"Model output was: {model_output}")
        return None

    def _text(value):
        """Return ``value`` stripped if it is a non-blank string, else None."""
        if isinstance(value, str) and value.strip():
            return value.strip()
        return None

    # 1. Known key spellings, in priority order; skip non-string values
    #    instead of aborting the whole search.
    possible_keys = (
        "Rewritten", "rewritten", "Rewrited", "rewrited", "Rewrittent",
        "Output", "output", "Enhanced", "enhanced",
    )
    for key in possible_keys:
        found = _text(data.get(key))
        if found:
            return found

    # 2. Explicit nested path {"Response": {"Rewritten": ...}}; only index into
    #    "Response" when it really is an object.
    response = data.get("Response")
    if isinstance(response, dict):
        found = _text(response.get("Rewritten"))
        if found:
            return found

    # 3. Any nested object carrying a "Rewritten" key.
    for value in data.values():
        if isinstance(value, dict):
            found = _text(value.get("Rewritten"))
            if found:
                return found

    # 4. Last resort: the first string value of instruction-like length.
    for value in data.values():
        if isinstance(value, str) and 10 < len(value) < 500:
            return value.strip()
    return None
def polish_prompt(original_prompt: str) -> str:
    """Rewrite a user's edit instruction with the local rewriter LLM.

    The instruction is sent to ``rewriter_model`` with the edit-rewriting
    system prompt plus its safety section. The reply is parsed with
    ``extract_json_response``; if that yields nothing, a plain-text cleanup of
    the raw reply is used instead.

    Args:
        original_prompt: The user's raw edit instruction.

    Returns:
        The rewritten instruction. A blank (empty/whitespace/None) prompt is
        returned unchanged without invoking the model. In the fallback path the
        result is truncated to 200 characters, and ``original_prompt`` is
        returned when nothing usable remains.
    """
    # Nothing to enhance: calling the LLM on a blank prompt would only let it
    # invent an edit the user never asked for.
    if not original_prompt or not original_prompt.strip():
        return original_prompt

    # Qwen chat format: system rules (with safety section) + user instruction.
    messages = [
        {"role": "system", "content": create_safety_system_prompt(SYSTEM_PROMPT_EDIT)},
        {"role": "user", "content": original_prompt},
    ]
    text = rewriter_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = rewriter_tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = rewriter_model.generate(
            **model_inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.75,
            top_p=0.85,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            pad_token_id=rewriter_tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (drop the echoed prompt).
    enhanced = rewriter_tokenizer.decode(
        generated_ids[0][model_inputs.input_ids.shape[1]:],
        skip_special_tokens=True,
    ).strip()
    logger.info(f"Original Prompt: {original_prompt}")
    logger.info(f"Model raw output: {enhanced}")  # Debug logging

    rewritten_prompt = extract_json_response(enhanced)
    if rewritten_prompt:
        # Drop the quotes around each quoted span that directly follows
        # Replace/Change/Add, then unescape leftover JSON escapes.
        # NOTE(review): this removes quoting that SYSTEM_PROMPT_EDIT asks the
        # model to add for text edits; confirm that is intended.
        rewritten_prompt = re.sub(r'(Replace|Change|Add) "(.*?)"', r'\1 \2', rewritten_prompt)
        rewritten_prompt = rewritten_prompt.replace('\\"', '"').replace('\\n', ' ')
        return rewritten_prompt

    # Fallback: take the first fenced code block if there is one, otherwise
    # the whole reply.
    if '```' in enhanced:
        parts = enhanced.split('```')
        rewritten_prompt = parts[1].strip() if len(parts) >= 2 else enhanced
    else:
        rewritten_prompt = enhanced

    # Collapse whitespace runs and drop a leading "Label: " prefix.
    rewritten_prompt = re.sub(r'\s\s+', ' ', rewritten_prompt).strip()
    if ': ' in rewritten_prompt:
        rewritten_prompt = rewritten_prompt.split(': ', 1)[-1].strip()
    return rewritten_prompt[:200] if rewritten_prompt else original_prompt
# Scheduler configuration for the Lightning (few-step) LoRA.
# NOTE(review): base_shift and max_shift are both log(3). With dynamic shifting
# the per-resolution shift is interpolated between these two values, so it
# should be a constant log(3) for every image size. Confirm this is intended.
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": math.log(3),
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)

# LOC comes from the QIE environment variable. Fail fast with a readable
# message rather than an opaque error from from_pretrained(None).
if not LOC:
    raise RuntimeError(
        "Environment variable QIE is not set; it must point to the "
        "Qwen-Image-Edit model repository or local path."
    )

# Load the main image-editing pipeline with the Lightning scheduler.
pipe = QwenImageEditPipeline.from_pretrained(
    LOC,
    scheduler=scheduler,
    torch_dtype=dtype,
).to(device)

# Load the 4-step Lightning LoRA and fuse it into the base weights.
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning",
    weight_name="Qwen-Image-Edit-Lightning-4steps-V1.0.safetensors",
)
pipe.fuse_lora()

# VAE slicing is best-effort: a failure is only logged, never fatal.
try:
    pipe.enable_vae_slicing()
except Exception as e:
    logger.info(f"VAE Slicing Failed: {e}")
def toggle_output_count(preset_type):
    """Show/hide the preset editor and lock/unlock the output-count slider.

    Args:
        preset_type: Selected preset name (falsy when none is selected).

    Returns:
        A 6-tuple: (editor ``gr.Group`` update, output-count ``gr.Slider``
        update, prompt 1, prompt 2, prompt 3, prompt 4). With a known preset the
        editor is shown, the slider is locked to the number of non-blank
        prompts (clamped to the slider's 1..4 range), and up to four of the
        preset's prompts are returned, padded with "". Otherwise the editor is
        hidden, the slider is unlocked, and the four prompts are cleared.
    """
    if not (preset_type and preset_type in ORIGINAL_PRESETS):
        return (
            gr.Group(visible=False),
            gr.Slider(interactive=True),
            "", "", "", "",
        )

    # Slice copies the list, so padding never touches the stored preset.
    prompts = list(ORIGINAL_PRESETS[preset_type]["prompts"][:4])
    prompts += [""] * (4 - len(prompts))
    active = sum(1 for p in prompts if p.strip())
    # The slider's minimum is 1; an all-blank preset must not set it to 0.
    count = max(1, min(4, active))
    return (
        gr.Group(visible=True),
        gr.Slider(interactive=False, value=count),
        *prompts,
    )
def update_prompt_preview(preset_type, base_prompt):
    """Render the batch-generation preview for a preset from ORIGINAL_PRESETS.

    Args:
        preset_type: Selected preset name (falsy when none is selected).
        base_prompt: The user's base prompt; each preset prompt is appended
            after ", ".

    Returns:
        Markdown-style preview text listing every non-blank preset prompt, or a
        hint message when no preset is selected or the preset has no prompts.
    """
    if not (preset_type and preset_type in ORIGINAL_PRESETS):
        return "Select a preset above to see how your base prompt will be modified for batch generation."

    chosen = ORIGINAL_PRESETS[preset_type]
    active = [text for text in chosen["prompts"] if text.strip()]
    if not active:
        return "No prompts defined. Please enter at least one prompt in the editor."

    count = len(active)
    plural = 's' if count > 1 else ''
    lines = [
        f"**Preset: {preset_type}**\n\n",
        f"*{chosen['description']}*\n\n",
        f"**Generating {count} image{plural}:**\n",
    ]
    lines.extend(f"{n}. {base_prompt}, {extra}\n" for n, extra in enumerate(active, 1))
    return "".join(lines)
def update_preset_prompt_textbox(preset_type, p1, p2, p3, p4):
    """Save the four editor prompts into the selected preset and return a preview.

    Args:
        preset_type: Selected preset name (falsy when none is selected).
        p1, p2, p3, p4: Current contents of the four preset-prompt textboxes.

    Returns:
        The preview text for the updated preset, or "Select a preset first."
        when no known preset is selected.
    """
    if not preset_type or preset_type not in preset_state.value:
        return "Select a preset first."
    # Swap in a shallow copy of the preset that carries the new prompt list.
    # The containing ``preset_state.value`` dict is still updated in place.
    edited = dict(preset_state.value[preset_type])
    edited["prompts"] = [p1, p2, p3, p4]
    preset_state.value[preset_type] = edited
    # NOTE(review): ``prompt.value`` is the Textbox component's stored
    # (initial) value, not necessarily the text currently typed in the browser.
    return update_prompt_preview_with_presets(preset_type, prompt.value, preset_state.value)
def update_prompt_preview_with_presets(preset_type, base_prompt, custom_presets):
    """Render a preview of the prompts a preset batch would generate.

    Args:
        preset_type: Selected preset name (falsy when none is selected).
        base_prompt: The user's base prompt; each preset prompt is appended
            after ", ".
        custom_presets: Mapping of preset name -> dict with "description" and
            "prompts" (list of strings). Blank prompts are skipped.

    Returns:
        The preview text, or a hint message when no valid preset is selected or
        the preset has no non-blank prompts.
    """
    if not preset_type or preset_type not in custom_presets:
        return "Select a preset to see the preview."

    entry = custom_presets[preset_type]
    usable = [text for text in entry["prompts"] if text.strip()]
    if not usable:
        return "No prompts defined. Please enter at least one prompt in the editor."

    total = len(usable)
    suffix = "s" if total > 1 else ""
    parts = [
        f"**Preset: {preset_type}**\n\n{entry['description']}\n\n",
        f"**Generating {total} image{suffix}:**\n",
    ]
    parts += [f"{idx}. {base_prompt}, {text}\n" for idx, text in enumerate(usable, start=1)]
    return "".join(parts)
def _resize_image(pil_image, max_size=1024):
    """Downscale ``pil_image`` so its longest side is at most ``max_size`` px.

    Aspect ratio is preserved, and neither side is allowed to drop below 1 px.
    ``None`` and images already within the limit are returned unchanged. On
    any resize error the original image is returned and a warning is logged.
    """
    if pil_image is None:
        return pil_image
    try:
        width, height = pil_image.size
        longest = max(width, height)
        if longest <= max_size:
            return pil_image  # no resize needed
        scale = max_size / longest
        new_width = max(1, int(width * scale))
        new_height = max(1, int(height * scale))
        resized = pil_image.resize((new_width, new_height), Image.LANCZOS)
        logger.info(f"📝 Image resized from {width}x{height} to {new_width}x{new_height}")
        return resized
    except Exception as e:
        logger.warning(f"⚠️ Image resize failed: {e}")
        return pil_image


def _add_noise_to_image(pil_image, noise_level=0.001, rng=None):
    """Return a copy of ``pil_image`` with slight Gaussian pixel noise added.

    Args:
        pil_image: Source image (``None`` is passed through unchanged).
        noise_level: Standard deviation of the noise on the [0, 1] pixel scale.
        rng: Optional ``numpy.random.Generator``; pass a seeded one to make the
            noise reproducible. Defaults to NumPy's global random state.

    Returns:
        The noisy image, or the original image if the noise could not be added.
    """
    if pil_image is None:
        return pil_image
    try:
        pixels = np.asarray(pil_image).astype(np.float32) / 255.0
        if rng is None:
            noise = np.random.normal(0, noise_level, pixels.shape)
        else:
            noise = rng.normal(0, noise_level, pixels.shape)
        noisy = np.clip(pixels + noise, 0.0, 1.0)
        # Round instead of truncating: truncation would push about half of the
        # pixels down one level and darken the image.
        return Image.fromarray(np.rint(noisy * 255).astype(np.uint8))
    except Exception as e:
        logger.warning(f"Warning: Could not add noise to image: {e}")
        return pil_image


def _info_box(title, *paragraphs, accent=None, padding=15):
    """Build the HTML status card shown in the ``prompt_info`` panel.

    Args:
        title: Heading text, rendered inside ``<h4>``.
        *paragraphs: Pre-formatted HTML fragments (e.g. ``<p>…</p>``) appended
            after the heading, in order.
        accent: Optional CSS colour for a 4px left border.
        padding: Inner padding in pixels.
    """
    style = f"margin:10px; padding:{padding}px; border-radius:8px;"
    if accent:
        style += f" border-left:4px solid {accent};"
    # The style attribute must be closed before ">"; otherwise the browser
    # swallows the following markup into the attribute value.
    return (
        f"<div style='{style}'>"
        f"<h4 style='margin-top: 0;'>{title}</h4>"
        f"{''.join(paragraphs)}"
        f"</div>"
    )


# On ZeroGPU hardware a GPU is only attached while a @spaces.GPU function runs;
# on other hardware the decorator has no effect.
@spaces.GPU
def infer(
    image,
    prompt,
    seed=42,
    randomize_seed=False,
    true_guidance_scale=4.0,
    num_inference_steps=3,
    rewrite_prompt=True,
    num_images_per_prompt=1,
    preset_type=None,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the Qwen-Image-Edit pipeline on ``image`` and return the results.

    Two modes:
      * Preset mode: ``preset_type`` names a preset with at least one non-blank
        prompt. One image is generated per preset prompt, each using
        ``"<prompt>, <preset prompt>"``. The rewrite toggle is not applied in
        this mode.
      * Manual mode: a single prompt, optionally rewritten by ``polish_prompt``,
        producing ``num_images_per_prompt`` images.

    Args:
        image: Source PIL image (downscaled to 1024 px on its longest side).
        prompt: Edit instruction, or the base prompt in preset mode.
        seed: Base seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a fresh base seed in ``[0, MAX_SEED]``.
        true_guidance_scale: True-CFG scale; jittered by ±0.1 per image and
            clamped to [1, 10].
        num_inference_steps: Denoising steps per image.
        rewrite_prompt: Enhance the manual prompt with the rewriter LLM.
        num_images_per_prompt: Output count for manual mode.
        preset_type: Optional preset name from ``preset_state``.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        ``(images, seed_used, info_html)``: the images returned by the
        pipeline, the base seed actually used, and an HTML status card.
    """
    # NOTE(review): reads the State component's stored value directly rather
    # than per-session state.
    session_presets = preset_state.value

    if image is None:
        gr.Warning("Please upload a source image first.")
        return [], seed, _info_box(
            "⚠️ No Source Image",
            "<p>Please upload a source image first.</p>",
            accent="#FF9800",
        )

    image = _resize_image(image, max_size=1024)
    original_prompt = prompt
    prompt_info = ""
    # Images requested per pipeline call: manual mode honours the Output Count
    # slider, while preset mode overrides this to one image per preset prompt.
    images_per_call = max(1, int(num_images_per_prompt or 1))

    if preset_type and preset_type in session_presets:
        preset = session_presets[preset_type]
        non_empty_preset_prompts = [p for p in preset["prompts"] if p.strip()]
        if non_empty_preset_prompts:
            batch_prompts = [f"{original_prompt}, {preset_prompt}" for preset_prompt in non_empty_preset_prompts]
            images_per_call = 1
            count = len(batch_prompts)
            prompt_info = _info_box(
                f"🎨 Preset: {preset_type}",
                f"<p>{preset['description']}</p>",
                f"<p><strong>Base Prompt:</strong> {original_prompt}</p>",
                f"<p>Generating {count} image{'s' if count > 1 else ''}</p>",
                accent="#2196F3",
            )
            logger.info(f"Using preset: {preset_type} with {len(batch_prompts)} variations")
        else:
            # The preset has no usable prompts: fall back to the manual prompt.
            batch_prompts = [prompt]
            prompt_info = _info_box(
                "⚠️ Invalid Preset",
                "<p>No valid prompts found. Using manual prompt.</p>",
                f"<p><strong>Prompt:</strong> {original_prompt}</p>",
                accent="#FF9800",
            )
    else:
        batch_prompts = [prompt]
        if rewrite_prompt:
            try:
                enhanced_instruction = polish_prompt(original_prompt)
                if enhanced_instruction and enhanced_instruction != original_prompt:
                    prompt_info = _info_box(
                        "🚀 Prompt Enhancement",
                        f"<p><strong>Original:</strong> {original_prompt}</p>",
                        f"<p><strong style='color:#2E7D32;'>Enhanced:</strong> {enhanced_instruction}</p>",
                        accent="#4CAF50",
                    )
                    batch_prompts = [enhanced_instruction]
                else:
                    prompt_info = _info_box(
                        "📝 Prompt Enhancement",
                        "<p>No enhancement applied or enhancement failed</p>",
                        accent="#FF9800",
                    )
            except Exception as e:
                logger.warning(f"Prompt enhancement error: {str(e)}")
                gr.Warning(f"Prompt enhancement failed: {str(e)}")
                prompt_info = _info_box(
                    "⚠️ Enhancement Not Applied",
                    f"<p>Using original prompt. Error: {str(e)[:100]}</p>",
                    accent="#FF5252",
                )
        else:
            prompt_info = _info_box("📝 Original Prompt", f"<p>{original_prompt}</p>", padding=10)

    # Base seed; each image derives its own seed from it, so a run can be
    # reproduced from the seed returned to the UI.
    base_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    steps = int(num_inference_steps)

    try:
        edited_images = []
        for i, current_prompt in enumerate(batch_prompts):
            image_seed = base_seed + i * 1000
            generator = torch.Generator(device=device).manual_seed(image_seed)
            # Seeded per-image RNGs keep the input noise and CFG jitter
            # reproducible (the global RNGs are not tied to the seed).
            noise_rng = np.random.default_rng(image_seed)
            jitter_rng = random.Random(image_seed)

            # In batch mode the first image uses the clean input. Every other
            # image (including a single manual run) gets slight noise.
            if i == 0 and len(batch_prompts) > 1:
                input_image = image
            else:
                input_image = _add_noise_to_image(image, noise_level=0.001 + i * 0.003, rng=noise_rng)

            varied_guidance = true_guidance_scale + jitter_rng.uniform(-0.1, 0.1)
            varied_guidance = max(1.0, min(10.0, varied_guidance))

            result = pipe(
                image=input_image,
                prompt=current_prompt,
                negative_prompt=" ",
                num_inference_steps=steps,
                generator=generator,
                true_cfg_scale=varied_guidance,
                num_images_per_prompt=images_per_call,
            ).images
            edited_images.extend(result)
            logger.info(f"Generated image {i+1}/{len(batch_prompts)} with prompt: {current_prompt}...")

        return edited_images, base_seed, prompt_info
    except Exception as e:
        logger.error(f"Image generation failed: {e}")
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        # gr.Error only has an effect when raised. Show a non-fatal toast
        # instead, so the error card below still reaches the UI.
        gr.Warning(f"Image generation failed: {str(e)}")
        return [], base_seed, _info_box(
            "⚠️ Processing Error",
            f"<p>{str(e)[:200]}</p>",
            accent="#dd2c00",
        )
with gr.Blocks(title="'Qwen Image Edit' Model Playground & Showcase [4-Step Lightning Mode]") as demo:
    # NOTE(review): this State is never read or written by any handler below.
    preset_prompts_state = gr.State(value=[])
    # Rebinds the module-level ``preset_state``. Its value is ORIGINAL_PRESETS
    # itself (not a copy), and the handlers read/write ``preset_state.value``
    # directly. NOTE(review): that appears to share preset edits across all
    # sessions; confirm this is intended.
    preset_state = gr.State(value=ORIGINAL_PRESETS)
    gr.Markdown("## ⚡️ Qwen-Image-Edit Lightning Presets")

    with gr.Row(equal_height=True):
        # Input column
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Source Image",
                type="pil",
                height=300,
            )
        # Output column
        with gr.Column(scale=2):
            result = gr.Gallery(
                label="Edited Images",
                columns=2,
                container=True,
            )

    with gr.Row():
        # Controls column
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Edit Instructions / Base Prompt",
                placeholder="e.g. Replace the background with a beach sunset... When a preset is selected, use as the base prompt, e.g. the lamborghini",
                lines=2,
                max_lines=4,
                scale=2,
            )
            preset_dropdown = gr.Dropdown(
                choices=get_preset_choices(),
                value=None,
                label="Preset Batch Generation",
                interactive=True,
            )
            # Editable preset prompts (hidden until a preset is selected)
            preset_editor = gr.Group(visible=False)
            with preset_editor:
                gr.Markdown("### 🎨 Edit Preset Prompts")
                preset_prompt_1 = gr.Textbox(label="Prompt 1", lines=1, value="")
                preset_prompt_2 = gr.Textbox(label="Prompt 2", lines=1, value="")
                preset_prompt_3 = gr.Textbox(label="Prompt 3", lines=1, value="")
                preset_prompt_4 = gr.Textbox(label="Prompt 4", lines=1, value="")
                update_preset_button = gr.Button("Update Preset", variant="secondary", visible=False)
                reset_button = gr.Button("Reset Presets", variant="stop", visible=False)
            # Prompt preview
            prompt_preview = gr.Textbox(
                label="📋 Prompt Preview",
                interactive=False,
                lines=6,
                max_lines=10,
                value="Enter a base prompt and select a preset above to see how your prompt will be modified for batch generation.",
                placeholder="Prompt preview will appear here...",
            )
            # NOTE(review): ``infer`` only applies prompt enhancement in manual
            # mode, although this info text describes it for preset prompts.
            rewrite_toggle = gr.Checkbox(
                label="Additional Prompt Enhancement",
                info="Setting this to true will pass the basic prompt(s) generated via the static preset template to a secondary LLM tasked with improving the overall cohesiveness and details of the final generation prompt.",
                value=True,
                interactive=True,
            )
            run_button = gr.Button(
                "Generate Edit(s)",
                variant="primary",
            )
            with gr.Accordion("Advanced Parameters", open=False):
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=42,
                    )
                    randomize_seed = gr.Checkbox(
                        label="Random Seed",
                        value=True,
                    )
                with gr.Row():
                    true_guidance_scale = gr.Slider(
                        label="True CFG Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.1,
                        value=1.1,
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=1,
                        maximum=16,
                        step=1,
                        value=3,
                    )
                    num_images_per_prompt = gr.Slider(
                        label="Output Count (Manual)",
                        minimum=1,
                        maximum=4,
                        step=1,
                        value=2,
                        interactive=True,
                    )
        # Status / info column
        with gr.Column(scale=2):
            prompt_info = gr.Markdown(
                value="<div style='padding:15px; margin-top:15px'>"
                "Hint: depending on the original image, prompt quality, and complexity, you can often get away with 3 steps, even 2 steps without much loss in quality. </div>"
            )

    def show_preset_editor(preset_type):
        """Reveal the editor and fill it with the chosen preset's first four prompts."""
        if preset_type and preset_type in preset_state.value:
            preset = preset_state.value[preset_type]
            prompts = preset["prompts"] + [""] * (4 - len(preset["prompts"]))
            return gr.Group(visible=True), *prompts[:4]
        return gr.Group(visible=False), "", "", "", ""

    def update_preset_count(preset_type, p1, p2, p3, p4):
        """Lock the output-count slider to the number of non-blank preset prompts (1..4)."""
        if preset_type and preset_type in preset_state.value:
            count = len([p for p in (p1, p2, p3, p4) if p.strip()])
            return gr.Slider(value=max(1, min(4, count)), interactive=False)
        return gr.Slider(interactive=True)

    def refresh_editor_preview(preset_type, p1, p2, p3, p4, base_prompt):
        """Store the edited preset prompts, then preview them with the live base prompt.

        update_preset_prompt_textbox saves the edits, but its own preview uses
        ``prompt.value`` (the textbox's stored value). Here the preview is
        rebuilt from ``base_prompt``, the text currently in the prompt box.
        """
        status = update_preset_prompt_textbox(preset_type, p1, p2, p3, p4)
        if preset_type and preset_type in preset_state.value:
            return update_prompt_preview_with_presets(preset_type, base_prompt, preset_state.value)
        return status

    preset_boxes = [preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4]
    editor_inputs = [preset_dropdown, *preset_boxes]

    # Selecting a preset fills the editor and refreshes the preview.
    preset_dropdown.change(
        fn=show_preset_editor,
        inputs=[preset_dropdown],
        outputs=[preset_editor, *preset_boxes],
    )
    preset_dropdown.change(
        fn=update_prompt_preview,
        inputs=[preset_dropdown, prompt],
        outputs=prompt_preview,
    )
    # Editing any preset prompt saves it, refreshes the preview and re-syncs
    # the output-count slider.
    for box in preset_boxes:
        box.change(
            fn=refresh_editor_preview,
            inputs=[*editor_inputs, prompt],
            outputs=prompt_preview,
        )
        box.change(
            fn=update_preset_count,
            inputs=editor_inputs,
            outputs=num_images_per_prompt,
        )
    prompt.change(
        fn=update_prompt_preview,
        inputs=[preset_dropdown, prompt],
        outputs=prompt_preview,
    )
    update_preset_button.click(
        fn=refresh_editor_preview,
        inputs=[*editor_inputs, prompt],
        outputs=prompt_preview,
    )

    # Set up processing
    inputs = [
        input_image,
        prompt,
        seed,
        randomize_seed,
        true_guidance_scale,
        num_inference_steps,
        rewrite_toggle,
        num_images_per_prompt,
        preset_dropdown,
    ]
    outputs = [result, seed, prompt_info]
    run_button.click(
        fn=infer,
        inputs=inputs,
        outputs=outputs,
    )
    prompt.submit(
        fn=infer,
        inputs=inputs,
        outputs=outputs,
    )
    reset_button.click(fn=reset_presets, outputs=preset_state)

demo.queue(max_size=5).launch()