fix retrieval placeholders
- app.py +170 -94
- gen_utils.py +62 -2
- model_utils.py +35 -1
- renderer.py +138 -0
- requirements.txt +1 -0
- retrieval_loader.py +67 -0
- tmr_model.py +128 -0
app.py
CHANGED
@@ -3,40 +3,62 @@ import gradio as gr
 import spaces
 import torch
 import random
+import os
+from pathlib import Path
+from aitviewer.headless import HeadlessRenderer
+from aitviewer.configuration import CONFIG as AITVIEWER_CONFIG
+# import cv2
+# import moderngl
+# ctx = moderngl.create_context(standalone=True)
+# print(ctx)
+access_token_smpl = os.environ.get('HF_SMPL_TOKEN')
 
 zero = torch.Tensor([0]).cuda()
-print(zero.device) # <-- '
-
+print(zero.device) # <-- 'cuda:0' 🤗
+
 DEFAULT_TEXT = "A person is "
+from aitviewer.models.smpl import SMPLLayer
+def get_smpl_models():
+    REPO_ID = 'athn-nik/smpl_models'
+    from huggingface_hub import snapshot_download
+    return snapshot_download(repo_id=REPO_ID, allow_patterns="smplh*",
+                             token=access_token_smpl)
+
+def get_renderer():
+    from aitviewer.headless import HeadlessRenderer
+    from aitviewer.configuration import CONFIG as AITVIEWER_CONFIG
+    smpl_models_path = str(Path(get_smpl_models()))
+    AITVIEWER_CONFIG.update_conf({'playback_fps': 30,
+                                  'auto_set_floor': True,
+                                  'smplx_models': smpl_models_path,
+                                  'z_up': True})
+    return HeadlessRenderer()
+
 
-<div class="embed_hidden">
-[old placeholder project header and author links, truncated in this page view]
-<p>
-This space illustrates <a href='project.com' target='_blank'><b>XXX</b></a>, a method for XXX.
-What does it do?
-</p>
+
+WEBSITE = ("""<div class="embed_hidden" style="text-align: center;">
+<h1>MotionFix: Text-Driven 3D Human Motion Editing</h1>
+<h3>
+<a href="https://is.mpg.de/person/~nathanasiou" target="_blank" rel="noopener noreferrer">Nikos Athanasiou</a><sup>1</sup>,
+<a href="https://is.mpg.de/person/acseke" target="_blank" rel="noopener noreferrer">Alpar Cseke</a><sup>1</sup>,
+<br>
+<a href="https://ps.is.mpg.de/person/mdiomataris" target="_blank" rel="noopener noreferrer">Markos Diomataris</a><sup>1, 3</sup>,
+<a href="https://is.mpg.de/person/black" target="_blank" rel="noopener noreferrer">Michael J. Black</a><sup>1</sup>,
+<a href="https://imagine.enpc.fr/~varolg/" target="_blank" rel="noopener noreferrer">Gül Varol</a><sup>2</sup>,
+</h3>
+<h3>
+<sup>1</sup>Max Planck Institute for Intelligent Systems, Tübingen, Germany;
+<sup>2</sup>LIGM, École des Ponts, Univ Gustave Eiffel, CNRS, France,
+<sup>3</sup>ETH Zürich, Switzerland
+</h3>
+</div>
+<div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
+<a href='https://arxiv.org/abs/'><img src='https://img.shields.io/badge/Arxiv-2405.20340-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a>
+<a href='https://arxiv.org/pdf/'><img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'></a>
+<a href='https://motionfix.is.tue.mpg.de'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
+<a href='https://youtube.com/'><img src='https://img.shields.io/badge/YouTube-red?style=flat&logo=youtube&logoColor=white'></a>
 </div>
-"""
+""")
 
 @spaces.GPU
 def greet(n):
@@ -50,82 +72,136 @@ def greet(n)
 def clear():
     return ""
 
-[old helper function definition, truncated in this page view]
-
+def show_video(input_text):
+    from normalization import Normalizer
+    normalizer = Normalizer()
+    from diffusion import create_diffusion
+    from text_encoder import ClipTextEncoder
+    from tmed_denoiser import TMED_denoiser
+    model_ckpt = download_models()
+    checkpoint = torch.load(model_ckpt)
+
+    checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
+    tmed_denoiser = TMED_denoiser().to('cuda')
+    tmed_denoiser.load_state_dict(checkpoint, strict=False)
+    tmed_denoiser.eval()
+    text_encoder = ClipTextEncoder()
+    texts_cond = [input_text]
+
+    diffusion_process = create_diffusion(timestep_respacing=None,
+                                         learn_sigma=False, sigma_small=True,
+                                         diffusion_steps=300,
+                                         noise_schedule='squaredcos_cap_v2',
+                                         predict_xstart=True)
+    bsz = 1
+    seqlen_tgt = 180
+    no_of_texts = len(texts_cond)
+    texts_cond = ['']*no_of_texts + texts_cond
+    texts_cond = ['']*no_of_texts + texts_cond
+    text_emb, text_mask = text_encoder(texts_cond)
+
+    cond_emb_motion = torch.zeros(seqlen_tgt, bsz,
+                                  512,
+                                  device='cuda')
+    cond_motion_mask = torch.ones((bsz, seqlen_tgt),
+                                  dtype=bool, device='cuda')
+    mask_target = torch.ones((bsz, seqlen_tgt),
+                             dtype=bool, device='cuda')
+
+    diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
+                                                text_mask.to(cond_emb_motion.device),
+                                                cond_emb_motion,
+                                                cond_motion_mask,
+                                                mask_target,
+                                                diffusion_process,
+                                                init_vec=None,
+                                                init_from='noise',
+                                                gd_text=4.0,
+                                                gd_motion=2.0,
+                                                steps_num=300)
+    edited_motion = diffout2motion(diff_out, normalizer).squeeze()
+    from renderer import render_motion, color_map, pack_to_render
+    # aitrenderer = get_renderer()
+    AIT_RENDERER = get_renderer()
+    SMPL_LAYER = SMPLLayer(model_type='smplh', ext='npz', gender='neutral')
+    edited_mot_to_render = pack_to_render(rots=edited_motion[..., 3:],
+                                          trans=edited_motion[..., :3])
+    import random
+    xx = random.randint(1, 1000)
+    fname = render_motion(AIT_RENDERER, [edited_mot_to_render],
+                          f"movie_example--{str(xx)}",
+                          pose_repr='aa',
+                          color=[color_map['generated']],
+                          smpl_layer=SMPL_LAYER)
+    return fname
+
+def retrieve_video(retrieve_text):
+    pass
+
 from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
 
 def download_models():
     REPO_ID = 'athn-nik/example-model'
-
     return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt")
-
+
+def download_tmr():
+    REPO_ID = 'athn-nik/example-model'
+    # return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt")
+    from huggingface_hub import snapshot_download
+    return snapshot_download(repo_id=REPO_ID, allow_patterns="tmr*",
+                             token=access_token_smpl)
+import gradio as gr
+
+def clear():
+    return ""
+
+def random_number():
+    return "Random text"
+
 with gr.Blocks() as demo:
     gr.Markdown(WEBSITE)
-
-    input_text = gr.Textbox(placeholder="Type the edit text you want:",
-                            show_label=True,label="Input Text", value=DEFAULT_TEXT)
-    # output_text = gr.Textbox(label="Output Text")
 
     with gr.Row():
+        with gr.Column(scale=8):
+            retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to Retrieve:",
+                                       show_label=True, label="Retrieval Text", value=DEFAULT_TEXT)
+        with gr.Column(scale=1):
+            clear_button_retrieval = gr.Button("Clear Retrieval Text")
+
+    with gr.Row():
+        with gr.Column(scale=8):
+            input_text = gr.Textbox(placeholder="Type the edit text you want:",
+                                    show_label=True, label="Input Text", value=DEFAULT_TEXT)
+        with gr.Column(scale=1):
+            clear_button_edit = gr.Button("Clear Edit Text")
+
+    with gr.Row():
+        video_output = gr.Video(label="Generated Video", height=240, width=320)
+        retrieved_video_output = gr.Video(label="Retrieved Motion", height=240, width=320)
+
+    with gr.Row():
+        edit_button = gr.Button("Edit")
         retrieve_button = gr.Button("Retrieve")
-
+
         random_button = gr.Button("Random")
-    [old inline generation code, truncated in this page view]
-    noise_schedule='squaredcos_cap_v2',
-    predict_xstart=True) # noise vs sample
-    # uncond_tokens = [""] * len(texts_cond)
-    # if self.condition == 'text':
-    #     uncond_tokens.extend(texts_cond)
-    # elif self.condition == 'text_uncond':
-    #     uncond_tokens.extend(uncond_tokens)
-    bsz = 1
-    seqlen_tgt = 180
-    no_of_texts = len(texts_cond)
-    texts_cond = ['']*no_of_texts + texts_cond
-    texts_cond = ['']*no_of_texts + texts_cond
-    print(texts_cond)
-    text_emb, text_mask = text_encoder(texts_cond)
-
-    cond_emb_motion = torch.zeros(1, bsz,
-                                  512,
-                                  device='cuda')
-    cond_motion_mask = torch.ones((bsz, 1),
-                                  dtype=bool, device='cuda')
-    mask_target = torch.ones((1, bsz),
-                             dtype=bool, device='cuda')
-    # complete noise
-    # import ipdb;ipdb.set_trace()
-    diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
-                                                text_mask.to(cond_emb_motion.device),
-                                                cond_emb_motion,
-                                                cond_motion_mask,
-                                                mask_target,
-                                                diffusion_process,
-                                                init_vec=None,
-                                                init_from='noise',
-                                                gd_text=4.0,
-                                                gd_motion=2.0,
-                                                steps_num=300)
-    edited_motion = diffout2motion(diff_out, normalizer)
-    clear_button.click(clear, outputs=input_text)
+
+    def process_and_show_video(input_text):
+        fname = show_video(input_text)
+        return fname
+
+    def process_and_retrieve_video(input_text):
+        fname = retrieve_video(input_text)
+        return fname
+
+    from gen_utils import read_config
+    from retrieval_loader import load_model_from_cfg
+    from retrieval_loader import get_tmr_model
+    tmr = get_tmr_model(download_tmr())
+    edit_button.click(process_and_show_video, inputs=input_text, outputs=video_output)
+    retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=retrieved_video_output)
+    # import ipdb;ipdb.set_trace()
+
+    clear_button_edit.click(clear, outputs=input_text)
+    clear_button_retrieval.click(clear, outputs=retrieve_text)
     random_button.click(random_number, outputs=input_text)
 
 demo.launch()
gen_utils.py
CHANGED
@@ -1,5 +1,10 @@
 import torch
 import numpy as np
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
 def cast_dict_to_tensors(d, device="cpu"):
     if isinstance(d, dict):
         return {k: cast_dict_to_tensors(v, device) for k, v in d.items()}
@@ -9,5 +14,60 @@ def cast_dict_to_tensors(d, device="cpu"):
         return d.to(device)
     else:
         return d
-
-
+
+def rgba(c: str):
+    from matplotlib import colors as mcolors
+    return mcolors.to_rgba(c)
+
+def rgb(c: str):
+    from matplotlib import colors as mcolors
+    return mcolors.to_rgb(c)
+
+# split the lightning checkpoint into
+# seperate state_dict modules for faster loading
+def extract_ckpt(run_dir, ckpt_name="last"):
+    import torch
+
+    ckpt_path = os.path.join(run_dir, f"logs/checkpoints/{ckpt_name}.ckpt")
+
+    extracted_path = os.path.join(run_dir, f"{ckpt_name}_weights")
+    os.makedirs(extracted_path, exist_ok=True)
+
+    new_path_template = os.path.join(extracted_path, "{}.pt")
+    ckpt_dict = torch.load(ckpt_path)
+    state_dict = ckpt_dict["state_dict"]
+    module_names = list(set([x.split(".")[0] for x in state_dict.keys()]))
+
+    # should be ['motion_encoder', 'text_encoder', 'motion_decoder'] for example
+    for module_name in module_names:
+        path = new_path_template.format(module_name)
+        sub_state_dict = {
+            ".".join(x.split(".")[1:]): y.cpu()
+            for x, y in state_dict.items()
+            if x.split(".")[0] == module_name
+        }
+        torch.save(sub_state_dict, path)
+
+import os
+import json
+from omegaconf import DictConfig, OmegaConf
+
+
+def save_config(cfg: DictConfig) -> str:
+    path = os.path.join(cfg.run_dir, "config.json")
+    config = OmegaConf.to_container(cfg, resolve=True)
+    with open(path, "w") as f:
+        string = json.dumps(config, indent=4)
+        f.write(string)
+    return path
+
+
+def read_config(run_dir: str, return_json=False) -> DictConfig:
+    path = os.path.join(run_dir, "config.json")
+    with open(path, "r") as f:
+        config = json.load(f)
+    if return_json:
+        return config
+    cfg = OmegaConf.create(config)
+    cfg.run_dir = run_dir
+    return cfg
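A quick usage sketch for the two new helpers (not part of the commit; the run-directory name is hypothetical, while the expected layout of logs/checkpoints/<name>.ckpt and config.json comes from the code above):

    from gen_utils import extract_ckpt, read_config

    run_dir = "experiments/tmr_run"   # hypothetical run directory
    cfg = read_config(run_dir)        # loads <run_dir>/config.json and sets cfg.run_dir
    extract_ckpt(run_dir, "last")     # splits logs/checkpoints/last.ckpt into last_weights/<module>.pt files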
model_utils.py
CHANGED
@@ -3,6 +3,10 @@ import torch.nn as nn
 import numpy as np
 import torch
 from torch import nn
+import torch
+
+from typing import List, Dict, Optional
+from torch import Tensor
 
 class TimestepEmbedderMDM(nn.Module):
     def __init__(self, latent_dim):
@@ -61,4 +65,34 @@ class PositionalEncoding(nn.Module):
         else:
             last = first + x.shape[0]
         x = x + self.pe[first:last, :]
-        return self.dropout(x)
+        return self.dropout(x)
+
+def collate_tensor_with_padding(batch: List[Tensor]) -> Tensor:
+    dims = batch[0].dim()
+    max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
+    size = (len(batch),) + tuple(max_size)
+    canvas = batch[0].new_zeros(size=size)
+    for i, b in enumerate(batch):
+        sub_tensor = canvas[i]
+        for d in range(dims):
+            sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
+        sub_tensor.add_(b)
+    return canvas
+
+
+def collate_x_dict(lst_x_dict: List, *, device: Optional[str] = 'cuda') -> Dict:
+    x = collate_tensor_with_padding([x_dict["x"] for x_dict in lst_x_dict])
+    if device is not None:
+        x = x.to(device)
+    length = [x_dict["length"] for x_dict in lst_x_dict]
+
+    if isinstance(length, list):
+        length = torch.tensor(length, device=device)
+
+    max_len = max(length)
+    mask = torch.arange(max_len, device=device).expand(
+        len(length), max_len
+    ) < length.unsqueeze(1)
+
+    batch = {"x": x, "length": length, "mask": mask}
+    return batch
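A minimal sketch of how the new collate helpers behave (illustrative only; the feature size, lengths, and CPU device below are assumptions, not part of the commit):

    import torch
    from model_utils import collate_x_dict

    batch = collate_x_dict([
        {"x": torch.randn(10, 64), "length": 10},
        {"x": torch.randn(7, 64), "length": 7},
    ], device='cpu')
    # batch["x"] is zero-padded to shape (2, 10, 64)
    # batch["mask"] is a (2, 10) boolean tensor, True for valid frames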
renderer.py
ADDED
@@ -0,0 +1,138 @@
+import os
+import torch
+from transform3d import transform_body_pose
+from aitviewer.headless import HeadlessRenderer
+from gen_utils import rgb, rgba
+
+color_map = {
+    'source_motion': rgba('darkred'),
+    'source': rgba('darkred'),
+    'target_motion': rgba('olivedrab'),
+    'input': rgba('olivedrab'),
+    'target': rgba('olivedrab'),
+    'generation': rgba('purple'),
+    'generated': rgba('steelblue'),
+    'denoised': rgba('purple'),
+    'noised': rgba('darkgrey'),
+}
+
+
+def pack_to_render(rots, trans, pose_repr='6d'):
+    # make axis-angle
+    # global_orient = transform_body_pose(rots, f"{pose_repr}->aa")
+
+    if rots.is_cuda:
+        rots = rots.detach().cpu()
+    if trans.is_cuda:
+        trans = trans.detach().cpu()
+
+    if pose_repr != 'aa':
+        body_pose = transform_body_pose(rots, f"{pose_repr}->aa")
+    else:
+        body_pose = rots
+    if trans is None:
+        trans = torch.zeros((rots.shape[0], rots.shape[1], 3),
+                            device=rots.device)
+    render_d = {'body_transl': trans,
+                'body_orient': body_pose[..., :3],
+                'body_pose': body_pose[..., 3:]}
+    return render_d
+
+
+def render_motion(renderer: HeadlessRenderer, datum: dict,
+                  filename: str, pose_repr='6d',
+                  color=(160 / 255, 160 / 255, 160 / 255, 1.0),
+                  return_verts=False, smpl_layer=None) -> None:
+    """
+    Function to render a video of a motion sequence
+    renderer: aitviewer renderer
+    datum: dictionary containing sequence of poses, body translations and body orientations
+        data could be numpy or pytorch tensors
+    filename: the absolute path you want the video to be saved at
+    """
+    from aitviewer.headless import HeadlessRenderer
+    from aitviewer.renderables.smpl import SMPLSequence
+
+    if isinstance(datum, dict): datum = [datum]
+    if not isinstance(color, list):
+        colors = [color]
+    else:
+        colors = color
+    # assert {'body_transl', 'body_orient', 'body_pose'}.issubset(set(datum[0].keys()))
+    # os.environ['DISPLAY'] = ":11"
+    gender = 'neutral'
+    only_skel = False
+    import sys
+    seqs_of_human_motions = []
+    if smpl_layer is None:
+        from aitviewer.models.smpl import SMPLLayer
+        smpl_layer = SMPLLayer(model_type='smplh',
+                               ext='npz',
+                               gender=gender)
+
+    for iid, mesh_seq in enumerate(datum):
+
+        if pose_repr != 'aa':
+            global_orient = transform_body_pose(mesh_seq['body_orient'],
+                                                f"{pose_repr}->aa")
+            body_pose = transform_body_pose(mesh_seq['body_pose'],
+                                            f"{pose_repr}->aa")
+        else:
+            global_orient = mesh_seq['body_orient']
+            body_pose = mesh_seq['body_pose']
+
+        body_transl = mesh_seq['body_transl']
+        sys.stdout.flush()
+
+        old = os.dup(1)
+        os.close(1)
+        os.open(os.devnull, os.O_WRONLY)
+        print(body_pose.shape)
+        print('\n')
+        smpl_template = SMPLSequence(body_pose,
+                                     smpl_layer,
+                                     poses_root=global_orient,
+                                     trans=body_transl,
+                                     color=colors[iid],
+                                     z_up=True)
+        if only_skel:
+            smpl_template.remove(smpl_template.mesh_seq)
+
+        seqs_of_human_motions.append(smpl_template)
+        renderer.scene.add(smpl_template)
+    # camera follows smpl sequence
+    # FIX CAMERA
+    from transform3d import get_z_rot
+    R_z = get_z_rot(global_orient[0], in_format='aa')
+    heading = -R_z[:, 1]
+    xy_facing = body_transl[0] + heading*2.5
+    camera = renderer.lock_to_node(seqs_of_human_motions[0],
+                                   (xy_facing[0], xy_facing[1], 1.5), smooth_sigma=5.0)
+
+    # /FIX CAMERA
+    if len(mesh_seq['body_pose']) == 1:
+        renderer.save_frame(file_path=str(filename) + '.png')
+        sfx = 'png'
+    else:
+        renderer.save_video(video_dir=str(filename), output_fps=30)
+        sfx = 'mp4'
+
+    # aitviewer adds a counter to the filename, we remove it
+    # filename.split('_')[-1].replace('.mp4', '')
+    # os.rename(filename + '_0.mp4', filename[:-4] + '.mp4')
+    if sfx == 'mp4':
+        os.rename(str(filename) + f'_0.{sfx}', str(filename) + f'.{sfx}')
+
+    # empty scene for the next rendering
+    for mesh in seqs_of_human_motions:
+        renderer.scene.remove(mesh)
+    renderer.scene.remove(camera)
+
+    sys.stdout.flush()
+    os.close(1)
+    os.dup(old)
+    os.close(old)
+    renderer.reset()
+    fname = f'{filename}.{sfx}'
+    return fname
requirements.txt
CHANGED
@@ -2,3 +2,4 @@ spaces
 gradio==4.36.1
 torch
 transformers==4.41.2
+hydra-core
retrieval_loader.py
ADDED
@@ -0,0 +1,67 @@
+from gen_utils import extract_ckpt
+import hydra
+import os
+from hydra.utils import instantiate
+from gen_utils import read_config
+from model_utils import collate_x_dict
+import torch
+from tmr_model import TMR_textencoder
+
+
+def load_model_from_cfg(cfg, ckpt_name="last", device="cuda", eval_mode=True):
+    import src.prepare  # noqa
+    import torch
+
+    run_dir = cfg.run_dir
+    model = hydra.utils.instantiate(cfg.model)
+
+    # Loading modules one by one
+    # motion_encoder / text_encoder / text_decoder
+    pt_path = os.path.join(run_dir, f"{ckpt_name}_weights")
+
+    if not os.path.exists(pt_path):
+        extract_ckpt(run_dir, ckpt_name)
+
+    for fname in os.listdir(pt_path):
+        module_name, ext = os.path.splitext(fname)
+        if ext != ".pt":
+            continue
+
+        module = getattr(model, module_name, None)
+        if module is None:
+            continue
+
+        module_path = os.path.join(pt_path, fname)
+        state_dict = torch.load(module_path)
+        module.load_state_dict(state_dict)
+    model = model.to(device)
+    if eval_mode:
+        model = model.eval()
+    return model
+
+# def get_tmr_model(run_dir):
+#     from gen_utils import read_config
+#     cfg = read_config(run_dir+'/tmr')
+#     import ipdb;ipdb.set_trace()
+#     text_model = instantiate(cfg.data.text_to_token_emb, device='cuda')
+#     model = load_model_from_cfg(cfg, 'last', eval_mode=True, device='cuda')
+#     return text_model, model
+
+def get_tmr_model(run_dir):
+    text_params = {
+        "latent_dim": 256,
+        "ff_size": 1024,
+        "num_layers": 6,
+        "num_heads": 4,
+        "activation": "gelu",
+        "modelpath": "distilbert-base-uncased",
+    }
+    "unit_motion_embs"
+    model = TMR_textencoder(**text_params)
+    state_dict = torch.load(f"{run_dir}/tmr/last_weights/text_encoder.pt",
+                            map_location='cuda')
+    # load values for the transformer only
+    model.load_state_dict(state_dict, strict=False)
+    model = model.eval()
+    return model.to('cuda')
tmr_model.py
ADDED
@@ -0,0 +1,128 @@
+from typing import List
+import torch.nn as nn
+import os
+
+import torch
+import numpy as np
+from torch import Tensor
+from transformers import AutoTokenizer, AutoModel
+from transformers import logging
+from torch.nn.functional import normalize
+
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_len=5000):
+        super().__init__()
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0).transpose(0, 1)
+
+        self.register_buffer('pe', pe, persistent=False)
+
+    def forward(self, x):
+        return x + self.pe[:x.shape[0], :]
+
+
+class TMR_textencoder(nn.Module):
+    def __init__(self, modelpath: str, latent_dim: int, ff_size: int,
+                 num_layers: int, num_heads: int, activation: str, **kwargs) -> None:
+        super().__init__()
+
+        logging.set_verbosity_error()
+
+        # Tokenizer
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
+
+        # Text model
+        self.text_model = AutoModel.from_pretrained(modelpath)
+        # Then configure the model
+        self.text_encoded_dim = self.text_model.config.hidden_size
+
+        # Projection of the text-outputs into the latent space
+        self.projection = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(self.text_encoded_dim, latent_dim)
+        )
+
+        self.mu_token = nn.Parameter(torch.randn(latent_dim))
+        self.logvar_token = nn.Parameter(torch.randn(latent_dim))
+        self.sequence_pos_encoding = PositionalEncoding(latent_dim)
+
+        seq_trans_encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim,
+                                                             nhead=num_heads,
+                                                             dim_feedforward=ff_size,
+                                                             dropout=0.0,
+                                                             activation=activation)
+        self.seqTransEncoder = nn.TransformerEncoder(
+            seq_trans_encoder_layer,
+            num_layers=num_layers
+        )
+
+    def get_last_hidden_state(self, texts: List[str],
+                              return_mask: bool = False):
+        encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
+        output = self.text_model(**encoded_inputs.to(self.text_model.device))
+        if not return_mask:
+            return output.last_hidden_state
+        return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool)
+
+    def forward(self, texts: List[str]) -> Tensor:
+        text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True)
+
+        x = self.projection(text_encoded)
+        bs, nframes, _ = x.shape
+        # bs, nframes, totjoints, nfeats = x.shape
+        # Switch sequence and batch_size because the input of
+        # Pytorch Transformer is [Sequence, Batch size, ...]
+        x = x.permute(1, 0, 2)  # now it is [nframes, bs, latent_dim]
+
+        mu_token = torch.tile(self.mu_token, (bs,)).reshape(bs, -1)
+        logvar_token = torch.tile(self.logvar_token, (bs,)).reshape(bs, -1)
+
+        # adding the distribution tokens for all sequences
+        xseq = torch.cat((mu_token[None], logvar_token[None], x), 0)
+
+        # create a bigger mask, to allow attend to mu and logvar
+        token_mask = torch.ones((bs, 2), dtype=bool, device=x.device)
+        aug_mask = torch.cat((token_mask, mask), 1)
+
+        # add positional encoding
+        xseq = self.sequence_pos_encoding(xseq)
+        final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
+
+        # only mu for inference
+        mu = final[0]
+        return mu
+
+    # compute score for retrieval
+    def compute_scores(self, texts, unit_embs=None, embs=None):
+        # not both empty
+        assert not (unit_embs is None and embs is None)
+        # not both filled
+        assert not (unit_embs is not None and embs is not None)
+
+        output_str = False
+        # if one input, squeeze the output
+        if isinstance(texts, str):
+            texts = [texts]
+            output_str = True
+
+        # compute unit_embs from embs if not given
+        if embs is not None:
+            unit_embs = normalize(embs)
+
+        with torch.no_grad():
+            latent_unit_texts = normalize(self(texts))
+            # compute cosine similarity between 0 and 1
+            scores = (unit_embs @ latent_unit_texts.T).T/2 + 0.5
+        scores = scores.cpu().numpy()
+
+        if output_str:
+            scores = scores[0]
+
+        return scores
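Since the commit is about the retrieval placeholders, here is a hedged sketch of how this text encoder could drive retrieval. It is not part of the commit: the motion_embs.pt file and the wiring into retrieve_video are assumptions; only get_tmr_model and compute_scores come from the files above.

    import torch
    from retrieval_loader import get_tmr_model

    run_dir = "<tmr snapshot dir>"                      # e.g. what download_tmr() in app.py returns
    tmr = get_tmr_model(run_dir)
    motion_embs = torch.load('motion_embs.pt').cuda()   # assumed precomputed (num_motions, 256) motion embeddings
    scores = tmr.compute_scores("A person is walking", embs=motion_embs)
    best_idx = scores.argmax()                          # index of the best-matching motion to render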