Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_macros_and_flags():
define_macros += [("WITH_HIP", None)]
nvcc_flags = []
else:
define_macros += [("WITH_CUDA", None)]
define_macros += [("WITH_CUDA", None), ("USE_CUDA", None)]
if NVCC_FLAGS is None:
nvcc_flags = []
else:
Expand Down
122 changes: 81 additions & 41 deletions torchvision/csrc/ops/cpu/nms_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,48 +1,83 @@
#include <ATen/ATen.h>
#include <torch/library.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/Dispatch.h>
#include <torch/headeronly/core/ScalarType.h>

namespace vision {
namespace ops {

namespace {

namespace stable = torch::stable;
namespace headeronly = torch::headeronly;
using headeronly::ScalarType;
using stable::Tensor;

inline std::pair<Tensor, Tensor> dispatch_sort(
const Tensor& self,
bool is_stable,
int64_t dim,
bool descending) {
constexpr int num_args = 4;
std::array<StableIValue, num_args> stack{
stable::detail::from(self),
stable::detail::from(std::optional<bool>(is_stable)),
stable::detail::from(dim),
stable::detail::from(descending)};
STABLE_TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::sort", "stable", stack.data(), TORCH_ABI_VERSION));
return {
stable::detail::to<Tensor>(stack[0]),
stable::detail::to<Tensor>(stack[1])};
}

// Boxed call to the "aten::mul" operator, overload "Tensor", through
// torch_call_dispatcher.
//
// `self` and `other` are packed into a two-slot StableIValue stack; after the
// call the first slot is read back as the resulting tensor. A non-success
// status from the dispatcher call is passed to STABLE_TORCH_ERROR_CODE_CHECK.
inline Tensor dispatch_mul(const Tensor& self, const Tensor& other) {
  constexpr int kStackSlots = 2;
  std::array<StableIValue, kStackSlots> boxed{
      stable::detail::from(self), stable::detail::from(other)};

  STABLE_TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
      "aten::mul", "Tensor", boxed.data(), TORCH_ABI_VERSION));

  // The result is taken from the same stack the inputs were passed in.
  return stable::detail::to<Tensor>(boxed[0]);
}

template <typename scalar_t>
at::Tensor nms_kernel_impl(
const at::Tensor& dets,
const at::Tensor& scores,
Tensor nms_kernel_impl(
const Tensor& dets,
const Tensor& scores,
double iou_threshold) {
TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor");
TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor");
TORCH_CHECK(
STD_TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor");
STD_TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor");
STD_TORCH_CHECK(
dets.scalar_type() == scores.scalar_type(),
"dets should have the same type as scores");

if (dets.numel() == 0) {
return at::empty({0}, dets.options().dtype(at::kLong));
return stable::new_empty(dets, {0}, ScalarType::Long);
}

auto x1_t = dets.select(1, 0).contiguous();
auto y1_t = dets.select(1, 1).contiguous();
auto x2_t = dets.select(1, 2).contiguous();
auto y2_t = dets.select(1, 3).contiguous();
auto x1_t = stable::contiguous(stable::select(dets, 1, 0));
auto y1_t = stable::contiguous(stable::select(dets, 1, 1));
auto x2_t = stable::contiguous(stable::select(dets, 1, 2));
auto y2_t = stable::contiguous(stable::select(dets, 1, 3));

at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
auto areas_t =
dispatch_mul(stable::subtract(x2_t, x1_t), stable::subtract(y2_t, y1_t));

auto order_t = std::get<1>(
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
auto [values, order_t] = dispatch_sort(scores, true, 0, true);

auto ndets = dets.size(0);
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));

auto suppressed = suppressed_t.data_ptr<uint8_t>();
auto keep = keep_t.data_ptr<int64_t>();
auto order = order_t.data_ptr<int64_t>();
auto x1 = x1_t.data_ptr<scalar_t>();
auto y1 = y1_t.data_ptr<scalar_t>();
auto x2 = x2_t.data_ptr<scalar_t>();
auto y2 = y2_t.data_ptr<scalar_t>();
auto areas = areas_t.data_ptr<scalar_t>();
auto suppressed_t = stable::new_zeros(dets, {ndets}, ScalarType::Byte);
auto keep_t = stable::new_zeros(dets, {ndets}, ScalarType::Long);

auto suppressed = suppressed_t.mutable_data_ptr<uint8_t>();
auto keep = keep_t.mutable_data_ptr<int64_t>();
auto order = order_t.const_data_ptr<int64_t>();
auto x1 = x1_t.const_data_ptr<scalar_t>();
auto y1 = y1_t.const_data_ptr<scalar_t>();
auto x2 = x2_t.const_data_ptr<scalar_t>();
auto y2 = y2_t.const_data_ptr<scalar_t>();
auto areas = areas_t.const_data_ptr<scalar_t>();

int64_t num_to_keep = 0;

Expand Down Expand Up @@ -77,44 +112,49 @@ at::Tensor nms_kernel_impl(
}
}
}
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
return stable::narrow(keep_t, 0, 0, num_to_keep);
}

