atnikos commited on
Commit
10ff2d6
·
1 Parent(s): cfbdd27

attempts to fix

Browse files
Files changed (5) hide show
  1. app.py +16 -12
  2. body_renderer.py +7 -5
  3. dataset_utils.py +7 -3
  4. download_deps.py +3 -4
  5. tmed_denoiser.py +21 -14
app.py CHANGED
@@ -76,8 +76,9 @@ class MotionEditor:
76
  self.MFIX_p = download_motionfix() + '/motionfix'
77
  # self.SOURCE_MOTS_p = download_embeddings() + '/embeddings'
78
  self.MFIX_DATASET_DICT = download_motionfix_dataset()
79
- self.model_ckpt_path = download_models("last_zipped")
80
- self.model_config_feats = download_model_config()
 
81
 
82
  @spaces.GPU
83
  def initialize_if_needed(self):
@@ -113,18 +114,20 @@ class MotionEditor:
113
  self.infeats = self.model_config_feats
114
  checkpoint = torch.load(model_ckpt, map_location=self.device)
115
  checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
116
-
117
  # Setup denoiser
118
- self.tmed_denoiser = TMED_denoiser().to(self.device)
 
 
 
 
119
  self.tmed_denoiser.load_state_dict(checkpoint, strict=False)
120
  self.tmed_denoiser.eval()
121
-
122
  # Setup diffusion
123
  self.diffusion = create_diffusion(
124
  timestep_respacing=None,
125
  learn_sigma=False,
126
  sigma_small=True,
127
- diffusion_steps=300,
128
  noise_schedule='squaredcos_cap_v2',
129
  predict_xstart=True
130
  )
@@ -144,6 +147,7 @@ class MotionEditor:
144
  def process_motion(self, input_text, key_to_use):
145
  """Main processing function, GPU-decorated"""
146
  self.initialize_if_needed()
 
147
  # Load dataset sample
148
  ds_sample = self.MFIX_DATASET_DICT[key_to_use]
149
 
@@ -192,9 +196,8 @@ class MotionEditor:
192
  seqlen_tgt = target_motion.shape[0]
193
  cond_motion_mask = torch.ones((bsz, seqlen_src), dtype=bool, device=self.device)
194
  mask_target = torch.ones((bsz, seqlen_tgt), dtype=bool, device=self.device)
195
-
196
  # Generate diffusion output
197
- diff_out = self.tmed_denoiser._diffusion_reverse(
198
  text_emb.to(self.device),
199
  text_mask.to(self.device),
200
  source_motion,
@@ -204,8 +207,8 @@ class MotionEditor:
204
  init_vec=None,
205
  init_from='noise',
206
  gd_text=2.0,
207
- gd_motion=2.0,
208
- steps_num=300
209
  )
210
 
211
  return self.denormalize_motion(diff_out)
@@ -213,13 +216,14 @@ class MotionEditor:
213
  def denormalize_motion(self, diff_out):
214
  """Denormalize motion - called from within GPU-decorated function"""
215
  from geometry_utils import diffout2motion
 
 
216
  return diffout2motion(diff_out.permute(1, 0, 2), self.normalizer).squeeze()
217
 
218
  def render_result(self, edited_motion, source_motion):
219
  """Render result - called from within GPU-decorated function"""
220
  from body_renderer import get_render
221
  from transform3d import transform_body_pose, rotate_body_degrees
222
- # import ipdb; ipdb.set_trace()
223
  # Transform motions
224
  edited_motion_transformed = self.transform_motion(edited_motion)
225
  source_motion_transformed = self.transform_motion(source_motion)
@@ -227,7 +231,7 @@ class MotionEditor:
227
  # Render video
228
  if os.path.exists('./output_movie.mp4'):
229
  os.remove('./output_movie.mp4')
