diff --git a/setup.py b/setup.py index 23cd26dd7ee..fa3c061d64f 100644 --- a/setup.py +++ b/setup.py @@ -118,7 +118,7 @@ def get_macros_and_flags(): define_macros += [("WITH_HIP", None)] nvcc_flags = [] else: - define_macros += [("WITH_CUDA", None)] + define_macros += [("WITH_CUDA", None), ("USE_CUDA", None)] if NVCC_FLAGS is None: nvcc_flags = [] else: diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index 454ce118a6d..1f4fe0e78bc 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -1,48 +1,83 @@ -#include -#include +#include +#include +#include +#include +#include namespace vision { namespace ops { namespace { +namespace stable = torch::stable; +namespace headeronly = torch::headeronly; +using headeronly::ScalarType; +using stable::Tensor; + +inline std::pair dispatch_sort( + const Tensor& self, + bool is_stable, + int64_t dim, + bool descending) { + constexpr int num_args = 4; + std::array stack{ + stable::detail::from(self), + stable::detail::from(std::optional(is_stable)), + stable::detail::from(dim), + stable::detail::from(descending)}; + STABLE_TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::sort", "stable", stack.data(), TORCH_ABI_VERSION)); + return { + stable::detail::to(stack[0]), + stable::detail::to(stack[1])}; +} + +inline Tensor dispatch_mul(const Tensor& self, const Tensor& other) { + constexpr int num_args = 2; + std::array stack{ + stable::detail::from(self), stable::detail::from(other)}; + STABLE_TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::mul", "Tensor", stack.data(), TORCH_ABI_VERSION)); + return stable::detail::to(stack[0]); +} + template -at::Tensor nms_kernel_impl( - const at::Tensor& dets, - const at::Tensor& scores, +Tensor nms_kernel_impl( + const Tensor& dets, + const Tensor& scores, double iou_threshold) { - TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor"); - TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor"); - TORCH_CHECK( + STD_TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor"); + STD_TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor"); + STD_TORCH_CHECK( dets.scalar_type() == scores.scalar_type(), "dets should have the same type as scores"); if (dets.numel() == 0) { - return at::empty({0}, dets.options().dtype(at::kLong)); + return stable::new_empty(dets, {0}, ScalarType::Long); } - auto x1_t = dets.select(1, 0).contiguous(); - auto y1_t = dets.select(1, 1).contiguous(); - auto x2_t = dets.select(1, 2).contiguous(); - auto y2_t = dets.select(1, 3).contiguous(); + auto x1_t = stable::contiguous(stable::select(dets, 1, 0)); + auto y1_t = stable::contiguous(stable::select(dets, 1, 1)); + auto x2_t = stable::contiguous(stable::select(dets, 1, 2)); + auto y2_t = stable::contiguous(stable::select(dets, 1, 3)); - at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t); + auto areas_t = + dispatch_mul(stable::subtract(x2_t, x1_t), stable::subtract(y2_t, y1_t)); - auto order_t = std::get<1>( - scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); + auto [values, order_t] = dispatch_sort(scores, true, 0, true); auto ndets = dets.size(0); - at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); - at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong)); - - auto suppressed = suppressed_t.data_ptr(); - auto keep = keep_t.data_ptr(); - auto order = order_t.data_ptr(); - auto x1 = x1_t.data_ptr(); - auto y1 = y1_t.data_ptr(); - auto x2 = x2_t.data_ptr(); - auto y2 = y2_t.data_ptr(); - auto areas = areas_t.data_ptr(); + auto suppressed_t = stable::new_zeros(dets, {ndets}, ScalarType::Byte); + auto keep_t = stable::new_zeros(dets, {ndets}, ScalarType::Long); + + auto suppressed = suppressed_t.mutable_data_ptr(); + auto keep = keep_t.mutable_data_ptr(); + auto order = order_t.const_data_ptr(); + auto x1 = x1_t.const_data_ptr(); + auto y1 = y1_t.const_data_ptr(); + auto x2 = x2_t.const_data_ptr(); + auto y2 = y2_t.const_data_ptr(); + auto areas = areas_t.const_data_ptr(); int64_t num_to_keep = 0; @@ -77,25 +112,25 @@ at::Tensor nms_kernel_impl( } } } - return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); + return stable::narrow(keep_t, 0, 0, num_to_keep); } -at::Tensor nms_kernel( - const at::Tensor& dets, - const at::Tensor& scores, +Tensor nms_kernel( + const Tensor& dets, + const Tensor& scores, double iou_threshold) { - TORCH_CHECK( + STD_TORCH_CHECK( dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); - TORCH_CHECK( + STD_TORCH_CHECK( dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1)); - TORCH_CHECK( + STD_TORCH_CHECK( scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D"); - TORCH_CHECK( + STD_TORCH_CHECK( dets.size(0) == scores.size(0), "boxes and scores should have same number of elements in ", "dimension 0, got ", @@ -103,18 +138,23 @@ at::Tensor nms_kernel( " and ", scores.size(0)); - auto result = at::empty({0}, dets.options()); + auto result = stable::new_empty(dets, {0}); - AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] { - result = nms_kernel_impl(dets, scores, iou_threshold); - }); + THO_DISPATCH_SWITCH( + dets.scalar_type(), + "nms_kernel", + THO_DISPATCH_CASE(ScalarType::Float, [&] { + result = nms_kernel_impl(dets, scores, iou_threshold); + }) THO_DISPATCH_CASE(ScalarType::Double, [&] { + result = nms_kernel_impl(dets, scores, iou_threshold); + })); return result; } } // namespace -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); +STABLE_TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl("nms", TORCH_BOX(&nms_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/cuda/nms_kernel.cu b/torchvision/csrc/ops/cuda/nms_kernel.cu index 44ce8db6b8e..f13b80a6a4a 100644 --- a/torchvision/csrc/ops/cuda/nms_kernel.cu +++ b/torchvision/csrc/ops/cuda/nms_kernel.cu @@ -1,8 +1,13 @@ -#include -#include -#include -#include -#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include #include "cuda_helpers.h" @@ -11,6 +16,21 @@ namespace ops { namespace { +namespace stable = torch::stable; +namespace headeronly = torch::headeronly; +using headeronly::ScalarType; +using stable::Tensor; + +template +struct CudaAccType { + using type = T; +}; + +template <> +struct CudaAccType { + using type = float; +}; + int const threadsPerBlock = sizeof(unsigned long long) * 8; template @@ -21,7 +41,7 @@ __device__ inline bool devIoU( T left = max(a[0], b[0]), right = min(a[2], b[2]); T top = max(a[1], b[1]), bottom = min(a[3], b[3]); T width = max(right - left, (T)0), height = max(bottom - top, (T)0); - using acc_T = at::acc_type; + using acc_T = typename CudaAccType::type; acc_T interS = (acc_T)width * height; acc_T Sa = ((acc_T)a[2] - a[0]) * (a[3] - a[1]); acc_T Sb = ((acc_T)b[2] - b[0]) * (b[3] - b[1]); @@ -122,25 +142,66 @@ __global__ static void gather_keep_from_mask( } } -at::Tensor nms_kernel( - const at::Tensor& dets, - const at::Tensor& scores, +inline std::pair dispatch_sort( + const Tensor& self, + bool is_stable, + int64_t dim, + bool descending) { + constexpr int num_args = 4; + std::array stack{ + stable::detail::from(self), + stable::detail::from(std::optional(is_stable)), + stable::detail::from(dim), + stable::detail::from(descending)}; + STABLE_TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::sort", "stable", stack.data(), TORCH_ABI_VERSION)); + return { + stable::detail::to(stack[0]), + stable::detail::to(stack[1])}; +} + +inline Tensor dispatch_index_select( + const Tensor& self, + int64_t dim, + const Tensor& index) { + constexpr int num_args = 3; + std::array stack{ + stable::detail::from(self), + stable::detail::from(dim), + stable::detail::from(index)}; + STABLE_TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::index_select", "", stack.data(), TORCH_ABI_VERSION)); + return stable::detail::to(stack[0]); +} + +inline Tensor dispatch_masked_select(const Tensor& self, const Tensor& mask) { + constexpr int num_args = 2; + std::array stack{ + stable::detail::from(self), stable::detail::from(mask)}; + STABLE_TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::masked_select", "", stack.data(), TORCH_ABI_VERSION)); + return stable::detail::to(stack[0]); +} + +Tensor nms_kernel( + const Tensor& dets, + const Tensor& scores, double iou_threshold) { - TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor"); - TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor"); + STD_TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor"); + STD_TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor"); - TORCH_CHECK( + STD_TORCH_CHECK( dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); - TORCH_CHECK( + STD_TORCH_CHECK( dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1)); - TORCH_CHECK( + STD_TORCH_CHECK( scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D"); - TORCH_CHECK( + STD_TORCH_CHECK( dets.size(0) == scores.size(0), "boxes and scores should have same number of elements in ", "dimension 0, got ", @@ -148,40 +209,59 @@ at::Tensor nms_kernel( " and ", scores.size(0)) - at::cuda::CUDAGuard device_guard(dets.device()); + torch::stable::accelerator::DeviceGuard device_guard(dets.get_device_index()); if (dets.numel() == 0) { - return at::empty({0}, dets.options().dtype(at::kLong)); + return stable::new_empty(dets, {0}, ScalarType::Long); } - auto order_t = std::get<1>( - scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); - auto dets_sorted = dets.index_select(0, order_t).contiguous(); + auto [values, order_t] = dispatch_sort(scores, true, 0, true); + auto dets_sorted = + stable::contiguous(dispatch_index_select(dets, 0, order_t)); - int dets_num = dets.size(0); + int64_t dets_num = dets.size(0); - const int col_blocks = ceil_div(dets_num, threadsPerBlock); + const int col_blocks = ceil_div(static_cast(dets_num), threadsPerBlock); - at::Tensor mask = - at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); + auto mask = + stable::new_empty(dets, {dets_num * col_blocks}, ScalarType::Long); dim3 blocks(col_blocks, col_blocks); dim3 threads(threadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - dets_sorted.scalar_type(), "nms_kernel", [&] { + void* stream_ptr = nullptr; + STABLE_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_current_cuda_stream(dets.get_device_index(), &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); + + THO_DISPATCH_SWITCH( + dets_sorted.scalar_type(), + "nms_kernel", + THO_DISPATCH_CASE(ScalarType::Float, [&] { + nms_kernel_impl<<>>( + dets_num, + iou_threshold, + dets_sorted.const_data_ptr(), + reinterpret_cast( + mask.mutable_data_ptr())); + }) THO_DISPATCH_CASE(ScalarType::Double, [&] { + nms_kernel_impl<<>>( + dets_num, + iou_threshold, + dets_sorted.const_data_ptr(), + reinterpret_cast( + mask.mutable_data_ptr())); + }) THO_DISPATCH_CASE(ScalarType::Half, [&] { nms_kernel_impl<<>>( dets_num, iou_threshold, - dets_sorted.data_ptr(), - (unsigned long long*)mask.data_ptr()); - }); + dets_sorted.const_data_ptr(), + reinterpret_cast( + mask.mutable_data_ptr())); + })); - at::Tensor keep = - at::zeros({dets_num}, dets.options().dtype(at::kBool).device(at::kCUDA)); + auto keep = stable::new_zeros(dets, {dets_num}, ScalarType::Bool); - // Unwrap the mask to fill keep with proper values // Keeping the unwrap on device instead of applying iterative for loops on cpu // prevents the device -> cpu -> device transfer that could be bottleneck for // large number of boxes. @@ -191,18 +271,18 @@ at::Tensor nms_kernel( min(col_blocks, threadsPerBlock), col_blocks * sizeof(unsigned long long), stream>>>( - keep.data_ptr(), - (unsigned long long*)mask.data_ptr(), + keep.mutable_data_ptr(), + reinterpret_cast(mask.mutable_data_ptr()), dets_num); - AT_CUDA_CHECK(cudaGetLastError()); - return order_t.masked_select(keep); + STD_CUDA_KERNEL_LAUNCH_CHECK(); + return dispatch_masked_select(order_t, keep); } } // namespace -TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); +STABLE_TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { + m.impl("nms", TORCH_BOX(&nms_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/nms.cpp b/torchvision/csrc/ops/nms.cpp index 5ecf8812f1b..b8078f6d0c0 100644 --- a/torchvision/csrc/ops/nms.cpp +++ b/torchvision/csrc/ops/nms.cpp @@ -1,7 +1,7 @@ #include "nms.h" #include -#include +#include #include namespace vision { @@ -18,10 +18,9 @@ at::Tensor nms( return op.call(dets, scores, iou_threshold); } -TORCH_LIBRARY_FRAGMENT(torchvision, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(torchvision, m) { m.set_python_module("torchvision._meta_registrations"); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor")); + m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"); } } // namespace ops