diff --git a/sam2/csrc/connected_components.cu b/sam2/csrc/connected_components.cu index ced21eb32..5585d0776 100644 --- a/sam2/csrc/connected_components.cu +++ b/sam2/csrc/connected_components.cu @@ -226,6 +226,8 @@ std::vector get_connected_componnets( AT_ASSERTM((H % 2) == 0, "height must be an even number"); AT_ASSERTM((W % 2) == 0, "width must be an even number"); + at::cuda::CUDAGuard guard(inputs_.device()); + // label must be uint32_t auto label_options = torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());