@cjaverliat cjaverliat commented May 3, 2025

Inference on video without extracting images

This PR proposes SAM2Generic, an alternative to the original SAM2Base that exposes new APIs. Additionally, I added SAM2GenericVideoPredictor, a re-implementation of the video predictor with configurable strategies for memorizing frames and removing past memories (cf. here for an example), which solves the issue of keeping everything in VRAM.
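As an illustration of what such a removal strategy could look like, here is a minimal sketch of a sliding-window policy; the class and method names below are hypothetical and do not reflect the PR's actual interface:

from typing import List

# Hypothetical sketch only: this strategy interface is illustrative,
# not the actual API introduced by this PR.
class SlidingWindowRemoval:
    """Keep only the memories of the most recent `window_size` frames."""

    def __init__(self, window_size: int = 7):
        self.window_size = window_size

    def select_frames_to_remove(self, memorized_frame_indices: List[int]) -> List[int]:
        # Evict everything except the `window_size` most recent frames,
        # so VRAM usage stays bounded regardless of video length.
        return sorted(memorized_frame_indices)[:-self.window_size]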

More importantly, this gives more flexibility in how frames are provided (as tensors, rather than a video URL or individual image files):

import cv2
import torch
from tqdm import tqdm
from sam2.sam2_generic_video_predictor import Prompt
from sam2.build_sam import build_sam2_generic_video_predictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sam2_checkpoint = "../checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

predictor = build_sam2_generic_video_predictor(model_cfg, sam2_checkpoint, device=device)

cap = cv2.VideoCapture("./videos/bedroom.mp4")
n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
orig_hw = (height, width)

def read_frame(cap) -> torch.Tensor | None:
    ret, frame = cap.read()
    if not ret:
        return None
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = torch.as_tensor(frame).permute(2, 0, 1).to(device)  # HWC -> CHW
    frame = frame / 255.0  # normalize uint8 to float in [0, 1]
    return frame
 
# Add a prompt on the first frame
initial_frame = read_frame(cap)
points_coords = torch.tensor([400.0, 150.0], device=device).reshape((1, 1, 2))
points_labels = torch.tensor([1], device=device).reshape((1, 1))
prompt = Prompt(obj_id=0, points_coords=points_coords, points_labels=points_labels)
results = predictor.forward(frame=initial_frame, object_prompts=[prompt])

for f in tqdm(range(1, n_frames)):
    frame = read_frame(cap)

    if frame is None:
        break

    results = predictor.forward(frame=frame)
    
    # Do something with the result, for example:
    #     show_mask((results[0].best_mask_logits > 0), plt.gca(), obj_id=0)

cap.release()
The full usage example is available in the generic_video_predictor_example.ipynb notebook.
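The show_mask helper referenced in the loop is not defined in this snippet (it is presumably defined in that notebook); a minimal sketch, assuming a matplotlib axis and a boolean mask tensor of shape (1, H, W) or (H, W), might look like the following (the color choice is illustrative):

import matplotlib.pyplot as plt
import numpy as np

def show_mask(mask, ax, obj_id=0):
    # Overlay a boolean mask on a matplotlib axis, colored per object id.
    mask = mask.squeeze().cpu().numpy().astype(np.float32)
    h, w = mask.shape[-2:]
    color = np.array([*plt.cm.tab10(obj_id % 10)[:3], 0.6])  # RGB from tab10 + alpha
    ax.imshow(mask.reshape(h, w, 1) * color.reshape(1, 1, 4))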

@cjaverliat cjaverliat marked this pull request as draft May 3, 2025 17:11
hjj-lmx commented Aug 11, 2025


Why does my GPU use less than 3 GB and show 0% utilization during inference, as if it isn't being used at all? It is very slow: a four-minute video with over seven thousand frames takes more than forty minutes to process.