230
-
231
  return get_render(
232
  self.body_model,
233
  [edited_motion_transformed['trans'].detach().cpu(),
 
76
  self.MFIX_p = download_motionfix() + '/motionfix'
77
  # self.SOURCE_MOTS_p = download_embeddings() + '/embeddings'
78
  self.MFIX_DATASET_DICT = download_motionfix_dataset()
79
+ self.model_ckpt_path = download_models("899_bs128_zipped") # small_model_zipped_last/last_zipped
80
+ self.model_cfg = download_model_config('bs_128_conf') # small_model_config / big_model_config
81
+ self.model_config_feats = self.model_cfg.model.input_feats
82
 
83
  @spaces.GPU
84
  def initialize_if_needed(self):
 
114
  self.infeats = self.model_config_feats
115
  checkpoint = torch.load(model_ckpt, map_location=self.device)
116
  checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
 
117
  # Setup denoiser
118
+
119
+ self.tmed_denoiser = TMED_denoiser(latent_dim=self.model_cfg.model.latent_dim,
120
+ num_layers=8,
121
+ ff_size=1024,
122
+ num_heads=4).to(self.device)
123
  self.tmed_denoiser.load_state_dict(checkpoint, strict=False)
124
  self.tmed_denoiser.eval()
 
125
  # Setup diffusion
126
  self.diffusion = create_diffusion(
127
  timestep_respacing=None,
128
  learn_sigma=False,
129
  sigma_small=True,
130
+ diffusion_steps=self.model_cfg.model.diff_params.num_train_timesteps,
131
  noise_schedule='squaredcos_cap_v2',
132
  predict_xstart=True
133
  )
 
147
  def process_motion(self, input_text, key_to_use):
148
  """Main processing function, GPU-decorated"""
149
  self.initialize_if_needed()
150
+ # import ipdb; ipdb.set_trace()
151
  # Load dataset sample
152
  ds_sample = self.MFIX_DATASET_DICT[key_to_use]
153
 
 
196
  seqlen_tgt = target_motion.shape[0]
197
  cond_motion_mask = torch.ones((bsz, seqlen_src), dtype=bool, device=self.device)
198
  mask_target = torch.ones((bsz, seqlen_tgt), dtype=bool, device=self.device)
 
199
  # Generate diffusion output
200
+ diff_out = self.tmed_cenoiser._diffusion_reverse(
201
  text_emb.to(self.device),
202
  text_mask.to(self.device),
203
  source_motion,
 
207
  init_vec=None,
208
  init_from='noise',
209
  gd_text=2.0,
210
+ gd_motion=3.0,
211
+ steps_num=self.model_cfg.model.diff_params.num_train_timesteps
212
  )
213
 
214
  return self.denormalize_motion(diff_out)
 
216
  def denormalize_motion(self, diff_out):
217
  """Denormalize motion - called from within GPU-decorated function"""
218
  from geometry_utils import diffout2motion
219
+ # import ipdb; ipdb.set_trace()
220
+
221
  return diffout2motion(diff_out.permute(1, 0, 2), self.normalizer).squeeze()
222
 
223
  def render_result(self, edited_motion, source_motion):
224
  """Render result - called from within GPU-decorated function"""
225
  from body_renderer import get_render
226
  from transform3d import transform_body_pose, rotate_body_degrees
 
227
  # Transform motions
228
  edited_motion_transformed = self.transform_motion(edited_motion)
229
  source_motion_transformed = self.transform_motion(source_motion)
 
231
  # Render video
232
  if os.path.exists('./output_movie.mp4'):
233
  os.remove('./output_movie.mp4')
234
+ # import ipdb; ipdb.set_trace()
235
  return get_render(
236
  self.body_model,
237
  [edited_motion_transformed['trans'].detach().cpu(),
body_renderer.py CHANGED
@@ -30,7 +30,9 @@ def get_render(body_model_loaded,
30
  if not isinstance(body_pose, list):
31
  body_pose = [body_pose]
32
 
33
- for trans, orient,pose in zip(body_trans,body_orient,body_pose):
 
 
34
 
35
  vertices= run_smpl_fwd_vertices(body_model_loaded,
36
  trans,
@@ -38,15 +40,15 @@ def get_render(body_model_loaded,
38
  pose)
39
 
40
  vertices=vertices.vertices
41
- vertices = subsample_tensor(vertices, original_fps=30, target_fps=25)
42
  vertices = vertices.detach().cpu().numpy()
43
  vertices_list.append(vertices)
44
 
45
  #Initialising the renderer
46
  from renderer.humor import HumorRenderer
47
- fps = 25.0
48
- imw = 480 # 480
49
- imh = 360 # 360
50
  renderer = HumorRenderer(fps=fps, imw=imw, imh=imh)
51
 
52
  if len(vertices_list)==2:
 
30
  if not isinstance(body_pose, list):
31
  body_pose = [body_pose]
32
 
33
+ for trans, orient, pose in zip(body_trans,
34
+ body_orient,
35
+ body_pose):
36
 
37
  vertices= run_smpl_fwd_vertices(body_model_loaded,
38
  trans,
 
40
  pose)
41
 
42
  vertices=vertices.vertices
43
+ # vertices = subsample_tensor(vertices, original_fps=30, target_fps=25)
44
  vertices = vertices.detach().cpu().numpy()
45
  vertices_list.append(vertices)
46
 
47
  #Initialising the renderer
48
  from renderer.humor import HumorRenderer
49
+ fps = 30.0
50
+ imw = 720 # 480
51
+ imh = 540 # 360
52
  renderer = HumorRenderer(fps=fps, imw=imw, imh=imh)
53
 
54
  if len(vertices_list)==2:
dataset_utils.py CHANGED
@@ -10,13 +10,17 @@ def load_motionfix(path_to_data):
10
 
11
  # Fill each dictionary with the corresponding data
12
  for key in splits['train']:
13
- train_data[key] = dataset[key]
 
14
 
15
  for key in splits['val']:
16
- val_data[key] = dataset[key]
 
 
17
 
18
  for key in splits['test']:
19
- test_data[key] = dataset[key]
 
20
  validation_test_data = {**val_data, **test_data}
21
 
22
  return train_data, validation_test_data
 
10
 
11
  # Fill each dictionary with the corresponding data
12
  for key in splits['train']:
13
+ if key in dataset:
14
+ train_data[key] = dataset[key]
15
 
16
  for key in splits['val']:
17
+ if key in dataset:
18
+
19
+ val_data[key] = dataset[key]
20
 
21
  for key in splits['test']:
22
+ if key in dataset:
23
+ test_data[key] = dataset[key]
24
  validation_test_data = {**val_data, **test_data}
25
 
26
  return train_data, validation_test_data
download_deps.py CHANGED
@@ -16,13 +16,12 @@ def download_models(ckpt_to_dl):
16
  REPO_ID = 'athn-nik/example-model'
17
  return hf_hub_download(REPO_ID, filename=f"{ckpt_to_dl}.ckpt")
18
 
19
- def download_model_config():
20
  REPO_ID = 'athn-nik/example-model'
21
- path_to_config = hf_hub_download(REPO_ID, filename="tmed/.hydra/config.yaml")
22
  from omegaconf import OmegaConf
23
  model_cfg = OmegaConf.load(path_to_config)
24
- return model_cfg.model.input_feats
25
-
26
 
27
  def download_motion_from_dataset(key_to_dl):
28
  REPO_ID = 'athn-nik/example-model'
 
16
  REPO_ID = 'athn-nik/example-model'
17
  return hf_hub_download(REPO_ID, filename=f"{ckpt_to_dl}.ckpt")
18
 
19
+ def download_model_config(config_name):
20
  REPO_ID = 'athn-nik/example-model'
21
+ path_to_config = hf_hub_download(REPO_ID, filename=f"{config_name}/.hydra/config.yaml")
22
  from omegaconf import OmegaConf
23
  model_cfg = OmegaConf.load(path_to_config)
24
+ return model_cfg
 
25
 
26
  def download_motion_from_dataset(key_to_dl):
27
  REPO_ID = 'athn-nik/example-model'
tmed_denoiser.py CHANGED
@@ -17,6 +17,7 @@ class TMED_denoiser(nn.Module):
17
  text_encoded_dim: int = 768,
18
  pred_delta_motion: bool = False,
19
  use_sep: bool = True,
 
20
  **kwargs) -> None:
21
 
22
  super().__init__()
@@ -28,6 +29,8 @@ class TMED_denoiser(nn.Module):
28
  self.pose_proj_in_source = nn.Linear(nfeats, self.latent_dim)
29
  self.pose_proj_in_target = nn.Linear(nfeats, self.latent_dim)
30
  self.pose_proj_out = nn.Linear(self.latent_dim, nfeats)
 
 
31
 
32
  # emb proj
33
  if self.condition in ["text", "text_uncond"]:
@@ -47,8 +50,9 @@ class TMED_denoiser(nn.Module):
47
  self.use_sep = use_sep
48
  self.query_pos = PositionalEncoding(self.latent_dim, dropout)
49
  self.mem_pos = PositionalEncoding(self.latent_dim, dropout)
50
- if self.use_sep:
51
- self.sep_token = nn.Parameter(torch.randn(1, self.latent_dim))
 
52
 
53
  # use torch transformer
54
  encoder_layer = nn.TransformerEncoderLayer(
@@ -83,7 +87,7 @@ class TMED_denoiser(nn.Module):
83
 
84
  # 1. time_embeddingno
85
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
86
- timesteps = timestep.expand(noised_motion.shape[1]).clone().to(noised_motion.device)
87
  time_emb = self.embed_timestep(timesteps).to(dtype=noised_motion.dtype)
88
  # make it S first
89
  # time_emb = self.time_embedding(time_emb).unsqueeze(0)
@@ -119,16 +123,19 @@ class TMED_denoiser(nn.Module):
119
  # if self.diffusion_only:
120
  proj_noised_motion = self.pose_proj_in_target(noised_motion)
121
 
122
- if self.use_sep:
123
-
124
- sep_token_batch = torch.tile(self.sep_token, (bs,)).reshape(bs,
125
- -1)
126
- xseq = torch.cat((emb_latent, motion_embeds_proj,
127
- sep_token_batch[None],
128
- proj_noised_motion), axis=0)
129
  else:
130
- xseq = torch.cat((emb_latent, motion_embeds_proj,
131
- proj_noised_motion), axis=0)
 
 
 
 
 
 
 
 
132
  # if self.ablation_skip_connection:
133
  # xseq = self.query_pos(xseq)
134
  # tokens = self.encoder(xseq)
@@ -249,8 +256,8 @@ class TMED_denoiser(nn.Module):
249
  mask_src_parts = inpaint_dict['mask'].unsqueeze(1).repeat(1,
250
  mot_len,
251
  1)
252
- uncond_eps = uncond_eps*(~mask_src_parts) + source_mot*mask_src_parts
253
- cond_eps_text = cond_eps_text*(~mask_src_parts) + source_mot*mask_src_parts
254
  half_eps = uncond_eps + guidance_text_n_motion * (cond_eps_text - uncond_eps)
255
  eps = torch.cat([half_eps, half_eps], dim=0)
256
  else:
 
17
  text_encoded_dim: int = 768,
18
  pred_delta_motion: bool = False,
19
  use_sep: bool = True,
20
+ motion_condition: str = 'source',
21
  **kwargs) -> None:
22
 
23
  super().__init__()
 
29
  self.pose_proj_in_source = nn.Linear(nfeats, self.latent_dim)
30
  self.pose_proj_in_target = nn.Linear(nfeats, self.latent_dim)
31
  self.pose_proj_out = nn.Linear(self.latent_dim, nfeats)
32
+ self.first_pose_proj = nn.Linear(self.latent_dim, nfeats)
33
+ self.motion_condition = motion_condition
34
 
35
  # emb proj
36
  if self.condition in ["text", "text_uncond"]:
 
50
  self.use_sep = use_sep
51
  self.query_pos = PositionalEncoding(self.latent_dim, dropout)
52
  self.mem_pos = PositionalEncoding(self.latent_dim, dropout)
53
+ if self.motion_condition == "source":
54
+ if self.use_sep:
55
+ self.sep_token = nn.Parameter(torch.randn(1, self.latent_dim))
56
 
57
  # use torch transformer
58
  encoder_layer = nn.TransformerEncoderLayer(
 
87
 
88
  # 1. time_embeddingno
89
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
90
+ timesteps = timestep.expand(noised_motion.shape[1]).clone()
91
  time_emb = self.embed_timestep(timesteps).to(dtype=noised_motion.dtype)
92
  # make it S first
93
  # time_emb = self.time_embedding(time_emb).unsqueeze(0)
 
123
  # if self.diffusion_only:
124
  proj_noised_motion = self.pose_proj_in_target(noised_motion)
125
 
126
+ if motion_embeds is None:
127
+ xseq = torch.cat((emb_latent, proj_noised_motion), axis=0)
 
 
 
 
 
128
  else:
129
+ if self.use_sep:
130
+
131
+ sep_token_batch = torch.tile(self.sep_token, (bs,)).reshape(bs,
132
+ -1)
133
+ xseq = torch.cat((emb_latent, motion_embeds_proj,
134
+ sep_token_batch[None],
135
+ proj_noised_motion), axis=0)
136
+ else:
137
+ xseq = torch.cat((emb_latent, motion_embeds_proj,
138
+ proj_noised_motion), axis=0)
139
  # if self.ablation_skip_connection:
140
  # xseq = self.query_pos(xseq)
141
  # tokens = self.encoder(xseq)
 
256
  mask_src_parts = inpaint_dict['mask'].unsqueeze(1).repeat(1,
257
  mot_len,
258
  1)
259
+ uncond_eps = uncond_eps*(mask_src_parts) + source_mot*(~mask_src_parts)
260
+ cond_eps_text = cond_eps_text*(mask_src_parts) + source_mot*(~mask_src_parts)
261
  half_eps = uncond_eps + guidance_text_n_motion * (cond_eps_text - uncond_eps)
262
  eps = torch.cat([half_eps, half_eps], dim=0)
263
  else: