Commit 4dff355
Parent(s): 8214cae

cache the ckpt; fix bugs when input new video

Files changed:
- .gitignore (+2 -1)
- FateZero/test_fatezero.py (+24 -18)
- FateZero/video_diffusion/common/util.py (+8 -2)
- FateZero/video_diffusion/data/dataset.py (+11 -5)
- app_fatezero.py (+4 -4)
- inference_fatezero.py (+84 -51)
.gitignore CHANGED

@@ -1 +1,2 @@
-trash/*
+trash/*
+tmp
FateZero/test_fatezero.py CHANGED

@@ -48,6 +48,10 @@ def test(
     config: str,
     pretrained_model_path: str,
     train_dataset: Dict,
+    tokenizer = None,
+    text_encoder = None,
+    vae = None,
+    unet = None,
     logdir: str = None,
     validation_sample_logger_config: Optional[Dict] = None,
     test_pipeline_config: Optional[Dict] = None,

@@ -79,26 +83,28 @@ def test(
     set_seed(seed)

     # Load the tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(
-        pretrained_model_path,
-        subfolder="tokenizer",
-        use_fast=False,
-    )
+    if tokenizer is None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_path,
+            subfolder="tokenizer",
+            use_fast=False,
+        )

     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        pretrained_model_path,
-        subfolder="text_encoder",
-    )
-
-    vae = AutoencoderKL.from_pretrained(
-        pretrained_model_path,
-        subfolder="vae",
-    )
-
-    unet = UNetPseudo3DConditionModel.from_2d_model(
-        os.path.join(pretrained_model_path, "unet"), model_config=model_config
-    )
+    if text_encoder is None:
+        text_encoder = CLIPTextModel.from_pretrained(
+            pretrained_model_path,
+            subfolder="text_encoder",
+        )
+    if vae is None:
+        vae = AutoencoderKL.from_pretrained(
+            pretrained_model_path,
+            subfolder="vae",
+        )
+    if unet is None:
+        unet = UNetPseudo3DConditionModel.from_2d_model(
+            os.path.join(pretrained_model_path, "unet"), model_config=model_config
+        )

     if 'target' not in test_pipeline_config:
         test_pipeline_config['target'] = 'video_diffusion.pipelines.stable_diffusion.SpatioTemporalStableDiffusionPipeline'
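The change above makes `test()` accept pre-loaded model components and fall back to `from_pretrained` only when a component is missing, so a long-lived caller can pay the checkpoint-loading cost once. A minimal, self-contained sketch of that optional-dependency pattern; `DummyEncoder` and `load_encoder` are hypothetical stand-ins for the diffusers loaders:

```python
from typing import Optional

class DummyEncoder:
    """Hypothetical stand-in for a heavyweight module such as CLIPTextModel."""
    def __init__(self, path: str) -> None:
        print(f"loading weights from {path}")  # the expensive step

def load_encoder(path: str) -> DummyEncoder:
    return DummyEncoder(path)

def test(pretrained_model_path: str,
         text_encoder: Optional[DummyEncoder] = None) -> DummyEncoder:
    # Hit the disk only when the caller did not inject a cached module.
    if text_encoder is None:
        text_encoder = load_encoder(pretrained_model_path)
    return text_encoder

cached = test("ckpt/stable-diffusion-v1-4")              # loads once
test("ckpt/stable-diffusion-v1-4", text_encoder=cached)  # reuses the cache
```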
FateZero/video_diffusion/common/util.py CHANGED

@@ -4,7 +4,7 @@ import copy
 import inspect
 import datetime
 from typing import List, Tuple, Optional, Dict
-
+import torch

 def glob_files(
     root_path: str,

@@ -68,6 +68,12 @@ def get_time_string() -> str:
 def get_function_args() -> Dict:
     frame = sys._getframe(1)
     args, _, _, values = inspect.getargvalues(frame)
-    args_dict = copy.deepcopy({arg: values[arg] for arg in args})
+    tmp_dict = {}
+    for arg in args:
+        v = values[arg]
+        if not isinstance(v, torch.nn.Module) and arg != 'tokenizer':
+            tmp_dict[arg] = v
+
+    args_dict = copy.deepcopy(tmp_dict)

     return args_dict
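Since `test()` can now be handed live `torch.nn.Module` objects, `get_function_args` has to exclude them before deep-copying: duplicating a loaded U-Net would copy every weight tensor, and the tokenizer is likewise skipped by name. A behaviorally equivalent standalone sketch, using a dict comprehension instead of the loop:

```python
import copy
import inspect
import sys

import torch

def get_function_args() -> dict:
    """Snapshot the caller's arguments, omitting anything too heavy to copy."""
    frame = sys._getframe(1)
    args, _, _, values = inspect.getargvalues(frame)
    safe = {
        arg: values[arg]
        for arg in args
        # Skip live models (huge to deep-copy) and the tokenizer (by name).
        if not isinstance(values[arg], torch.nn.Module) and arg != 'tokenizer'
    }
    return copy.deepcopy(safe)

def demo(x: int, unet=torch.nn.Linear(2, 2)) -> dict:
    return get_function_args()

print(demo(3))  # {'x': 3} -- the Linear module is filtered out
```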
FateZero/video_diffusion/data/dataset.py CHANGED

@@ -6,6 +6,7 @@ from einops import rearrange
 from pathlib import Path
 import imageio
 import cv2
+import shutil

 import torch
 from torch.utils.data import Dataset

@@ -156,7 +157,7 @@ class ImageSequenceDataset(Dataset):
         images = []
         if path[-4:] == '.mp4':
             path = self.mp4_to_png(path)
-
+

         for file in sorted(os.listdir(path)):
             if file.endswith(IMAGE_EXTENSION):

@@ -164,14 +165,19 @@ class ImageSequenceDataset(Dataset):
         return images

     # @staticmethod
+
     def mp4_to_png(self, video_source=None):
         reader = imageio.get_reader(video_source)
-        [...]
-        [...]
+        dir_path = './tmp/fatezero_user_video'
+        if os.path.exists(dir_path):
+            shutil.rmtree(dir_path)
+        os.makedirs(dir_path, exist_ok=True)
+
         for i, im in enumerate(reader):
             # use :05d to add zero, no space before the 05d
             # if (i+1)%10 == 0:
-            path = os.path.join(
+            path = os.path.join(dir_path, f"{i:05d}.png")
             # print(path)
             cv2.imwrite(path, im[:, :, ::-1])
-        [...]
+        self.path = dir_path
+        return self.path
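`mp4_to_png` now dumps frames into a fixed scratch directory that is wiped and recreated on every call, so frames from a previously uploaded clip cannot leak into a new sequence; this is the "fix bugs when input new video" part of the commit. A standalone sketch of the routine (the `./tmp/frames` default path is illustrative):

```python
import os
import shutil

import cv2
import imageio

def mp4_to_png(video_source: str, dir_path: str = './tmp/frames') -> str:
    """Extract every frame of a video as a zero-padded PNG sequence."""
    # Recreate the dump directory from scratch so stale frames from an
    # earlier video cannot mix into the new sequence.
    if os.path.exists(dir_path):
        shutil.rmtree(dir_path)
    os.makedirs(dir_path, exist_ok=True)

    reader = imageio.get_reader(video_source)
    for i, im in enumerate(reader):
        # imageio yields RGB frames while cv2.imwrite expects BGR, hence
        # im[:, :, ::-1]; the :05d padding keeps sorted() in frame order.
        cv2.imwrite(os.path.join(dir_path, f"{i:05d}.png"), im[:, :, ::-1])
    return dir_path
```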
app_fatezero.py CHANGED

@@ -28,7 +28,7 @@ from inference_fatezero import merge_config_then_run
 # TITLE = '# [FateZero](http://fate-zero-edit.github.io/)'
 HF_TOKEN = os.getenv('HF_TOKEN')
 # pipe = InferencePipeline(HF_TOKEN)
-
+pipe = merge_config_then_run()
 # app = InferenceUtil(HF_TOKEN)

 with gr.Blocks(css='style.css') as demo:

@@ -288,7 +288,7 @@ with gr.Blocks(css='style.css') as demo:
                 *ImageSequenceDataset_list
             ],
             outputs=result,
-            fn=merge_config_then_run,
+            fn=pipe.run,
             cache_examples=os.getenv('SYSTEM') == 'spaces')

     # model_id.change(fn=app.load_model_info,

@@ -312,8 +312,8 @@ with gr.Blocks(css='style.css') as demo:
         *ImageSequenceDataset_list
     ]
     # prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
-    target_prompt.submit(fn=merge_config_then_run, inputs=inputs, outputs=result)
+    target_prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
     # run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
-    run_button.click(fn=merge_config_then_run, inputs=inputs, outputs=result)
+    run_button.click(fn=pipe.run, inputs=inputs, outputs=result)

 demo.queue().launch()
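Where the Space previously passed the `merge_config_then_run` function itself to each callback, a single `pipe` object is now built at import time, so the checkpoint is read once per process and every Gradio event reuses the cached weights. A minimal runnable sketch of this wiring; `Pipeline` is a hypothetical stand-in for `merge_config_then_run`:

```python
import gradio as gr

class Pipeline:
    """Hypothetical stand-in for merge_config_then_run: loads weights once."""
    def __init__(self) -> None:
        self.model = "...expensive checkpoint load would happen here, once..."

    def run(self, prompt: str) -> str:
        return f"edited video for: {prompt}"

# Instantiated at import time, so the heavy load happens once per process
# instead of once per button click.
pipe = Pipeline()

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="target prompt")
    result = gr.Textbox(label="result")
    btn = gr.Button("Run")
    # Both triggers share the same cached pipeline object.
    btn.click(fn=pipe.run, inputs=prompt, outputs=result)
    prompt.submit(fn=pipe.run, inputs=prompt, outputs=result)

if __name__ == "__main__":
    demo.queue().launch()
```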
inference_fatezero.py CHANGED

@@ -4,8 +4,40 @@ from FateZero.test_fatezero import *
 import copy
 import gradio as gr

-
-def merge_config_then_run(
+class merge_config_then_run():
+    def __init__(self) -> None:
+        # Load the tokenizer
+        pretrained_model_path = 'FateZero/ckpt/stable-diffusion-v1-4'
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_path,
+            # 'FateZero/ckpt/stable-diffusion-v1-4',
+            subfolder="tokenizer",
+            use_fast=False,
+        )
+
+        # Load models and create wrapper for stable diffusion
+        self.text_encoder = CLIPTextModel.from_pretrained(
+            pretrained_model_path,
+            subfolder="text_encoder",
+        )
+
+        self.vae = AutoencoderKL.from_pretrained(
+            pretrained_model_path,
+            subfolder="vae",
+        )
+        model_config = {
+            "lora": 160,
+            # temporal_downsample_time: 4
+            "SparseCausalAttention_index": ['mid'],
+            "least_sc_channel": 640
+        }
+        self.unet = UNetPseudo3DConditionModel.from_2d_model(
+            os.path.join(pretrained_model_path, "unet"), model_config=model_config
+        )
+
+    def run(
+        self,
+        # def merge_config_then_run(
         model_id,
         data_path,
         source_prompt,

@@ -27,58 +59,59 @@ def merge_config_then_run(
         top_crop=0,
         bottom_crop=0,
     ):
-    [...]
-    # fatezero config
-    p2p_config_now = copy.deepcopy(config_now['validation_sample_logger_config']['p2p_config'][0])
-    p2p_config_now['cross_replace_steps']['default_'] = cross_replace_steps
-    p2p_config_now['self_replace_steps'] = self_replace_steps
-    p2p_config_now['eq_params']['words'] = enhance_words.split(" ")
-    p2p_config_now['eq_params']['values'] = [enhance_words_value,]*len(p2p_config_now['eq_params']['words'])
-    config_now['validation_sample_logger_config']['p2p_config'][0] = copy.deepcopy(p2p_config_now)
-
-    # ddim config
-    config_now['validation_sample_logger_config']['guidance_scale'] = guidance_scale
-    config_now['validation_sample_logger_config']['num_inference_steps'] = num_steps
-    [...]
-    mp4_path = save_path.replace('_0.gif', '_0_0_0.mp4')
-    return mp4_path
-
-if __name__ == "__main__":
-    run()
+        # , ] = inputs
+        default_edit_config='FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml'
+        Omegadict_default_edit_config = OmegaConf.load(default_edit_config)
+
+        dataset_time_string = get_time_string()
+        config_now = copy.deepcopy(Omegadict_default_edit_config)
+        print(f"config_now['pretrained_model_path'] = model_id {model_id}")
+        # config_now['pretrained_model_path'] = model_id
+        config_now['train_dataset']['prompt'] = source_prompt
+        config_now['train_dataset']['path'] = data_path
+        # ImageSequenceDataset_dict = { }
+        offset_dict = {
+            "left": left_crop,
+            "right": right_crop,
+            "top": top_crop,
+            "bottom": bottom_crop,
+        }
+        ImageSequenceDataset_dict = {
+            "start_sample_frame": start_sample_frame,
+            "n_sample_frame": n_sample_frame,
+            "stride": stride,
+            "offset": offset_dict,
+        }
+        config_now['train_dataset'].update(ImageSequenceDataset_dict)
+        if user_input_video and data_path is None:
+            raise gr.Error('You need to upload a video or choose a provided video')
+        if user_input_video is not None and user_input_video.name is not None:
+            config_now['train_dataset']['path'] = user_input_video.name
+        config_now['validation_sample_logger_config']['prompts'] = [target_prompt]
+
+        # fatezero config
+        p2p_config_now = copy.deepcopy(config_now['validation_sample_logger_config']['p2p_config'][0])
+        p2p_config_now['cross_replace_steps']['default_'] = cross_replace_steps
+        p2p_config_now['self_replace_steps'] = self_replace_steps
+        p2p_config_now['eq_params']['words'] = enhance_words.split(" ")
+        p2p_config_now['eq_params']['values'] = [enhance_words_value,]*len(p2p_config_now['eq_params']['words'])
+        config_now['validation_sample_logger_config']['p2p_config'][0] = copy.deepcopy(p2p_config_now)
+
+        # ddim config
+        config_now['validation_sample_logger_config']['guidance_scale'] = guidance_scale
+        config_now['validation_sample_logger_config']['num_inference_steps'] = num_steps
+
+        logdir = default_edit_config.replace('config', 'result').replace('.yml', '').replace('.yaml', '') + f'_{dataset_time_string}'
+        config_now['logdir'] = logdir
+        print(f'Saving at {logdir}')
+        save_path = test(tokenizer=self.tokenizer,
+                         text_encoder=self.text_encoder,
+                         vae=self.vae,
+                         unet=self.unet,
+                         config=default_edit_config, **config_now)
+        mp4_path = save_path.replace('_0.gif', '_0_0_0.mp4')
+        return mp4_path
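Inside `run()`, the per-request work is pure config surgery: the default YAML is loaded with OmegaConf, deep-copied, and the user's UI inputs overwrite individual fields before everything is forwarded to `test()` along with the cached modules. A small sketch of that merge step, using an inline YAML document in place of the real `jeep_watercolor_ddim_10_steps.yaml` (the field values here are illustrative):

```python
import copy

from omegaconf import OmegaConf

# Illustrative stand-in for the real default config file.
default_yaml = """
train_dataset:
  prompt: a silver jeep driving down a curvy road
  path: FateZero/data/jeep
validation_sample_logger_config:
  guidance_scale: 7.5
  num_inference_steps: 10
"""

base = OmegaConf.create(default_yaml)
config_now = copy.deepcopy(base)  # never mutate the shared default

# Per-request overrides, mirroring run(): UI inputs replace defaults.
config_now['train_dataset']['prompt'] = 'a watercolor painting of a jeep'
config_now['validation_sample_logger_config']['guidance_scale'] = 12.5

print(OmegaConf.to_yaml(config_now))
```

Keeping the overrides on a deep copy is what lets one long-lived pipeline object serve many requests without earlier edits bleeding into later ones.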