Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 25c25da

Browse files
Zha0q1 and haojin2 authored
port & update pr16744 numpy gcd (#19547)
* numpy-compatible gcd operator * use BinaryScalarRTCCompute * Update _op.py * Update np_elemwise_broadcast_op_extended.cc * fix * Update operator_tune.cc * fix kernel * add large tensor test * add gcd interoperability workload * Update test_numpy_interoperability.py * Update np_elemwise_broadcast_op_extended.cc * Update np_elemwise_broadcast_op_extended.cc * avoid ci linspace issue Co-authored-by: Hao Jin <[email protected]>
1 parent c723ae2 commit 25c25da

File tree

15 files changed

+278
-5
lines changed

15 files changed

+278
-5
lines changed

ci/docker/install/requirements

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
# the whole docker cache for the image
2020

2121
# Required dependencies
22-
numpy<1.20.0
22+
numpy>=1.17,<1.20.0
2323
requests>=2.20.0,<3
2424
graphviz<0.9.0,>=0.8.1
2525
contextvars;python_version<"3.7"

python/mxnet/amp/lists/symbol_fp16.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@
245245
'_npi_logistic',
246246
'_npi_lcm',
247247
'_npi_lcm_scalar',
248+
'_npi_gcd',
249+
'_npi_gcd_scalar',
248250
'_npi_linspace',
249251
'_npi_logical_not',
250252
'_npi_logical_and_scalar',

python/mxnet/ndarray/numpy/_op.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
'max', 'min', 'amax', 'amin', 'logical_and', 'logical_or', 'logical_xor',
4545
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
4646
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
47-
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
47+
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'gcd',
4848
'tril', 'triu', 'tri', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross', 'kron',
4949
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
5050
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'interp',
@@ -2081,6 +2081,46 @@ def expand_dims(a, axis):
20812081
return _api_internal.expand_dims(a, axis)
20822082

20832083

2084+
@set_module('mxnet.ndarray.numpy')
2085+
@wrap_np_binary_func
2086+
def gcd(x1, x2, out=None, **kwargs):
2087+
"""
2088+
Returns the greatest common divisor of ``|x1|`` and ``|x2|``
2089+
2090+
Parameters
2091+
----------
2092+
x1, x2 : ndarrays or scalar values
2093+
The arrays for computing greatest common divisor. If x1.shape != x2.shape,
2094+
they must be broadcastable to a common shape (which may be the shape of
2095+
one or the other).
2096+
2097+
out : ndarray or None, optional
2098+
A location into which the result is stored. If provided, it must have a shape
2099+
that the inputs broadcast to. If not provided or None, a freshly-allocated array
2100+
is returned.
2101+
2102+
Returns
2103+
-------
2104+
y : ndarray or scalar
2105+
The greatest common divisor of the absolute values of the inputs.
2106+
This is a scalar if both `x1` and `x2` are scalars.
2107+
2108+
See Also
2109+
--------
2110+
lcm : The lowest common multiple
2111+
2112+
Examples
2113+
--------
2114+
>>> np.gcd(12, 20)
2115+
4
2116+
>>> np.gcd(np.arange(6, dtype=int), 20)
2117+
array([20, 1, 2, 1, 4, 5], dtype=int64)
2118+
"""
2119+
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
2120+
return _np.gcd(x1, x2, out=out)
2121+
return _api_internal.gcd(x1, x2, out)
2122+
2123+
20842124
@set_module('mxnet.ndarray.numpy')
20852125
@wrap_np_binary_func
20862126
def lcm(x1, x2, out=None, **kwargs):

python/mxnet/numpy/multiarray.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot',
7575
'triu_indices_from', 'triu_indices', 'tri',
7676
'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
77-
'unique', 'lcm', 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
77+
'unique', 'lcm', 'gcd', 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
7878
'cross', 'kron', 'equal', 'not_equal', 'interp',
7979
'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero',
8080
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
@@ -3620,6 +3620,44 @@ def power(x1, x2, out=None, **kwargs):
36203620
return _mx_nd_np.power(x1, x2, out=out)
36213621

36223622

3623+
@set_module('mxnet.numpy')
3624+
@wrap_np_binary_func
3625+
def gcd(x1, x2, out=None, **kwargs):
3626+
"""
3627+
Returns the greatest common divisor of ``|x1|`` and ``|x2|``
3628+
3629+
Parameters
3630+
----------
3631+
x1, x2 : ndarrays or scalar values
3632+
The arrays for computing greatest common divisor. If x1.shape != x2.shape,
3633+
they must be broadcastable to a common shape (which may be the shape of
3634+
one or the other).
3635+
3636+
out : ndarray or None, optional
3637+
A location into which the result is stored. If provided, it must have a shape
3638+
that the inputs broadcast to. If not provided or None, a freshly-allocated array
3639+
is returned.
3640+
3641+
Returns
3642+
-------
3643+
y : ndarray or scalar
3644+
The greatest common divisor of the absolute values of the inputs.
3645+
This is a scalar if both `x1` and `x2` are scalars.
3646+
3647+
See Also
3648+
--------
3649+
lcm : The lowest common multiple
3650+
3651+
Examples
3652+
--------
3653+
>>> np.gcd(12, 20)
3654+
4
3655+
>>> np.gcd(np.arange(6, dtype=int), 20)
3656+
array([20, 1, 2, 1, 4, 5], dtype=int64)
3657+
"""
3658+
return _mx_nd_np.gcd(x1, x2, out=out)
3659+
3660+
36233661
@set_module('mxnet.numpy')
36243662
@wrap_np_binary_func
36253663
def lcm(x1, x2, out=None, **kwargs):

python/mxnet/numpy_dispatch_protocol.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ def _register_array_function():
249249
'degrees',
250250
'hypot',
251251
'lcm',
252+
'gcd',
252253
# 'ldexp',
253254
'subtract',
254255
'multiply',

python/mxnet/symbol/numpy/_symbol.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
'flatnonzero', 'tril_indices', 'amax', 'amin', 'max', 'min', 'logical_and', 'logical_or', 'logical_xor',
5050
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
5151
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
52-
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'interp',
52+
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'gcd', 'interp',
5353
'tril', 'triu', 'tri', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross', 'kron',
5454
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
5555
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d',
@@ -1678,6 +1678,37 @@ def power(x1, x2, out=None, **kwargs):
16781678
return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out)
16791679

