Skip to content

Commit 35816af

Browse files
dcherian and claude committed
Move _postprocess_numbagg to lib.py
Move the numbagg postprocessing function to lib.py for better organization.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent 3f58663 commit 35816af

File tree

2 files changed

+24
-19
lines changed

2 files changed

+24
-19
lines changed

flox/core.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -196,25 +196,7 @@ class FactorizeKwargs(TypedDict, total=False):
196196
sort: bool
197197

198198

199-
def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
    """Overwrite unobserved groups in a numbagg result with ``fill_value``.

    numbagg offers no ``fill_value`` kwarg, so the requested value is applied
    here after the fact, and only when it differs from the entry registered
    for ``func`` in ``DEFAULT_FILL_VALUE``.

    Parameters
    ----------
    result
        Reduction output; group ids index its last axis.
    func
        Reduction identifier. Unless it is a string key of
        ``DEFAULT_FILL_VALUE``, ``result`` is returned as-is.
    fill_value
        Requested fill value; ``None`` disables masking.
    size
        Number of groups; candidate group ids are ``np.arange(size)``.
    seen_groups
        Group ids that were observed; treated as unique by ``np.isin``.

    Returns
    -------
    The same ``result`` object, updated in place when masking applies.
    """
    from .aggregate_numbagg import DEFAULT_FILL_VALUE

    handled = isinstance(func, str) and func in DEFAULT_FILL_VALUE
    if not handled:
        return result

    # NOTE(review): the original author questioned whether the trigger should
    # instead be "fewer groups seen than ``size``" — left open here.
    builtin_fv = DEFAULT_FILL_VALUE[func]
    if fill_value is None:
        needs_masking = False
    else:
        needs_masking = not np.array_equal(fill_value, builtin_fv, equal_nan=True)
    group_ids = np.arange(size)

    if needs_masking:
        # True where a group id never appeared in ``seen_groups``.
        missing = np.isin(group_ids, seen_groups, assume_unique=True, invert=True)
        if missing.any():
            if isinstance(result, sparse_array_type):
                # Sparse results take the value through their fill_value attribute.
                result.fill_value = fill_value
            else:
                result[..., group_ids[missing]] = fill_value
    return result
199+
from .lib import _postprocess_numbagg
218200

219201

220202
def identity(x: T) -> T:

flox/lib.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,26 @@ def to_array(self, dep: DaskArray) -> DaskArray:
2929

3030
graph = HighLevelGraph.from_collections(self.name, self.layer, dependencies=[dep])
3131
return Array(graph, self.name, self.chunks, meta=dep._meta)
32+
33+
34+
def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
    """Apply a caller-requested ``fill_value`` to a numbagg result.

    numbagg does not accept a ``fill_value`` kwarg. When the requested value
    differs from the ``DEFAULT_FILL_VALUE`` entry for ``func``, every group id
    in ``range(size)`` that is absent from ``seen_groups`` is overwritten.

    Parameters
    ----------
    result
        Reduction output; group ids index its last axis.
    func
        Reduction identifier. Anything that is not a string key of
        ``DEFAULT_FILL_VALUE`` leaves ``result`` untouched.
    fill_value
        Requested fill value; ``None`` means no masking.
    size
        Number of groups.
    seen_groups
        Group ids that were observed; passed to ``np.isin`` with
        ``assume_unique=True``.

    Returns
    -------
    ``result`` itself, modified in place when masking is required.
    """
    import numpy as np

    from .aggregate_numbagg import DEFAULT_FILL_VALUE

    # NOTE(review): ``sparse_array_type`` is used below but no import for it is
    # visible in this hunk — confirm it is in scope in this module.

    if not (isinstance(func, str) and func in DEFAULT_FILL_VALUE):
        return result

    # NOTE(review): the original author questioned whether the trigger should
    # instead be "fewer groups seen than ``size``" — left open here.
    numbagg_default = DEFAULT_FILL_VALUE[func]
    if fill_value is None:
        needs_masking = False
    else:
        needs_masking = not np.array_equal(fill_value, numbagg_default, equal_nan=True)
    all_groups = np.arange(size)
    if not needs_masking:
        return result

    # True where a group id never appeared in ``seen_groups``.
    unseen = np.isin(all_groups, seen_groups, assume_unique=True, invert=True)
    if not unseen.any():
        return result

    if isinstance(result, sparse_array_type):
        # Sparse results take the value through their fill_value attribute.
        result.fill_value = fill_value
    else:
        result[..., all_groups[unseen]] = fill_value
    return result

0 commit comments

Comments
 (0)