import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
from diffusers import QwenImageEditPipeline, FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_xformers_available
from presets import PRESETS, get_preset_choices, get_preset_info, update_preset_prompt
import os
import sys
import re
import gc
import math
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import logging
from copy import deepcopy

#############################

os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Model configuration
REWRITER_MODEL = "Qwen/Qwen1.5-4B-Chat"  # 4B variant handles JSON output more reliably
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
LOC = os.getenv("QIE")  # Pipeline path / repo id supplied via the QIE env var

# Quantization configuration for the prompt-rewriter LLM (4-bit NF4)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)

# Preload the enhancement model and tokenizer at startup
print("🔄 Loading prompt enhancement model...")
rewriter_model = AutoModelForCausalLM.from_pretrained(
    REWRITER_MODEL,
    torch_dtype=dtype,
    device_map="auto",
    quantization_config=bnb_config,
)
rewriter_tokenizer = AutoTokenizer.from_pretrained(REWRITER_MODEL)
print("✅ Enhancement model loaded and ready!")


def get_fresh_presets():
    """Return a fresh copy of presets to avoid persistence across users."""
    return deepcopy(PRESETS)


# Store original presets for reference
ORIGINAL_PRESETS = deepcopy(PRESETS)
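# ---------------------------------------------------------------------------
# Illustrative sketch (assumption): PRESETS lives in presets.py, which is not
# shown here. The code in this file only relies on the shape sketched below:
# a dict keyed by preset name, where each entry carries a "description" string
# and a "prompts" list (the UI editor exposes up to 4 of them). The preset
# name and prompt strings in this example are hypothetical, not entries from
# the real presets.py.
#
#   PRESETS = {
#       "Product Shots": {
#           "description": "Studio-style variations of the subject",
#           "prompts": [
#               "on a white seamless background, softbox lighting",
#               "on dark slate, dramatic rim light",
#               "lifestyle setting, shallow depth of field",
#               "top-down flat lay, even lighting",
#           ],
#       },
#   }
# ---------------------------------------------------------------------------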

SYSTEM_PROMPT_EDIT = '''
# Edit Instruction Rewriter
You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable instruction based on the user's intent and the input image.

## 1. General Principles
- Keep the rewritten instruction **concise** and clear.
- Avoid contradictions, vagueness, or unachievable instructions.
- Maintain the core logic of the original instruction; only enhance clarity and feasibility.
- Ensure newly added elements or modifications align with the image's original context and art style.

## 2. Task Types

### Add, Delete, Replace:
- When the input is detailed, only refine grammar and clarity.
- For vague instructions, infer minimal but sufficient details.
- For replacement, use the format: `"Replace X with Y"`.

### Text Editing (e.g., text replacement):
- Enclose text content in quotes, e.g., `Replace "abc" with "xyz"`.
- Preserve the original structure and language; **do not translate** or alter style.

### Human Editing (e.g., change a person's face/hair):
- Preserve core visual identity (gender, ethnic features).
- Describe expressions in subtle and natural terms.
- Maintain key clothing or styling details unless explicitly replaced.

### Style Transformation:
- If a style is specified, e.g., `Disco style`, rewrite it to encapsulate the essential visual traits.
- If applicable, use a fixed template for **coloring/restoration**: `"Restore old photograph, remove scratches, reduce noise, enhance details, high resolution, realistic, natural skin tones, clear facial features, no distortion, vintage photo restoration"`.

## 3. Output Format
Please provide the rewritten instruction in a clean `json` format as:
{
    "Rewritten": "..."
}
'''


def extract_json_response(model_output: str) -> str | None:
    """Extract the rewritten instruction from potentially messy JSON output."""
    # Remove code block markers first
    model_output = re.sub(r'```(?:json)?\s*', '', model_output)

    try:
        # Find the JSON portion in the output
        start_idx = model_output.find('{')
        end_idx = model_output.rfind('}')

        # Bail out if no valid brace pair was found
        if start_idx == -1 or end_idx == -1 or start_idx >= end_idx:
            print(f"No valid JSON structure found in output. Start: {start_idx}, End: {end_idx}")
            return None

        end_idx += 1  # Include the closing brace
        json_str = model_output[start_idx:end_idx].strip()

        # Try to parse the JSON directly first
        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            print(f"Direct JSON parsing failed: {e}")
            # If direct parsing fails, attempt cleanup:
            # quote bare keys, then strip trailing commas
            json_str = re.sub(r'([^{}[\],\s"]+)(?=\s*:)', r'"\1"', json_str)
            json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
            data = json.loads(json_str)

        # Extract the rewritten prompt from possible key variations
        # (including common model misspellings)
        possible_keys = [
            "Rewritten", "rewritten", "Rewrited", "rewrited",
            "Rewrittent", "Output", "output", "Enhanced", "enhanced"
        ]
        for key in possible_keys:
            if key in data:
                return data[key].strip()

        # Try the nested path {"Response": {"Rewritten": ...}}
        if "Response" in data and "Rewritten" in data["Response"]:
            return data["Response"]["Rewritten"].strip()

        # Handle nested JSON objects (additional protection)
        if isinstance(data, dict):
            for value in data.values():
                if isinstance(value, dict) and "Rewritten" in value:
                    return value["Rewritten"].strip()

        # Last resort: any string value that looks like an instruction
        str_values = [v for v in data.values() if isinstance(v, str) and 10 < len(v) < 500]
        if str_values:
            return str_values[0].strip()

    except Exception as e:
        print(f"JSON parse error: {str(e)}")
        print(f"Model output was: {model_output}")

    return None
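
# Illustrative sketch (assumption, never called by the app): the kinds of
# malformed outputs extract_json_response is designed to recover from.
# _demo_extract_json_response is a hypothetical helper added for illustration.
def _demo_extract_json_response():
    samples = [
        # Clean output wrapped in a markdown fence
        '```json\n{"Rewritten": "Replace the sky with a sunset"}\n```',
        # Bare key and trailing comma -- repaired by the cleanup regexes
        '{Rewritten: "Add a red scarf to the dog",}',
        # Misspelled key -- caught by the possible_keys list
        '{"Rewrited": "Remove the background crowd"}',
    ]
    for s in samples:
        print(repr(extract_json_response(s)))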

def polish_prompt(original_prompt: str) -> str:
    """Enhance an edit instruction via the rewriter LLM, with JSON handling."""
    # Format as a Qwen chat conversation
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_EDIT},
        {"role": "user", "content": original_prompt}
    ]
    text = rewriter_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = rewriter_tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = rewriter_model.generate(
            **model_inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.8,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            pad_token_id=rewriter_tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens
    enhanced = rewriter_tokenizer.decode(
        generated_ids[0][model_inputs.input_ids.shape[1]:],
        skip_special_tokens=True
    ).strip()

    print(f"Original Prompt: {original_prompt}")
    print(f"Model raw output: {enhanced}")  # Debug logging

    # Try to extract the JSON content
    rewritten_prompt = extract_json_response(enhanced)

    if rewritten_prompt:
        # Clean up remaining artifacts
        rewritten_prompt = re.sub(r'(Replace|Change|Add) "(.*?)"', r'\1 \2', rewritten_prompt)
        rewritten_prompt = rewritten_prompt.replace('\\"', '"').replace('\\n', ' ')
        return rewritten_prompt

    # Fallback: try to extract from code blocks, or just return cleaned content
    if '```' in enhanced:
        parts = enhanced.split('```')
        rewritten_prompt = parts[1].strip() if len(parts) >= 2 else enhanced
    else:
        rewritten_prompt = enhanced

    # Basic cleanup
    rewritten_prompt = re.sub(r'\s\s+', ' ', rewritten_prompt).strip()
    if ': ' in rewritten_prompt:
        rewritten_prompt = rewritten_prompt.split(': ', 1)[-1].strip()

    return rewritten_prompt[:200] if rewritten_prompt else original_prompt


# Scheduler configuration for Lightning (few-step) inference
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": math.log(3),
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}

# Initialize scheduler with the Lightning config
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)

# Load the main image-editing pipeline
pipe = QwenImageEditPipeline.from_pretrained(
    LOC,
    scheduler=scheduler,
    torch_dtype=dtype
).to(device)

# Load LoRA weights for accelerated few-step inference
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning",
    # weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
    weight_name="Qwen-Image-Lightning-4steps-V1.0.safetensors"
)
pipe.fuse_lora()

# if is_xformers_available():
#     pipe.enable_xformers_memory_efficient_attention()
# else:
#     print("xformers not available")

try:
    pipe.enable_vae_slicing()
except Exception as e:
    print(f"VAE slicing failed: {e}")


def toggle_output_count(preset_type):
    """Control output-count slider interactivity and show/hide the preset editor."""
    if preset_type and preset_type in ORIGINAL_PRESETS:
        # When a preset is selected, disable the manual output count and show the editor
        preset = ORIGINAL_PRESETS[preset_type]
        prompts = preset["prompts"][:4]  # Take up to 4 prompts (slicing copies the list)
        # Pad prompts to 4 items if needed
        while len(prompts) < 4:
            prompts.append("")
        return (
            gr.Group(visible=True),
            gr.Slider(interactive=False, value=len([p for p in prompts if p.strip()])),  # Count non-empty prompts
            prompts[0], prompts[1], prompts[2], prompts[3]  # Populate preset prompts
        )
    else:
        # When no preset is selected, enable the manual output count and hide the editor
        return (
            gr.Group(visible=False),
            gr.Slider(interactive=True),  # Enable slider
            "", "", "", ""  # Clear preset prompts
        )


def update_prompt_preview(preset_type, base_prompt):
    """Update the prompt preview display based on the selected preset and base prompt."""
    if preset_type and preset_type in ORIGINAL_PRESETS:
        preset = ORIGINAL_PRESETS[preset_type]
        non_empty_prompts = [p for p in preset["prompts"] if p.strip()]

        if not non_empty_prompts:
            return "No prompts defined. Please enter at least one prompt in the editor."

        preview_text = f"**Preset: {preset_type}**\n\n"
        preview_text += f"*{preset['description']}*\n\n"
        preview_text += f"**Generating {len(non_empty_prompts)} image{'s' if len(non_empty_prompts) > 1 else ''}:**\n"
        for i, preset_prompt in enumerate(non_empty_prompts, 1):
            full_prompt = f"{base_prompt}, {preset_prompt}"
            preview_text += f"{i}. {full_prompt}\n"
        return preview_text
    else:
        return "Select a preset above to see how your base prompt will be modified for batch generation."
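
# Illustrative sketch (assumption, never called by the app): how batch prompts
# are composed in preset mode. Each non-empty preset prompt is appended to the
# user's base prompt with a comma -- the same rule infer() applies below.
# _demo_prompt_composition and its strings are hypothetical examples.
def _demo_prompt_composition():
    base_prompt = "the red lamborghini"
    preset_prompts = ["on a white seamless background", "at night, neon reflections", ""]
    batch_prompts = [f"{base_prompt}, {p}" for p in preset_prompts if p.strip()]
    for p in batch_prompts:
        print(p)
    # -> the red lamborghini, on a white seamless background
    # -> the red lamborghini, at night, neon reflections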

def update_preset_prompt_textbox(preset_type, prompt_1, prompt_2, prompt_3, prompt_4):
    """Update preset prompts from user input; works on a session copy."""
    if preset_type and preset_type in ORIGINAL_PRESETS:
        # Update each prompt in a preset copy (this won't persist globally)
        new_prompts = [prompt_1, prompt_2, prompt_3, prompt_4]

        # Create a working copy for preview purposes
        working_presets = get_fresh_presets()
        for i, new_prompt in enumerate(new_prompts):
            if i < len(working_presets[preset_type]["prompts"]):
                working_presets[preset_type]["prompts"][i] = new_prompt.strip()
            else:
                working_presets[preset_type]["prompts"].append(new_prompt.strip())

        # Return updated preset info for the preview
        return update_prompt_preview_with_presets(preset_type, "your subject", working_presets)
    return "Select a preset first to edit its prompts."


def update_prompt_preview_with_presets(preset_type, base_prompt, custom_presets):
    """Update the prompt preview display using a custom preset dict."""
    if preset_type and preset_type in custom_presets:
        preset = custom_presets[preset_type]
        non_empty_prompts = [p for p in preset["prompts"] if p.strip()]

        if not non_empty_prompts:
            return "No prompts defined. Please enter at least one prompt in the editor."

        preview_text = f"**Preset: {preset_type}**\n\n"
        preview_text += f"*{preset['description']}*\n\n"
        preview_text += f"**Generating {len(non_empty_prompts)} image{'s' if len(non_empty_prompts) > 1 else ''}:**\n"
        for i, preset_prompt in enumerate(non_empty_prompts, 1):
            full_prompt = f"{base_prompt}, {preset_prompt}"
            preview_text += f"{i}. {full_prompt}\n"
        return preview_text
    else:
        return "Select a preset above to see how your base prompt will be modified for batch generation."


@spaces.GPU()
def infer(
    image,
    prompt,
    seed=42,
    randomize_seed=False,
    true_guidance_scale=4.0,
    num_inference_steps=4,
    rewrite_prompt=True,
    num_images_per_prompt=1,
    preset_type=None,
    progress=gr.Progress(track_tqdm=True),
):
    """Image editing endpoint with optimized prompt handling; uses fresh presets per call."""

    def resize_image(pil_image, max_size=1024):
        """Resize image so its longest side is at most max_size, keeping aspect ratio."""
        try:
            if pil_image is None:
                return pil_image

            width, height = pil_image.size
            max_dimension = max(width, height)
            if max_dimension <= max_size:
                return pil_image  # No resize needed

            # Calculate new dimensions, maintaining aspect ratio
            scale = max_size / max_dimension
            new_width = int(width * scale)
            new_height = int(height * scale)

            resized_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
            print(f"📏 Image resized from {width}x{height} to {new_width}x{new_height}")
            return resized_image
        except Exception as e:
            print(f"⚠️ Image resize failed: {e}")
            return pil_image  # Return original if resize fails

    def add_noise_to_image(pil_image, noise_level=0.001):
        """Add slight Gaussian noise to the image to create variation across outputs."""
        try:
            if pil_image is None:
                return pil_image

            img_array = np.array(pil_image).astype(np.float32) / 255.0
            noise = np.random.normal(0, noise_level, img_array.shape)
            # Clip values to the valid range
            noisy_array = np.clip(img_array + noise, 0, 1)

            # Convert back to PIL
            noisy_array = (noisy_array * 255).astype(np.uint8)
            return Image.fromarray(noisy_array)
        except Exception as e:
            print(f"Warning: Could not add noise to image: {e}")
            return pil_image  # Return original if noise addition fails

    # Get fresh presets for this session
    session_presets = get_fresh_presets()
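    # NOTE: infer() works on a deepcopy of PRESETS rather than the module-level
    # dict. On a shared Gradio Space, module globals are shared by all visitors,
    # so mutating PRESETS in place would leak one user's edited preset prompts
    # into everyone else's sessions; the per-call copy keeps each generation
    # self-contained.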

    # Resize the input image first
    image = resize_image(image, max_size=1024)

    original_prompt = prompt
    prompt_info = ""

    # Handle preset batch generation
    if preset_type and preset_type in session_presets:
        preset = session_presets[preset_type]
        # Filter out empty prompts
        non_empty_preset_prompts = [p for p in preset["prompts"] if p.strip()]

        if non_empty_preset_prompts:
            batch_prompts = [f"{original_prompt}, {preset_prompt}" for preset_prompt in non_empty_preset_prompts]
            num_images_per_prompt = len(non_empty_preset_prompts)  # Use the actual count of non-empty prompts
            prompt_info = (
                f"<div>"
                f"<b>🎨 Preset: {preset_type}</b><br>"
                f"<i>{preset['description']}</i><br>"
                f"<b>Base Prompt:</b> {original_prompt}<br>"
                f"Generating {len(non_empty_preset_prompts)} image{'s' if len(non_empty_preset_prompts) > 1 else ''}"
                f"</div>"
            )
            print(f"Using preset: {preset_type} with {len(batch_prompts)} variations")
" f"

โš ๏ธ Invalid Preset

" f"

No valid prompts found. Using manual prompt.

" f"

Prompt: {original_prompt}

" f"
" ) else: batch_prompts = [prompt] # Single prompt in list # Handle regular prompt rewriting if rewrite_prompt: try: enhanced_instruction = polish_prompt(original_prompt) if enhanced_instruction and enhanced_instruction != original_prompt: prompt_info = ( f"
" f"

๐Ÿš€ Prompt Enhancement

" f"

Original: {original_prompt}

" f"

Enhanced: {enhanced_instruction}

" f"
" ) batch_prompts = [enhanced_instruction] else: prompt_info = ( f"
" f"

๐Ÿ“ Prompt Enhancement

" f"

No enhancement applied or enhancement failed

" f"
" ) except Exception as e: print(f"Prompt enhancement error: {str(e)}") # Debug logging gr.Warning(f"Prompt enhancement failed: {str(e)}") prompt_info = ( f"
" f"

โš ๏ธ Enhancement Not Applied

" f"

Using original prompt. Error: {str(e)[:100]}

" f"
" ) else: prompt_info = ( f"
" f"

๐Ÿ“ Original Prompt

" f"

{original_prompt}

" f"
" ) # Set base seed for reproducibility base_seed = seed if not randomize_seed else random.randint(0, MAX_SEED) try: edited_images = [] # Generate images for each prompt in the batch for i, current_prompt in enumerate(batch_prompts): # Create unique seed for each image generator = torch.Generator(device=device).manual_seed(base_seed + i*1000) # Add slight noise to the image for variation (except for first image to maintain base) if i == 0 and len(batch_prompts) > 1: input_image = image else: input_image = add_noise_to_image(image, noise_level=0.001 + i*0.003) # Slightly vary guidance scale for each image varied_guidance = true_guidance_scale + random.uniform(-0.1, 0.1) varied_guidance = max(1.0, min(10.0, varied_guidance)) # Generate single image result = pipe( image=input_image, prompt=current_prompt, negative_prompt=" ", num_inference_steps=num_inference_steps, generator=generator, true_cfg_scale=varied_guidance, num_images_per_prompt=1 ).images edited_images.extend(result) print(f"Generated image {i+1}/{len(batch_prompts)} with prompt: {current_prompt[:75]}...") # Clear cache after generation # if device == "cuda": # torch.cuda.empty_cache() # gc.collect() return edited_images, base_seed, prompt_info except Exception as e: # Clear cache on error if device == "cuda": torch.cuda.empty_cache() gc.collect() gr.Error(f"Image generation failed: {str(e)}") return [], base_seed, ( f"
" f"

โš ๏ธ Processing Error

" f"

{str(e)[:200]}

" f"
" ) with gr.Blocks(title="Qwen Image Edit - Fast Lightning Mode w/ Batch") as demo: preset_prompts_state = gr.State(value=[]) # preset_prompts_state = gr.State(value=["", "", "", ""]) gr.Markdown("""
    gr.Markdown("""
# ⚡️ Qwen-Image-Edit Lightning

✨ 4-step inference with lightx2v's LoRA.

📝 Local Prompt Enhancement, Batched Multi-image Generation, 🎨 Preset Batches
""")

    with gr.Row(equal_height=True):
        # Input column
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Source Image",
                type="pil",
                height=300
            )
            prompt = gr.Textbox(
                label="Edit Instructions / Base Prompt",
                placeholder="e.g. Replace the background with a beach sunset... When a preset is selected, use as the base prompt, e.g. the lamborghini",
                lines=2,
                max_lines=4,
                scale=2
            )
            preset_dropdown = gr.Dropdown(
                choices=get_preset_choices(),
                value=None,
                label="Preset Batch Generation",
                interactive=True
            )

            # Editable preset prompts (initially hidden)
            preset_editor = gr.Group(visible=False)
            with preset_editor:
                gr.Markdown("### 🎨 Edit Preset Prompts")
                preset_prompt_1 = gr.Textbox(label="Prompt 1", lines=1, value="")
                preset_prompt_2 = gr.Textbox(label="Prompt 2", lines=1, value="")
                preset_prompt_3 = gr.Textbox(label="Prompt 3", lines=1, value="")
                preset_prompt_4 = gr.Textbox(label="Prompt 4", lines=1, value="")
                update_preset_button = gr.Button("Update Preset", variant="secondary")

            rewrite_toggle = gr.Checkbox(
                label="Enable Prompt Enhancement",
                value=True,
                interactive=True
            )

            # Prompt preview component
            prompt_preview = gr.Textbox(
                label="📋 Prompt Preview",
                interactive=False,
                lines=6,
                max_lines=10,
                value="Enter a base prompt and select a preset above to see how your prompt will be modified for batch generation.",
                placeholder="Prompt preview will appear here..."
            )

            run_button = gr.Button(
                "Generate Edit(s)",
                variant="primary"
            )

            with gr.Accordion("Advanced Parameters", open=False):
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=42
                    )
                    randomize_seed = gr.Checkbox(
                        label="Random Seed",
                        value=True
                    )
                with gr.Row():
                    true_guidance_scale = gr.Slider(
                        label="True CFG Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.1,
                        value=1.0
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=2,
                        maximum=16,
                        step=1,
                        value=4
                    )
                num_images_per_prompt = gr.Slider(
                    label="Output Count (Manual)",
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=2,
                    interactive=True
                )

        # Output column
        with gr.Column(scale=2):
            result = gr.Gallery(
                label="Edited Images",
                columns=2,  # gr.Gallery expects an int (or list), not a callable
                height=500,
                object_fit="cover",
                preview=True
            )
" "Prompt details will appear after generation. Ability to edit Preset Prompts on the fly will be implemented shortly.
" ) # Fix the show_preset_editor function to use ORIGINAL_PRESETS: def show_preset_editor(preset_type): if preset_type and preset_type in ORIGINAL_PRESETS: # Changed from PRESETS to ORIGINAL_PRESETS preset = ORIGINAL_PRESETS[preset_type] prompts = preset["prompts"] # Pad prompts to 4 items if needed while len(prompts) < 4: prompts.append("") return gr.Group(visible=True), prompts[0], prompts[1], prompts[2], prompts[3] return gr.Group(visible=False), "", "", "", "" # Fix the update_preset_count function to use ORIGINAL_PRESETS: def update_preset_count(preset_type, prompt_1, prompt_2, prompt_3, prompt_4): """Update the output count slider based on non-empty preset prompts""" if preset_type and preset_type in ORIGINAL_PRESETS: # Changed from PRESETS to ORIGINAL_PRESETS non_empty_count = len([p for p in [prompt_1, prompt_2, prompt_3, prompt_4] if p.strip()]) return gr.Slider(value=max(1, min(4, non_empty_count)), interactive=False) return gr.Slider(interactive=True) # Update the preset_dropdown.change handlers to use ORIGINAL_PRESETS preset_dropdown.change( fn=toggle_output_count, inputs=preset_dropdown, outputs=[preset_editor, num_images_per_prompt, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4] ) preset_dropdown.change( fn=update_prompt_preview, inputs=[preset_dropdown, prompt], outputs=prompt_preview ) preset_prompt_1.change( fn=update_preset_count, inputs=[preset_dropdown, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4], outputs=num_images_per_prompt ) preset_prompt_2.change( fn=update_preset_count, inputs=[preset_dropdown, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4], outputs=num_images_per_prompt ) preset_prompt_3.change( fn=update_preset_count, inputs=[preset_dropdown, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4], outputs=num_images_per_prompt ) preset_prompt_4.change( fn=update_preset_count, inputs=[preset_dropdown, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4], outputs=num_images_per_prompt ) prompt.change( fn=update_prompt_preview, inputs=[preset_dropdown, prompt], outputs=prompt_preview ) update_preset_button.click( fn=update_preset_prompt_textbox, inputs=[preset_dropdown, preset_prompt_1, preset_prompt_2, preset_prompt_3, preset_prompt_4], outputs=prompt_preview ) # Set up processing inputs = [ input_image, prompt, seed, randomize_seed, true_guidance_scale, num_inference_steps, rewrite_toggle, num_images_per_prompt, preset_dropdown ] outputs = [result, seed, prompt_info] run_button.click( fn=infer, inputs=inputs, outputs=outputs ) prompt.submit( fn=infer, inputs=inputs, outputs=outputs ) demo.queue(max_size=5).launch()