16801680

1681+
@set_module('mxnet.symbol.numpy')
1682+
@wrap_np_binary_func
1683+
def gcd(x1, x2, out=None, **kwargs):
1684+
"""
1685+
Returns the greatest common divisor of ``|x1|`` and ``|x2|``
1686+
1687+
Parameters
1688+
----------
1689+
x1, x2 : ndarrays or scalar values
1690+
The arrays for computing greatest common divisor. If x1.shape != x2.shape,
1691+
they must be broadcastable to a common shape (which may be the shape of
1692+
one or the other).
1693+
1694+
out : ndarray or None, optional
1695+
A location into which the result is stored. If provided, it must have a shape
1696+
that the inputs broadcast to. If not provided or None, a freshly-allocated array
1697+
is returned.
1698+
1699+
Returns
1700+
-------
1701+
y : ndarray or scalar
1702+
The greatest common divisor of the absolute values of the inputs.
1703+
This is a scalar if both `x1` and `x2` are scalars.
1704+
1705+
See Also
1706+
--------
1707+
lcm : The lowest common multiple
1708+
"""
1709+
return _ufunc_helper(x1, x2, _npi.gcd, _np.gcd, _npi.gcd_scalar, None, out)
1710+
1711+
16811712
@set_module('mxnet.symbol.numpy')
16821713
@wrap_np_binary_func
16831714
def matmul(a, b, out=None, **kwargs):

src/api/operator/numpy/np_elemwise_broadcast_op.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ MXNET_REGISTER_API("_npi.lcm")
8888
UFuncHelper(args, ret, op, op_scalar, nullptr);
8989
});
9090

91+
MXNET_REGISTER_API("_npi.gcd")
92+
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
93+
using namespace runtime;
94+
const nnvm::Op* op = Op::Get("_npi_gcd");
95+
const nnvm::Op* op_scalar = Op::Get("_npi_gcd_scalar");
96+
UFuncHelper(args, ret, op, op_scalar, nullptr);
97+
});
98+
9199
MXNET_REGISTER_API("_npi.logical_and")
92100
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
93101
using namespace runtime;

