rollback sam2_base and sam2_video_predictor
Browse files- sam2/modeling/sam2_base.py +2 -2
- sam2/sam2_video_predictor.py +11 -19
- trimm_examples.py +30 -0
sam2/modeling/sam2_base.py
CHANGED
|
@@ -617,7 +617,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 617 |
if self.use_signed_tpos_enc_to_obj_ptrs
|
| 618 |
else abs(frame_idx - t)
|
| 619 |
),
|
| 620 |
-
out["obj_ptr"]
|
| 621 |
)
|
| 622 |
for t, out in ptr_cond_outputs.items()
|
| 623 |
]
|
|
@@ -630,7 +630,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 630 |
t, unselected_cond_outputs.get(t, None)
|
| 631 |
)
|
| 632 |
if out is not None:
|
| 633 |
-
pos_and_ptrs.append((t_diff, out["obj_ptr"]
|
| 634 |
# If we have at least one object pointer, add them to the across attention
|
| 635 |
if len(pos_and_ptrs) > 0:
|
| 636 |
pos_list, ptrs_list = zip(*pos_and_ptrs)
|
|
|
|
| 617 |
if self.use_signed_tpos_enc_to_obj_ptrs
|
| 618 |
else abs(frame_idx - t)
|
| 619 |
),
|
| 620 |
+
out["obj_ptr"],
|
| 621 |
)
|
| 622 |
for t, out in ptr_cond_outputs.items()
|
| 623 |
]
|
|
|
|
| 630 |
t, unselected_cond_outputs.get(t, None)
|
| 631 |
)
|
| 632 |
if out is not None:
|
| 633 |
+
pos_and_ptrs.append((t_diff, out["obj_ptr"]))
|
| 634 |
# If we have at least one object pointer, add them to the across attention
|
| 635 |
if len(pos_and_ptrs) > 0:
|
| 636 |
pos_list, ptrs_list = zip(*pos_and_ptrs)
|
sam2/sam2_video_predictor.py
CHANGED
|
@@ -470,7 +470,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 470 |
size=(batch_size, self.hidden_dim),
|
| 471 |
fill_value=NO_OBJ_SCORE,
|
| 472 |
dtype=torch.float32,
|
| 473 |
-
device=inference_state["
|
| 474 |
),
|
| 475 |
"object_score_logits": torch.full(
|
| 476 |
size=(batch_size, 1),
|
|
@@ -478,7 +478,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 478 |
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
|
| 479 |
fill_value=10.0,
|
| 480 |
dtype=torch.float32,
|
| 481 |
-
device=inference_state["
|
| 482 |
),
|
| 483 |
}
|
| 484 |
empty_mask_ptr = None
|
|
@@ -545,9 +545,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 545 |
frame_idx=frame_idx,
|
| 546 |
batch_size=batch_size,
|
| 547 |
high_res_masks=high_res_masks,
|
| 548 |
-
object_score_logits=consolidated_out["object_score_logits"]
|
| 549 |
-
device, non_blocking=True
|
| 550 |
-
),
|
| 551 |
is_mask_from_pts=True, # these frames are what the user interacted with
|
| 552 |
)
|
| 553 |
consolidated_out["maskmem_features"] = maskmem_features
|
|
@@ -881,10 +879,9 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 881 |
def _get_image_feature(self, inference_state, frame_idx, batch_size):
|
| 882 |
"""Compute the image features on a given frame."""
|
| 883 |
# Look up in the cache first
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
|
| 887 |
-
image, backbone_out = None, None
|
| 888 |
if backbone_out is None:
|
| 889 |
# Cache miss -- we will run inference on a single image
|
| 890 |
device = inference_state["device"]
|
|
@@ -892,7 +889,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 892 |
backbone_out = self.forward_image(image)
|
| 893 |
# Cache the most recent frame's feature (for repeated interactions with
|
| 894 |
# a frame; we can use an LRU cache for more frames in the future).
|
| 895 |
-
|
| 896 |
|
| 897 |
# expand the features to have the same dimension as the number of objects
|
| 898 |
expanded_image = image.expand(batch_size, -1, -1, -1)
|
|
@@ -967,11 +964,9 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 967 |
pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
|
| 968 |
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
| 969 |
maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
|
| 970 |
-
# object pointer is a small tensor, so we always keep it on GPU memory for fast access
|
| 971 |
-
obj_ptr = current_out["obj_ptr"]
|
| 972 |
-
object_score_logits = current_out["object_score_logits"]
|
| 973 |
-
storage_device, non_blocking=True
|
| 974 |
-
)
|
| 975 |
# make a compact version of this frame's output to reduce the state size
|
| 976 |
compact_current_out = {
|
| 977 |
"maskmem_features": maskmem_features,
|
|
@@ -1023,7 +1018,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 1023 |
`maskmem_pos_enc` is the same across frames and objects, so we cache it as
|
| 1024 |
a constant in the inference session to reduce session storage size.
|
| 1025 |
"""
|
| 1026 |
-
storage_device = inference_state["storage_device"]
|
| 1027 |
model_constants = inference_state["constants"]
|
| 1028 |
# "out_maskmem_pos_enc" should be either a list of tensors or None
|
| 1029 |
out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
|
@@ -1032,9 +1026,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
| 1032 |
assert isinstance(out_maskmem_pos_enc, list)
|
| 1033 |
# only take the slice for one object, since it's same across objects
|
| 1034 |
maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
|
| 1035 |
-
model_constants["maskmem_pos_enc"] = maskmem_pos_enc
|
| 1036 |
-
storage_device, non_blocking=True
|
| 1037 |
-
)
|
| 1038 |
else:
|
| 1039 |
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
|
| 1040 |
# expand the cached maskmem_pos_enc to the actual batch size
|
|
|
|
| 470 |
size=(batch_size, self.hidden_dim),
|
| 471 |
fill_value=NO_OBJ_SCORE,
|
| 472 |
dtype=torch.float32,
|
| 473 |
+
device=inference_state["device"],
|
| 474 |
),
|
| 475 |
"object_score_logits": torch.full(
|
| 476 |
size=(batch_size, 1),
|
|
|
|
| 478 |
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
|
| 479 |
fill_value=10.0,
|
| 480 |
dtype=torch.float32,
|
| 481 |
+
device=inference_state["device"],
|
| 482 |
),
|
| 483 |
}
|
| 484 |
empty_mask_ptr = None
|
|
|
|
| 545 |
frame_idx=frame_idx,
|
| 546 |
batch_size=batch_size,
|
| 547 |
high_res_masks=high_res_masks,
|
| 548 |
+
object_score_logits=consolidated_out["object_score_logits"],
|
|
|
|
|
|
|
| 549 |
is_mask_from_pts=True, # these frames are what the user interacted with
|
| 550 |
)
|
| 551 |
consolidated_out["maskmem_features"] = maskmem_features
|
|
|
|
| 879 |
def _get_image_feature(self, inference_state, frame_idx, batch_size):
|
| 880 |
"""Compute the image features on a given frame."""
|
| 881 |
# Look up in the cache first
|
| 882 |
+
image, backbone_out = inference_state["cached_features"].get(
|
| 883 |
+
frame_idx, (None, None)
|
| 884 |
+
)
|
|
|
|
| 885 |
if backbone_out is None:
|
| 886 |
# Cache miss -- we will run inference on a single image
|
| 887 |
device = inference_state["device"]
|
|
|
|
| 889 |
backbone_out = self.forward_image(image)
|
| 890 |
# Cache the most recent frame's feature (for repeated interactions with
|
| 891 |
# a frame; we can use an LRU cache for more frames in the future).
|
| 892 |
+
inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
|
| 893 |
|
| 894 |
# expand the features to have the same dimension as the number of objects
|
| 895 |
expanded_image = image.expand(batch_size, -1, -1, -1)
|
|
|
|
| 964 |
pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
|
| 965 |
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
| 966 |
maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
|
| 967 |
+
# object pointer is a small tensor, so we always keep it on GPU memory for fast access
|
| 968 |
+
obj_ptr = current_out["obj_ptr"]
|
| 969 |
+
object_score_logits = current_out["object_score_logits"]
|
|
|
|
|
|
|
| 970 |
# make a compact version of this frame's output to reduce the state size
|
| 971 |
compact_current_out = {
|
| 972 |
"maskmem_features": maskmem_features,
|
|
|
|
| 1018 |
`maskmem_pos_enc` is the same across frames and objects, so we cache it as
|
| 1019 |
a constant in the inference session to reduce session storage size.
|
| 1020 |
"""
|
|
|
|
| 1021 |
model_constants = inference_state["constants"]
|
| 1022 |
# "out_maskmem_pos_enc" should be either a list of tensors or None
|
| 1023 |
out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
|
|
|
| 1026 |
assert isinstance(out_maskmem_pos_enc, list)
|
| 1027 |
# only take the slice for one object, since it's same across objects
|
| 1028 |
maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
|
| 1029 |
+
model_constants["maskmem_pos_enc"] = maskmem_pos_enc
|
|
|
|
|
|
|
| 1030 |
else:
|
| 1031 |
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
|
| 1032 |
# expand the cached maskmem_pos_enc to the actual batch size
|
trimm_examples.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from moviepy.editor import VideoFileClip
|
| 4 |
+
|
| 5 |
+
# Define the folder and duration
|
| 6 |
+
input_folder = "examples"
|
| 7 |
+
output_folder = "examples/trimmed"
|
| 8 |
+
trim_duration = 3 # seconds
|
| 9 |
+
|
| 10 |
+
# Create output folder if it doesn't exist
|
| 11 |
+
os.makedirs(output_folder, exist_ok=True)
|
| 12 |
+
|
| 13 |
+
# Process each .mp4 file
|
| 14 |
+
for filename in os.listdir(input_folder):
|
| 15 |
+
if filename.lower().endswith(".mp4"):
|
| 16 |
+
input_path = os.path.join(input_folder, filename)
|
| 17 |
+
output_path = os.path.join(output_folder, filename)
|
| 18 |
+
|
| 19 |
+
print(f"Trimming: {input_path} -> {output_path}")
|
| 20 |
+
try:
|
| 21 |
+
clip = VideoFileClip(input_path).subclip(0, trim_duration)
|
| 22 |
+
clip.write_videofile(
|
| 23 |
+
output_path,
|
| 24 |
+
codec="libx264",
|
| 25 |
+
audio_codec="aac",
|
| 26 |
+
verbose=False,
|
| 27 |
+
logger=None,
|
| 28 |
+
)
|
| 29 |
+
except Exception as e:
|
| 30 |
+
print(f"Failed to process {filename}: {e}")
|