// Entry point for the nms kernel in this file: validates the arguments and
// forwards to nms_kernel_impl<scalar_t> for the scalar type of `dets`.
//
// Checks performed (each failure reports the offending value):
//   - `dets` is 2-dimensional with 4 elements along dimension 1,
//   - `scores` is 1-dimensional,
//   - `dets` and `scores` have the same size along dimension 0.
//
// NOTE(review): this span appears to be a rendered diff in which the removed
// lines (at::Tensor / TORCH_CHECK / AT_DISPATCH_FLOATING_TYPES) and the added
// lines (Tensor / STD_TORCH_CHECK / THO_DISPATCH_SWITCH) are interleaved
// without +/- markers; it is not compilable as shown. Confirm against the
// actual file before relying on it.
at::Tensor nms_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
Tensor nms_kernel(
const Tensor& dets,
const Tensor& scores,
double iou_threshold) {
TORCH_CHECK(
STD_TORCH_CHECK(
dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
TORCH_CHECK(
STD_TORCH_CHECK(
dets.size(1) == 4,
"boxes should have 4 elements in dimension 1, got ",
dets.size(1));
TORCH_CHECK(
STD_TORCH_CHECK(
scores.dim() == 1,
"scores should be a 1d tensor, got ",
scores.dim(),
"D");
TORCH_CHECK(
STD_TORCH_CHECK(
dets.size(0) == scores.size(0),
"boxes and scores should have same number of elements in ",
"dimension 0, got ",
dets.size(0),
" and ",
scores.size(0));

// Placeholder value; it is overwritten inside the dispatch lambda below.
auto result = at::empty({0}, dets.options());
auto result = stable::new_empty(dets, {0});

// Dispatch on dets.scalar_type(). The THO_* form lists Float and Double
// explicitly, and both cases run the same call into nms_kernel_impl.
AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] {
result = nms_kernel_impl<scalar_t>(dets, scores, iou_threshold);
});
THO_DISPATCH_SWITCH(
dets.scalar_type(),
"nms_kernel",
THO_DISPATCH_CASE(ScalarType::Float, [&] {
result = nms_kernel_impl<scalar_t>(dets, scores, iou_threshold);
}) THO_DISPATCH_CASE(ScalarType::Double, [&] {
result = nms_kernel_impl<scalar_t>(dets, scores, iou_threshold);
}));
return result;
}

} // namespace

// Registers nms_kernel as the implementation of the torchvision "nms" op for
// the CPU dispatch key.
// NOTE(review): the two registration forms below (TORCH_LIBRARY_IMPL with
// TORCH_FN, and STABLE_TORCH_LIBRARY_IMPL with TORCH_BOX) appear to be the
// removed and added sides of a diff shown together; only one is expected in
// the real file. Confirm against the actual source.
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
STABLE_TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("nms", TORCH_BOX(&nms_kernel));
}

} // namespace ops
Expand Down
Loading
Loading