src/common/cuda/rtc/forward_functions-inl.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,49 @@ lcm(const DType a, const DType2 b) {
541541
}
542542
}
543543
544+
template <typename DType, typename DType2>
545+
__device__ inline typename type_util::mixed_type<DType, DType2>::type
546+
gcd(const DType a, const DType2 b) {
547+
if (type_util::is_integral<DType>::value &&
548+
type_util::is_integral<DType2>::value) {
549+
DType A = a;
550+
DType2 B = b;
551+
// minus cases.
552+
if (a < 0) {
553+
A = -a;
554+
}
555+
if (b < 0) {
556+
B = -b;
557+
}
558+
// handle zero-valued cases.
559+
DType c;
560+
if (a == 0 && b != 0) {
561+
c = B;
562+
} else if (b == 0 && a != 0) {
563+
c = A;
564+
} else if (a == 0 && b == 0) {
565+
c = 0;
566+
} else {
567+
DType tmp;
568+
if (A < B) {
569+
tmp = A;
570+
A = B;
571+
B = tmp;
572+
}
573+
while (A % B != 0) {
574+
A = A % B;
575+
tmp = A;
576+
A = B;
577+
B = tmp;
578+
}
579+
c = B;
580+
}
581+
return c;
582+
} else {
583+
return 0;
584+
}
585+
}
586+
544587
template <typename DType, typename DType2>
545588
__device__ inline typename type_util::mixed_type<DType, DType2>::type bitwise_xor(const DType a,
546589
const DType2 b) {

src/operator/mshadow_op.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,6 +1704,52 @@ struct nanprod_grad : public mxnet_op::tunable {
17041704
#pragma GCC diagnostic ignored "-Wint-in-bool-context"
17051705
#pragma GCC diagnostic ignored "-Wbool-compare"
17061706
#endif
1707+
1708+
/*! \brief used for computing binary greatest common divisor */
1709+
struct gcd : public mxnet_op::tunable {
1710+
template<typename DType>
1711+
MSHADOW_XINLINE static typename enable_if<is_integral<DType>::value, DType>::type
1712+
Map(DType a, DType b) {
1713+
// negative inputs: work with their absolute values.
1714+
if (a < 0) {
1715+
a = -a;
1716+
}
1717+
if (b < 0) {
1718+
b = -b;
1719+
}
1720+
// handle zero-valued cases.
1721+
DType c;
1722+
if (a == 0 && b != 0) {
1723+
c = b;
1724+
} else if (b == 0 && a != 0) {
1725+
c = a;
1726+
} else if (a == 0 && b == 0) {
1727+
c = 0;
1728+
} else {
1729+
DType tmp;
1730+
if (a < b) {
1731+
tmp = a;
1732+
a = b;
1733+
b = tmp;
1734+
}
1735+
while (a % b != 0) {
1736+
a = a % b;
1737+
tmp = a;
1738+
a = b;
1739+
b = tmp;
1740+
}
1741+
c = b;
1742+
}
1743+
return c;
1744+
}
1745+
1746+
template<typename DType>
1747+
MSHADOW_XINLINE static typename enable_if<!is_integral<DType>::value, DType>::type
1748+
Map(DType a, DType b) {
1749+
return DType(0.0f);
1750+
}
1751+
};
1752+
17071753
/*! \brief used for computing binary lowest common multiple */
17081754
struct lcm : public mxnet_op::tunable {
17091755
template<typename DType>

src/operator/numpy/np_elemwise_broadcast_op_extended.cc

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,39 @@ NNVM_REGISTER_OP(_backward_npi_copysign)
6363
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::copysign_grad,
6464
mshadow_op::copysign_rgrad>);
6565

66+
NNVM_REGISTER_OP(_npi_gcd)
67+
.set_num_inputs(2)
68+
.set_num_outputs(1)
69+
.set_attr<nnvm::FListInputNames>("FListInputNames",
70+
[](const NodeAttrs& attrs) {
71+
return std::vector<std::string>{"lhs", "rhs"};
72+
})
73+
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
74+
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>)
75+
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
76+
[](const NodeAttrs& attrs){
77+
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
78+
})
79+
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
80+
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastIntCompute<cpu, mshadow_op::gcd>)
81+
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
82+
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");
83+
84+
NNVM_REGISTER_OP(_npi_gcd_scalar)
85+
.set_num_inputs(1)
86+
.set_num_outputs(1)
87+
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
88+
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
89+
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
90+
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
91+
[](const NodeAttrs& attrs){
92+
return std::vector<std::pair<int, int> >{{0, 0}};
93+
})
94+
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
95+
.add_argument("data", "NDArray-or-Symbol", "source input")
96+
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
97+
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::gcd>);
98+
6699
NNVM_REGISTER_OP(_npi_lcm)
67100
.set_num_inputs(2)
68101
.set_num_outputs(1)
@@ -94,7 +127,7 @@ NNVM_REGISTER_OP(_npi_lcm_scalar)
94127
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
95128
.add_argument("data", "NDArray-or-Symbol", "source input")
96129
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
97-
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::lcm>);
130+
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::lcm>);
98131

99132
NNVM_REGISTER_OP(_npi_bitwise_and)
100133
.set_num_inputs(2)

0 commit comments

Comments
 (0)