atnikos committed on
Commit
38c4910
·
1 Parent(s): 7495dca

first tmr retrieval effort

Browse files
Files changed (2) hide show
  1. app.py +102 -23
  2. gen_utils.py +10 -0
app.py CHANGED
@@ -17,6 +17,7 @@ zero = torch.Tensor([0]).cuda()
17
  print(zero.device) # <-- 'cuda:0' 🤗
18
 
19
  DEFAULT_TEXT = "A person is "
 
20
  from aitviewer.models.smpl import SMPLLayer
21
  def get_smpl_models():
22
  REPO_ID = 'athn-nik/smpl_models'
@@ -35,7 +36,6 @@ def get_renderer():
35
  return HeadlessRenderer()
36
 
37
 
38
-
39
  WEBSITE = ("""<div class="embed_hidden" style="text-align: center;">
40
  <h1>MotionFix: Text-Driven 3D Human Motion Editing</h1>
41
  <h3>
@@ -65,8 +65,8 @@ WEB_source = ("""<div class="embed_hidden" style="text-align: center;">
65
  <h1>Pick a motion to edit!</h1>
66
  <h3>
67
  Here you should pick a source motion
 
68
  </h3>
69
- <hr class="double">
70
  </div>
71
  """)
72
 
@@ -74,6 +74,7 @@ WEB_target = ("""<div class="embed_hidden" style="text-align: center;">
74
  <h1>Now type the text to edit that motion!</h1>
75
  <h3>
76
  Here you should get the generated motion!
 
77
  </h3>
78
  </div>
79
  """)
@@ -152,13 +153,6 @@ def show_video(input_text):
152
  smpl_layer=SMPL_LAYER)
153
  return fname
154
 
155
- def retrieve_video(retrieve_text):
156
- tmr_text_encoder = get_tmr_model(download_tmr())
157
- text_encoded = tmr_text_encoder([retrieve_text])
158
- motion_embeds = None
159
- retrieved_motion = tmr_text_encoder.compute_scores(text_encoded, motion_embeds)
160
- return
161
-
162
  from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
163
 
164
  def download_models():
@@ -171,15 +165,70 @@ def download_tmr():
171
  from huggingface_hub import snapshot_download
172
  return snapshot_download(repo_id=REPO_ID, allow_patterns="tmr*",
173
  token=access_token_smpl)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  import gradio as gr
175
 
176
  def clear():
177
  return ""
178
 
179
- def random_number():
180
- return "Random text"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
- with gr.Blocks() as demo:
183
  gr.Markdown(WEBSITE)
184
  gr.Markdown(WEB_source)
185
  # TODO load TMR text-encoder
@@ -188,18 +237,44 @@ with gr.Blocks() as demo:
188
  # edit that motion!
189
  with gr.Row():
190
  with gr.Column(scale=10):
191
- retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to Retrieve:",
 
 
192
  show_label=True, label="Retrieval Text", value=DEFAULT_TEXT)
193
- xxx = 'https://motion-editing.s3.eu-central-1.amazonaws.com/collection_wo_walks_runs/rendered_pairs/011327_120_240-002682_120_240.mp4'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  with gr.Column(scale=8):
196
  retrieved_video_output = gr.Video(label="Retrieved Motion",
197
- value=xxx,
198
- height=360, width=480)
199
- with gr.Row():
200
- clear_button_retrieval = gr.Button("Clear Retrieval Text")
201
- retrieve_button = gr.Button("TMRetrieve")
202
- random_button = gr.Button("Random")
203
 
204
  gr.Markdown(WEB_target)
205
  with gr.Row():
@@ -223,14 +298,18 @@ with gr.Blocks() as demo:
223
  return fname
224
 
225
  from retrieval_loader import get_tmr_model
 
226
  # load the dataset and splits
227
-
 
228
  edit_button.click(process_and_show_video, inputs=input_text, outputs=video_output)
229
- retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=retrieved_video_output)
 
 
 
230
  # import ipdb;ipdb.set_trace()
231
 
232
  clear_button_edit.click(clear, outputs=input_text)
233
  clear_button_retrieval.click(clear, outputs=retrieve_text)
234
- random_button.click(random_number, outputs=input_text)
235
 
236
  demo.launch()
 
17
  print(zero.device) # <-- 'cuda:0' 🤗
18
 
19
  DEFAULT_TEXT = "A person is "
20
+
21
  from aitviewer.models.smpl import SMPLLayer
22
  def get_smpl_models():
23
  REPO_ID = 'athn-nik/smpl_models'
 
36
  return HeadlessRenderer()
37
 
38
 
 
39
  WEBSITE = ("""<div class="embed_hidden" style="text-align: center;">
40
  <h1>MotionFix: Text-Driven 3D Human Motion Editing</h1>
41
  <h3>
 
65
  <h1>Pick a motion to edit!</h1>
66
  <h3>
67
  Here you should pick a source motion
68
+ <hr class="double">
69
  </h3>
 
70
  </div>
71
  """)
72
 
 
74
  <h1>Now type the text to edit that motion!</h1>
75
  <h3>
76
  Here you should get the generated motion!
77
+ <hr class="double">
78
  </h3>
79
  </div>
80
  """)
 
