gloriforge commited on
Commit
1139535
·
verified ·
1 Parent(s): 3611b6a

Upload folder using huggingface_hub

Browse files
__pycache__/miner.cpython-312.pyc CHANGED
Binary files a/__pycache__/miner.cpython-312.pyc and b/__pycache__/miner.cpython-312.pyc differ
 
__pycache__/pitch.cpython-312.pyc CHANGED
Binary files a/__pycache__/pitch.cpython-312.pyc and b/__pycache__/pitch.cpython-312.pyc differ
 
pitch.py CHANGED
@@ -16,6 +16,8 @@ import torchvision.transforms as T
16
  import torchvision.transforms.functional as f
17
  from pydantic import BaseModel
18
 
 
 
19
  import logging
20
  logger = logging.getLogger(__name__)
21
 
@@ -614,6 +616,9 @@ def inference_batch(frames, model, kp_threshold, device, batch_size=8):
614
  model_device = next(model.parameters()).device
615
  print(model_device)
616
 
 
 
 
617
  # Process all frames in optimally-sized batches
618
  for i in range(0, num_frames, batch_size):
619
  current_batch_size = min(batch_size, num_frames - i)
@@ -626,8 +631,8 @@ def inference_batch(frames, model, kp_threshold, device, batch_size=8):
626
  # Move batch to model device
627
  batch = batch.to(model_device)
628
 
629
- with torch.no_grad():
630
- heatmaps = model(batch)
631
 
632
  # Ultra-fast keypoint extraction
633
  kp_coords = extract_keypoints_from_heatmap_fast(heatmaps[:,:-1,:,:], scale=2, max_keypoints=1)
 
16
  import torchvision.transforms.functional as f
17
  from pydantic import BaseModel
18
 
19
+ import tensorflow as tf
20
+
21
  import logging
22
  logger = logging.getLogger(__name__)
23
 
 
616
  model_device = next(model.parameters()).device
617
  print(model_device)
618
 
619
+ @tf.function
620
+ def infer(x):
621
+ return model(x)
622
  # Process all frames in optimally-sized batches
623
  for i in range(0, num_frames, batch_size):
624
  current_batch_size = min(batch_size, num_frames - i)
 
631
  # Move batch to model device
632
  batch = batch.to(model_device)
633
 
634
+ # with torch.no_grad():
635
+ heatmaps = infer(batch)
636
 
637
  # Ultra-fast keypoint extraction
638
  kp_coords = extract_keypoints_from_heatmap_fast(heatmaps[:,:-1,:,:], scale=2, max_keypoints=1)