"""Gradio demo for Stable Diffusion 2.1 Base on a ZeroGPU Space.

At import time the script loads the diffusion pipeline, attempts to
ahead-of-time (AoT) compile the UNet, and builds a small text-to-image UI.
If AoT compilation fails, the app keeps running with the uncompiled UNet.
"""

import os

import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline

# --- Model Configuration and Loading ---
MODEL_ID = "Manojb/stable-diffusion-2-1-base"
DTYPE = torch.bfloat16

# Sampling settings shared by the AoT capture run and by real inference, so
# the UNet call that gets captured looks like the calls made while serving.
NUM_INFERENCE_STEPS = 25
GUIDANCE_SCALE = 9.0

# The pipeline is loaded OUTSIDE the AoT try/except on purpose: a failure to
# load the model is fatal and must surface as-is. Retrying the identical
# `from_pretrained` call from an exception handler cannot succeed either.
pipe = DiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    use_safetensors=True,
)
pipe.to("cuda")

# True only once a compiled UNet has been applied to `pipe.unet`.
AOT_ENABLED = False


# --- Mandatory ZeroGPU AoT Compilation for Optimization ---
@spaces.GPU(duration=1500)  # Extended duration for startup compilation
def compile_unet():
    """Export and AoT-compile ``pipe.unet`` from a real pipeline call.

    The pipeline is invoked inside ``spaces.aoti_capture`` so that the exact
    positional and keyword arguments the pipeline passes to the UNet are
    recorded on ``call``; those are then fed to ``torch.export.export``.

    Returns:
        The object produced by ``spaces.aoti_compile``, ready to be handed to
        ``spaces.aoti_apply``.
    """
    print("Starting AoT compilation for UNet...")

    # Run the pipeline itself (one prompt) rather than hand-built dummy
    # tensors: the capture then reflects the real call signature and shapes.
    # NOTE(review): per the ZeroGPU docs the capture intercepts the UNet call
    # instead of executing it, so this does not run a full generation.
    with spaces.aoti_capture(pipe.unet) as call:
        pipe("arbitrary example prompt", guidance_scale=GUIDANCE_SCALE)

    # No `dynamic_shapes` are given, so the exported graph is specialised to
    # the captured input shapes (i.e. a single-prompt pipeline call).
    exported = torch.export.export(
        pipe.unet,
        args=call.args,
        kwargs=call.kwargs,
    )
    compiled_model = spaces.aoti_compile(exported)
    print("AoT compilation successful.")
    return compiled_model


# Execute compilation during startup. This is a deliberate best-effort
# boundary: any failure is logged and the app continues unoptimized.
try:
    spaces.aoti_apply(compile_unet(), pipe.unet)
except Exception as e:
    print(f"⚠️ Warning: AoT compilation failed ({e}). Running without optimization.")
    print("Model loaded successfully without AoT.")
else:
    AOT_ENABLED = True


def _run_pipeline(prompts):
    """Run the pipeline on ``prompts`` (a string or list of strings).

    Returns the ``images`` attribute of the pipeline output.
    """
    output = pipe(
        prompts,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE_SCALE,
    )
    return output.images


@spaces.GPU(duration=60)  # Standard GPU allocation for inference
def generate(prompt: str, num_images: int):
    """Generates images using the Stable Diffusion pipeline.

    Args:
        prompt: Text prompt; must contain at least one non-whitespace
            character.
        num_images: Number of images to produce (coerced to ``int``).

    Returns:
        A list with ``num_images`` generated images.

    Raises:
        gr.Error: If the prompt is empty/blank or ``num_images`` is below 1.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt cannot be empty.")

    # UI sliders may deliver the value as a float; list repetition needs int.
    count = int(num_images)
    if count < 1:
        raise gr.Error("Number of images must be at least 1.")

    if AOT_ENABLED:
        # The compiled UNet was exported from a single-prompt call with static
        # shapes, so keep every call at that batch size.
        images = []
        for _ in range(count):
            images.extend(_run_pipeline(prompt))
        return images

    # Uncompiled UNet: generate everything in one batched call.
    return _run_pipeline([prompt] * count)


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="SD 2.1 Base Generator") as demo:
    gr.HTML(
        """
        <div>
            <h1>Stable Diffusion 2.1 Base (512x512)</h1>
            <p>Model: Manojb/stable-diffusion-2-1-base | Optimized with ZeroGPU AoT</p>
            <p>Built with anycoder</p>
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A detailed digital painting of a majestic dragon flying over a medieval castle, fantasy art",
                lines=3,
            )
            num_images = gr.Slider(
                minimum=1,
                maximum=4,
                step=1,
                value=2,
                label="Number of Images to Generate (Max 4)",
                info="How many images to generate for the prompt.",
            )
            generate_btn = gr.Button("Generate Images", variant="primary")

        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated Images (512x512)",
                height=512,
                columns=2,
                rows=2,
                object_fit="contain",
            )

    generate_btn.click(
        fn=generate,
        inputs=[prompt, num_images],
        outputs=output_gallery,
    )

    gr.Examples(
        examples=[
            ["A photorealistic portrait of a golden retriever wearing sunglasses on a beach, cinematic lighting", 2],
            ["Steampunk owl on a bookshelf, detailed brass gears, oil painting", 4],
            ["High contrast black and white photograph of an old lighthouse during a storm", 1],
        ],
        inputs=[prompt, num_images],
        outputs=output_gallery,
        fn=generate,
        cache_examples=True,
        cache_mode="eager",
    )

demo.queue()

if __name__ == "__main__":
    demo.launch()