153
  smpl_layer=SMPL_LAYER)
154
  return fname
155
 
 
 
 
 
 
 
 
156
  from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
157
 
158
  def download_models():
 
165
  from huggingface_hub import snapshot_download
166
  return snapshot_download(repo_id=REPO_ID, allow_patterns="tmr*",
167
  token=access_token_smpl)
168
+
169
def download_motionfix():
    """Download the MotionFix files from the Hugging Face Hub.

    Requests only the entries of ``athn-nik/example-model`` that match the
    ``motionfix*`` pattern, authenticating with the module-level
    ``access_token_smpl``.

    Returns:
        The value returned by ``snapshot_download``; the caller appends
        ``'/motionfix'`` to it to build a path.
    """
    from huggingface_hub import snapshot_download

    repo = 'athn-nik/example-model'
    return snapshot_download(
        repo_id=repo,
        allow_patterns="motionfix*",
        token=access_token_smpl,
    )
175
+
176
def download_embeddings():
    """Download the precomputed embedding files from the Hugging Face Hub.

    Requests only the entries of ``athn-nik/example-model`` that match the
    ``embeddings*`` pattern, authenticating with the module-level
    ``access_token_smpl``.

    Returns:
        The value returned by ``snapshot_download``; the caller appends
        ``'/embeddings'`` to it to build a path.
    """
    from huggingface_hub import snapshot_download

    repo = 'athn-nik/example-model'
    return snapshot_download(
        repo_id=repo,
        allow_patterns="embeddings*",
        token=access_token_smpl,
    )
182
+
183
+
184
# Both downloads run at import time; the sub-folder name is appended to the
# snapshot location returned by each helper.
MFIX_p = f"{download_motionfix()}/motionfix"
SOURCE_MOTS_p = f"{download_embeddings()}/embeddings"
186
+
187
  import gradio as gr
188
 
189
  def clear():
190
  return ""
191
 
192
def random_source_motion(set_to_pick):
    """Pick a random MotionFix entry from the requested split.

    The dataset is loaded from ``MFIX_p`` via ``load_motionfix`` on every
    call, and one key is drawn uniformly at random from the chosen split.

    Args:
        set_to_pick: ``'all'`` (test and train merged, train entries winning
            on duplicate keys), ``'train'`` or ``'test'``.

    Returns:
        A ``(motion_a, annotation)`` tuple taken from the chosen entry.

    Raises:
        ValueError: if ``set_to_pick`` is not one of the three accepted
            values (previously this surfaced as an ``UnboundLocalError``).
    """
    import random

    mfix_train, mfix_test = load_motionfix(MFIX_p)
    if set_to_pick == 'all':
        current_set = mfix_test | mfix_train
    elif set_to_pick == 'train':
        current_set = mfix_train
    elif set_to_pick == 'test':
        current_set = mfix_test
    else:
        raise ValueError(
            f"set_to_pick must be 'all', 'train' or 'test', got {set_to_pick!r}"
        )

    random_key = random.choice(list(current_set))
    entry = current_set[random_key]
    return entry['motion_a'], entry['annotation']
206
+
207
def retrieve_video(retrieve_text):
    """Retrieve the best-scoring MotionFix entry for a text query.

    Scores ``retrieve_text`` against the stored motion embeddings with the
    TMR model and returns the video and annotation of the top-ranked key.

    Args:
        retrieve_text: the query, passed unchanged to ``compute_scores``.

    Returns:
        A ``(motion_b, annotation)`` tuple of the highest-scoring entry.
    """
    import os

    # NOTE(review): the model is rebuilt on every call; consider loading it
    # once if get_tmr_model/download_tmr are expensive.
    tmr_text_encoder = get_tmr_model(download_tmr())

    from gen_utils import read_json
    import numpy as np

    motion_embeds = torch.load(
        os.path.join(SOURCE_MOTS_p, 'source_motions_embeddings.pt')
    )
    # Key ids are wrapped in an array so they can be reordered by index below;
    # presumably they are stored in the same order as the embeddings -- TODO confirm.
    motion_keyids = np.array(
        read_json(os.path.join(SOURCE_MOTS_p, 'keyids_embeddings.json'))
    )

    mfix_train, mfix_test = load_motionfix(MFIX_p)
    all_mots = mfix_test | mfix_train

    scores = tmr_text_encoder.compute_scores(retrieve_text, embs=motion_embeds)
    # Negate so that argsort yields indices from highest to lowest score.
    sorted_idxs = np.argsort(-scores)
    best_keyids = motion_keyids[sorted_idxs]

    top_mot = best_keyids[0]
    # NOTE(review): random_source_motion returns 'motion_a' while this returns
    # 'motion_b' -- confirm which one is the intended source motion.
    curvid = all_mots[top_mot]['motion_b']
    text_annot = all_mots[top_mot]['annotation']
    return curvid, text_annot
