Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import spaces | |
| import os | |
| import numpy as np | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from omegaconf import OmegaConf | |
| from image_datasets.dataset import image_resize | |
| from src.flux.util import load_ae, load_clip, load_flow_model2, load_t5, tensor_to_pil_image | |
| from src.flux.xflux_pipeline import XFluxSampler | |
# Inference settings are read once at import time from the YAML config.
args = OmegaConf.load("inference_configs/inference.yaml")
is_schnell = args.model_name == "flux-schnell"

# All models run on the GPU in bfloat16.
device = torch.device("cuda")
dtype = torch.bfloat16

# Each model is instantiated on the CPU first and then moved to `device`
# with the target dtype.
dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
# T5 is loaded with max_length 256 for "flux-schnell" and 512 otherwise.
t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
clip = load_clip("cpu").to(device, dtype=dtype)
# Sampler shared by all requests. It is built on the first call to
# generate() so the fine-tuned checkpoint is read and loaded only once
# instead of on every click.
_sampler = None


def _build_sampler():
    """Load the ByteMorpher DiT weights and wrap the models in an XFluxSampler."""
    # The VAE and both text encoders are only used for inference here.
    vae.requires_grad_(False)
    t5.requires_grad_(False)
    clip.requires_grad_(False)

    # `token` replaces the deprecated `use_auth_token` argument; an unset
    # HF_TOKEN yields None, which leaves the default behaviour in place.
    model_path = hf_hub_download(
        repo_id="Boese0601/ByteMorpher",
        filename="dit.safetensors",
        token=os.getenv("HF_TOKEN"),
    )
    dit.load_state_dict(load_file(model_path))
    dit.eval()
    dit.to(device, dtype=dtype)

    return XFluxSampler(
        clip=clip,
        t5=t5,
        ae=vae,
        model=dit,
        device=device,
        ip_loaded=False,
        spatial_condition=False,
        clip_image_processor=None,
        image_encoder=None,
        improj=None,
    )


# NOTE(review): this file imports `spaces` but no `@spaces.GPU` decorator is
# applied to this handler — confirm whether the Space runs on ZeroGPU hardware.
def generate(image: Image.Image, edit_prompt: str):
    """Edit ``image`` according to ``edit_prompt`` and return the result.

    Args:
        image: Condition image supplied by the Gradio image component.
        edit_prompt: Text instruction describing the desired edit.

    Returns:
        The output of ``tensor_to_pil_image`` applied to the sampler result.

    Raises:
        gr.Error: If no condition image was provided.
    """
    global _sampler

    # Gradio passes None when the user presses "Run" without an image.
    if image is None:
        raise gr.Error("Please provide a condition image before running.")

    if _sampler is None:
        _sampler = _build_sampler()

    # Force three channels: RGBA or grayscale uploads would otherwise break
    # the (H, W, C) -> (C, H, W) permute below or yield a wrong channel count.
    img = image_resize(image.convert("RGB"), 512)
    w, h = img.size
    # Round both sides down to a multiple of 32.
    img = img.resize(((w // 32) * 32, (h // 32) * 32))
    # Map 8-bit pixel values from [0, 255] to [-1, 1].
    img = torch.from_numpy((np.array(img) / 127.5) - 1)
    img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)

    # NOTE(review): the output size comes from the config, not from the
    # resized input — confirm the two agree when spatial conditioning is on.
    result = _sampler(
        prompt=edit_prompt,
        width=args.sample_width,
        height=args.sample_height,
        num_steps=args.sample_steps,
        image_prompt=None,
        true_gs=args.cfg_scale,
        seed=args.seed,
        ip_scale=args.ip_scale if args.use_ip else 1.0,
        source_image=img if args.use_spatial_condition else None,
    )
    return tensor_to_pil_image(result)
def get_samples():
    """Return the rows shown in the demo's "Examples" widget.

    Each row is ``[image, edit_prompt]``: the sample image opened from disk
    and resized to 512x512, followed by its edit instruction.
    """
    sample_list = [
        {
            "image": "assets/0_camera_zoom/20486354.png",
            "edit_prompt": "Zoom in on the coral and add a small blue fish in the background.",
        },
    ]

    rows = []
    for entry in sample_list:
        preview = Image.open(entry["image"]).resize((512, 512))
        rows.append([preview, entry["edit_prompt"]])
    return rows
# Markdown/HTML banner rendered at the top of the demo page.
# TODO: the paper badge still has an empty href — fill in the paper URL.
# NOTE(review): the "Model" badge links to a dataset URL — confirm the target.
header = """
# ByteMorph
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href=""><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/datasets/Boese0601/ByteMorph-Bench"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/Boese0601/ByteMorph"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
def create_app():
    """Assemble the Gradio Blocks UI for the demo and return it unlaunched.

    Layout: the header banner, a row with an input panel (condition image,
    edit prompt, "Run" button) and an output panel (result image), a row of
    examples, and a footer crediting the template. The "Run" button calls
    ``generate`` with the image and prompt and shows its result.
    """
    footer_html = """
<div style="text-align: center;">
* This demo's template was modified from <a href="https://arxiv.org/abs/2411.15098" target="_blank">OminiControl</a>.
</div>
"""

    with gr.Blocks() as app:
        gr.Markdown(header, elem_id="header")

        with gr.Row(equal_height=False):
            # Input panel.
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                original_image = gr.Image(
                    type="pil",
                    label="Condition Image",
                    width=300,
                    elem_id="input",
                )
                edit_prompt = gr.Textbox(
                    lines=2,
                    label="Edit Prompt",
                    elem_id="edit_prompt",
                )
                submit_btn = gr.Button("Run", elem_id="submit_btn")

            # Output panel.
            with gr.Column(variant="panel", elem_classes="outputPanel"):
                output_image = gr.Image(type="pil", elem_id="output")

        # Clicking an example fills in the image and the prompt.
        with gr.Row():
            gr.Examples(
                examples=get_samples(),
                inputs=[original_image, edit_prompt],
                label="Examples",
            )

        submit_btn.click(
            fn=generate,
            inputs=[original_image, edit_prompt],
            outputs=output_image,
        )

        gr.HTML(footer_html)

    return app
if __name__ == "__main__":
    # Build the UI, then launch it with debug mode, public share links and
    # server-side rendering all switched off.
    app = create_app()
    app.launch(debug=False, share=False, ssr_mode=False)