Spaces:
Running
Running
first tmr retrieval effort
Browse files- app.py +102 -23
- gen_utils.py +10 -0
app.py
CHANGED
|
@@ -17,6 +17,7 @@ zero = torch.Tensor([0]).cuda()
|
|
| 17 |
print(zero.device) # <-- 'cuda:0' 🤗
|
| 18 |
|
| 19 |
DEFAULT_TEXT = "A person is "
|
|
|
|
| 20 |
from aitviewer.models.smpl import SMPLLayer
|
| 21 |
def get_smpl_models():
|
| 22 |
REPO_ID = 'athn-nik/smpl_models'
|
|
@@ -35,7 +36,6 @@ def get_renderer():
|
|
| 35 |
return HeadlessRenderer()
|
| 36 |
|
| 37 |
|
| 38 |
-
|
| 39 |
WEBSITE = ("""<div class="embed_hidden" style="text-align: center;">
|
| 40 |
<h1>MotionFix: Text-Driven 3D Human Motion Editing</h1>
|
| 41 |
<h3>
|
|
@@ -65,8 +65,8 @@ WEB_source = ("""<div class="embed_hidden" style="text-align: center;">
|
|
| 65 |
<h1>Pick a motion to edit!</h1>
|
| 66 |
<h3>
|
| 67 |
Here you should pick a source motion
|
|
|
|
| 68 |
</h3>
|
| 69 |
-
<hr class="double">
|
| 70 |
</div>
|
| 71 |
""")
|
| 72 |
|
|
@@ -74,6 +74,7 @@ WEB_target = ("""<div class="embed_hidden" style="text-align: center;">
|
|
| 74 |
<h1>Now type the text to edit that motion!</h1>
|
| 75 |
<h3>
|
| 76 |
Here you should get the generated motion!
|
|
|
|
| 77 |
</h3>
|
| 78 |
</div>
|
| 79 |
""")
|
|
@@ -152,13 +153,6 @@ def show_video(input_text):
|
|
| 152 |
smpl_layer=SMPL_LAYER)
|
| 153 |
return fname
|
| 154 |
|
| 155 |
-
def retrieve_video(retrieve_text):
|
| 156 |
-
tmr_text_encoder = get_tmr_model(download_tmr())
|
| 157 |
-
text_encoded = tmr_text_encoder([retrieve_text])
|
| 158 |
-
motion_embeds = None
|
| 159 |
-
retrieved_motion = tmr_text_encoder.compute_scores(text_encoded, motion_embeds)
|
| 160 |
-
return
|
| 161 |
-
|
| 162 |
from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
|
| 163 |
|
| 164 |
def download_models():
|
|
@@ -171,15 +165,70 @@ def download_tmr():
|
|
| 171 |
from huggingface_hub import snapshot_download
|
| 172 |
return snapshot_download(repo_id=REPO_ID, allow_patterns="tmr*",
|
| 173 |
token=access_token_smpl)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
import gradio as gr
|
| 175 |
|
| 176 |
def clear():
|
| 177 |
return ""
|
| 178 |
|
| 179 |
-
def
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
-
with gr.Blocks() as demo:
|
| 183 |
gr.Markdown(WEBSITE)
|
| 184 |
gr.Markdown(WEB_source)
|
| 185 |
# TODO load TMR text-encoder
|
|
@@ -188,18 +237,44 @@ with gr.Blocks() as demo:
|
|
| 188 |
# edit that motion!
|
| 189 |
with gr.Row():
|
| 190 |
with gr.Column(scale=10):
|
| 191 |
-
|
|
|
|
|
|
|
| 192 |
show_label=True, label="Retrieval Text", value=DEFAULT_TEXT)
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
with gr.Column(scale=8):
|
| 196 |
retrieved_video_output = gr.Video(label="Retrieved Motion",
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
with gr.Row():
|
| 200 |
-
clear_button_retrieval = gr.Button("Clear Retrieval Text")
|
| 201 |
-
retrieve_button = gr.Button("TMRetrieve")
|
| 202 |
-
random_button = gr.Button("Random")
|
| 203 |
|
| 204 |
gr.Markdown(WEB_target)
|
| 205 |
with gr.Row():
|
|
@@ -223,14 +298,18 @@ with gr.Blocks() as demo:
|
|
| 223 |
return fname
|
| 224 |
|
| 225 |
from retrieval_loader import get_tmr_model
|
|
|
|
| 226 |
# load the dataset and splits
|
| 227 |
-
|
|
|
|
| 228 |
edit_button.click(process_and_show_video, inputs=input_text, outputs=video_output)
|
| 229 |
-
retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=retrieved_video_output)
|
|
|
|
|
|
|
|
|
|
| 230 |
# import ipdb;ipdb.set_trace()
|
| 231 |
|
| 232 |
clear_button_edit.click(clear, outputs=input_text)
|
| 233 |
clear_button_retrieval.click(clear, outputs=retrieve_text)
|
| 234 |
-
random_button.click(random_number, outputs=input_text)
|
| 235 |
|
| 236 |
demo.launch()
|
|
|
|
| 17 |
print(zero.device) # <-- 'cuda:0' 🤗
|
| 18 |
|
| 19 |
DEFAULT_TEXT = "A person is "
|
| 20 |
+
|
| 21 |
from aitviewer.models.smpl import SMPLLayer
|
| 22 |
def get_smpl_models():
|
| 23 |
REPO_ID = 'athn-nik/smpl_models'
|
|
|
|
| 36 |
return HeadlessRenderer()
|
| 37 |
|
| 38 |
|
|
|
|
| 39 |
WEBSITE = ("""<div class="embed_hidden" style="text-align: center;">
|
| 40 |
<h1>MotionFix: Text-Driven 3D Human Motion Editing</h1>
|
| 41 |
<h3>
|
|
|
|
| 65 |
<h1>Pick a motion to edit!</h1>
|
| 66 |
<h3>
|
| 67 |
Here you should pick a source motion
|
| 68 |
+
<hr class="double">
|
| 69 |
</h3>
|
|
|
|
| 70 |
</div>
|
| 71 |
""")
|
| 72 |
|
|
|
|
| 74 |
<h1>Now type the text to edit that motion!</h1>
|
| 75 |
<h3>
|
| 76 |
Here you should get the generated motion!
|
| 77 |
+
<hr class="double">
|
| 78 |
</h3>
|
| 79 |
</div>
|
| 80 |
""")
|
|
|
|
| 153 |
smpl_layer=SMPL_LAYER)
|
| 154 |
return fname
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
|
| 157 |
|
| 158 |
def download_models():
|
|
|
|
| 165 |
from huggingface_hub import snapshot_download
|
| 166 |
return snapshot_download(repo_id=REPO_ID, allow_patterns="tmr*",
|
| 167 |
token=access_token_smpl)
|
| 168 |
+
|
| 169 |
+
def download_motionfix():
    """Fetch the ``motionfix*`` files of the example-model Hub repository.

    Returns:
        Whatever ``snapshot_download`` returns for the filtered snapshot
        (the caller appends ``'/motionfix'`` to it, so presumably the local
        snapshot directory -- confirm against the huggingface_hub docs).
    """
    from huggingface_hub import snapshot_download

    hub_repo = 'athn-nik/example-model'
    # Only files matching the pattern are pulled; the token is the
    # module-level access token shared with the other download helpers.
    return snapshot_download(
        repo_id=hub_repo,
        allow_patterns="motionfix*",
        token=access_token_smpl,
    )
