@@ -60,6 +60,51 @@ struct TakeZeroAxisCPU {
6060 }
6161};
6262
63+ template <bool clip = true >
64+ struct TakeNonzeroAxisCPU {
65+ /* !
66+ * \brief Map function for take operator
67+ * \param i global thread id
68+ * \param out_data ptr to output buffer
69+ * \param in_data ptr to input buffer
70+ * \param indices ptr to indices buffer
71+ * \param outer_dim_stride stride of dimension before axis
72+ * \param axis_dim_stride stride of axis dimension
73+ * \param idx_size size of the indices tensor
74+ * \param axis_dim dim size of the axis dimension
75+ * \param axis axis id
76+ */
77+ template <typename DType, typename IType>
78+ MSHADOW_XINLINE static void Map (index_t i,
79+ DType* out_data,
80+ const DType* in_data,
81+ const IType* indices,
82+ const index_t outer_dim_stride,
83+ const index_t axis_dim_stride,
84+ const int idx_size,
85+ const int axis_dim,
86+ const int axis) {
87+ for (index_t j = 0 ; j < static_cast <index_t >(idx_size); ++j) {
88+ int index = indices[j];
89+ if (clip) {
90+ index = std::max (index, 0 );
91+ index = std::min (axis_dim - 1 , index);
92+ } else {
93+ index %= axis_dim;
94+ index += (index < 0 ) ? axis_dim : 0 ;
95+ }
96+ size_t in_offset = i * outer_dim_stride + index * axis_dim_stride;
97+ size_t out_offset = (i * idx_size + j) * axis_dim_stride;
98+ #pragma GCC diagnostic push
99+ #if __GNUC__ >= 8
100+ #pragma GCC diagnostic ignored "-Wclass-memaccess"
101+ #endif
102+ std::memcpy (out_data + out_offset, in_data + in_offset, axis_dim_stride * sizeof (DType));
103+ #pragma GCC diagnostic pop
104+ }
105+ }
106+ };
107+
63108/*
64109 * \brief returns true if all indices are between [min, max]
65110 * \param data_ptr the indices to check
@@ -323,6 +368,7 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
323368 const std::vector<OpReqType>& req,
324369 const std::vector<TBlob>& outputs) {
325370 using namespace mxnet_op ;
371+
326372 if (req[take_::kOut ] == kNullOp )
327373 return ;
328374 const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed );
@@ -375,39 +421,32 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
375421 for (int i = arrshape.ndim () - 1 ; i >= 0 ; stride *= arrshape[i], --i) {
376422 in_strides[i] = stride;
377423 }
378- mshadow::Shape<10 > out_strides;
379- stride = 1 ;
380- for (int i = oshape.ndim () - 1 ; i >= 0 ; stride *= oshape[i], --i) {
381- out_strides[i] = stride;
424+ int outer_dimensions = 1 ;
425+ for (int i = 0 ; i < actual_axis; i++) {
426+ outer_dimensions *= oshape[i];
382427 }
383428 if (param.mode == take_::kClip ) {
384- Kernel<TakeNonzeroAxis<true >, cpu>::Launch (s,
385- oshape.Size (),
386- outputs[take_::kOut ].dptr <DType>(),
387- inputs[take_::kArr ].dptr <DType>(),
388- inputs[take_::kIdx ].dptr <IType>(),
389- out_strides[actual_axis - 1 ],
390- in_strides[actual_axis - 1 ],
391- in_strides[actual_axis],
392- arrshape.ndim (),
393- oshape.ndim (),
394- idxshape.ndim (),
395- arrshape[actual_axis],
396- actual_axis);
429+ Kernel<TakeNonzeroAxisCPU<true >, cpu>::Launch (s,
430+ outer_dimensions,
431+ outputs[take_::kOut ].dptr <DType>(),
432+ inputs[take_::kArr ].dptr <DType>(),
433+ inputs[take_::kIdx ].dptr <IType>(),
434+ in_strides[actual_axis - 1 ],
435+ in_strides[actual_axis],
436+ idxshape.Size (),
437+ arrshape[actual_axis],
438+ actual_axis);
397439 } else {
398- Kernel<TakeNonzeroAxis<false >, cpu>::Launch (s,
399- oshape.Size (),
400- outputs[take_::kOut ].dptr <DType>(),
401- inputs[take_::kArr ].dptr <DType>(),
402- inputs[take_::kIdx ].dptr <IType>(),
403- out_strides[actual_axis - 1 ],
404- in_strides[actual_axis - 1 ],
405- in_strides[actual_axis],
406- arrshape.ndim (),
407- oshape.ndim (),
408- idxshape.ndim (),
409- arrshape[actual_axis],
410- actual_axis);
440+ Kernel<TakeNonzeroAxisCPU<false >, cpu>::Launch (s,
441+ outer_dimensions,
442+ outputs[take_::kOut ].dptr <DType>(),
443+ inputs[take_::kArr ].dptr <DType>(),
444+ inputs[take_::kIdx ].dptr <IType>(),
445+ in_strides[actual_axis - 1 ],
446+ in_strides[actual_axis],
447+ idxshape.Size (),
448+ arrshape[actual_axis],
449+ actual_axis);
411450 }
412451 }
413452 });
0 commit comments