import gradio as gr
import diffusers
import random
import json
diffusers.utils.logging.set_verbosity_error()
import torch
from PIL import Image
import numpy as np
from unet.unet_controller import UNetController
from main import load_unet_controller
from unet import utils

# Global flag to control interruption
interrupt_flag = False

def main_gradio(model_path, id_prompt, frame_prompt_list, precision, seed, window_length, alpha_weaken, beta_weaken, alpha_enhance, beta_enhance, ipca_drop_out, use_freeu, use_same_init_noise):
    global interrupt_flag
    interrupt_flag = False  # Reset the flag at the start of the function

    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    frame_prompt_list = frame_prompt_list.split(",")
    pipe, _ = utils.load_pipe_from_path(model_path, "cuda:1", torch.float16 if precision == "fp16" else torch.float32, precision)

    if interrupt_flag:
        print("Generation interrupted")
        del pipe
        torch.cuda.empty_cache()
        if 'story_image' not in locals():
            empty_image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
            yield empty_image
        return
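
    # Load the attention controller and apply the hyperparameters chosen in the UI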
    unet_controller = load_unet_controller(pipe, "cuda:1")
    unet_controller.Alpha_enhance = alpha_enhance
    unet_controller.Beta_enhance = beta_enhance
    unet_controller.Alpha_weaken = alpha_weaken
    unet_controller.Beta_weaken = beta_weaken
    unet_controller.Ipca_dropout = ipca_drop_out
    unet_controller.Is_freeu_enabled = use_freeu
    unet_controller.Use_same_init_noise = use_same_init_noise

    import os
    from datetime import datetime
    current_time = datetime.now().strftime("%Y%m%d%H")
    current_time_ = datetime.now().strftime("%M%S")
    save_dir = os.path.join(".", f'result/{current_time}/{current_time_}_gradio_seed{seed}')
    os.makedirs(save_dir, exist_ok=True)

    generate = torch.Generator().manual_seed(seed)
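    # If IPCA (identity-preserving cross-attention) is enabled, run one warm-up pass
    # on the identity prompt alone so the controller caches its QKV for the frame passes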
    if unet_controller.Use_ipca is True:
        unet_controller.Store_qkv = True
        original_prompt_embeds_mode = unet_controller.Prompt_embeds_mode
        unet_controller.Prompt_embeds_mode = "original"
        _ = pipe(id_prompt, generator=generate, unet_controller=unet_controller).images
        unet_controller.Prompt_embeds_mode = original_prompt_embeds_mode
        unet_controller.Store_qkv = False
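
    # Clamp the requested window length and, if the frame prompts do not all fit in
    # one window, split them into circular sliding windows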
    max_window_length = utils.get_max_window_length(unet_controller, id_prompt, frame_prompt_list)
    window_length = min(window_length, max_window_length)
    if window_length < len(frame_prompt_list):
        movement_lists = utils.circular_sliding_windows(frame_prompt_list, window_length)
    else:
        movement_lists = [movement for movement in frame_prompt_list]

    story_image_list = []
    generate = torch.Generator().manual_seed(seed)
    unet_controller.id_prompt = id_prompt
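
    # Generate one frame per prompt: the controller "expresses" the current frame
    # prompt and "suppresses" the others so all frames share the same identity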
    for index, movement in enumerate(frame_prompt_list):
        if interrupt_flag:
            print("Generation interrupted")
            del pipe
            torch.cuda.empty_cache()
            if 'story_image' not in locals():
                empty_image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
                yield empty_image
            return

        if unet_controller is not None:
            if window_length < len(frame_prompt_list):
                unet_controller.frame_prompt_suppress = movement_lists[index][1:]
                unet_controller.frame_prompt_express = movement_lists[index][0]
                gen_prompts = [f'{id_prompt} {" ".join(movement_lists[index])}']
            else:
                unet_controller.frame_prompt_suppress = movement_lists[:index] + movement_lists[index+1:]
                unet_controller.frame_prompt_express = movement_lists[index]
                gen_prompts = [f'{id_prompt} {" ".join(movement_lists)}']
        else:
            gen_prompts = f'{id_prompt} {movement}'

        print(f"suppress: {unet_controller.frame_prompt_suppress}")
        print(f"express: {unet_controller.frame_prompt_express}")
        print(f'id_prompt: {id_prompt}')
        print(f"gen_prompts: {gen_prompts}")

        if unet_controller is not None and unet_controller.Use_same_init_noise is True:
            generate = torch.Generator().manual_seed(seed)

        images = pipe(gen_prompts, generator=generate, unet_controller=unet_controller).images
        story_image_list.append(images[0])

        # Concatenate the frames generated so far into one horizontal strip and
        # stream the partial storyboard to the UI
        story_image = np.concatenate(story_image_list, axis=1)
        story_image = Image.fromarray(story_image.astype(np.uint8))
        yield story_image

        images[0].save(os.path.join(save_dir, f'{id_prompt} {unet_controller.frame_prompt_express}.jpg'))

    story_image.save(os.path.join(save_dir, 'story_image.jpg'))

    # Free the pipeline and GPU memory once generation finishes
    import gc
    del pipe
    gc.collect()
    torch.cuda.empty_cache()


# Gradio interface
def gradio_interface():
    global interrupt_flag
    with gr.Blocks() as demo:
        gr.Markdown("### Consistent Image Generation with 1Prompt1Story")

        # Load JSON data
        with open('./resource/example.json', 'r') as f:
            data = json.load(f)

        # Extract id_prompts and frame_prompts
        id_prompts = [item['id_prompt'] for item in data['combinations']]
        frame_prompts = [", ".join(item['frame_prompt_list']) for item in data['combinations']]

        # Input fields
        id_prompt = gr.Dropdown(
            label="ID Prompt",
            choices=id_prompts,
            value=id_prompts[0],
            allow_custom_value=True
        )
        frame_prompt_list = gr.Dropdown(
            label="Frame Prompts (comma-separated)",
            choices=frame_prompts,
            value=frame_prompts[0],
            allow_custom_value=True
        )
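        # Preset checkpoints (all SDXL-family); a custom model path can also be
        # typed because allow_custom_value=True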
        model_path = gr.Dropdown(
            label="Model Path",
            choices=["stabilityai/stable-diffusion-xl-base-1.0", "RunDiffusion/Juggernaut-X-v10", "playgroundai/playground-v2.5-1024px-aesthetic", "SG161222/RealVisXL_V4.0", "RunDiffusion/Juggernaut-XI-v11", "SG161222/RealVisXL_V5.0"],
            value="playgroundai/playground-v2.5-1024px-aesthetic",
            allow_custom_value=True
        )
        with gr.Row():
            seed = gr.Slider(label="Seed (set -1 for random seed)", minimum=-1, maximum=10000, value=-1, step=1)
            window_length = gr.Slider(label="Window Length", minimum=1, maximum=20, value=10, step=1)
        with gr.Row():
            alpha_weaken = gr.Number(label="Alpha Weaken", value=UNetController.Alpha_weaken, interactive=True, step=0.01)
            beta_weaken = gr.Number(label="Beta Weaken", value=UNetController.Beta_weaken, interactive=True, step=0.01)
            alpha_enhance = gr.Number(label="Alpha Enhance", value=UNetController.Alpha_enhance, interactive=True, step=0.001)
            beta_enhance = gr.Number(label="Beta Enhance", value=UNetController.Beta_enhance, interactive=True, step=0.1)
        with gr.Row():
            ipca_drop_out = gr.Number(label="Ipca Dropout", value=UNetController.Ipca_dropout, interactive=True, step=0.1, minimum=0, maximum=1)
            precision = gr.Dropdown(label="Precision", choices=["fp16", "fp32"], value="fp16")
            use_freeu = gr.Dropdown(label="Use FreeU", choices=[False, True], value=UNetController.Is_freeu_enabled)
            use_same_init_noise = gr.Dropdown(label="Use Same Init Noise", choices=[True, False], value=UNetController.Use_same_init_noise)
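        # Restore all hyperparameter widgets to the UNetController class defaults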
        reset_button = gr.Button("Reset to Default")

        def reset_values():
            return UNetController.Alpha_weaken, UNetController.Beta_weaken, UNetController.Alpha_enhance, UNetController.Beta_enhance, UNetController.Ipca_dropout, "fp16", UNetController.Is_freeu_enabled, UNetController.Use_same_init_noise

        reset_button.click(
            fn=reset_values,
            inputs=[],
            outputs=[alpha_weaken, beta_weaken, alpha_enhance, beta_enhance, ipca_drop_out, precision, use_freeu, use_same_init_noise]
        )

        # Output
        output_gallery = gr.Image()

        # Buttons
        generate_button = gr.Button("Generate Images")
        interrupt_button = gr.Button("Interrupt")
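        # The interrupt button only flips the module-level flag; main_gradio checks it
        # before loading the pipeline and before each frame, so the interruption takes
        # effect at the next frame boundary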
        def interrupt_generation():
            global interrupt_flag
            interrupt_flag = True

        interrupt_button.click(
            fn=interrupt_generation,
            inputs=[],
            outputs=[]
        )
        generate_button.click(
            fn=main_gradio,
            inputs=[
                model_path, id_prompt, frame_prompt_list, precision, seed, window_length, alpha_weaken, beta_weaken, alpha_enhance, beta_enhance, ipca_drop_out, use_freeu, use_same_init_noise
            ],
            outputs=output_gallery
        )
    return demo


if __name__ == "__main__":
    demo = gradio_interface()
    demo.launch(share=True)