55from .xrutils import isnull
66
77
def _prepare_for_flox(group_idx, array):
    """
    Sort the input array once to save time.

    Downstream grouped reductions in this module assume the group labels are
    already in non-decreasing order; doing the sort here means it happens at
    most once per call chain.

    Parameters
    ----------
    group_idx : array-like
        Group label for every element along the last axis of ``array``.
        Assumed to be 1-D (only ``shape[0]`` is checked) -- TODO confirm
        against callers.
    array : array-like
        Values to be reduced; its last axis must have the same length as
        ``group_idx``.

    Returns
    -------
    tuple
        ``(group_idx, array)`` reordered along the last axis so that
        ``group_idx`` is non-decreasing. When the labels are already ordered,
        the inputs are returned unchanged (no copy is made).
    """
    assert array.shape[-1] == group_idx.shape[0], (
        "last axis of array must match the length of group_idx"
    )

    # Cheap O(n) check that lets us skip the O(n log n) argsort entirely.
    issorted = (group_idx[:-1] <= group_idx[1:]).all()
    if issorted:
        return group_idx, array

    # A stable sort keeps elements of the same group in their original
    # relative order.
    perm = group_idx.argsort(kind="stable")
    return group_idx[..., perm], array[..., perm]
21+
22+
823def _np_grouped_op (group_idx , array , op , axis = - 1 , size = None , fill_value = None , dtype = None , out = None ):
924 """
1025 most of this code is from shoyer's gist
@@ -13,7 +28,7 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
1328 # assumes input is sorted, which I do in core._prepare_for_flox
1429 aux = group_idx
1530
16- flag = np .concatenate (([True ], aux [1 :] != aux [:- 1 ]))
31+ flag = np .concatenate ((np . array ( [True ], like = array ) , aux [1 :] != aux [:- 1 ]))
1732 uniques = aux [flag ]
1833 (inv_idx ,) = flag .nonzero ()
1934
@@ -25,11 +40,11 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
2540 if out is None :
2641 out = np .full (array .shape [:- 1 ] + (size ,), fill_value = fill_value , dtype = dtype )
2742
28- if (len (uniques ) == size ) and (uniques == np .arange (size )).all ():
43+ if (len (uniques ) == size ) and (uniques == np .arange (size , like = array )).all ():
2944 # The previous version of this if condition
3045 # ((uniques[1:] - uniques[:-1]) == 1).all():
3146 # does not work when group_idx is [1, 2] for e.g.
32- # This happens during binning
47+ # This happens during binning
3348 op .reduceat (array , inv_idx , axis = axis , dtype = dtype , out = out )
3449 else :
3550 out [..., uniques ] = op .reduceat (array , inv_idx , axis = axis , dtype = dtype )
@@ -91,16 +106,14 @@ def nanlen(group_idx, array, *args, **kwargs):
def mean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
    """
    Grouped mean: the grouped sum divided by the per-group count.

    Parameters
    ----------
    group_idx, array
        Group labels and the values to reduce; forwarded unchanged to the
        grouped ``sum`` and to ``nanlen``.
    axis : int, optional
        Axis to reduce along; forwarded to both helpers.
    size : int, optional
        Forwarded to both helpers (presumably the number of output groups --
        verify against the helpers).
    fill_value : scalar, optional
        Forwarded to the grouped sum. ``None`` is replaced by ``0``.
    dtype : optional
        Forwarded to the grouped sum only.

    Returns
    -------
    The array returned by the grouped sum, divided in place by the counts.
    """
    if fill_value is None:
        fill_value = 0

    # NOTE(review): ``sum`` is called with keyword arguments the builtin does
    # not accept, so it is presumably this module's grouped sum -- confirm.
    # It allocates and returns the output itself, so no separate
    # pre-allocation is needed here.
    out = sum(group_idx, array, axis=axis, size=size, dtype=dtype, fill_value=fill_value)

    # Counts are requested with fill_value=0 regardless of the caller's
    # fill_value. The division is in place, so ``out`` keeps its dtype.
    out /= nanlen(group_idx, array, size=size, axis=axis, fill_value=0)
    return out
98112
99113
def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
    """
    Grouped mean built from ``nansum`` divided by the per-group ``nanlen``.

    Parameters
    ----------
    group_idx, array
        Group labels and the values to reduce; forwarded unchanged to
        ``nansum`` and ``nanlen``.
    axis : int, optional
        Axis to reduce along; forwarded to both helpers.
    size : int, optional
        Forwarded to both helpers (presumably the number of output groups --
        verify against the helpers).
    fill_value : scalar, optional
        Forwarded to ``nansum``. ``None`` is replaced by ``0``.
    dtype : optional
        Forwarded to ``nansum`` only.

    Returns
    -------
    The array returned by ``nansum``, divided in place by the counts.
    """
    if fill_value is None:
        fill_value = 0

    # ``nansum`` allocates and returns the output itself, so no separate
    # pre-allocation is needed here.
    out = nansum(group_idx, array, size=size, axis=axis, dtype=dtype, fill_value=fill_value)

    # Counts are requested with fill_value=0 regardless of the caller's
    # fill_value. The division is in place, so ``out`` keeps its dtype.
    out /= nanlen(group_idx, array, size=size, axis=axis, fill_value=0)
    return out
0 commit comments