229
+
230
 
231
+ with gr.Blocks(css="style.css") as demo:
232
  gr.Markdown(WEBSITE)
233
  gr.Markdown(WEB_source)
234
  # TODO load TMR text-encoder
 
237
  # edit that motion!
238
  with gr.Row():
239
  with gr.Column(scale=10):
240
+ with gr.Column(scale=5):
241
+
242
+ retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to Retrieve:",
243
  show_label=True, label="Retrieval Text", value=DEFAULT_TEXT)
244
+ suggested_edit_text = gr.Textbox(placeholder="Texts likely to edit the motion:",
245
+ show_label=True, label="Suggested Edit Text",
246
+ value='')
247
+ xxx = 'https://motion-editing.s3.eu-central-1.amazonaws.com/collection_wo_walks_runs/rendered_pairs/011327_120_240-002682_120_240.mp4'
248
+
249
+ with gr.Column(scale=5):
250
+ set_to_pick = gr.Radio(['all', 'train', 'test'],
251
+ value='all',
252
+ label="Set to pick from",
253
+ info="Motion will be picked from whole dataset or test or train data.")
254
+
255
+ with gr.Row():
256
+ with gr.Column(scale=10):
257
+ retrieve_button = gr.Button("TMRetrieve")
258
+ random_button = gr.Button("Random")
259
+ with gr.Column(scale=10):
260
+ how_many_videos = gr.Radio([1, 3, 5, 7],
261
+ value=3,
262
+ label="# Videos",
263
+ info="# Videos to be retrieved in each case."),
264
+
265
+ # temp_slider = gr.Slider(minimum=1,
266
+ # maximum=5,
267
+ # value=1,
268
+ # step=2,
269
+ # interactive=True,
270
+ # label="Slide me")
271
+ with gr.Column(scale=10,elem_id="center-column"):
272
+ clear_button_retrieval = gr.Button("Clear Retrieval Text")
273
 
274
  with gr.Column(scale=8):
275
  retrieved_video_output = gr.Video(label="Retrieved Motion",
276
+ value=xxx,
277
+ height=360, width=480)
 
 
 
 
278
 
279
  gr.Markdown(WEB_target)
280
  with gr.Row():
 
298
  return fname
299
 
300
  from retrieval_loader import get_tmr_model
301
+ from dataset_utils import load_motionfix
302
  # load the dataset and splits
303
+ # import ipdb;ipdb.set_trace()
304
+
305
  edit_button.click(process_and_show_video, inputs=input_text, outputs=video_output)
306
+ retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=[retrieved_video_output, suggested_edit_text])
307
+ random_button.click(random_source_motion, inputs=set_to_pick, outputs=[retrieved_video_output, suggested_edit_text])
308
+ # import ipdb;ipdb.set_trace()
309
+
310
  # import ipdb;ipdb.set_trace()
311
 
312
  clear_button_edit.click(clear, outputs=input_text)
313
  clear_button_retrieval.click(clear, outputs=retrieve_text)
 
314
 
315
  demo.launch()
gen_utils.py CHANGED
@@ -61,6 +61,16 @@ def save_config(cfg: DictConfig) -> str:
61
  f.write(string)
62
  return path
63
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  def read_config(run_dir: str, return_json=False) -> DictConfig:
66
  path = os.path.join(run_dir, "config.json")
 
61
  f.write(string)
62
  return path
63
 
64
def write_json(data, p):
    """Serialize ``data`` as JSON into the file at path ``p``.

    The file is opened for writing (replacing existing content) with an
    explicit UTF-8 encoding and the output is indented by two spaces.

    Args:
        data: any object ``json.dump`` can serialize.
        p: path of the file to write.
    """
    import json

    # Explicit encoding keeps the output independent of the platform locale.
    with open(p, 'w', encoding='utf-8') as fp:
        json.dump(data, fp, indent=2)
68
+
69
def read_json(p):
    """Load and return the JSON content of the file at path ``p``.

    The file is decoded as UTF-8 rather than with the platform's default
    encoding, so non-ASCII content is read the same way on every system.

    Args:
        p: path of the JSON file to read.

    Returns:
        The object produced by ``json.load``.
    """
    import json

    with open(p, 'r', encoding='utf-8') as fp:
        return json.load(fp)
74
 
75
  def read_config(run_dir: str, return_json=False) -> DictConfig:
76
  path = os.path.join(run_dir, "config.json")