|
| 175 |
+
|
| 176 |
+
def download_embeddings():
    """Fetch the ``embeddings*`` files of the example-model Hub repository.

    Returns:
        Whatever ``snapshot_download`` returns for the filtered snapshot
        (the caller appends ``'/embeddings'`` to it, so presumably the local
        snapshot directory -- confirm against the huggingface_hub docs).
    """
    from huggingface_hub import snapshot_download

    hub_repo = 'athn-nik/example-model'
    # Only files matching the pattern are pulled; the token is the
    # module-level access token shared with the other download helpers.
    return snapshot_download(
        repo_id=hub_repo,
        allow_patterns="embeddings*",
        token=access_token_smpl,
    )
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
MFIX_p = download_motionfix() + '/motionfix'
|
| 185 |
+
SOURCE_MOTS_p = download_embeddings() + '/embeddings'
|
| 186 |
+
|
| 187 |
import gradio as gr
|
| 188 |
|
| 189 |
def clear():
    """Return an empty string; wired to the "clear" buttons to blank a textbox."""
    return str()
|
| 191 |
|
| 192 |
+
def random_source_motion(set_to_pick):
    """Pick a random MotionFix sample from the requested split.

    Args:
        set_to_pick: One of ``'all'``, ``'train'`` or ``'test'`` (the values
            offered by the "Set to pick from" radio button).

    Returns:
        A ``(motion_a, annotation)`` tuple taken from the randomly chosen
        entry. The pair feeds the retrieved-video and suggested-edit-text
        outputs of the UI.

    Raises:
        ValueError: If ``set_to_pick`` is not one of the supported splits.
            Previously this surfaced as an ``UnboundLocalError`` because no
            branch assigned the working set.
    """
    import random

    mfix_train, mfix_test = load_motionfix(MFIX_p)
    if set_to_pick == 'all':
        # Dict union: on duplicate keys the train entry wins.
        current_set = mfix_test | mfix_train
    elif set_to_pick == 'train':
        current_set = mfix_train
    elif set_to_pick == 'test':
        current_set = mfix_test
    else:
        raise ValueError(
            f"Unknown set_to_pick {set_to_pick!r}; "
            "expected 'all', 'train' or 'test'."
        )

    random_key = random.choice(list(current_set.keys()))
    picked = current_set[random_key]
    return picked['motion_a'], picked['annotation']
