Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 69e6c04

Browse files
bgawrych, szha, bartekkuncer
authored
Optimize 'take' operator for CPU (#20745)
* Improve performance of take operator * remove comment * Fix build * fix sanity * Add comment * review * Update src/operator/tensor/indexing_op.h Co-authored-by: bartekkuncer <[email protected]> Co-authored-by: Sheng Zha <[email protected]> Co-authored-by: bartekkuncer <[email protected]>
1 parent 7d84b59 commit 69e6c04

File tree

2 files changed

+72
-32
lines changed

2 files changed

+72
-32
lines changed

src/operator/tensor/indexing_op.cc

Lines changed: 69 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,51 @@ struct TakeZeroAxisCPU {
6060
}
6161
};
6262

63+
template <bool clip = true>
64+
struct TakeNonzeroAxisCPU {
65+
/*!
66+
* \brief Map function for take operator (CPU path, axis other than zero).
* For one work item i it walks all idx_size indices, normalizes each index
* (clamped to [0, axis_dim - 1] when clip is true, wrapped modulo axis_dim
* otherwise) and copies a contiguous run of axis_dim_stride elements from
* in_data to out_data with memcpy.
67+
* \param i work item id; used as the index over the flattened dimensions before the axis
68+
* \param out_data ptr to output buffer
69+
* \param in_data ptr to input buffer
70+
* \param indices ptr to indices buffer; each value is converted to int before use
71+
* \param outer_dim_stride stride of dimension before axis (input elements between consecutive i)
72+
* \param axis_dim_stride stride of axis dimension; also the number of elements copied per index
73+
* \param idx_size size of the indices tensor (number of indices processed for each i)
74+
* \param axis_dim dim size of the axis dimension
75+
* \param axis axis id (not read by this function body)
76+
*/
77+
template <typename DType, typename IType>
78+
MSHADOW_XINLINE static void Map(index_t i,
79+
DType* out_data,
80+
const DType* in_data,
81+
const IType* indices,
82+
const index_t outer_dim_stride,
83+
const index_t axis_dim_stride,
84+
const int idx_size,
85+
const int axis_dim,
86+
const int axis) {
87+
for (index_t j = 0; j < static_cast<index_t>(idx_size); ++j) {
88+
int index = indices[j];
89+
if (clip) {
90+
index = std::max(index, 0);
91+
index = std::min(axis_dim - 1, index);
92+
} else {
93+
index %= axis_dim;
94+
// % keeps the sign of the dividend, so shift negative remainders into [0, axis_dim)
index += (index < 0) ? axis_dim : 0;
95+
}
96+
// source run: slice i of the input, position `index` along the axis
size_t in_offset = i * outer_dim_stride + index * axis_dim_stride;
97+
// destination run: slice i of the output, position j along the (replaced) axis
size_t out_offset = (i * idx_size + j) * axis_dim_stride;
98+
// Silence -Wclass-memaccess (GCC >= 8) for the memcpy below only;
// presumably DType can be a class type here - TODO confirm.
#pragma GCC diagnostic push
99+
#if __GNUC__ >= 8
100+
#pragma GCC diagnostic ignored "-Wclass-memaccess"
101+
#endif
102+
std::memcpy(out_data + out_offset, in_data + in_offset, axis_dim_stride * sizeof(DType));
103+
#pragma GCC diagnostic pop
104+
}
105+
}
106+
};
107+
63108
/*
64109
* \brief returns true if all indices are between [min, max]
65110
* \param data_ptr the indices to check
@@ -323,6 +368,7 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
323368
const std::vector<OpReqType>& req,
324369
const std::vector<TBlob>& outputs) {
325370
using namespace mxnet_op;
371+
326372
if (req[take_::kOut] == kNullOp)
327373
return;
328374
const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
@@ -375,39 +421,32 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
375421
for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
376422
in_strides[i] = stride;
377423
}
378-
mshadow::Shape<10> out_strides;
379-
stride = 1;
380-
for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
381-
out_strides[i] = stride;
424+
int outer_dimensions = 1;
425+
for (int i = 0; i < actual_axis; i++) {
426+
outer_dimensions *= oshape[i];
382427
}
383428
if (param.mode == take_::kClip) {
384-
Kernel<TakeNonzeroAxis<true>, cpu>::Launch(s,
385-
oshape.Size(),
386-
outputs[take_::kOut].dptr<DType>(),
387-
inputs[take_::kArr].dptr<DType>(),
388-
inputs[take_::kIdx].dptr<IType>(),
389-
out_strides[actual_axis - 1],
390-
in_strides[actual_axis - 1],
391-
in_strides[actual_axis],
392-
arrshape.ndim(),
393-
oshape.ndim(),
394-
idxshape.ndim(),
395-
arrshape[actual_axis],
396-
actual_axis);
429+
Kernel<TakeNonzeroAxisCPU<true>, cpu>::Launch(s,
430+
outer_dimensions,
431+
outputs[take_::kOut].dptr<DType>(),
432+
inputs[take_::kArr].dptr<DType>(),
433+
inputs[take_::kIdx].dptr<IType>(),
434+
in_strides[actual_axis - 1],
435+
in_strides[actual_axis],
436+
idxshape.Size(),
437+
arrshape[actual_axis],
438+
actual_axis);
397439
} else {
398-
Kernel<TakeNonzeroAxis<false>, cpu>::Launch(s,
399-
oshape.Size(),
400-
outputs[take_::kOut].dptr<DType>(),
401-
inputs[take_::kArr].dptr<DType>(),
402-
inputs[take_::kIdx].dptr<IType>(),
403-
out_strides[actual_axis - 1],
404-
in_strides[actual_axis - 1],
405-
in_strides[actual_axis],
406-
arrshape.ndim(),
407-
oshape.ndim(),
408-
idxshape.ndim(),
409-
arrshape[actual_axis],
410-
actual_axis);
440+
Kernel<TakeNonzeroAxisCPU<false>, cpu>::Launch(s,
441+
outer_dimensions,
442+
outputs[take_::kOut].dptr<DType>(),
443+
inputs[take_::kArr].dptr<DType>(),
444+
inputs[take_::kIdx].dptr<IType>(),
445+
in_strides[actual_axis - 1],
446+
in_strides[actual_axis],
447+
idxshape.Size(),
448+
arrshape[actual_axis],
449+
actual_axis);
411450
}
412451
}
413452
});

src/operator/tensor/indexing_op.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,9 @@ inline bool EmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
215215
return dispatched;
216216
}
217217

218-
/*! \brief name the struct TakeNonzeroAxis for general take when
219-
* axis is not zero, use TakeZeroAxisGPU or TakeZeroAxisCPU for axis zero
218+
/*! \brief TakeNonzeroAxis is designated for general take when
219+
* axis is not zero (for CPU optimized version use TakeNonzeroAxisCPU and
220+
for axis zero use TakeZeroAxisGPU or TakeZeroAxisCPU)
220221
*/
221222
template <bool clip = true>
222223
struct TakeNonzeroAxis {

0 commit comments

Comments
 (0)