Spaces:
Runtime error
Runtime error
Update opensora/serve/gradio_web_server.py
Browse files
opensora/serve/gradio_web_server.py
CHANGED
|
@@ -22,32 +22,38 @@ from opensora.models.ae import ae_stride_config, getae, getae_wrapper
|
|
| 22 |
from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper
|
| 23 |
from opensora.models.diffusion.latte.modeling_latte import LatteT2V
|
| 24 |
from opensora.sample.pipeline_videogen import VideoGenPipeline
|
| 25 |
-
from opensora.serve.gradio_utils import block_css, title_markdown, randomize_seed_fn, set_env,
|
| 26 |
-
|
| 27 |
|
| 28 |
@spaces.GPU(duration=300)
|
|
|
|
|
|
|
| 29 |
@torch.inference_mode()
|
| 30 |
def generate_img(prompt, sample_steps, scale, seed=0, randomize_seed=False, force_images=False):
|
| 31 |
-
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 32 |
-
set_env(seed)
|
| 33 |
video_length = transformer_model.config.video_length if not force_images else 1
|
| 34 |
height, width = int(args.version.split('x')[1]), int(args.version.split('x')[2])
|
| 35 |
num_frames = 1 if video_length == 1 else int(args.version.split('x')[0])
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
display_model_info = f"Video size: {num_frames}×{height}×{width}, \nSampling Step: {sample_steps}, \nGuidance Scale: {scale}"
|
| 52 |
return tmp_save_path, prompt, display_model_info, seed
|
| 53 |
|
|
|
|
| 22 |
from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper
|
| 23 |
from opensora.models.diffusion.latte.modeling_latte import LatteT2V
|
| 24 |
from opensora.sample.pipeline_videogen import VideoGenPipeline
|
| 25 |
+
from opensora.serve.gradio_utils import block_css, title_markdown, randomize_seed_fn, set_env, DESCRIPTION
|
| 26 |
+
from opensora.serve.gradio_utils import examples_txt, examples
|
| 27 |
|
| 28 |
@spaces.GPU(duration=300)
|
| 29 |
+
@torch.inference_mode()
|
| 30 |
+
|
| 31 |
@torch.inference_mode()
|
| 32 |
def generate_img(prompt, sample_steps, scale, seed=0, randomize_seed=False, force_images=False):
|
|
|
|
|
|
|
| 33 |
video_length = transformer_model.config.video_length if not force_images else 1
|
| 34 |
height, width = int(args.version.split('x')[1]), int(args.version.split('x')[2])
|
| 35 |
num_frames = 1 if video_length == 1 else int(args.version.split('x')[0])
|
| 36 |
+
if not force_images and prompt in examples_txt:
|
| 37 |
+
idx = examples_txt.index(prompt)
|
| 38 |
+
tmp_save_path = f'demo65-221/f65/{idx+1}.mp4'
|
| 39 |
+
else:
|
| 40 |
+
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 41 |
+
set_env(seed)
|
| 42 |
+
videos = videogen_pipeline(prompt,
|
| 43 |
+
num_frames=num_frames,
|
| 44 |
+
height=height,
|
| 45 |
+
width=width,
|
| 46 |
+
num_inference_steps=sample_steps,
|
| 47 |
+
guidance_scale=scale,
|
| 48 |
+
enable_temporal_attentions=not force_images,
|
| 49 |
+
num_images_per_prompt=1,
|
| 50 |
+
mask_feature=True,
|
| 51 |
+
).video
|
| 52 |
|
| 53 |
+
torch.cuda.empty_cache()
|
| 54 |
+
videos = videos[0]
|
| 55 |
+
tmp_save_path = 'tmp.mp4'
|
| 56 |
+
imageio.mimwrite(tmp_save_path, videos, fps=24, quality=9) # highest quality is 10, lowest is 0
|
| 57 |
display_model_info = f"Video size: {num_frames}×{height}×{width}, \nSampling Step: {sample_steps}, \nGuidance Scale: {scale}"
|
| 58 |
return tmp_save_path, prompt, display_model_info, seed
|
| 59 |
|