|
| 206 |
+
|
| 207 |
+
def retrieve_video(retrieve_text):
    """Retrieve the stored motion whose embedding scores highest for a text.

    Args:
        retrieve_text: The query text, passed unchanged to the TMR model's
            ``compute_scores``.

    Returns:
        A ``(motion_b, annotation)`` tuple from the dataset entry whose key
        has the highest score.
    """
    tmr_text_encoder = get_tmr_model(download_tmr())

    from gen_utils import read_json
    import numpy as np

    # NOTE(security): torch.load unpickles the file. The embeddings come from
    # a downloaded Hub snapshot, so only load them from a repository you trust.
    motion_embeds = torch.load(f"{SOURCE_MOTS_p}/source_motions_embeddings.pt")
    # Key ids are stored alongside the embeddings; presumably row i of the
    # embeddings corresponds to motion_keyids[i] -- TODO confirm.
    motion_keyids = np.array(read_json(f"{SOURCE_MOTS_p}/keyids_embeddings.json"))

    mfix_train, mfix_test = load_motionfix(MFIX_p)
    # Dict union: on duplicate keys the train entry wins.
    all_mots = mfix_test | mfix_train

    scores = tmr_text_encoder.compute_scores(retrieve_text, embs=motion_embeds)
    # Negate so argsort yields indices from highest to lowest score.
    sorted_idxs = np.argsort(-scores)
    best_keyids = motion_keyids[sorted_idxs]

    top_mot = best_keyids[0]
    best_entry = all_mots[top_mot]
    return best_entry['motion_b'], best_entry['annotation']
