import gradio as gr
import random
import os
import spaces
import torch
import time
import json
import numpy as np
from diffusers import BriaFiboPipeline
from diffusers.modular_pipelines import ModularPipeline
import requests
import io
import base64

MAX_SEED = np.iinfo(np.int32).max
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Bria endpoint that turns a text prompt / inspiration image / existing structured
# prompt into FIBO's JSON ("structured") prompt.
BRIA_STRUCTURED_PROMPT_URL = "https://engine.prod.bria-api.com/v2/structured_prompt/generate/pro"
# Upper bound in seconds for one synchronous structured-prompt request, so a stalled
# connection cannot hang a GPU-reserved worker indefinitely.
BRIA_REQUEST_TIMEOUT = 180

# NOTE: vlm_pipe is loaded but not used by any code below -- JSON prompts are produced
# by the Bria API in generate_json_prompt(). Kept so a local VLM path can be restored.
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
# vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-gemini-prompt-to-JSON", trust_remote_code=True)
pipe = BriaFiboPipeline.from_pretrained("briaai/FIBO", trust_remote_code=True, torch_dtype=dtype).to(device)


def get_default_negative_prompt(existing_json: dict) -> str:
    """Return the fallback negative prompt for a parsed structured prompt.

    Photographic prompts get a negative prompt steering away from illustration;
    everything else (including non-dict input) gets an empty negative prompt.
    Pure Python, so it is deliberately not wrapped in ``@spaces.GPU``.
    """
    if not isinstance(existing_json, dict):
        return ""
    # ``or ""`` guards against an explicit null "style_medium" in the JSON.
    style_medium = str(existing_json.get("style_medium") or "").lower()
    if style_medium in ("photograph", "photography", "photo"):
        return """{'style_medium':'digital illustration','artistic_style':'non-realistic'}"""
    return ""


def generate_json_prompt(
    prompt,
    json_prompt=None,
    prompt_inspire_image=None,
    seed=42,
):
    """Ask the Bria structured-prompt API for a FIBO JSON prompt.

    Args:
        prompt: Free-text prompt or refinement instruction; omitted when empty.
        json_prompt: Existing structured prompt (text) to refine; omitted when empty.
        prompt_inspire_image: Optional PIL image, sent base64-encoded as PNG.
        seed: Seed forwarded to the API.

    Returns:
        The ``result.structured_prompt`` field of the API's JSON response.

    Raises:
        gr.Error: if ``BRIA_API_TOKEN`` is unset, the request fails or returns an
            HTTP error, or the response lacks the expected field.
    """
    api_key = os.environ.get("BRIA_API_TOKEN")
    if not api_key:
        raise gr.Error("BRIA_API_TOKEN is not set; cannot call the structured prompt API.")

    payload = {"seed": int(seed), "sync": True}
    if json_prompt:
        payload["structured_prompt"] = json_prompt
    if prompt:
        payload["prompt"] = prompt
    if prompt_inspire_image is not None:
        buffered = io.BytesIO()
        prompt_inspire_image.save(buffered, format="PNG")
        payload["images"] = [base64.b64encode(buffered.getvalue()).decode("utf-8")]

    headers = {"Content-Type": "application/json", "api_token": api_key}
    try:
        response = requests.post(
            BRIA_STRUCTURED_PROMPT_URL,
            json=payload,
            headers=headers,
            timeout=BRIA_REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        data = response.json()
        return data["result"]["structured_prompt"]
    except (requests.RequestException, ValueError, KeyError, TypeError) as err:
        # ValueError covers a non-JSON body; KeyError/TypeError a missing "result" field.
        raise gr.Error(f"Structured prompt request failed: {err}") from err


def _as_json_text(value):
    """Serialise dicts/lists to indented JSON text; return anything else as ``str``."""
    if isinstance(value, (dict, list)):
        return json.dumps(value, indent=4)
    return str(value)


def _is_blank(value):
    """True for ``None``, empty containers, and empty/whitespace-only strings."""
    if isinstance(value, str):
        return not value.strip()
    return not value


def _parse_json_prompt(json_prompt):
    """Parse structured-prompt text, turning invalid JSON into a user-facing error."""
    try:
        return json.loads(json_prompt)
    except (TypeError, json.JSONDecodeError) as err:
        raise gr.Error(f"The JSON prompt is not valid JSON: {err}") from err


def _format_negative_prompt(neg_json_prompt):
    """Render a negative prompt for the JSON code view.

    Valid JSON is pretty-printed; anything else (e.g. the single-quoted default
    negative prompt) is shown verbatim rather than double-encoded as a JSON string.
    """
    if not neg_json_prompt:
        return ""
    try:
        return json.dumps(json.loads(neg_json_prompt), indent=4)
    except (TypeError, json.JSONDecodeError):
        return str(neg_json_prompt)


def _json_prompt_from_inputs(prompt, prompt_inspire_image=None, seed=42):
    """Create structured-prompt text from an inspiration image or a text prompt.

    An inspiration image takes precedence: when present, the text prompt is not
    sent. Used by both the "Generate JSON Prompt" and "Generate Image" buttons so
    they produce the same JSON for the same inputs.
    """
    if prompt_inspire_image is not None:
        structured = generate_json_prompt(prompt="", prompt_inspire_image=prompt_inspire_image, seed=seed)
    else:
        structured = generate_json_prompt(prompt=prompt, seed=seed)
    return _as_json_text(structured)


def _negative_json_prompt(negative_prompt, parsed_prompt, seed):
    """Convert a user negative prompt via the API, or fall back to the default one."""
    if negative_prompt:
        return _as_json_text(generate_json_prompt(prompt=negative_prompt, seed=seed))
    return get_default_negative_prompt(parsed_prompt)


def _render(json_prompt, neg_json_prompt, width, height, guidance_scale, num_inference_steps):
    """Run the FIBO pipeline once and return the first generated image."""
    return pipe(
        prompt=json_prompt,
        num_inference_steps=int(num_inference_steps),
        negative_prompt=neg_json_prompt,
        width=int(width),
        height=int(height),
        guidance_scale=guidance_scale,
    ).images[0]


@spaces.GPU(duration=300)
def generate_image(
    prompt,
    prompt_inspire_image,
    prompt_in_json,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=5,
    num_inference_steps=50,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image, creating the JSON prompt first if none was supplied.

    Returns 8 values matching the click outputs: image, seed, formatted JSON
    prompt, formatted negative prompt, refine-tab visibility update, image (for
    the refine tab), formatted JSON prompt (for the refine tab), seed (state).
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    with torch.inference_mode():
        if _is_blank(prompt_in_json):
            json_prompt = _json_prompt_from_inputs(prompt, prompt_inspire_image, seed)
        else:
            # Use the JSON prompt the user provided / edited.
            json_prompt = _as_json_text(prompt_in_json)
        parsed_prompt = _parse_json_prompt(json_prompt)

        neg_json_prompt = _negative_json_prompt(negative_prompt, parsed_prompt, seed)
        image = _render(json_prompt, neg_json_prompt, width, height, guidance_scale, num_inference_steps)

    print(neg_json_prompt)
    formatted_prompt = json.dumps(parsed_prompt, indent=4)
    return (
        image,
        seed,
        formatted_prompt,
        _format_negative_prompt(neg_json_prompt),
        gr.update(visible=True),
        image,
        formatted_prompt,
        seed,
    )


def _load_refine_json(refine_json):
    """Turn the refine tab's JSON editor value into a Python object.

    Empty input yields ``{}``. Unparseable text also falls back to ``{}`` (best
    effort), but the user is warned so their edits are not silently discarded.
    """
    if refine_json is None:
        return {}
    if isinstance(refine_json, str):
        if not refine_json:
            return {}
        try:
            return json.loads(refine_json)
        except json.JSONDecodeError:
            gr.Warning("Could not parse the JSON prompt; refining from an empty prompt instead.")
            return {}
    return refine_json


@spaces.GPU(duration=300)
def refine_prompt(
    refine_instruction,
    refine_json,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=5,
    num_inference_steps=50,
    progress=gr.Progress(track_tqdm=True),
):
    """Apply a text instruction to an existing JSON prompt and regenerate the image.

    Returns 7 values matching the refine click outputs: image, seed, formatted
    JSON prompt, formatted negative prompt, image, formatted JSON prompt, seed.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    with torch.inference_mode():
        json_prompt_str = _as_json_text(_load_refine_json(refine_json))
        json_prompt = _as_json_text(
            generate_json_prompt(prompt=refine_instruction, json_prompt=json_prompt_str, seed=seed)
        )
        parsed_prompt = _parse_json_prompt(json_prompt)

        neg_json_prompt = _negative_json_prompt(negative_prompt, parsed_prompt, seed)
        image = _render(json_prompt, neg_json_prompt, width, height, guidance_scale, num_inference_steps)

    print(neg_json_prompt)
    formatted_prompt = json.dumps(parsed_prompt, indent=4)
    return (
        image,
        seed,
        formatted_prompt,
        _format_negative_prompt(neg_json_prompt),
        image,
        formatted_prompt,
        seed,
    )


css = """
#col-container{
    margin: 0 auto;
    max-width: 1000px;
}
#app-title {
    text-align: center;
    margin: 0 auto;
    width: 100%;
    font-weight: bold;
}
#json_prompt, #json_prompt_refine{max-height: 800px}
"""

with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="violet")) as demo:
    gr.Markdown(
        """# FIBO
8B param text-to-image model trained on structured JSON captions up to 1,000+ words.
[[non-commercial license]()] [[arxiv](https://arxiv.org/)] [[model](https://huggingface.co/briaai/FIBO)] [[code](https://github.com/Bria-AI/FIBO)]
""",
        elem_id="app-title",
    )

    # State to store the last used seed
    last_seed_state = gr.State(value=0)

    with gr.Row(elem_id="col-container"):
        with gr.Column(scale=1):
            with gr.Tabs() as tabs:
                with gr.Tab("Generate", id="generate_tab") as tab_generate:
                    with gr.Accordion("Inspire from Image", open=False) as inspire_accordion:
                        prompt_inspire_image = gr.Image(
                            label="Inspiration Image",
                            type="pil",
                        )
                    prompt_generate = gr.Textbox(
                        label="Prompt",
                        placeholder="a man holding a goose screaming",
                        lines=3,
                    )
                    prompt_in_json = gr.Code(
                        label="Editable JSON Prompt",
                        language="json",
                        wrap_lines=True,
                        lines=10,
                        elem_id="json_prompt",
                    )
                    with gr.Row():
                        generate_json_btn = gr.Button("Generate JSON Prompt", variant="secondary")
                        generate_image_btn = gr.Button("Generate Image", variant="primary")

                with gr.Tab("Refine existing image", id="refine_tab", visible=False) as tab_refine:
                    previous_result = gr.Image(label="Previous Generation", interactive=False, visible=False)
                    refine_instruction = gr.Textbox(
                        label="Refinement instruction",
                        placeholder="make the cat white",
                        lines=2,
                        info="Describe the changes you want to make",
                    )
                    refine_json = gr.Code(
                        label="Editable JSON Prompt",
                        language="json",
                        wrap_lines=True,
                        elem_id="json_prompt_refine",
                    )
                    with gr.Row():
                        refine_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                        refine_randomize_seed = gr.Checkbox(
                            label="Randomize seed",
                            value=False,
                            info="Keep off for consistent refinements",
                        )
                    refine_btn = gr.Button("Refine Image", variant="primary")

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    guidance_scale = gr.Slider(label="guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=5.0)
                    num_inference_steps = gr.Slider(
                        label="number of inference steps", minimum=1, maximum=60, step=1, value=50
                    )
                    height = gr.Slider(label="Height", minimum=768, maximum=1248, step=32, value=1024)
                    width = gr.Slider(label="Width", minimum=832, maximum=1344, step=64, value=1024)
                with gr.Row():
                    negative_prompt = gr.Textbox(label="negative prompt")
                    negative_prompt_json = gr.Code(label="json negative prompt", language="json", wrap_lines=True)

        with gr.Column(scale=1):
            result = gr.Image(label="output")

    # Generate JSON prompt only. The handler's positional parameters must line up
    # with `inputs` (prompt, image, seed); generate_json_prompt itself takes
    # json_prompt second, so it cannot be wired directly.
    generate_json_btn.click(
        fn=_json_prompt_from_inputs,
        inputs=[prompt_generate, prompt_inspire_image, seed],
        outputs=[prompt_in_json],
    )

    # Generate image (generates JSON first if needed)
    generate_image_btn.click(
        fn=generate_image,
        inputs=[
            prompt_generate,
            prompt_inspire_image,
            prompt_in_json,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[
            result,
            seed,
            prompt_in_json,
            negative_prompt_json,
            tab_refine,
            previous_result,
            refine_json,
            last_seed_state,
        ],
    ).then(
        fn=lambda s: s,  # Copy seed to refine tab
        inputs=[last_seed_state],
        outputs=[refine_seed],
    )

    # Refine image with dedicated seed controls
    refine_btn.click(
        fn=refine_prompt,
        inputs=[
            refine_instruction,
            refine_json,
            negative_prompt,
            refine_seed,
            refine_randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[
            result,
            refine_seed,
            refine_json,
            negative_prompt_json,
            previous_result,
            refine_json,
            last_seed_state,
        ],
    )

demo.queue().launch()