Spaces:
Running
Running
attempts to fix
Browse files
- app.py +16 -12
- body_renderer.py +7 -5
- dataset_utils.py +7 -3
- download_deps.py +3 -4
- tmed_denoiser.py +21 -14
app.py
CHANGED
|
@@ -76,8 +76,9 @@ class MotionEditor:
|
|
| 76 |
self.MFIX_p = download_motionfix() + '/motionfix'
|
| 77 |
# self.SOURCE_MOTS_p = download_embeddings() + '/embeddings'
|
| 78 |
self.MFIX_DATASET_DICT = download_motionfix_dataset()
|
| 79 |
-
self.model_ckpt_path = download_models("
|
| 80 |
-
self.
|
|
|
|
| 81 |
|
| 82 |
@spaces.GPU
|
| 83 |
def initialize_if_needed(self):
|
|
@@ -113,18 +114,20 @@ class MotionEditor:
|
|
| 113 |
self.infeats = self.model_config_feats
|
| 114 |
checkpoint = torch.load(model_ckpt, map_location=self.device)
|
| 115 |
checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
|
| 116 |
-
|
| 117 |
# Setup denoiser
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
self.tmed_denoiser.load_state_dict(checkpoint, strict=False)
|
| 120 |
self.tmed_denoiser.eval()
|
| 121 |
-
|
| 122 |
# Setup diffusion
|
| 123 |
self.diffusion = create_diffusion(
|
| 124 |
timestep_respacing=None,
|
| 125 |
learn_sigma=False,
|
| 126 |
sigma_small=True,
|
| 127 |
-
diffusion_steps=
|
| 128 |
noise_schedule='squaredcos_cap_v2',
|
| 129 |
predict_xstart=True
|
| 130 |
)
|
|
@@ -144,6 +147,7 @@ class MotionEditor:
|
|
| 144 |
def process_motion(self, input_text, key_to_use):
|
| 145 |
"""Main processing function, GPU-decorated"""
|
| 146 |
self.initialize_if_needed()
|
|
|
|
| 147 |
# Load dataset sample
|
| 148 |
ds_sample = self.MFIX_DATASET_DICT[key_to_use]
|
| 149 |
|
|
@@ -192,9 +196,8 @@ class MotionEditor:
|
|
| 192 |
seqlen_tgt = target_motion.shape[0]
|
| 193 |
cond_motion_mask = torch.ones((bsz, seqlen_src), dtype=bool, device=self.device)
|
| 194 |
mask_target = torch.ones((bsz, seqlen_tgt), dtype=bool, device=self.device)
|
| 195 |
-
|
| 196 |
# Generate diffusion output
|
| 197 |
-
diff_out = self.
|
| 198 |
text_emb.to(self.device),
|
| 199 |
text_mask.to(self.device),
|
| 200 |
source_motion,
|
|
@@ -204,8 +207,8 @@ class MotionEditor:
|
|
| 204 |
init_vec=None,
|
| 205 |
init_from='noise',
|
| 206 |
gd_text=2.0,
|
| 207 |
-
gd_motion=
|
| 208 |
-
steps_num=
|
| 209 |
)
|
| 210 |
|
| 211 |
return self.denormalize_motion(diff_out)
|
|
@@ -213,13 +216,14 @@ class MotionEditor:
|
|
| 213 |
def denormalize_motion(self, diff_out):
|
| 214 |
"""Denormalize motion - called from within GPU-decorated function"""
|
| 215 |
from geometry_utils import diffout2motion
|
|
|
|
|
|
|
| 216 |
return diffout2motion(diff_out.permute(1, 0, 2), self.normalizer).squeeze()
|
| 217 |
|
| 218 |
def render_result(self, edited_motion, source_motion):
|
| 219 |
"""Render result - called from within GPU-decorated function"""
|
| 220 |
from body_renderer import get_render
|
| 221 |
from transform3d import transform_body_pose, rotate_body_degrees
|
| 222 |
-
# import ipdb; ipdb.set_trace()
|
| 223 |
# Transform motions
|
| 224 |
edited_motion_transformed = self.transform_motion(edited_motion)
|
| 225 |
source_motion_transformed = self.transform_motion(source_motion)
|
|
@@ -227,7 +231,7 @@ class MotionEditor:
|
|
| 227 |
# Render video
|
| 228 |
if os.path.exists('./output_movie.mp4'):
|
| 229 |
os.remove('./output_movie.mp4')
|
| 230 |
-
|
| 231 |
return get_render(
|
| 232 |
self.body_model,
|
| 233 |
[edited_motion_transformed['trans'].detach().cpu(),
|
|
|
|
| 76 |
self.MFIX_p = download_motionfix() + '/motionfix'
|
| 77 |
# self.SOURCE_MOTS_p = download_embeddings() + '/embeddings'
|
| 78 |
self.MFIX_DATASET_DICT = download_motionfix_dataset()
|
| 79 |
+
self.model_ckpt_path = download_models("899_bs128_zipped") # small_model_zipped_last/last_zipped
|
| 80 |
+
self.model_cfg = download_model_config('bs_128_conf') # small_model_config / big_model_config
|
| 81 |
+
self.model_config_feats = self.model_cfg.model.input_feats
|
| 82 |
|
| 83 |
@spaces.GPU
|
| 84 |
def initialize_if_needed(self):
|
|
|
|
| 114 |
self.infeats = self.model_config_feats
|
| 115 |
checkpoint = torch.load(model_ckpt, map_location=self.device)
|
| 116 |
checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
|
|
|
|
| 117 |
# Setup denoiser
|
| 118 |
+
|
| 119 |
+
self.tmed_denoiser = TMED_denoiser(latent_dim=self.model_cfg.model.latent_dim,
|
| 120 |
+
num_layers=8,
|
| 121 |
+
ff_size=1024,
|
| 122 |
+
num_heads=4).to(self.device)
|
| 123 |
self.tmed_denoiser.load_state_dict(checkpoint, strict=False)
|
| 124 |
self.tmed_denoiser.eval()
|
|
|
|
| 125 |
# Setup diffusion
|
| 126 |
self.diffusion = create_diffusion(
|
| 127 |
timestep_respacing=None,
|
| 128 |
learn_sigma=False,
|
| 129 |
sigma_small=True,
|
| 130 |
+
diffusion_steps=self.model_cfg.model.diff_params.num_train_timesteps,
|
| 131 |
noise_schedule='squaredcos_cap_v2',
|
| 132 |
predict_xstart=True
|
| 133 |
)
|
|
|
|
| 147 |
def process_motion(self, input_text, key_to_use):
|
| 148 |
"""Main processing function, GPU-decorated"""
|
| 149 |
self.initialize_if_needed()
|
| 150 |
+
# import ipdb; ipdb.set_trace()
|
| 151 |
# Load dataset sample
|
| 152 |
ds_sample = self.MFIX_DATASET_DICT[key_to_use]
|
| 153 |
|
|
|
|
| 196 |
seqlen_tgt = target_motion.shape[0]
|
| 197 |
cond_motion_mask = torch.ones((bsz, seqlen_src), dtype=bool, device=self.device)
|
| 198 |
mask_target = torch.ones((bsz, seqlen_tgt), dtype=bool, device=self.device)
|
|
|
|
| 199 |
# Generate diffusion output
|
| 200 |
+
diff_out = self.tmed_denoiser._diffusion_reverse(
|
| 201 |
text_emb.to(self.device),
|
| 202 |
text_mask.to(self.device),
|
| 203 |
source_motion,
|
|
|
|
| 207 |
init_vec=None,
|
| 208 |
init_from='noise',
|
| 209 |
gd_text=2.0,
|
| 210 |
+
gd_motion=3.0,
|
| 211 |
+
steps_num=self.model_cfg.model.diff_params.num_train_timesteps
|
| 212 |
)
|
| 213 |
|
| 214 |
return self.denormalize_motion(diff_out)
|
|
|
|
| 216 |
def denormalize_motion(self, diff_out):
|
| 217 |
"""Denormalize motion - called from within GPU-decorated function"""
|
| 218 |
from geometry_utils import diffout2motion
|
| 219 |
+
# import ipdb; ipdb.set_trace()
|
| 220 |
+
|
| 221 |
return diffout2motion(diff_out.permute(1, 0, 2), self.normalizer).squeeze()
|
| 222 |
|
| 223 |
def render_result(self, edited_motion, source_motion):
|
| 224 |
"""Render result - called from within GPU-decorated function"""
|
| 225 |
from body_renderer import get_render
|
| 226 |
from transform3d import transform_body_pose, rotate_body_degrees
|
|
|
|
| 227 |
# Transform motions
|
| 228 |
edited_motion_transformed = self.transform_motion(edited_motion)
|
| 229 |
source_motion_transformed = self.transform_motion(source_motion)
|
|
|
|
| 231 |
# Render video
|
| 232 |
if os.path.exists('./output_movie.mp4'):
|
| 233 |
os.remove('./output_movie.mp4')
|
| 234 |
+
# import ipdb; ipdb.set_trace()
|
| 235 |
return get_render(
|
| 236 |
self.body_model,
|
| 237 |
[edited_motion_transformed['trans'].detach().cpu(),
|
body_renderer.py
CHANGED
|
@@ -30,7 +30,9 @@ def get_render(body_model_loaded,
|
|
| 30 |
if not isinstance(body_pose, list):
|
| 31 |
body_pose = [body_pose]
|
| 32 |
|
| 33 |
-
for trans, orient,pose in zip(body_trans,
|
|
|
|
|
|
|
| 34 |
|
| 35 |
vertices= run_smpl_fwd_vertices(body_model_loaded,
|
| 36 |
trans,
|
|
@@ -38,15 +40,15 @@ def get_render(body_model_loaded,
|
|
| 38 |
pose)
|
| 39 |
|
| 40 |
vertices=vertices.vertices
|
| 41 |
-
vertices = subsample_tensor(vertices, original_fps=30, target_fps=25)
|
| 42 |
vertices = vertices.detach().cpu().numpy()
|
| 43 |
vertices_list.append(vertices)
|
| 44 |
|
| 45 |
#Initialising the renderer
|
| 46 |
from renderer.humor import HumorRenderer
|
| 47 |
-
fps =
|
| 48 |
-
imw =
|
| 49 |
-
imh =
|
| 50 |
renderer = HumorRenderer(fps=fps, imw=imw, imh=imh)
|
| 51 |
|
| 52 |
if len(vertices_list)==2:
|
|
|
|
| 30 |
if not isinstance(body_pose, list):
|
| 31 |
body_pose = [body_pose]
|
| 32 |
|
| 33 |
+
for trans, orient, pose in zip(body_trans,
|
| 34 |
+
body_orient,
|
| 35 |
+
body_pose):
|
| 36 |
|
| 37 |
vertices= run_smpl_fwd_vertices(body_model_loaded,
|
| 38 |
trans,
|
|
|
|
| 40 |
pose)
|
| 41 |
|
| 42 |
vertices=vertices.vertices
|
| 43 |
+
# vertices = subsample_tensor(vertices, original_fps=30, target_fps=25)
|
| 44 |
vertices = vertices.detach().cpu().numpy()
|
| 45 |
vertices_list.append(vertices)
|
| 46 |
|
| 47 |
#Initialising the renderer
|
| 48 |
from renderer.humor import HumorRenderer
|
| 49 |
+
fps = 30.0
|
| 50 |
+
imw = 720 # 480
|
| 51 |
+
imh = 540 # 360
|
| 52 |
renderer = HumorRenderer(fps=fps, imw=imw, imh=imh)
|
| 53 |
|
| 54 |
if len(vertices_list)==2:
|
dataset_utils.py
CHANGED
|
@@ -10,13 +10,17 @@ def load_motionfix(path_to_data):
|
|
| 10 |
|
| 11 |
# Fill each dictionary with the corresponding data
|
| 12 |
for key in splits['train']:
|
| 13 |
-
|
|
|
|
| 14 |
|
| 15 |
for key in splits['val']:
|
| 16 |
-
|
|
|
|
|
|
|
| 17 |
|
| 18 |
for key in splits['test']:
|
| 19 |
-
|
|
|
|
| 20 |
validation_test_data = {**val_data, **test_data}
|
| 21 |
|
| 22 |
return train_data, validation_test_data
|
|
|
|
| 10 |
|
| 11 |
# Fill each dictionary with the corresponding data
|
| 12 |
for key in splits['train']:
|
| 13 |
+
if key in dataset:
|
| 14 |
+
train_data[key] = dataset[key]
|
| 15 |
|
| 16 |
for key in splits['val']:
|
| 17 |
+
if key in dataset:
|
| 18 |
+
|
| 19 |
+
val_data[key] = dataset[key]
|
| 20 |
|
| 21 |
for key in splits['test']:
|
| 22 |
+
if key in dataset:
|
| 23 |
+
test_data[key] = dataset[key]
|
| 24 |
validation_test_data = {**val_data, **test_data}
|
| 25 |
|
| 26 |
return train_data, validation_test_data
|
download_deps.py
CHANGED
|
@@ -16,13 +16,12 @@ def download_models(ckpt_to_dl):
|
|
| 16 |
REPO_ID = 'athn-nik/example-model'
|
| 17 |
return hf_hub_download(REPO_ID, filename=f"{ckpt_to_dl}.ckpt")
|
| 18 |
|
| 19 |
-
def download_model_config():
|
| 20 |
REPO_ID = 'athn-nik/example-model'
|
| 21 |
-
path_to_config = hf_hub_download(REPO_ID, filename="
|
| 22 |
from omegaconf import OmegaConf
|
| 23 |
model_cfg = OmegaConf.load(path_to_config)
|
| 24 |
-
return model_cfg
|
| 25 |
-
|
| 26 |
|
| 27 |
def download_motion_from_dataset(key_to_dl):
|
| 28 |
REPO_ID = 'athn-nik/example-model'
|
|
|
|
| 16 |
REPO_ID = 'athn-nik/example-model'
|
| 17 |
return hf_hub_download(REPO_ID, filename=f"{ckpt_to_dl}.ckpt")
|
| 18 |
|
| 19 |
+
def download_model_config(config_name):
|
| 20 |
REPO_ID = 'athn-nik/example-model'
|
| 21 |
+
path_to_config = hf_hub_download(REPO_ID, filename=f"{config_name}/.hydra/config.yaml")
|
| 22 |
from omegaconf import OmegaConf
|
| 23 |
model_cfg = OmegaConf.load(path_to_config)
|
| 24 |
+
return model_cfg
|
|
|
|
| 25 |
|
| 26 |
def download_motion_from_dataset(key_to_dl):
|
| 27 |
REPO_ID = 'athn-nik/example-model'
|
tmed_denoiser.py
CHANGED
|
@@ -17,6 +17,7 @@ class TMED_denoiser(nn.Module):
|
|
| 17 |
text_encoded_dim: int = 768,
|
| 18 |
pred_delta_motion: bool = False,
|
| 19 |
use_sep: bool = True,
|
|
|
|
| 20 |
**kwargs) -> None:
|
| 21 |
|
| 22 |
super().__init__()
|
|
@@ -28,6 +29,8 @@ class TMED_denoiser(nn.Module):
|
|
| 28 |
self.pose_proj_in_source = nn.Linear(nfeats, self.latent_dim)
|
| 29 |
self.pose_proj_in_target = nn.Linear(nfeats, self.latent_dim)
|
| 30 |
self.pose_proj_out = nn.Linear(self.latent_dim, nfeats)
|
|
|
|
|
|
|
| 31 |
|
| 32 |
# emb proj
|
| 33 |
if self.condition in ["text", "text_uncond"]:
|
|
@@ -47,8 +50,9 @@ class TMED_denoiser(nn.Module):
|
|
| 47 |
self.use_sep = use_sep
|
| 48 |
self.query_pos = PositionalEncoding(self.latent_dim, dropout)
|
| 49 |
self.mem_pos = PositionalEncoding(self.latent_dim, dropout)
|
| 50 |
-
if self.
|
| 51 |
-
|
|
|
|
| 52 |
|
| 53 |
# use torch transformer
|
| 54 |
encoder_layer = nn.TransformerEncoderLayer(
|
|
@@ -83,7 +87,7 @@ class TMED_denoiser(nn.Module):
|
|
| 83 |
|
| 84 |
# 1. time_embedding
|
| 85 |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 86 |
-
timesteps = timestep.expand(noised_motion.shape[1]).clone()
|
| 87 |
time_emb = self.embed_timestep(timesteps).to(dtype=noised_motion.dtype)
|
| 88 |
# make it S first
|
| 89 |
# time_emb = self.time_embedding(time_emb).unsqueeze(0)
|
|
@@ -119,16 +123,19 @@ class TMED_denoiser(nn.Module):
|
|
| 119 |
# if self.diffusion_only:
|
| 120 |
proj_noised_motion = self.pose_proj_in_target(noised_motion)
|
| 121 |
|
| 122 |
-
if
|
| 123 |
-
|
| 124 |
-
sep_token_batch = torch.tile(self.sep_token, (bs,)).reshape(bs,
|
| 125 |
-
-1)
|
| 126 |
-
xseq = torch.cat((emb_latent, motion_embeds_proj,
|
| 127 |
-
sep_token_batch[None],
|
| 128 |
-
proj_noised_motion), axis=0)
|
| 129 |
else:
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
# if self.ablation_skip_connection:
|
| 133 |
# xseq = self.query_pos(xseq)
|
| 134 |
# tokens = self.encoder(xseq)
|
|
@@ -249,8 +256,8 @@ class TMED_denoiser(nn.Module):
|
|
| 249 |
mask_src_parts = inpaint_dict['mask'].unsqueeze(1).repeat(1,
|
| 250 |
mot_len,
|
| 251 |
1)
|
| 252 |
-
uncond_eps = uncond_eps*(
|
| 253 |
-
cond_eps_text = cond_eps_text*(
|
| 254 |
half_eps = uncond_eps + guidance_text_n_motion * (cond_eps_text - uncond_eps)
|
| 255 |
eps = torch.cat([half_eps, half_eps], dim=0)
|
| 256 |
else:
|
|
|
|
| 17 |
text_encoded_dim: int = 768,
|
| 18 |
pred_delta_motion: bool = False,
|
| 19 |
use_sep: bool = True,
|
| 20 |
+
motion_condition: str = 'source',
|
| 21 |
**kwargs) -> None:
|
| 22 |
|
| 23 |
super().__init__()
|
|
|
|
| 29 |
self.pose_proj_in_source = nn.Linear(nfeats, self.latent_dim)
|
| 30 |
self.pose_proj_in_target = nn.Linear(nfeats, self.latent_dim)
|
| 31 |
self.pose_proj_out = nn.Linear(self.latent_dim, nfeats)
|
| 32 |
+
self.first_pose_proj = nn.Linear(self.latent_dim, nfeats)
|
| 33 |
+
self.motion_condition = motion_condition
|
| 34 |
|
| 35 |
# emb proj
|
| 36 |
if self.condition in ["text", "text_uncond"]:
|
|
|
|
| 50 |
self.use_sep = use_sep
|
| 51 |
self.query_pos = PositionalEncoding(self.latent_dim, dropout)
|
| 52 |
self.mem_pos = PositionalEncoding(self.latent_dim, dropout)
|
| 53 |
+
if self.motion_condition == "source":
|
| 54 |
+
if self.use_sep:
|
| 55 |
+
self.sep_token = nn.Parameter(torch.randn(1, self.latent_dim))
|
| 56 |
|
| 57 |
# use torch transformer
|
| 58 |
encoder_layer = nn.TransformerEncoderLayer(
|
|
|
|
| 87 |
|
| 88 |
# 1. time_embedding
|
| 89 |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 90 |
+
timesteps = timestep.expand(noised_motion.shape[1]).clone()
|
| 91 |
time_emb = self.embed_timestep(timesteps).to(dtype=noised_motion.dtype)
|
| 92 |
# make it S first
|
| 93 |
# time_emb = self.time_embedding(time_emb).unsqueeze(0)
|
|
|
|
| 123 |
# if self.diffusion_only:
|
| 124 |
proj_noised_motion = self.pose_proj_in_target(noised_motion)
|
| 125 |
|
| 126 |
+
if motion_embeds is None:
|
| 127 |
+
xseq = torch.cat((emb_latent, proj_noised_motion), axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
else:
|
| 129 |
+
if self.use_sep:
|
| 130 |
+
|
| 131 |
+
sep_token_batch = torch.tile(self.sep_token, (bs,)).reshape(bs,
|
| 132 |
+
-1)
|
| 133 |
+
xseq = torch.cat((emb_latent, motion_embeds_proj,
|
| 134 |
+
sep_token_batch[None],
|
| 135 |
+
proj_noised_motion), axis=0)
|
| 136 |
+
else:
|
| 137 |
+
xseq = torch.cat((emb_latent, motion_embeds_proj,
|
| 138 |
+
proj_noised_motion), axis=0)
|
| 139 |
# if self.ablation_skip_connection:
|
| 140 |
# xseq = self.query_pos(xseq)
|
| 141 |
# tokens = self.encoder(xseq)
|
|
|
|
| 256 |
mask_src_parts = inpaint_dict['mask'].unsqueeze(1).repeat(1,
|
| 257 |
mot_len,
|
| 258 |
1)
|
| 259 |
+
uncond_eps = uncond_eps*(mask_src_parts) + source_mot*(~mask_src_parts)
|
| 260 |
+
cond_eps_text = cond_eps_text*(mask_src_parts) + source_mot*(~mask_src_parts)
|
| 261 |
half_eps = uncond_eps + guidance_text_n_motion * (cond_eps_text - uncond_eps)
|
| 262 |
eps = torch.cat([half_eps, half_eps], dim=0)
|
| 263 |
else:
|