import spaces
import torch
import gradio as gr
from diffusers import StableDiffusionPipeline
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
import os

# --- 1. Model Loading and Optimization (AoT Compilation) ---

# Choose a Stable Diffusion model
MODEL_ID = "runwayml/stable-diffusion-v1-5"

# Initialize the pipeline; disable the safety checker for faster compilation and inference.
# Use torch.float16 for efficiency on CUDA hardware.
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False
)
pipe.to('cuda')
pipe.scheduler.set_timesteps(50)  # Set max steps for consistent performance testing

print("Starting AoT Compilation...")

@spaces.GPU(duration=1500)  # Reserve maximum time for startup compilation
def compile_optimized_unet():
    # 1. Apply FP8 quantization (optional; requires H100/H200-class hardware for maximum benefit)
    try:
        quantize_(pipe.unet, Float8DynamicActivationFloat8WeightConfig())
        print("✅ Applied FP8 quantization to UNet.")
    except Exception as e:
        print(f"⚠️ FP8 quantization failed (may require specific hardware/libraries): {e}")

    # 2. Define and capture example inputs for the UNet (the core engine).
    # Standard Stable Diffusion UNet inputs (batch_size=2 for classifier-free guidance).
    bsz = 2
    latent_model_input = torch.randn(bsz, 4, 64, 64, device="cuda", dtype=torch.float16)
    t = torch.randint(0, 1000, (bsz,), device="cuda")
    encoder_hidden_states = torch.randn(bsz, 77, 768, device="cuda", dtype=torch.float16)

    with spaces.aoti_capture(pipe.unet) as call:
        pipe.unet(latent_model_input, t, encoder_hidden_states)

    # 3. Export the model
    exported = torch.export.export(
        pipe.unet,
        args=call.args,
        kwargs=call.kwargs,
    )

    # 4. Compile the exported model ahead of time
    return spaces.aoti_compile(exported)

# Execute compilation during startup
compiled_unet = compile_optimized_unet()

# 5. Apply the compiled model to the pipeline's UNet component
spaces.aoti_apply(compiled_unet, pipe.unet)
print("✅ AoT Compilation completed successfully.")

# --- 2. Inference Function (Running on GPU) ---

@spaces.GPU(duration=60)  # Standard duration for image generation
def generate_image(
    prompt: str,
    negative_prompt: str,
    steps: int,
    seed: int
):
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")

    # Use a fixed seed unless -1 is passed, in which case let the pipeline pick one
    generator = torch.Generator(device="cuda").manual_seed(seed) if seed != -1 else None
    steps = int(steps)

    # Run inference using the optimized pipeline
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=steps,
        guidance_scale=7.5,
        generator=generator
    ).images
    return result

# --- 3. Gradio Interface ---

with gr.Blocks(title="Optimized Vision Model (AoT Powered)") as demo:
    gr.HTML(
        """
This demo simulates a creative Vision Language Model using AoT-compiled Stable Diffusion for lightning-fast image generation.