import spaces
import torch
import gradio as gr
from diffusers import StableDiffusionPipeline
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig

# --- 1. Model Loading and Optimization (AoT Compilation) ---

# Choose a Stable Diffusion model
MODEL_ID = "runwayml/stable-diffusion-v1-5"

# Initialize the pipeline; disable the safety checker for faster compilation and inference.
# Use torch.float16 for efficiency on CUDA hardware.
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False
)
pipe.to("cuda")
pipe.scheduler.set_timesteps(50)  # Set max steps for consistent performance testing

print("Starting AoT compilation...")

@spaces.GPU(duration=1500)  # Reserve maximum time for startup compilation
def compile_optimized_unet():
    # 1. Apply FP8 quantization (optional; requires H100/H200-class hardware for maximum benefit)
    try:
        quantize_(pipe.unet, Float8DynamicActivationFloat8WeightConfig())
        print("✅ Applied FP8 quantization to UNet.")
    except Exception as e:
        print(f"⚠️ FP8 quantization failed (may require specific hardware/libraries): {e}")

    # 2. Define and capture example inputs for the UNet (the core engine).
    #    Standard Stable Diffusion UNet inputs (batch_size=2 for classifier-free guidance).
    bsz = 2
    latent_model_input = torch.randn(bsz, 4, 64, 64, device="cuda", dtype=torch.float16)
    t = torch.randint(0, 1000, (bsz,), device="cuda")
    encoder_hidden_states = torch.randn(bsz, 77, 768, device="cuda", dtype=torch.float16)

    with spaces.aoti_capture(pipe.unet) as call:
        pipe.unet(latent_model_input, t, encoder_hidden_states)

    # 3. Export the model
    exported = torch.export.export(
        pipe.unet,
        args=call.args,
        kwargs=call.kwargs,
    )

    # 4. Compile the exported model ahead of time
    return spaces.aoti_compile(exported)

# Execute compilation during startup
compiled_unet = compile_optimized_unet()

# 5. Apply the compiled model to the pipeline's UNet component
spaces.aoti_apply(compiled_unet, pipe.unet)
print("✅ AoT compilation completed successfully.")

# --- 2. Inference Function (Running on GPU) ---

@spaces.GPU(duration=60)  # Standard duration for image generation
def generate_image(
    prompt: str,
    negative_prompt: str,
    steps: int,
    seed: int
):
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")

    seed = int(seed)  # gr.Number returns a float; manual_seed requires an int
    generator = torch.Generator(device="cuda").manual_seed(seed) if seed != -1 else None
    steps = int(steps)

    # Run inference using the optimized pipeline
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=steps,
        guidance_scale=7.5,
        generator=generator
    ).images

    return result

# --- 3. Gradio Interface ---
with gr.Blocks(title="Optimized Vision Model (AoT Powered)") as demo:
    gr.HTML(
        """

        <div style="text-align: center;">
            <p>Built with anycoder</p>
            <h1>High-Performance Creative VLM Simulator (AoT Optimized)</h1>
            <p>This demo simulates a creative Vision Language Model using AoT-compiled Stable Diffusion for lightning-fast image generation.</p>
        </div>

""" ) with gr.Row(): with gr.Column(scale=1): prompt = gr.Textbox( label="Prompt (Input to VLM)", placeholder="A futuristic city painted by Van Gogh, highly detailed.", lines=3 ) negative_prompt = gr.Textbox( label="Negative Prompt (What to avoid)", placeholder="Blurry, bad quality, low resolution", lines=2 ) with gr.Accordion("Generation Settings", open=True): steps = gr.Slider( minimum=10, maximum=50, step=1, value=30, label="Inference Steps (Higher = Slower/Better)" ) seed = gr.Number( value=-1, label="Seed (-1 for random)" ) generate_btn = gr.Button("Generate Image (AoT Fast!)", variant="primary") with gr.Column(scale=2): output_gallery = gr.Gallery( label="Creative VLM Output", show_label=True, height=512, columns=2, object_fit="contain" ) generate_btn.click( fn=generate_image, inputs=[prompt, negative_prompt, steps, seed], outputs=output_gallery ) gr.Examples( examples=[ ["A majestic wolf standing on a snowy mountain peak, cinematic lighting", "ugly, deformed, low detail", 30], ["Cyberpunk cat sitting in a neon-lit alley, 8k, digital art", "human, blurry, messy background", 40], ["A vintage photograph of a space shuttle launching from a tropical island", "modern, cartoon, painting", 25] ], inputs=[prompt, negative_prompt, steps], outputs=output_gallery, fn=generate_image, cache_examples=False, ) demo.queue() demo.launch()