Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
42c0c18
Add SAM2Generic class
cjaverliat May 3, 2025
3a4ce6e
Variable renaming + docstring
cjaverliat May 3, 2025
6ae45ca
Add device transfer for empty prompt embeddings
cjaverliat May 3, 2025
85255d7
Add generic video predictor
cjaverliat May 3, 2025
7431eb9
Fix formatting
cjaverliat May 3, 2025
54bfe74
Add build_sam2_generic
cjaverliat May 3, 2025
df361a0
Add autoscale when encoding uint8 images
cjaverliat May 3, 2025
86e16e0
Update assertion in condition_image_embeddings_on_memories
cjaverliat May 3, 2025
ea32c69
Add SAM2Result containing masks_logits, ious, obj_ptrs and obj_score_…
cjaverliat May 4, 2025
66436f2
Add SAM2Prompt containing points, boxes and masks prompts
cjaverliat May 4, 2025
bceef89
Add ObjectMemory, ObjectMemorySelection and ObjectMemoryBank
cjaverliat May 4, 2025
4a9d730
Implement SAM2 existing memory bank using new classes
cjaverliat May 4, 2025
0ce42f9
Remove unnecessary SAM2ObjectMemory
cjaverliat May 4, 2025
e3db8ee
Add SAM2Result concatenation method
cjaverliat May 4, 2025
c65a20f
Add device property to SAM2Result
cjaverliat May 4, 2025
b97448f
Rename obj_score_logits to obj_scores_logits + add __getitem__ and ca…
cjaverliat May 4, 2025
6657bed
Add missing implementation for data transfer in ObjectMemory
cjaverliat May 4, 2025
d197aa1
Rename object_memories to ptr_memories
cjaverliat May 4, 2025
7b94cf9
Update ObjectMemoryBank try_add_memories function
cjaverliat May 4, 2025
7dfdb9e
Add method to count the number of stored memories in the memory bank
cjaverliat May 4, 2025
3c4dcd8
Modify sam2_generic and its video predictor to reflect changes on the…
cjaverliat May 4, 2025
1cb6a77
Fix best_mask_logits indexing issue
cjaverliat May 4, 2025
53f5292
Reshape binarize to be broadcastable
cjaverliat May 4, 2025
5b0ed20
Update notebook
cjaverliat May 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
564 changes: 564 additions & 0 deletions notebooks/generic_video_predictor_example.ipynb

Large diffs are not rendered by default.

80 changes: 80 additions & 0 deletions sam2/build_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

import sam2

from sam2.modeling.sam2_generic import SAM2Generic
from sam2.modeling.sam2_memory import SAM2ObjectMemoryBank
from sam2.sam2_generic_video_predictor import SAM2GenericVideoPredictor

