import warnings
from collections import OrderedDict
import os

import torch
import torch.distributed
from torch import nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from tqdm import tqdm

from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

from .sam2.sam2_video_predictor import SAM2VideoPredictor as _SAM2VideoPredictor
from .sam2.modeling.sam2_base import NO_OBJ_SCORE
from .sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
from .sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames


def build_sam2_video_predictor(
    config_file,
    num_maskmem=7,
    hydra_overrides_extra=[],
    apply_postprocessing=True,
    **kwargs,
):
    hydra_overrides = [
        "++model._target_=inference.sam2_video_predictor.SAM2VideoPredictor",
    ]
    # copy to avoid mutating the caller's list (or the mutable default argument)
    hydra_overrides_extra = list(hydra_overrides_extra)
    if apply_postprocessing:
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder,
            # so that the encoded masks are exactly as what users see from clicking
            "++model.binarize_mask_from_pts_for_mem_enc=true",
            # fill small holes in the low-res masks up to `fill_hole_area`
            # (before resizing them to the original video resolution)
            "++model.fill_hole_area=8",
        ]
    hydra_overrides_extra.append(f"model.num_maskmem={num_maskmem}")
    hydra_overrides.extend(hydra_overrides_extra)

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    return model


class SAM2VideoPredictor(_SAM2VideoPredictor):
    def init_state(self, video_path, **kwargs):
        inference_state = super().init_state(video_path=video_path, **kwargs)
        frame_names = [
            os.path.splitext(p)[0]
            for p in os.listdir(video_path)
            if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
        ]
        frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
        inference_state["video_paths"] = [
            os.path.join(video_path, f"{frame_name}.jpg") for frame_name in frame_names
        ]
        return inference_state

    def _prepare_memory_conditioned_features(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        start_frame_idx=0,
        iou_thre=0.3,
    ):
        """Fuse the current frame's visual feature map with previous memory."""
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        device = current_vision_feats[-1].device
        # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
        # In this case, we skip the fusion with any memory.
        if self.num_maskmem == 0:  # Disable memory and skip fusion
            pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
            return pix_feat
        num_obj_ptr_tokens = 0
        tpos_sign_mul = -1 if track_in_reverse else 1
        # Step 1: condition the visual features of the current frame on previous memories
        if not is_init_cond_frame:
            # Retrieve the memories encoded with the maskmem backbone
            to_cat_memory, to_cat_memory_pos_embed = [], []
            # Add conditioning frames' output first (all cond frames have t_pos=0 for
            # when getting temporal positional embedding below)
            assert len(output_dict["cond_frame_outputs"]) > 0
            # Select a maximum number of temporally closest cond frames for cross attention
            cond_outputs = output_dict["cond_frame_outputs"]
            selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
                frame_idx, cond_outputs, self.max_cond_frames_in_attn
            )
            t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
            # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
            # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
            # We also allow taking the memory frame non-consecutively (with stride>1), in which case
            # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
            stride = 1 if self.training else self.memory_temporal_stride_for_eval
            _memory_frames = max(self.max_obj_ptrs_in_encoder, self.num_maskmem)
            max_obj_ptrs_in_encoder = min(num_frames, _memory_frames)
            valid_indices = [
                i
                for i in range(frame_idx - 1, start_frame_idx, -1)
                if (
                    output_dict["non_cond_frame_outputs"][i]["ious"].max().item() > iou_thre
                    and output_dict["non_cond_frame_outputs"][i]["object_score_logits"].item() > 0
                )
            ][: max_obj_ptrs_in_encoder - 1]
            valid_indices.sort()
            if frame_idx - 1 not in valid_indices and (frame_idx - 1) != start_frame_idx:
                valid_indices.append(frame_idx - 1)
            for t_pos in range(1, self.num_maskmem):
                t_rel = t_pos - self.num_maskmem  # how many frames before current frame
                if t_rel < -len(valid_indices):
                    continue
                prev_frame_idx = valid_indices[t_rel]
                out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
                if out is None:
                    # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
                    # frames, we still attend to it as if it's a non-conditioning frame.
                    out = unselected_cond_outputs.get(prev_frame_idx, None)
                t_pos_and_prevs.append((t_pos, out))

            for t_pos, prev in t_pos_and_prevs:
                if prev is None:
                    continue  # skip padding frames
                # "maskmem_features" might have been offloaded to CPU in demo use cases,
                # so we load it back to GPU (it's a no-op if it's already on GPU).
                feats = prev["maskmem_features"].to(device, non_blocking=True)
                to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
                # Spatial positional encoding (it might have been offloaded to CPU in eval)
                maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
                maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
                # Temporal positional encoding
                maskmem_enc = (
                    maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
                )
                to_cat_memory_pos_embed.append(maskmem_enc)

            # Construct the list of past object pointers
            if self.use_obj_ptrs_in_encoder:
                max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
                # First add those object pointers from selected conditioning frames
                # (optionally, only include object pointers in the past during evaluation)
                if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
                    ptr_cond_outputs = {
                        t: out
                        for t, out in selected_cond_outputs.items()
                        if (t >= frame_idx if track_in_reverse else t <= frame_idx)
                    }
                else:
                    ptr_cond_outputs = selected_cond_outputs
                pos_and_ptrs = [
                    # Temporal pos encoding contains how far away each pointer is from current frame
                    (
                        (
                            (frame_idx - t) * tpos_sign_mul
                            if self.use_signed_tpos_enc_to_obj_ptrs
                            else abs(frame_idx - t)
                        ),
                        out["obj_ptr"].to(device, non_blocking=True),
                    )
                    for t, out in ptr_cond_outputs.items()
                ]
                # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
                for t_diff in range(1, max_obj_ptrs_in_encoder):
                    # t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
                    if -t_diff <= -len(valid_indices):
                        break
                    # if t < 0 or (num_frames is not None and t >= num_frames):
                    #     break
                    out = output_dict["non_cond_frame_outputs"].get(
                        valid_indices[-t_diff],
                        unselected_cond_outputs.get(valid_indices[-t_diff], None),
                    )
                    if out is not None:
                        pos_and_ptrs.append(
                            (t_diff, out["obj_ptr"].to(device, non_blocking=True))
                        )
                # If we have at least one object pointer, add them to the cross attention
                if len(pos_and_ptrs) > 0:
                    pos_list, ptrs_list = zip(*pos_and_ptrs)
                    # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
                    obj_ptrs = torch.stack(ptrs_list, dim=0)
                    # a temporal positional embedding based on how far each object pointer is from
                    # the current frame (sine embedding normalized by the max pointer num).
                    if self.add_tpos_enc_to_obj_ptrs:
                        t_diff_max = max_obj_ptrs_in_encoder - 1
                        tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
                        obj_pos = torch.tensor(pos_list).to(
                            device=device, non_blocking=True
                        )
                        obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                        obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                        obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
                    else:
                        obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
                    if self.mem_dim < C:
                        # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
                        obj_ptrs = obj_ptrs.reshape(
                            -1, B, C // self.mem_dim, self.mem_dim
                        )
                        obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
                        obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
                    to_cat_memory.append(obj_ptrs)
                    to_cat_memory_pos_embed.append(obj_pos)
                    num_obj_ptr_tokens = obj_ptrs.shape[0]
                else:
                    num_obj_ptr_tokens = 0
        else:
            # for initial conditioning frames, encode them without using any previous memory
            if self.directly_add_no_mem_embed:
                # directly add no-mem embedding (instead of using the transformer encoder)
                pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
                pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
                return pix_feat_with_mem

            # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
            to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
            to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]

        # Step 2: Concatenate the memories and forward through the transformer encoder
        memory = torch.cat(to_cat_memory, dim=0)
        memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)

        pix_feat_with_mem = self.memory_attention(
            curr=current_vision_feats,
            curr_pos=current_vision_pos_embeds,
            memory=memory,
            memory_pos=memory_pos_embed,
            num_obj_ptr_tokens=num_obj_ptr_tokens,
        )
        # reshape the output (HW)BC => BCHW
        pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
        return pix_feat_with_mem

    def _track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse,
        prev_sam_mask_logits,
        ## Extension: LLM prompt
        start_frame_idx=0,
        language_embd=None,
    ):
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
        if len(current_vision_feats) > 1:
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]
        else:
            high_res_features = None
        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
            sam_outputs = self._use_mask_as_output(
                pix_feat, high_res_features, mask_inputs
            )
        else:
            # fuse the visual feature with previous memory features in the memory bank
            pix_feat = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats[-1:],
                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
                feat_sizes=feat_sizes[-1:],
                output_dict=output_dict,
                num_frames=num_frames,
                track_in_reverse=track_in_reverse,
                start_frame_idx=start_frame_idx,
            )

            if language_embd is not None:
                _language_embd = language_embd.reshape(-1, 1, 256 // self.mem_dim, self.mem_dim)
                _language_embd = _language_embd.permute(0, 2, 1, 3).flatten(0, 1)
                _language_embd_pos = _language_embd.new_zeros(
                    _language_embd.size(0), 1, self.mem_dim
                )
                pix_feat_with_language = self.token_attn(
                    curr=current_vision_feats[-1:],
                    curr_pos=current_vision_pos_embeds[-1:],
                    memory=_language_embd,
                    memory_pos=_language_embd_pos,
                    num_obj_ptr_tokens=_language_embd.shape[0],
                )
                pix_feat_with_language = pix_feat_with_language.permute(1, 2, 0).view(1, 256, 64, 64)
                pix_feat = (pix_feat_with_language + pix_feat) / 2

            # apply SAM-style segmentation head
            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
            if prev_sam_mask_logits is not None:
                assert point_inputs is not None and mask_inputs is None
                mask_inputs = prev_sam_mask_logits
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )
        return current_out, sam_outputs, high_res_features, pix_feat

    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
        # to skip the memory encoder with `run_mem_encoder=False`. For example,
        # in demo we might call `track_step` multiple times for each user click,
        # and only encode the memory when the user finalizes their clicks. And in ablation
        # settings like SAM training on static images, we don't need the memory encoder.
        run_mem_encoder=True,
        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
        prev_sam_mask_logits=None,
        ## Extension: LLM prompt
        start_frame_idx=0,
        language_embd=None,
    ):
        current_out, sam_outputs, _, _ = self._track_step(
            frame_idx,
            is_init_cond_frame,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
            point_inputs,
            mask_inputs,
            output_dict,
            num_frames,
            track_in_reverse,
            prev_sam_mask_logits,
            start_frame_idx,
            language_embd,
        )
        (
            low_res_multimasks,
            _,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        ) = sam_outputs

        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr
        current_out["low_res_multimasks"] = low_res_multimasks
        if not self.training:
            # Only add this in inference (to avoid unused param in activation checkpointing;
            # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
            current_out["object_score_logits"] = object_score_logits
            current_out["ious"] = ious

        # Finally run the memory encoder on the predicted mask to encode
        # it into a new memory feature (that can be used in future frames)
        self._encode_memory_in_output(
            current_vision_feats,
            feat_sizes,
            point_inputs,
            run_mem_encoder,
            high_res_masks,
            object_score_logits,
            current_out,
        )

        return current_out

    def _run_single_frame_inference(
        self,
        inference_state,
        output_dict,
        frame_idx,
        batch_size,
        is_init_cond_frame,
        point_inputs,
        mask_inputs,
        reverse,
        run_mem_encoder,
        prev_sam_mask_logits=None,
        ## Extension: LLM prompt
        start_frame_idx=0,
        language_embd=None,
    ):
        """Run tracking on a single frame based on current inputs and previous memory."""
        # Retrieve correct image features
        (
            _,
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = self._get_image_feature(inference_state, frame_idx, batch_size)

        # point and mask should not appear as input simultaneously on the same frame
        assert point_inputs is None or mask_inputs is None
        current_out = self.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            output_dict=output_dict,
            num_frames=inference_state["num_frames"],
            track_in_reverse=reverse,
            run_mem_encoder=run_mem_encoder,
            prev_sam_mask_logits=prev_sam_mask_logits,
            start_frame_idx=start_frame_idx,
            language_embd=language_embd,
        )

        # optionally offload the output to CPU memory to save GPU space
        storage_device = inference_state["storage_device"]
        maskmem_features = current_out["maskmem_features"]
        if maskmem_features is not None:
            maskmem_features = maskmem_features.to(torch.bfloat16)
            maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        pred_masks_gpu = current_out["pred_masks"]
        # potentially fill holes in the predicted masks
        if self.fill_hole_area > 0:
            pred_masks_gpu = fill_holes_in_mask_scores(
                pred_masks_gpu, self.fill_hole_area
            )
        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
        obj_ptr = current_out["obj_ptr"]
        ious = current_out["ious"]
        object_score_logits = current_out["object_score_logits"]
        low_res_multimasks = current_out["low_res_multimasks"]
        # make a compact version of this frame's output to reduce the state size
        compact_current_out = {
            "maskmem_features": maskmem_features,
            "maskmem_pos_enc": maskmem_pos_enc,
            "pred_masks": pred_masks,
            "obj_ptr": obj_ptr,
            "ious": ious,
            "object_score_logits": object_score_logits,
            "low_res_multimasks": low_res_multimasks,
        }
        return compact_current_out, pred_masks_gpu
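

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the predictor itself).
# The Hydra config module, config name, and frame directory below are
# placeholders and must be adapted to the surrounding project; note that
# `build_sam2_video_predictor` here only composes the config and instantiates
# the model (no checkpoint loading happens in this function).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from hydra import initialize_config_module

    # Placeholder config module; the upstream SAM 2 package registers its
    # configs under its own module name.
    with initialize_config_module("sam2", version_base="1.2"):
        predictor = build_sam2_video_predictor(
            config_file="sam2_hiera_l.yaml",  # placeholder config name
            num_maskmem=7,
        )
    # `video_path` should point to a directory of frames named "<index>.jpg";
    # `init_state` records their paths in inference_state["video_paths"].
    inference_state = predictor.init_state(video_path="path/to/frames")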