11import numpy as np
2+ from packaging .version import Version
23
34from .utils import (
45 aggregate_common_doc ,
1314 check_fill_value ,
1415 input_validation ,
1516 iscomplexobj ,
17+ maxval ,
1618 minimum_dtype ,
1719 minimum_dtype_scalar ,
1820 minval ,
19- maxval ,
2021)
2122
2223
24+ def _full (size , fill_value , * , dtype = None , like = None ):
25+ """Backcompat for numpy < 1.20.0 which does not support the `like` kwarg"""
26+ if (
27+ like is not None # numpy bug?
28+ and not np .isscalar (like ) # scalars don't work
29+ and Version (np .__version__ ) >= Version ("1.20.0" )
30+ ):
31+ kwargs = {"like" : like }
32+ else :
33+ kwargs = {}
34+
35+ return np .full (size , fill_value = fill_value , dtype = dtype , ** kwargs )
36+
37+
2338def _sum (group_idx , a , size , fill_value , dtype = None ):
2439 dtype = minimum_dtype_scalar (fill_value , dtype , a )
2540
@@ -44,7 +59,7 @@ def _sum(group_idx, a, size, fill_value, dtype=None):
4459
4560def _prod (group_idx , a , size , fill_value , dtype = None ):
4661 dtype = minimum_dtype_scalar (fill_value , dtype , a )
47- ret = np . full (size , fill_value , dtype = dtype , like = a )
62+ ret = _full (size , fill_value , dtype = dtype , like = a )
4863 if fill_value != 1 :
4964 ret [group_idx ] = 1 # product starts from 1
5065 np .multiply .at (ret , group_idx , a )
@@ -57,7 +72,7 @@ def _len(group_idx, a, size, fill_value, dtype=None):
5772
5873def _last (group_idx , a , size , fill_value , dtype = None ):
5974 dtype = minimum_dtype (fill_value , dtype or a .dtype )
60- ret = np . full (size , fill_value , dtype = dtype , like = a )
75+ ret = _full (size , fill_value , dtype = dtype , like = a )
6176 # repeated indexing gives last value, see:
6277 # the phrase "leaving behind the last value" on this page:
6378 # http://wiki.scipy.org/Tentative_NumPy_Tutorial
@@ -67,14 +82,14 @@ def _last(group_idx, a, size, fill_value, dtype=None):
6782
6883def _first (group_idx , a , size , fill_value , dtype = None ):
6984 dtype = minimum_dtype (fill_value , dtype or a .dtype )
70- ret = np . full (size , fill_value , dtype = dtype , like = a )
85+ ret = _full (size , fill_value , dtype = dtype , like = a )
7186 ret [group_idx [::- 1 ]] = a [::- 1 ] # same trick as _last, but in reverse
7287 return ret
7388
7489
7590def _all (group_idx , a , size , fill_value , dtype = None ):
7691 check_boolean (fill_value )
77- ret = np . full (size , fill_value , dtype = bool , like = a )
92+ ret = _full (size , fill_value , dtype = bool , like = a )
7893 if not fill_value :
7994 ret [group_idx ] = True
8095 ret [group_idx .compress (np .logical_not (a ))] = False
@@ -83,7 +98,7 @@ def _all(group_idx, a, size, fill_value, dtype=None):
8398
8499def _any (group_idx , a , size , fill_value , dtype = None ):
85100 check_boolean (fill_value )
86- ret = np . full (size , fill_value , dtype = bool , like = a )
101+ ret = _full (size , fill_value , dtype = bool , like = a )
87102 if fill_value :
88103 ret [group_idx ] = False
89104 ret [group_idx .compress (a )] = True
@@ -93,7 +108,7 @@ def _any(group_idx, a, size, fill_value, dtype=None):
93108def _min (group_idx , a , size , fill_value , dtype = None ):
94109 dtype = minimum_dtype (fill_value , dtype or a .dtype )
95110 dmax = maxval (fill_value , dtype )
96- ret = np . full (size , fill_value , dtype = dtype , like = a )
111+ ret = _full (size , fill_value , dtype = dtype , like = a )
97112 if fill_value != dmax :
98113 ret [group_idx ] = dmax # min starts from maximum
99114 np .minimum .at (ret , group_idx , a )
@@ -103,7 +118,7 @@ def _min(group_idx, a, size, fill_value, dtype=None):
103118def _max (group_idx , a , size , fill_value , dtype = None ):
104119 dtype = minimum_dtype (fill_value , dtype or a .dtype )
105120 dmin = minval (fill_value , dtype )
106- ret = np . full (size , fill_value , dtype = dtype , like = a )
121+ ret = _full (size , fill_value , dtype = dtype , like = a )
107122 if fill_value != dmin :
108123 ret [group_idx ] = dmin # max starts from minimum
109124 np .maximum .at (ret , group_idx , a )
@@ -115,7 +130,7 @@ def _argmax(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
115130 group_max = _max (group_idx , a_ , size , np .nan )
116131 # nan should never be maximum, so use a and not a_
117132 is_max = a == group_max [group_idx ]
118- ret = np . full (size , fill_value , dtype = dtype , like = a )
133+ ret = _full (size , fill_value , dtype = dtype , like = a )
119134 group_idx_max = group_idx [is_max ]
120135 (argmax ,) = is_max .nonzero ()
121136 ret [group_idx_max [::- 1 ]] = argmax [
@@ -129,7 +144,7 @@ def _argmin(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
129144 group_min = _min (group_idx , a_ , size , np .nan )
130145 # nan should never be minimum, so use a and not a_
131146 is_min = a == group_min [group_idx ]
132- ret = np . full (size , fill_value , dtype = dtype , like = a )
147+ ret = _full (size , fill_value , dtype = dtype , like = a )
133148 group_idx_min = group_idx [is_min ]
134149 (argmin ,) = is_min .nonzero ()
135150 ret [group_idx_min [::- 1 ]] = argmin [
@@ -148,7 +163,9 @@ def _mean(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
148163 sums .real = np .bincount (group_idx , weights = a .real , minlength = size )
149164 sums .imag = np .bincount (group_idx , weights = a .imag , minlength = size )
150165 else :
151- sums = np .bincount (group_idx , weights = a , minlength = size ).astype (dtype , copy = False )
166+ sums = np .bincount (group_idx , weights = a , minlength = size ).astype (
167+ dtype , copy = False
168+ )
152169
153170 with np .errstate (divide = "ignore" , invalid = "ignore" ):
154171 ret = sums .astype (dtype , copy = False ) / counts
@@ -223,7 +240,7 @@ def _generic_callable(
223240 """groups a by inds, and then applies foo to each group in turn, placing
224241 the results in an array."""
225242 groups = _array (group_idx , a , size , ())
226- ret = np . full (size , fill_value , dtype = dtype or np .float64 )
243+ ret = _full (size , fill_value , dtype = dtype or np .float64 )
227244
228245 for i , grp in enumerate (groups ):
229246 if np .ndim (grp ) == 1 and len (grp ) > 0 :
0 commit comments