# Check if the user is running Python from the parent directory of the sam2 repo
# (i.e. the directory where this repo is cloned into) -- this is not supported since
# it could shadow the sam2 package and cause issues.
Expand Down Expand Up @@ -97,6 +101,82 @@ def build_sam2(
return model


def build_sam2_generic(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=None,
    apply_postprocessing=True,
) -> SAM2Generic:
    """Build a SAM2Generic model from a Hydra config.

    Args:
        config_file: Name of the Hydra config to compose.
        ckpt_path: Checkpoint path forwarded to ``_load_checkpoint``.
        device: Device the instantiated model is moved to.
        mode: When equal to ``"eval"``, the model is switched to eval mode.
        hydra_overrides_extra: Optional extra Hydra overrides. They are placed
            after the ``_target_`` override and before the post-processing
            overrides. The caller's sequence is never mutated.
        apply_postprocessing: Whether to append the post-processing overrides
            (dynamic multi-mask fallback, mask binarization for the memory
            encoder, hole filling).

    Returns:
        The instantiated model, moved to ``device``.
    """
    hydra_overrides = [
        # Point at the module that actually defines SAM2Generic rather than
        # relying on another module re-exporting the class.
        "++model._target_=sam2.modeling.sam2_generic.SAM2Generic",
    ]
    # `None` replaces the former mutable `[]` default; the caller's sequence is
    # only read, never modified.
    if hydra_overrides_extra is not None:
        hydra_overrides.extend(hydra_overrides_extra)
    if apply_postprocessing:
        hydra_overrides.extend(
            [
                # dynamically fall back to multi-mask if the single mask is not stable
                "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
                "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
                "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
                # binarize the sigmoid mask logits on interacted frames with clicks in the
                # memory encoder so that the encoded masks are exactly what users see from clicking
                "++model.binarize_mask_from_pts_for_mem_enc=true",
                # fill small holes in the low-res masks up to `max_hole_area`
                # (before resizing them to the original video resolution)
                "++model.max_hole_area=8",
            ]
        )

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model


def build_sam2_generic_video_predictor(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=None,
    apply_postprocessing=True,
) -> SAM2GenericVideoPredictor:
    """Build a SAM2GenericVideoPredictor from a Hydra config.

    The predictor is instantiated with a fresh ``SAM2ObjectMemoryBank`` passed
    as its ``memory_bank`` argument.

    Args:
        config_file: Name of the Hydra config to compose.
        ckpt_path: Checkpoint path forwarded to ``_load_checkpoint``.
        device: Device the instantiated model is moved to.
        mode: When equal to ``"eval"``, the model is switched to eval mode.
        hydra_overrides_extra: Optional extra Hydra overrides. They are placed
            after the ``_target_`` override and before the post-processing
            overrides. The caller's sequence is never mutated.
        apply_postprocessing: Whether to append the post-processing overrides
            (dynamic multi-mask fallback, mask binarization for the memory
            encoder, hole filling).

    Returns:
        The instantiated predictor, moved to ``device``.
    """
    hydra_overrides = [
        "++model._target_=sam2.sam2_generic_video_predictor.SAM2GenericVideoPredictor",
    ]
    # `None` replaces the former mutable `[]` default; the caller's sequence is
    # only read, never modified.
    if hydra_overrides_extra is not None:
        hydra_overrides.extend(hydra_overrides_extra)
    if apply_postprocessing:
        hydra_overrides.extend(
            [
                # dynamically fall back to multi-mask if the single mask is not stable
                "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
                "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
                "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
                # binarize the sigmoid mask logits on interacted frames with clicks in the
                # memory encoder so that the encoded masks are exactly what users see from clicking
                "++model.binarize_mask_from_pts_for_mem_enc=true",
                # fill small holes in the low-res masks up to `max_hole_area`
                # (before resizing them to the original video resolution)
                "++model.max_hole_area=8",
            ]
        )

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(
        cfg.model,
        _recursive_=True,
        memory_bank=SAM2ObjectMemoryBank(),
    )
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model


def build_sam2_video_predictor(
config_file,
ckpt_path=None,
Expand Down
158 changes: 158 additions & 0 deletions sam2/modeling/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from __future__ import annotations
import torch

from abc import ABC, abstractmethod
from sam2.modeling.sam2_prompt import SAM2Prompt
from sam2.modeling.sam2_result import SAM2Result

class ObjectMemory:
    """A single memory entry recorded for one object at one frame.

    Attributes:
        obj_id: Identifier of the object this memory belongs to.
        frame_idx: Index of the frame the memory was produced on.
        memory_embeddings: Memory embeddings tensor for this object/frame.
        memory_pos_embeddings: Positional embeddings matching
            ``memory_embeddings``.
        ptr: Pointer tensor stored alongside the memory (presumably the SAM2
            object pointer -- confirm against the producer of these entries).
        is_conditional: Whether this entry is a conditional memory.
    """

    def __init__(
        self,
        obj_id: int,
        frame_idx: int,
        memory_embeddings: torch.Tensor,
        memory_pos_embeddings: torch.Tensor,
        ptr: torch.Tensor,
        is_conditional: bool = False,
    ):
        self.obj_id = obj_id
        self.frame_idx = frame_idx
        self.memory_embeddings = memory_embeddings
        self.memory_pos_embeddings = memory_pos_embeddings
        self.ptr = ptr
        self.is_conditional = is_conditional

    def to(self, device: torch.device) -> ObjectMemory:
        """Return a new ``ObjectMemory`` whose tensors live on ``device``.

        Args:
            device: Target device for the three tensors.

        Returns:
            A new entry with the same ``obj_id``, ``frame_idx`` and
            ``is_conditional`` flag, and with every tensor passed through
            ``Tensor.to(device)``.
        """
        return ObjectMemory(
            obj_id=self.obj_id,
            frame_idx=self.frame_idx,
            memory_embeddings=self.memory_embeddings.to(device),
            memory_pos_embeddings=self.memory_pos_embeddings.to(device),
            ptr=self.ptr.to(device),
            # Must be forwarded explicitly: otherwise a device transfer would
            # silently reset the flag to the constructor default (False).
            is_conditional=self.is_conditional,
        )

class ObjectMemorySelection:
    """Three lists of memories selected together for one object.

    Attributes:
        conditional_memories: The selected conditional memories.
        non_conditional_memories: The selected non-conditional memories.
        ptr_memories: The memories selected for their pointers.
    """

    def __init__(
        self,
        conditional_memories: list[ObjectMemory],
        non_conditional_memories: list[ObjectMemory],
        ptr_memories: list[ObjectMemory],
    ):
        self.conditional_memories = conditional_memories
        self.non_conditional_memories = non_conditional_memories
        self.ptr_memories = ptr_memories

    def to(self, device: torch.device) -> ObjectMemorySelection:
        """Return a new selection with every memory moved to ``device``.

        Each entry's own ``to(device)`` is called; the three lists of this
        instance are left untouched.
        """

        def _transfer(entries: list[ObjectMemory]) -> list[ObjectMemory]:
            # Build a fresh list so the source selection is not modified.
            return [entry.to(device) for entry in entries]

        moved_conditional = _transfer(self.conditional_memories)
        moved_non_conditional = _transfer(self.non_conditional_memories)
        moved_ptr = _transfer(self.ptr_memories)
        return ObjectMemorySelection(
            conditional_memories=moved_conditional,
            non_conditional_memories=moved_non_conditional,
            ptr_memories=moved_ptr,
        )


class ObjectMemoryBank(ABC):
    """Abstract interface for a per-object memory store.

    Concrete banks implement storage, pruning and selection of
    ``ObjectMemory`` entries. This base class only keeps ``known_obj_ids``,
    a set that starts empty and can be reset with ``clear_known_obj_ids``.
    """

    def __init__(self):
        self.known_obj_ids = set()

    @abstractmethod
    def count_stored_conditional_memories(self, obj_id: int) -> int:
        """Return the number of conditional memories stored for ``obj_id``."""
        raise NotImplementedError

    @abstractmethod
    def count_stored_non_conditional_memories(self, obj_id: int) -> int:
        """Return the number of non-conditional memories stored for ``obj_id``."""
        raise NotImplementedError

    def clear_known_obj_ids(self):
        """Replace ``known_obj_ids`` with a new empty set."""
        self.known_obj_ids = set()

    @abstractmethod
    def try_add_memories(
        self,
        frame_idx: int,
        obj_ids: list[int],
        memory_embeddings: torch.Tensor,
        memory_pos_embeddings: torch.Tensor,
        results: SAM2Result,
        prompts: list[SAM2Prompt],
    ) -> list[tuple[bool, ObjectMemory]]:
        """
        Attempt to store one memory per object for the given frame.

        Args:
            frame_idx: The frame index.
            obj_ids: The object IDs of shape (B,).
            memory_embeddings: The memory embeddings of shape (B, N, H, W).
            memory_pos_embeddings: The memory positional embeddings of shape (B, N, H, W).
            results: The SAM2Result for all the objects. Expected to have batch size B.
            prompts: The list of prompts. Can be of any length between 0 and B.

        Returns:
            A list of ``(added, memory)`` tuples, where ``added`` tells whether
            the memory was stored in the bank and ``memory`` is the memory itself.
        """
        raise NotImplementedError

    @abstractmethod
    def prune_memories(
        self, obj_ids: list[int], current_frame_idx: int
    ) -> dict[int, list[ObjectMemory]]:
        """
        Drop memories that are no longer needed for the given objects.

        Args:
            obj_ids: The object IDs.
            current_frame_idx: The current frame index.

        Returns:
            A dictionary mapping each object ID to the list of memories that
            were pruned for it.
        """
        raise NotImplementedError

    @abstractmethod
    def select_memories(
        self,
        obj_ids: list[int],
        current_frame_idx: int,
        max_conditional_memories: int,
        max_non_conditional_memories: int,
        max_ptr_memories: int,
        only_include_pointers_in_past: bool = False,
        reverse_tracking: bool = False,
    ) -> dict[int, ObjectMemorySelection]:
        """
        Pick, per object, the memories used to condition the frame at current_frame_idx.

        Args:
            obj_ids: The object IDs to select memories for.
            current_frame_idx: The current frame index.
            max_conditional_memories: The maximum number of conditional memories to select.
            max_non_conditional_memories: The maximum number of non-conditional memories to select.
            max_ptr_memories: The maximum number of pointer memories (obj_ptrs) to select.
            only_include_pointers_in_past: Presumably restricts the pointer memories
                to frames in the past of current_frame_idx -- TODO confirm against
                the concrete implementations.
            reverse_tracking: Whether the tracking direction is reversed.

        Returns:
            A dictionary mapping object IDs to memory selections.
        """
        raise NotImplementedError

    @abstractmethod
    def clear_object_non_conditional_memories_in_frame_range(
        self, obj_id: int, frame_idx_start: int, frame_idx_end: int
    ) -> list[ObjectMemory]:
        """
        Remove an object's non-conditional memories within a frame range (inclusive).

        Args:
            obj_id: The object ID.
            frame_idx_start: The start frame index (inclusive).
            frame_idx_end: The end frame index (inclusive).

        Returns:
            A list of removed memories.
        """
        raise NotImplementedError
Loading