Skip to content

Commit 287f4ee

Browse files
dcherian and claude committed
Move _postprocess_numbagg to lib.py
Move the numbagg postprocessing function to lib.py for better organization.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent 52691c7 commit 287f4ee

File tree

2 files changed

+25
-22
lines changed

2 files changed

+25
-22
lines changed

flox/core.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
quantile_new_dims_func,
4949
)
5050
from .cache import memoize
51-
from .lib import ArrayLayer, dask_array_type, sparse_array_type
51+
from .lib import ArrayLayer, _postprocess_numbagg, dask_array_type, sparse_array_type
5252
from .options import OPTIONS
5353
from .xrutils import (
5454
_contains_cftime_datetimes,
@@ -196,27 +196,6 @@ class FactorizeKwargs(TypedDict, total=False):
196196
sort: bool
197197

198198

199-
def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
200-
"""Account for numbagg not providing a fill_value kwarg."""
201-
from .aggregate_numbagg import DEFAULT_FILL_VALUE
202-
203-
if not isinstance(func, str) or func not in DEFAULT_FILL_VALUE:
204-
return result
205-
# The condition needs to be
206-
# len(found_groups) < size; if so we mask with fill_value (?)
207-
default_fv = DEFAULT_FILL_VALUE[func]
208-
needs_masking = fill_value is not None and not np.array_equal(fill_value, default_fv, equal_nan=True)
209-
groups = np.arange(size)
210-
if needs_masking:
211-
mask = np.isin(groups, seen_groups, assume_unique=True, invert=True)
212-
if mask.any():
213-
if isinstance(result, sparse_array_type):
214-
result.fill_value = fill_value
215-
else:
216-
result[..., groups[mask]] = fill_value
217-
return result
218-
219-
220199
def identity(x: T) -> T:
221200
return x
222201

flox/lib.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from dataclasses import dataclass
22

3+
import numpy as np
4+
35
from .types import DaskArray, Graph
46

57
try:
@@ -29,3 +31,25 @@ def to_array(self, dep: DaskArray) -> DaskArray:
2931

3032
graph = HighLevelGraph.from_collections(self.name, self.layer, dependencies=[dep])
3133
return Array(graph, self.name, self.chunks, meta=dep._meta)
34+
35+
36+
def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
37+
"""Account for numbagg not providing a fill_value kwarg."""
38+
39+
from .aggregate_numbagg import DEFAULT_FILL_VALUE
40+
41+
if not isinstance(func, str) or func not in DEFAULT_FILL_VALUE:
42+
return result
43+
# The condition needs to be
44+
# len(found_groups) < size; if so we mask with fill_value (?)
45+
default_fv = DEFAULT_FILL_VALUE[func]
46+
needs_masking = fill_value is not None and not np.array_equal(fill_value, default_fv, equal_nan=True)
47+
groups = np.arange(size)
48+
if needs_masking:
49+
mask = np.isin(groups, seen_groups, assume_unique=True, invert=True)
50+
if mask.any():
51+
if isinstance(result, sparse_array_type):
52+
result.fill_value = fill_value
53+
else:
54+
result[..., groups[mask]] = fill_value
55+
return result

0 commit comments

Comments
 (0)