atnikos committed
Commit 6837c8b · 1 Parent(s): 777e3d5

fix retrieval placeholders

Files changed (7)
  1. app.py +170 -94
  2. gen_utils.py +62 -2
  3. model_utils.py +35 -1
  4. renderer.py +138 -0
  5. requirements.txt +1 -0
  6. retrieval_loader.py +67 -0
  7. tmr_model.py +128 -0
app.py CHANGED
@@ -3,40 +3,62 @@ import gradio as gr
3
  import spaces
4
  import torch
5
  import random
6
 
7
  zero = torch.Tensor([0]).cuda()
8
- print(zero.device) # <-- 'cpu' 🤔
9
- # Gül Varol
10
  DEFAULT_TEXT = "A person is "
11
 
12
- WEBSITE = """
13
- <div class="embed_hidden">
14
- <h1 style='text-align: center'> ACRONYM: The actual title </h1>
15
-
16
- <h2 style='text-align: center'>
17
- <a href="https://google.com" target="_blank"><nobr>fname m. lname</nobr></a> &emsp;
18
- <a href="https://google.com" target="_blank"><nobr>fname m. lname</nobr></a> &emsp;
19
- <a href="https://google.com" target="_blank"><nobr>fname m. lname</nobr></a>
20
- </h2>
21
-
22
- <h2 style='text-align: center'>
23
- <nobr>XXX 2024</nobr>
24
- </h2>
25
-
26
- <h3 style="text-align:center;">
27
- <a target="_blank" href="https://arxiv.org/"> <button type="button" class="btn btn-primary btn-lg"> Paper </button></a>
28
- <a target="_blank" href="https://github.com/"> <button type="button" class="btn btn-primary btn-lg"> Code </button></a>
29
- <a target="_blank" href="google.com"> <button type="button" class="btn btn-primary btn-lg"> Webpage </button></a>
30
- <a target="_blank" href="bibfile.com"> <button type="button" class="btn btn-primary btn-lg"> BibTex </button></a>
31
- </h3>
32
-
33
- <h3> Description </h3>
34
- <p>
35
- This space illustrates <a href='project.com' target='_blank'><b>XXX</b></a>, a method for XXX.
36
- What does it do?
37
- </p>
38
  </div>
39
- """
40
 
41
  @spaces.GPU
42
  def greet(n):
@@ -50,82 +72,136 @@ def greet(n):
50
  def clear():
51
  return ""
52
 
53
- def random_number():
54
- return str(random.uniform(0, 100))
55
  from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
56
 
57
  def download_models():
58
  REPO_ID = 'athn-nik/example-model'
59
-
60
  return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt")
61
-
62
  with gr.Blocks() as demo:
63
  gr.Markdown(WEBSITE)
64
-
65
- input_text = gr.Textbox(placeholder="Type the edit text you want:",
66
- show_label=True,label="Input Text", value=DEFAULT_TEXT)
67
- # output_text = gr.Textbox(label="Output Text")
68
 
69
  with gr.Row():
70
  retrieve_button = gr.Button("Retrieve")
71
- clear_button = gr.Button("Clear")
72
  random_button = gr.Button("Random")
73
- from normalization import Normalizer
74
- normalizer = Normalizer()
75
- # tmed_den = load_model()
76
- from diffusion import create_diffusion
77
- from text_encoder import ClipTextEncoder
78
- from tmed_denoiser import TMED_denoiser
79
- model_ckpt = download_models()
80
- checkpoint = torch.load(model_ckpt)
81
-
82
- checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
83
- tmed_denoiser = TMED_denoiser().to('cuda')
84
- tmed_denoiser.load_state_dict(checkpoint, strict=False)
85
- tmed_denoiser.eval()
86
- text_encoder = ClipTextEncoder()
87
- texts_cond = [input_text.value]
88
-
89
- diffusion_process = create_diffusion(timestep_respacing=None,
90
- learn_sigma=False, sigma_small=True,
91
- diffusion_steps=300,
92
- noise_schedule='squaredcos_cap_v2',
93
- predict_xstart=True) # noise vs sample
94
- # uncond_tokens = [""] * len(texts_cond)
95
- # if self.condition == 'text':
96
- # uncond_tokens.extend(texts_cond)
97
- # elif self.condition == 'text_uncond':
98
- # uncond_tokens.extend(uncond_tokens)
99
- bsz = 1
100
- seqlen_tgt = 180
101
- no_of_texts = len(texts_cond)
102
- texts_cond = ['']*no_of_texts + texts_cond
103
- texts_cond = ['']*no_of_texts + texts_cond
104
- print(texts_cond)
105
- text_emb, text_mask = text_encoder(texts_cond)
106
-
107
- cond_emb_motion = torch.zeros(1, bsz,
108
- 512,
109
- device='cuda')
110
- cond_motion_mask = torch.ones((bsz, 1),
111
- dtype=bool, device='cuda')
112
- mask_target = torch.ones((1, bsz),
113
- dtype=bool, device='cuda')
114
- # complete noise
115
- # import ipdb;ipdb.set_trace()
116
- diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
117
- text_mask.to(cond_emb_motion.device),
118
- cond_emb_motion,
119
- cond_motion_mask,
120
- mask_target,
121
- diffusion_process,
122
- init_vec=None,
123
- init_from='noise',
124
- gd_text=4.0,
125
- gd_motion=2.0,
126
- steps_num=300)
127
- edited_motion = diffout2motion(diff_out, normalizer)
128
- clear_button.click(clear, outputs=input_text)
129
  random_button.click(random_number, outputs=input_text)
130
 
131
  demo.launch()
 
3
  import spaces
4
  import torch
5
  import random
6
+ import os
7
+ from pathlib import Path
8
+ from aitviewer.headless import HeadlessRenderer
9
+ from aitviewer.configuration import CONFIG as AITVIEWER_CONFIG
10
+ # import cv2
11
+ # import moderngl
12
+ # ctx = moderngl.create_context(standalone=True)
13
+ # print(ctx)
14
+ access_token_smpl = os.environ.get('HF_SMPL_TOKEN')
15
 
16
  zero = torch.Tensor([0]).cuda()
17
+ print(zero.device) # <-- 'cuda:0' 🤗
18
+
19
  DEFAULT_TEXT = "A person is "
20
+ from aitviewer.models.smpl import SMPLLayer
21
+ def get_smpl_models():
22
+ REPO_ID = 'athn-nik/smpl_models'
23
+ from huggingface_hub import snapshot_download
24
+ return snapshot_download(repo_id=REPO_ID, allow_patterns="smplh*",
25
+ token=access_token_smpl)
26
+
27
+ def get_renderer():
28
+ from aitviewer.headless import HeadlessRenderer
29
+ from aitviewer.configuration import CONFIG as AITVIEWER_CONFIG
30
+ smpl_models_path = str(Path(get_smpl_models()))
31
+ AITVIEWER_CONFIG.update_conf({'playback_fps': 30,
32
+ 'auto_set_floor': True,
33
+ 'smplx_models': smpl_models_path,
34
+ 'z_up': True})
35
+ return HeadlessRenderer()
36
+
37
 
38
+
39
+ WEBSITE = ("""<div class="embed_hidden" style="text-align: center;">
40
+ <h1>MotionFix: Text-Driven 3D Human Motion Editing</h1>
41
+ <h3>
42
+ <a href="https://is.mpg.de/person/~nathanasiou" target="_blank" rel="noopener noreferrer">Nikos Athanasiou</a><sup>1</sup>,
43
+ <a href="https://is.mpg.de/person/acseke" target="_blank" rel="noopener noreferrer">Alpar Cseke</a><sup>1</sup>,
44
+ <br>
45
+ <a href="https://ps.is.mpg.de/person/mdiomataris" target="_blank" rel="noopener noreferrer">Markos Diomataris</a><sup>1, 3</sup>,
46
+ <a href="https://is.mpg.de/person/black" target="_blank" rel="noopener noreferrer">Michael J. Black</a><sup>1</sup>,
47
+ <a href="https://imagine.enpc.fr/~varolg/" target="_blank" rel="noopener noreferrer">G&uuml;l Varol</a><sup>2</sup>,
48
+ </h3>
49
+ <h3>
50
+ <sup>1</sup>Max Planck Institute for Intelligent Systems, T&uuml;bingen, Germany;
51
+ <sup>2</sup>LIGM, &Eacute;cole des Ponts, Univ Gustave Eiffel, CNRS, France,
52
+ <sup>3</sup>ETH Z&uuml;rich, Switzerland
53
+ </h3>
54
+ </div>
55
+ <div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
56
+ <a href='https://arxiv.org/abs/'><img src='https://img.shields.io/badge/Arxiv-2405.20340-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a>
57
+ <a href='https://arxiv.org/pdf/'><img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'></a>
58
+ <a href='https://motionfix.is.tue.mpg.de'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
59
+ <a href='https://youtube.com/'><img src='https://img.shields.io/badge/YouTube-red?style=flat&logo=youtube&logoColor=white'></a>
60
  </div>
61
+ """)
62
 
63
  @spaces.GPU
64
  def greet(n):
 
72
  def clear():
73
  return ""
74
 
75
+ def show_video(input_text):
76
+ from normalization import Normalizer
77
+ normalizer = Normalizer()
78
+ from diffusion import create_diffusion
79
+ from text_encoder import ClipTextEncoder
80
+ from tmed_denoiser import TMED_denoiser
81
+ model_ckpt = download_models()
82
+ checkpoint = torch.load(model_ckpt)
83
+
84
+ checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
85
+ tmed_denoiser = TMED_denoiser().to('cuda')
86
+ tmed_denoiser.load_state_dict(checkpoint, strict=False)
87
+ tmed_denoiser.eval()
88
+ text_encoder = ClipTextEncoder()
89
+ texts_cond = [input_text]
90
+
91
+ diffusion_process = create_diffusion(timestep_respacing=None,
92
+ learn_sigma=False, sigma_small=True,
93
+ diffusion_steps=300,
94
+ noise_schedule='squaredcos_cap_v2',
95
+ predict_xstart=True)
96
+ bsz = 1
97
+ seqlen_tgt = 180
98
+ no_of_texts = len(texts_cond)
99
+ texts_cond = ['']*no_of_texts + texts_cond
100
+ texts_cond = ['']*no_of_texts + texts_cond
101
+ text_emb, text_mask = text_encoder(texts_cond)
102
+
103
+ cond_emb_motion = torch.zeros(seqlen_tgt, bsz,
104
+ 512,
105
+ device='cuda')
106
+ cond_motion_mask = torch.ones((bsz, seqlen_tgt),
107
+ dtype=bool, device='cuda')
108
+ mask_target = torch.ones((bsz, seqlen_tgt),
109
+ dtype=bool, device='cuda')
110
+
111
+ diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
112
+ text_mask.to(cond_emb_motion.device),
113
+ cond_emb_motion,
114
+ cond_motion_mask,
115
+ mask_target,
116
+ diffusion_process,
117
+ init_vec=None,
118
+ init_from='noise',
119
+ gd_text=4.0,
120
+ gd_motion=2.0,
121
+ steps_num=300)
122
+ edited_motion = diffout2motion(diff_out, normalizer).squeeze()
123
+ from renderer import render_motion, color_map, pack_to_render
124
+ # aitrenderer = get_renderer()
125
+ AIT_RENDERER = get_renderer()
126
+ SMPL_LAYER = SMPLLayer(model_type='smplh', ext='npz', gender='neutral')
127
+ edited_mot_to_render = pack_to_render(rots=edited_motion[..., 3:],
128
+ trans=edited_motion[..., :3])
129
+ import random
130
+ xx = random.randint(1, 1000)
131
+ fname = render_motion(AIT_RENDERER, [edited_mot_to_render],
132
+ f"movie_example--{str(xx)}",
133
+ pose_repr='aa',
134
+ color=[color_map['generated']],
135
+ smpl_layer=SMPL_LAYER)
136
+ return fname
137
+ def retrieve_video(retrieve_text):
138
+ pass
139
+
140
  from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
141
 
142
  def download_models():
143
  REPO_ID = 'athn-nik/example-model'
 
144
  return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt")
145
+
146
+ def download_tmr():
147
+ REPO_ID = 'athn-nik/example-model'
148
+ # return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt")
149
+ from huggingface_hub import snapshot_download
150
+ return snapshot_download(repo_id=REPO_ID, allow_patterns="tmr*",
151
+ token=access_token_smpl)
152
+ import gradio as gr
153
+
154
+ def clear():
155
+ return ""
156
+
157
+ def random_number():
158
+ return "Random text"
159
+
160
  with gr.Blocks() as demo:
161
  gr.Markdown(WEBSITE)
162
 
163
  with gr.Row():
164
+ with gr.Column(scale=8):
165
+ retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to Retrieve:",
166
+ show_label=True, label="Retrieval Text", value=DEFAULT_TEXT)
167
+ with gr.Column(scale=1):
168
+ clear_button_retrieval = gr.Button("Clear Retrieval Text")
169
+
170
+ with gr.Row():
171
+ with gr.Column(scale=8):
172
+ input_text = gr.Textbox(placeholder="Type the edit text you want:",
173
+ show_label=True, label="Input Text", value=DEFAULT_TEXT)
174
+ with gr.Column(scale=1):
175
+ clear_button_edit = gr.Button("Clear Edit Text")
176
+
177
+ with gr.Row():
178
+ video_output = gr.Video(label="Generated Video", height=240, width=320)
179
+ retrieved_video_output = gr.Video(label="Retrieved Motion", height=240, width=320)
180
+
181
+ with gr.Row():
182
+ edit_button = gr.Button("Edit")
183
  retrieve_button = gr.Button("Retrieve")
184
+
185
  random_button = gr.Button("Random")
186
+
187
+ def process_and_show_video(input_text):
188
+ fname = show_video(input_text)
189
+ return fname
190
+
191
+ def process_and_retrieve_video(input_text):
192
+ fname = retrieve_video(input_text)
193
+ return fname
194
+
195
+ from gen_utils import read_config
196
+ from retrieval_loader import load_model_from_cfg
197
+ from retrieval_loader import get_tmr_model
198
+ tmr = get_tmr_model(download_tmr())
199
+ edit_button.click(process_and_show_video, inputs=input_text, outputs=video_output)
200
+ retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=retrieved_video_output)
201
+ # import ipdb;ipdb.set_trace()
202
+
203
+ clear_button_edit.click(clear, outputs=input_text)
204
+ clear_button_retrieval.click(clear, outputs=retrieve_text)
205
  random_button.click(random_number, outputs=input_text)
206
 
207
  demo.launch()
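
The new retrieve_video function is still a placeholder (it only passes). Below is a minimal sketch of how it could be filled in with the tmr retrieval model loaded above, assuming a precomputed bank of unit-normalized motion embeddings and matching rendered videos already exist on disk; the two .npy files and their layout are hypothetical, not part of this commit:

import numpy as np
import torch

def retrieve_video(retrieve_text):
    # hypothetical gallery: one row per motion clip, embeddings already unit-normalized
    motion_unit_embs = torch.from_numpy(np.load("motion_unit_embs.npy")).to('cuda')
    video_paths = np.load("motion_video_paths.npy", allow_pickle=True)
    # cosine-similarity scores mapped to [0, 1] against every gallery motion
    scores = tmr.compute_scores(retrieve_text, unit_embs=motion_unit_embs)
    return str(video_paths[int(scores.argmax())])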
gen_utils.py CHANGED
@@ -1,5 +1,10 @@
1
  import torch
2
  import numpy as np
3
  def cast_dict_to_tensors(d, device="cpu"):
4
  if isinstance(d, dict):
5
  return {k: cast_dict_to_tensors(v, device) for k, v in d.items()}
@@ -9,5 +14,60 @@ def cast_dict_to_tensors(d, device="cpu"):
9
  return d.to(device)
10
  else:
11
  return d
12
-
13
-
1
  import torch
2
  import numpy as np
3
+ import logging
4
+ import os
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
  def cast_dict_to_tensors(d, device="cpu"):
9
  if isinstance(d, dict):
10
  return {k: cast_dict_to_tensors(v, device) for k, v in d.items()}
 
14
  return d.to(device)
15
  else:
16
  return d
17
+
18
+ def rgba(c: str):
19
+ from matplotlib import colors as mcolors
20
+ return mcolors.to_rgba(c)
21
+
22
+ def rgb(c: str):
23
+ from matplotlib import colors as mcolors
24
+ return mcolors.to_rgb(c)
25
+
26
+ # split the lightning checkpoint into
27
+ # separate state_dict modules for faster loading
28
+ def extract_ckpt(run_dir, ckpt_name="last"):
29
+ import torch
30
+
31
+ ckpt_path = os.path.join(run_dir, f"logs/checkpoints/{ckpt_name}.ckpt")
32
+
33
+ extracted_path = os.path.join(run_dir, f"{ckpt_name}_weights")
34
+ os.makedirs(extracted_path, exist_ok=True)
35
+
36
+ new_path_template = os.path.join(extracted_path, "{}.pt")
37
+ ckpt_dict = torch.load(ckpt_path)
38
+ state_dict = ckpt_dict["state_dict"]
39
+ module_names = list(set([x.split(".")[0] for x in state_dict.keys()]))
40
+
41
+ # should be ['motion_encoder', 'text_encoder', 'motion_decoder'] for example
42
+ for module_name in module_names:
43
+ path = new_path_template.format(module_name)
44
+ sub_state_dict = {
45
+ ".".join(x.split(".")[1:]): y.cpu()
46
+ for x, y in state_dict.items()
47
+ if x.split(".")[0] == module_name
48
+ }
49
+ torch.save(sub_state_dict, path)
50
+
51
+ import os
52
+ import json
53
+ from omegaconf import DictConfig, OmegaConf
54
+
55
+
56
+ def save_config(cfg: DictConfig) -> str:
57
+ path = os.path.join(cfg.run_dir, "config.json")
58
+ config = OmegaConf.to_container(cfg, resolve=True)
59
+ with open(path, "w") as f:
60
+ string = json.dumps(config, indent=4)
61
+ f.write(string)
62
+ return path
63
+
64
+
65
+ def read_config(run_dir: str, return_json=False) -> DictConfig:
66
+ path = os.path.join(run_dir, "config.json")
67
+ with open(path, "r") as f:
68
+ config = json.load(f)
69
+ if return_json:
70
+ return config
71
+ cfg = OmegaConf.create(config)
72
+ cfg.run_dir = run_dir
73
+ return cfg
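
A short usage sketch of the new checkpoint and config helpers, assuming run_dir points at a training run laid out the way extract_ckpt expects (logs/checkpoints/last.ckpt) and containing a config.json previously written by save_config; the directory name is hypothetical:

from gen_utils import extract_ckpt, read_config

run_dir = "runs/tmr_example"              # hypothetical run directory
extract_ckpt(run_dir, ckpt_name="last")   # writes runs/tmr_example/last_weights/<module>.pt
cfg = read_config(run_dir)                # OmegaConf config with cfg.run_dir patched in
print(cfg.run_dir)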
model_utils.py CHANGED
@@ -3,6 +3,10 @@ import torch.nn as nn
3
  import numpy as np
4
  import torch
5
  from torch import nn
6
 
7
  class TimestepEmbedderMDM(nn.Module):
8
  def __init__(self, latent_dim):
@@ -61,4 +65,34 @@ class PositionalEncoding(nn.Module):
61
  else:
62
  last = first + x.shape[0]
63
  x = x + self.pe[first:last, :]
64
- return self.dropout(x)
3
  import numpy as np
4
  import torch
5
  from torch import nn
6
+ import torch
7
+
8
+ from typing import List, Dict, Optional
9
+ from torch import Tensor
10
 
11
  class TimestepEmbedderMDM(nn.Module):
12
  def __init__(self, latent_dim):
 
65
  else:
66
  last = first + x.shape[0]
67
  x = x + self.pe[first:last, :]
68
+ return self.dropout(x)
69
+
70
+ def collate_tensor_with_padding(batch: List[Tensor]) -> Tensor:
71
+ dims = batch[0].dim()
72
+ max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
73
+ size = (len(batch),) + tuple(max_size)
74
+ canvas = batch[0].new_zeros(size=size)
75
+ for i, b in enumerate(batch):
76
+ sub_tensor = canvas[i]
77
+ for d in range(dims):
78
+ sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
79
+ sub_tensor.add_(b)
80
+ return canvas
81
+
82
+
83
+ def collate_x_dict(lst_x_dict: List, *, device: Optional[str] = 'cuda') -> Dict:
84
+ x = collate_tensor_with_padding([x_dict["x"] for x_dict in lst_x_dict])
85
+ if device is not None:
86
+ x = x.to(device)
87
+ length = [x_dict["length"] for x_dict in lst_x_dict]
88
+
89
+ if isinstance(length, list):
90
+ length = torch.tensor(length, device=device)
91
+
92
+ max_len = max(length)
93
+ mask = torch.arange(max_len, device=device).expand(
94
+ len(length), max_len
95
+ ) < length.unsqueeze(1)
96
+
97
+ batch = {"x": x, "length": length, "mask": mask}
98
+ return batch
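
A minimal sketch of collate_x_dict on toy tensors, run on CPU so it works without a GPU; the feature size of 256 is only an example:

import torch
from model_utils import collate_x_dict

batch = [
    {"x": torch.randn(5, 256), "length": 5},
    {"x": torch.randn(3, 256), "length": 3},
]
out = collate_x_dict(batch, device="cpu")
print(out["x"].shape)   # torch.Size([2, 5, 256]); shorter sequences are zero-padded
print(out["mask"])      # True for valid frames, False where a sequence is padding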
renderer.py ADDED
@@ -0,0 +1,138 @@
1
+ import os
2
+ import torch
3
+ from transform3d import transform_body_pose
4
+ from aitviewer.headless import HeadlessRenderer
5
+ from gen_utils import rgb, rgba
6
+
7
+ color_map = {
8
+ 'source_motion': rgba('darkred'),
9
+ 'source': rgba('darkred'),
10
+ 'target_motion': rgba('olivedrab'),
11
+ 'input': rgba('olivedrab'),
12
+ 'target': rgba('olivedrab'),
13
+ 'generation': rgba('purple'),
14
+ 'generated': rgba('steelblue'),
15
+ 'denoised': rgba('purple'),
16
+ 'noised': rgba('darkgrey'),
17
+ }
18
+
19
+
20
+ def pack_to_render(rots, trans, pose_repr='6d'):
21
+ # make axis-angle
22
+ # global_orient = transform_body_pose(rots, f"{pose_repr}->aa")
23
+
24
+ if rots.is_cuda:
25
+ rots = rots.detach().cpu()
26
+ if trans.is_cuda:
27
+ trans = trans.detach().cpu()
28
+
29
+ if pose_repr != 'aa':
30
+ body_pose = transform_body_pose(rots, f"{pose_repr}->aa")
31
+ else:
32
+ body_pose = rots
33
+ if trans is None:
34
+ trans = torch.zeros((rots.shape[0], rots.shape[1], 3),
35
+ device=rots.device)
36
+ render_d = {'body_transl': trans,
37
+ 'body_orient': body_pose[..., :3],
38
+ 'body_pose': body_pose[..., 3:]}
39
+ return render_d
40
+
41
+
42
+ def render_motion(renderer: HeadlessRenderer, datum: dict,
43
+ filename: str, pose_repr='6d',
44
+ color=(160 / 255, 160 / 255, 160 / 255, 1.0),
45
+ return_verts=False, smpl_layer=None) -> None:
46
+ """
47
+ Function to render a video of a motion sequence
48
+ renderer: aitviewer renderer
49
+ datum: dictionary containing sequence of poses, body translations and body orientations
50
+ data could be numpy or pytorch tensors
51
+ filename: the absolute path you want the video to be saved at
52
+
53
+ """
54
+ from aitviewer.headless import HeadlessRenderer
55
+ from aitviewer.renderables.smpl import SMPLSequence
56
+
57
+ if isinstance(datum, dict): datum = [datum]
58
+ if not isinstance(color, list):
59
+ colors = [color]
60
+ else:
61
+ colors = color
62
+ # assert {'body_transl', 'body_orient', 'body_pose'}.issubset(set(datum[0].keys()))
63
+ # os.environ['DISPLAY'] = ":11"
64
+ gender = 'neutral'
65
+ only_skel = False
66
+ import sys
67
+ seqs_of_human_motions = []
68
+ if smpl_layer is None:
69
+ from aitviewer.models.smpl import SMPLLayer
70
+ smpl_layer = SMPLLayer(model_type='smplh',
71
+ ext='npz',
72
+ gender=gender)
73
+
74
+ for iid, mesh_seq in enumerate(datum):
75
+
76
+ if pose_repr != 'aa':
77
+ global_orient = transform_body_pose(mesh_seq['body_orient'],
78
+ f"{pose_repr}->aa")
79
+ body_pose = transform_body_pose(mesh_seq['body_pose'],
80
+ f"{pose_repr}->aa")
81
+ else:
82
+ global_orient = mesh_seq['body_orient']
83
+ body_pose = mesh_seq['body_pose']
84
+
85
+ body_transl = mesh_seq['body_transl']
86
+ sys.stdout.flush()
87
+
88
+ old = os.dup(1)
89
+ os.close(1)
90
+ os.open(os.devnull, os.O_WRONLY)
91
+ print(body_pose.shape)
92
+ print('\n')
93
+ smpl_template = SMPLSequence(body_pose,
94
+ smpl_layer,
95
+ poses_root=global_orient,
96
+ trans=body_transl,
97
+ color=colors[iid],
98
+ z_up=True)
99
+ if only_skel:
100
+ smpl_template.remove(smpl_template.mesh_seq)
101
+
102
+ seqs_of_human_motions.append(smpl_template)
103
+ renderer.scene.add(smpl_template)
104
+ # camera follows smpl sequence
105
+ # FIX CAMERA
106
+ from transform3d import get_z_rot
107
+ R_z = get_z_rot(global_orient[0], in_format='aa')
108
+ heading = -R_z[:, 1]
109
+ xy_facing = body_transl[0] + heading*2.5
110
+ camera = renderer.lock_to_node(seqs_of_human_motions[0],
111
+ (xy_facing[0], xy_facing[1], 1.5), smooth_sigma=5.0)
112
+
113
+ # /FIX CAMERA
114
+ if len(mesh_seq['body_pose']) == 1:
115
+ renderer.save_frame(file_path=str(filename) + '.png')
116
+ sfx = 'png'
117
+ else:
118
+ renderer.save_video(video_dir=str(filename), output_fps=30)
119
+ sfx = 'mp4'
120
+
121
+ # aitviewer adds a counter to the filename, we remove it
122
+ # filename.split('_')[-1].replace('.mp4', '')
123
+ # os.rename(filename + '_0.mp4', filename[:-4] + '.mp4')
124
+ if sfx == 'mp4':
125
+ os.rename(str(filename) + f'_0.{sfx}', str(filename) + f'.{sfx}')
126
+
127
+ # empty scene for the next rendering
128
+ for mesh in seqs_of_human_motions:
129
+ renderer.scene.remove(mesh)
130
+ renderer.scene.remove(camera)
131
+
132
+ sys.stdout.flush()
133
+ os.close(1)
134
+ os.dup(old)
135
+ os.close(old)
136
+ renderer.reset()
137
+ fname = f'{filename}.{sfx}'
138
+ return fname
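
A hedged sketch of driving renderer.py outside the Gradio app, mirroring what show_video does in app.py. It assumes the SMPL-H model files are already on disk (app.py fetches them with get_smpl_models) and that a headless OpenGL context is available; the path below is a placeholder, and the 69-dimensional layout (3 translation values plus 66 axis-angle values for the root and 21 SMPL-H body joints) is an assumption about the motion representation, with random values used only to exercise the call:

import torch
from aitviewer.configuration import CONFIG as AITVIEWER_CONFIG
from aitviewer.headless import HeadlessRenderer
from aitviewer.models.smpl import SMPLLayer
from renderer import render_motion, pack_to_render, color_map

AITVIEWER_CONFIG.update_conf({'playback_fps': 30, 'auto_set_floor': True,
                              'smplx_models': '/path/to/smpl_models',  # hypothetical path
                              'z_up': True})
renderer = HeadlessRenderer()
smpl_layer = SMPLLayer(model_type='smplh', ext='npz', gender='neutral')

motion = torch.randn(60, 69)  # 60 frames of random translation + axis-angle values
packed = pack_to_render(rots=motion[..., 3:], trans=motion[..., :3], pose_repr='aa')
fname = render_motion(renderer, [packed], "movie_example", pose_repr='aa',
                      color=[color_map['generated']], smpl_layer=smpl_layer)
print(fname)  # movie_example.mp4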
requirements.txt CHANGED
@@ -2,3 +2,4 @@ spaces
2
  gradio==4.36.1
3
  torch
4
  transformers==4.41.2
 
 
2
  gradio==4.36.1
3
  torch
4
  transformers==4.41.2
5
+ hydra-core
retrieval_loader.py ADDED
@@ -0,0 +1,67 @@
1
+ from gen_utils import extract_ckpt
2
+ import hydra
3
+ import os
4
+ from hydra.utils import instantiate
5
+ from gen_utils import read_config
6
+ from model_utils import collate_x_dict
7
+ import torch
8
+ from tmr_model import TMR_textencoder
9
+
10
+
11
+ def load_model_from_cfg(cfg, ckpt_name="last", device="cuda", eval_mode=True):
12
+ import src.prepare # noqa
13
+ import torch
14
+
15
+ run_dir = cfg.run_dir
16
+ model = hydra.utils.instantiate(cfg.model)
17
+
18
+ # Loading modules one by one
19
+ # motion_encoder / text_encoder / text_decoder
20
+ pt_path = os.path.join(run_dir, f"{ckpt_name}_weights")
21
+
22
+ if not os.path.exists(pt_path):
23
+ extract_ckpt(run_dir, ckpt_name)
24
+
25
+ for fname in os.listdir(pt_path):
26
+ module_name, ext = os.path.splitext(fname)
27
+ if ext != ".pt":
28
+ continue
29
+
30
+ module = getattr(model, module_name, None)
31
+ if module is None:
32
+ continue
33
+
34
+ module_path = os.path.join(pt_path, fname)
35
+ state_dict = torch.load(module_path)
36
+ module.load_state_dict(state_dict)
37
+ model = model.to(device)
38
+ if eval_mode:
39
+ model = model.eval()
40
+ return model
41
+
42
+ # def get_tmr_model(run_dir):
43
+ # from gen_utils import read_config
44
+ # cfg = read_config(run_dir+'/tmr')
45
+ # import ipdb;ipdb.set_trace()
46
+ # text_model = instantiate(cfg.data.text_to_token_emb, device='cuda')
47
+ # model = load_model_from_cfg(cfg, 'last', eval_mode=True, device='cuda')
48
+ # return text_model, model
49
+
50
+ def get_tmr_model(run_dir):
51
+ text_params = {
52
+ "latent_dim": 256,
53
+ "ff_size": 1024,
54
+ "num_layers": 6,
55
+ "num_heads": 4,
56
+ "activation": "gelu",
57
+ "modelpath": "distilbert-base-uncased",
58
+ }
59
+ "unit_motion_embs"
60
+ model = TMR_textencoder(**text_params)
61
+ state_dict = torch.load(f"{run_dir}/tmr/last_weights/text_encoder.pt",
62
+ map_location='cuda')
63
+ # load values for the transformer only
64
+ model.load_state_dict(state_dict, strict=False)
65
+ model = model.eval()
66
+ return model.to('cuda')
67
+
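
A short sketch of how get_tmr_model is consumed (app.py calls it on the snapshot returned by download_tmr); it assumes the snapshot directory contains tmr/last_weights/text_encoder.pt and that a CUDA device is available, and the path below is hypothetical:

from retrieval_loader import get_tmr_model

run_dir = "/path/to/tmr_snapshot"              # hypothetical; app.py obtains this via download_tmr()
tmr = get_tmr_model(run_dir)                   # DistilBERT-based TMR text encoder on CUDA
text_latent = tmr(["a person walks forward"])  # shape (1, 256): latent used for retrieval scoring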
tmr_model.py ADDED
@@ -0,0 +1,128 @@
1
+ from typing import List
2
+ import torch.nn as nn
3
+ import os
4
+
5
+ import torch
6
+ import numpy as np
7
+ from torch import Tensor
8
+ from transformers import AutoTokenizer, AutoModel
9
+ from transformers import logging
10
+ from torch.nn.functional import normalize
11
+
12
+
13
+ class PositionalEncoding(nn.Module):
14
+ def __init__(self, d_model, max_len=5000):
15
+ super().__init__()
16
+
17
+ pe = torch.zeros(max_len, d_model)
18
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
19
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
20
+ pe[:, 0::2] = torch.sin(position * div_term)
21
+ pe[:, 1::2] = torch.cos(position * div_term)
22
+ pe = pe.unsqueeze(0).transpose(0, 1)
23
+
24
+ self.register_buffer('pe', pe, persistent=False)
25
+
26
+ def forward(self, x):
27
+ return x + self.pe[:x.shape[0], :]
28
+
29
+
30
+ class TMR_textencoder(nn.Module):
31
+ def __init__(self, modelpath: str, latent_dim: int, ff_size: int,
32
+ num_layers: int, num_heads: int, activation: str, **kwargs) -> None:
33
+ super().__init__()
34
+
35
+ logging.set_verbosity_error()
36
+
37
+ # Tokenizer
38
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
39
+ self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
40
+
41
+ # Text model
42
+ self.text_model = AutoModel.from_pretrained(modelpath)
43
+ # Then configure the model
44
+ self.text_encoded_dim = self.text_model.config.hidden_size
45
+
46
+ # Projection of the text-outputs into the latent space
47
+ self.projection = nn.Sequential(
48
+ nn.ReLU(),
49
+ nn.Linear(self.text_encoded_dim, latent_dim)
50
+ )
51
+
52
+ self.mu_token = nn.Parameter(torch.randn(latent_dim))
53
+ self.logvar_token = nn.Parameter(torch.randn(latent_dim))
54
+ self.sequence_pos_encoding = PositionalEncoding(latent_dim)
55
+
56
+ seq_trans_encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim,
57
+ nhead=num_heads,
58
+ dim_feedforward=ff_size,
59
+ dropout=0.0,
60
+ activation=activation)
61
+ self.seqTransEncoder = nn.TransformerEncoder(
62
+ seq_trans_encoder_layer,
63
+ num_layers=num_layers
64
+ )
65
+
66
+ def get_last_hidden_state(self, texts: List[str],
67
+ return_mask: bool = False):
68
+ encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
69
+ output = self.text_model(**encoded_inputs.to(self.text_model.device))
70
+ if not return_mask:
71
+ return output.last_hidden_state
72
+ return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool)
73
+
74
+ def forward(self, texts: List[str]) -> Tensor:
75
+ text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True)
76
+
77
+ x = self.projection(text_encoded)
78
+ bs, nframes, _ = x.shape
79
+ # bs, nframes, totjoints, nfeats = x.shape
80
+ # Switch sequence and batch_size because the input of
81
+ # Pytorch Transformer is [Sequence, Batch size, ...]
82
+ x = x.permute(1, 0, 2) # now it is [nframes, bs, latent_dim]
83
+
84
+ mu_token = torch.tile(self.mu_token, (bs,)).reshape(bs, -1)
85
+ logvar_token = torch.tile(self.logvar_token, (bs,)).reshape(bs, -1)
86
+
87
+ # adding the distribution tokens for all sequences
88
+ xseq = torch.cat((mu_token[None], logvar_token[None], x), 0)
89
+
90
+ # create a bigger mask, to allow attend to mu and logvar
91
+ token_mask = torch.ones((bs, 2), dtype=bool, device=x.device)
92
+ aug_mask = torch.cat((token_mask, mask), 1)
93
+
94
+ # add positional encoding
95
+ xseq = self.sequence_pos_encoding(xseq)
96
+ final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
97
+
98
+ # only mu for inference
99
+ mu = final[0]
100
+ return mu
101
+
102
+ # compute score for retrieval
103
+ def compute_scores(self, texts, unit_embs=None, embs=None):
104
+ # not both empty
105
+ assert not (unit_embs is None and embs is None)
106
+ # not both filled
107
+ assert not (unit_embs is not None and embs is not None)
108
+
109
+ output_str = False
110
+ # if one input, squeeze the output
111
+ if isinstance(texts, str):
112
+ texts = [texts]
113
+ output_str = True
114
+
115
+ # compute unit_embs from embs if not given
116
+ if embs is not None:
117
+ unit_embs = normalize(embs)
118
+
119
+ with torch.no_grad():
120
+ latent_unit_texts = normalize(self(texts))
121
+ # compute cosine similarity between 0 and 1
122
+ scores = (unit_embs @ latent_unit_texts.T).T/2 + 0.5
123
+ scores = scores.cpu().numpy()
124
+
125
+ if output_str:
126
+ scores = scores[0]
127
+
128
+ return scores
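
A minimal sketch of compute_scores on CPU with a freshly initialized (untrained) encoder and random placeholder motion embeddings, just to show the expected shapes and the [0, 1] score range; real embeddings would come from the TMR motion encoder, which is not part of this file:

import torch
from torch.nn.functional import normalize
from tmr_model import TMR_textencoder

model = TMR_textencoder(modelpath="distilbert-base-uncased", latent_dim=256, ff_size=1024,
                        num_layers=6, num_heads=4, activation="gelu").eval()
fake_motion_embs = normalize(torch.randn(10, 256))  # 10 placeholder unit-normalized motion embeddings
scores = model.compute_scores("a person waves", unit_embs=fake_motion_embs)
print(scores.shape)  # (10,): one cosine-similarity score per motion, mapped into [0, 1]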