Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| from __future__ import annotations | |
| import os | |
| import gradio as gr | |
| from constants import MODEL_LIBRARY_ORG_NAME, SAMPLE_MODEL_REPO, UploadTarget | |
| from inference import InferencePipeline | |
| from trainer import Trainer | |
def create_training_demo(trainer: Trainer,
                         pipe: InferencePipeline | None = None) -> gr.Blocks:
    """Build the Gradio UI that collects a training video plus settings and starts training.

    The layout has three areas:
    - the training video and a prompt describing it;
    - the training hyper-parameters, mostly inside an "Advanced settings" accordion;
    - the output-model and upload settings.

    Pressing "Start Training" forwards every widget value to ``trainer.run``,
    in a fixed positional order, and renders its result in the
    "Output message" box.

    Args:
        trainer: Object whose ``run`` method performs the training. The
            ``inputs`` list wired to the button below must match its
            positional parameters one-for-one.
        pipe: Optional inference pipeline. When given, ``pipe.clear`` is
            also bound to the "Start Training" button.

    Returns:
        The assembled, not yet launched, ``gr.Blocks`` app.
    """
    # When HF_TOKEN is configured in the environment, the write-token field
    # is hidden; users only have to type a token when none is configured.
    hf_token = os.getenv('HF_TOKEN')

    # NOTE(review): gr.Box was removed in Gradio 4.x. If this Space runs
    # Gradio >= 4, the gr.Box() containers below must become gr.Group().
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                with gr.Box():
                    gr.Markdown('Training Data')
                    training_video = gr.File(label='Training video')
                    training_prompt = gr.Textbox(
                        label='Training prompt',
                        max_lines=1,
                        placeholder='A man is surfing')
                    gr.Markdown('''
- Upload a video and write a `Training Prompt` that describes the video.
''')
            with gr.Column():
                with gr.Box():
                    gr.Markdown('Training Parameters')
                    with gr.Row():
                        base_model = gr.Text(label='Base Model',
                                             value='CompVis/stable-diffusion-v1-4',
                                             max_lines=1)
                        # Hidden: the resolution is fixed at its default '512'.
                        resolution = gr.Dropdown(choices=['512', '768'],
                                                 value='512',
                                                 label='Resolution',
                                                 visible=False)
                    token = gr.Text(label="Hugging Face Write Token",
                                    placeholder="",
                                    visible=not hf_token)
                    with gr.Accordion("Advanced settings", open=False):
                        num_training_steps = gr.Number(
                            label='Number of Training Steps', value=300, precision=0)
                        learning_rate = gr.Number(label='Learning Rate',
                                                  value=0.000035)
                        gradient_accumulation = gr.Number(
                            label='Number of Gradient Accumulation',
                            value=1,
                            precision=0)
                        seed = gr.Slider(label='Seed',
                                         minimum=0,
                                         maximum=100000,
                                         step=1,
                                         randomize=True,
                                         value=0)
                        fp16 = gr.Checkbox(label='FP16', value=True)
                        use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=False)
                        checkpointing_steps = gr.Number(label='Checkpointing Steps',
                                                        value=1000,
                                                        precision=0)
                        validation_epochs = gr.Number(label='Validation Epochs',
                                                      value=100,
                                                      precision=0)
                    gr.Markdown('''
- The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
- Expected time to train a model for 300 steps: ~20 minutes with T4
- You can check the training status by pressing the "Open logs" button if you are running this on your Space.
''')
        with gr.Row():
            with gr.Column():
                gr.Markdown('Output Model')
                output_model_name = gr.Text(label='Name of your model',
                                            placeholder='The surfer man',
                                            max_lines=1)
                validation_prompt = gr.Text(label='Validation Prompt', placeholder='prompt to test the model, e.g: a dog is surfing')
            with gr.Column():
                gr.Markdown('Upload Settings')
                with gr.Row():
                    upload_to_hub = gr.Checkbox(
                        label='Upload model to Hub', value=True)
                    use_private_repo = gr.Checkbox(label='Private',
                                                   value=True)
                    delete_existing_repo = gr.Checkbox(
                        label='Delete existing repo of the same name',
                        value=False)
                upload_to = gr.Radio(
                    label='Upload to',
                    choices=[target.value for target in UploadTarget],
                    value=UploadTarget.MODEL_LIBRARY.value)

        # Hidden control. It is only interactive when SPACE_ID is set in the
        # environment.
        remove_gpu_after_training = gr.Checkbox(
            label='Remove GPU after training',
            value=False,
            interactive=bool(os.getenv('SPACE_ID')),
            visible=False)
        run_button = gr.Button('Start Training')
        with gr.Box():
            gr.Markdown('Output message')
            output_message = gr.Markdown()

        if pipe is not None:
            run_button.click(fn=pipe.clear)
        # The order of this list is the positional argument order of
        # trainer.run. Do not reorder or shrink it without changing
        # Trainer.run too.
        run_button.click(fn=trainer.run,
                         inputs=[
                             training_video,
                             training_prompt,
                             output_model_name,
                             # NOTE(review): delete_existing_repo is passed twice
                             # (here and further down). This slot presumably maps to
                             # an "overwrite existing model" parameter of Trainer.run
                             # that has no widget of its own. Confirm against
                             # Trainer.run before adding one; the positional count
                             # must stay the same.
                             delete_existing_repo,
                             validation_prompt,
                             base_model,
                             resolution,
                             num_training_steps,
                             learning_rate,
                             gradient_accumulation,
                             seed,
                             fp16,
                             use_8bit_adam,
                             checkpointing_steps,
                             validation_epochs,
                             upload_to_hub,
                             use_private_repo,
                             delete_existing_repo,
                             upload_to,
                             remove_gpu_after_training,
                             token,
                         ],
                         outputs=output_message)
    return demo
if __name__ == '__main__':
    # Script entry point. Build a Trainer from the HF_TOKEN environment
    # variable (None when unset), assemble the UI without an inference
    # pipeline, and serve it with a request queue capped at one pending
    # event and no public share link.
    trainer = Trainer(os.getenv('HF_TOKEN'))
    training_app = create_training_demo(trainer)
    training_app.queue(max_size=1).launch(share=False)