Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
|
|
|
|
| 2 |
import os
|
| 3 |
import json
|
| 4 |
import torch
|
| 5 |
import random
|
|
|
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
from glob import glob
|
|
@@ -147,8 +149,8 @@ class AnimateController:
|
|
| 147 |
raise gr.Error(f"Please select a pretrained model path.")
|
| 148 |
if motion_module_dropdown == "":
|
| 149 |
raise gr.Error(f"Please select a motion module.")
|
| 150 |
-
if base_model_dropdown == "":
|
| 151 |
-
|
| 152 |
|
| 153 |
if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
|
| 154 |
|
|
@@ -158,11 +160,13 @@ class AnimateController:
|
|
| 158 |
).to("cuda")
|
| 159 |
|
| 160 |
if self.lora_model_state_dict != {}:
|
| 161 |
-
|
|
|
|
| 162 |
|
| 163 |
pipeline.to("cuda")
|
| 164 |
|
| 165 |
-
|
|
|
|
| 166 |
else: torch.seed()
|
| 167 |
seed = torch.initial_seed()
|
| 168 |
|
|
@@ -259,7 +263,7 @@ def ui():
|
|
| 259 |
)
|
| 260 |
lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])
|
| 261 |
|
| 262 |
-
lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.
|
| 263 |
|
| 264 |
personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
|
| 265 |
def update_personalized_model():
|
|
|
|
| 1 |
|
| 2 |
+
|
| 3 |
import os
|
| 4 |
import json
|
| 5 |
import torch
|
| 6 |
import random
|
| 7 |
+
import copy
|
| 8 |
|
| 9 |
import gradio as gr
|
| 10 |
from glob import glob
|
|
|
|
| 149 |
raise gr.Error(f"Please select a pretrained model path.")
|
| 150 |
if motion_module_dropdown == "":
|
| 151 |
raise gr.Error(f"Please select a motion module.")
|
| 152 |
+
# if base_model_dropdown == "":
|
| 153 |
+
# raise gr.Error(f"Please select a base DreamBooth model.")
|
| 154 |
|
| 155 |
if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
|
| 156 |
|
|
|
|
| 160 |
).to("cuda")
|
| 161 |
|
| 162 |
if self.lora_model_state_dict != {}:
|
| 163 |
+
print(f"Lora alpha: {lora_alpha_slider}")
|
| 164 |
+
pipeline = convert_lora(copy.deepcopy(pipeline), self.lora_model_state_dict, alpha=lora_alpha_slider)
|
| 165 |
|
| 166 |
pipeline.to("cuda")
|
| 167 |
|
| 168 |
+
seed_textbox = int(seed_textbox)
|
| 169 |
+
if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(seed_textbox)
|
| 170 |
else: torch.seed()
|
| 171 |
seed = torch.initial_seed()
|
| 172 |
|
|
|
|
| 263 |
)
|
| 264 |
lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])
|
| 265 |
|
| 266 |
+
lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.7, minimum=0, maximum=2, interactive=True)
|
| 267 |
|
| 268 |
personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
|
| 269 |
def update_personalized_model():
|