|
| 229 |
+
|
| 230 |
|
| 231 |
+
with gr.Blocks(css="style.css") as demo:
|
| 232 |
gr.Markdown(WEBSITE)
|
| 233 |
gr.Markdown(WEB_source)
|
| 234 |
# TODO load TMR text-encoder
|
|
|
|
| 237 |
# edit that motion!
|
| 238 |
with gr.Row():
|
| 239 |
with gr.Column(scale=10):
|
| 240 |
+
with gr.Column(scale=5):
|
| 241 |
+
|
| 242 |
+
retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to Retrieve:",
|
| 243 |
show_label=True, label="Retrieval Text", value=DEFAULT_TEXT)
|
| 244 |
+
suggested_edit_text = gr.Textbox(placeholder="Texts likely to edit the motion:",
|
| 245 |
+
show_label=True, label="Suggested Edit Text",
|
| 246 |
+
value='')
|
| 247 |
+
xxx = 'https://motion-editing.s3.eu-central-1.amazonaws.com/collection_wo_walks_runs/rendered_pairs/011327_120_240-002682_120_240.mp4'
|
| 248 |
+
|
| 249 |
+
with gr.Column(scale=5):
|
| 250 |
+
set_to_pick = gr.Radio(['all', 'train', 'test'],
|
| 251 |
+
value='all',
|
| 252 |
+
label="Set to pick from",
|
| 253 |
+
info="Motion will be picked from whole dataset or test or train data.")
|
| 254 |
+
|
| 255 |
+
with gr.Row():
|
| 256 |
+
with gr.Column(scale=10):
|
| 257 |
+
retrieve_button = gr.Button("TMRetrieve")
|
| 258 |
+
random_button = gr.Button("Random")
|
| 259 |
+
with gr.Column(scale=10):
|
| 260 |
+
how_many_videos = gr.Radio([1, 3, 5, 7],
|
| 261 |
+
value=3,
|
| 262 |
+
label="# Videos",
|
| 263 |
+
info="# Videos to be retrieved in each case."),
|
| 264 |
+
|
| 265 |
+
# temp_slider = gr.Slider(minimum=1,
|
| 266 |
+
# maximum=5,
|
| 267 |
+
# value=1,
|
| 268 |
+
# step=2,
|
| 269 |
+
# interactive=True,
|
| 270 |
+
# label="Slide me")
|
| 271 |
+
with gr.Column(scale=10,elem_id="center-column"):
|
| 272 |
+
clear_button_retrieval = gr.Button("Clear Retrieval Text")
|
| 273 |
|
| 274 |
with gr.Column(scale=8):
|
| 275 |
retrieved_video_output = gr.Video(label="Retrieved Motion",
|
| 276 |
+
value=xxx,
|
| 277 |
+
height=360, width=480)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
gr.Markdown(WEB_target)
|
| 280 |
with gr.Row():
|
|
|
|
| 298 |
return fname
|
| 299 |
|
| 300 |
from retrieval_loader import get_tmr_model
|
| 301 |
+
from dataset_utils import load_motionfix
|
| 302 |
# load the dataset and splits
|
| 303 |
+
# import ipdb;ipdb.set_trace()
|
| 304 |
+
|
| 305 |
edit_button.click(process_and_show_video, inputs=input_text, outputs=video_output)
|
| 306 |
+
retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=[retrieved_video_output, suggested_edit_text])
|
| 307 |
+
random_button.click(random_source_motion, inputs=set_to_pick, outputs=[retrieved_video_output, suggested_edit_text])
|
| 308 |
+
# import ipdb;ipdb.set_trace()
|
| 309 |
+
|
| 310 |
# import ipdb;ipdb.set_trace()
|
| 311 |
|
| 312 |
clear_button_edit.click(clear, outputs=input_text)
|
| 313 |
clear_button_retrieval.click(clear, outputs=retrieve_text)
|
|
|
|
| 314 |
|
| 315 |
demo.launch()
|
gen_utils.py
CHANGED
|
@@ -61,6 +61,16 @@ def save_config(cfg: DictConfig) -> str:
|
|
| 61 |
f.write(string)
|
| 62 |
return path
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
def read_config(run_dir: str, return_json=False) -> DictConfig:
|
| 66 |
path = os.path.join(run_dir, "config.json")
|
|
|
|
| 61 |
f.write(string)
|
| 62 |
return path
|
| 63 |
|
| 64 |
+
def write_json(data, p):
    """Serialize ``data`` as 2-space-indented JSON to the file at ``p``.

    Args:
        data: Any object ``json.dump`` can serialize.
        p: Destination path; the file is created or overwritten.

    Raises:
        TypeError: If ``data`` contains values that are not JSON serializable.
        OSError: If the file cannot be opened for writing.
    """
    import json

    # Explicit UTF-8 keeps the file independent of the platform locale and
    # consistent with how JSON files are expected to be encoded.
    with open(p, 'w', encoding='utf-8') as fp:
        json.dump(data, fp, indent=2)
|
| 68 |
+
|
| 69 |
+
def read_json(p):
    """Load and return the JSON document stored in the file at ``p``.

    Args:
        p: Path of the JSON file to read.

    Returns:
        The decoded Python object (dict, list, str, number, bool or None).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    import json

    # Decode as UTF-8 explicitly: the locale default can differ per platform
    # and would misread non-ASCII content.
    with open(p, 'r', encoding='utf-8') as fp:
        return json.load(fp)
|
| 74 |
|
| 75 |
def read_config(run_dir: str, return_json=False) -> DictConfig:
|
| 76 |
path = os.path.join(run_dir, "config.json")
|