
Commit 3f58663

dcherian and claude committed
refactor: extract dask-specific functions to dask.py
Create new modules for dask and cubed specific functionality:

- flox/dask.py: re-exports dask functions from core.py
- flox/cubed.py: adds cubed_groupby_agg implementation

This improves code organization and provides a clearer API for backend-specific functionality.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent 16d2fec commit 3f58663
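
A minimal sketch of the new import surface, assuming a flox build that includes this commit (the names come straight from the diffs below):

    # Backend-specific entry points after this refactor; flox.dask currently
    # re-exports these from flox.core, while flox.cubed hosts its own implementation.
    from flox.cubed import cubed_groupby_agg
    from flox.dask import dask_groupby_agg, dask_groupby_scan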

File tree

2 files changed: +200 −0 lines changed

- flox/cubed.py
- flox/dask.py

flox/cubed.py

Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
"""Cubed-specific functions for groupby operations.

This module provides Cubed-specific implementations for groupby operations.
"""

from __future__ import annotations

from collections.abc import Sequence
from functools import partial
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd

if TYPE_CHECKING:
    from .aggregations import Aggregation
    from .core import T_Axes, T_Engine, T_Method
    from .types import CubedArray, T_By

from .core import (
    ReindexStrategy,
    _finalize_results,
    _get_chunk_reduction,
    _is_arg_reduction,
    _reduce_blockwise,
)
from .xrutils import is_chunked_array


def cubed_groupby_agg(
    array: CubedArray,
    by: T_By,
    agg: Aggregation,
    expected_groups: pd.Index | None,
    reindex: ReindexStrategy,
    axis: T_Axes = (),
    fill_value: Any = None,
    method: T_Method = "map-reduce",
    engine: T_Engine = "numpy",
    sort: bool = True,
    chunks_cohorts=None,
) -> tuple[CubedArray, tuple[pd.Index | np.ndarray | CubedArray]]:
    import cubed
    import cubed.core.groupby

    # I think _tree_reduce expects this
    assert isinstance(axis, Sequence)
    assert all(ax >= 0 for ax in axis)

    if method == "blockwise":
        assert by.ndim == 1
        assert expected_groups is not None

        def _reduction_func(a, by, axis, start_group, num_groups):
            # adjust group labels to start from 0 for each chunk
            by_for_chunk = by - start_group
            expected_groups_for_chunk = pd.RangeIndex(num_groups)

            axis = (axis,)  # convert integral axis to tuple

            blockwise_method = partial(
                _reduce_blockwise,
                agg=agg,
                axis=axis,
                expected_groups=expected_groups_for_chunk,
                fill_value=fill_value,
                engine=engine,
                sort=sort,
                reindex=reindex,
            )
            out = blockwise_method(a, by_for_chunk)
            return out[agg.name]

        num_groups = len(expected_groups)
        result = cubed.core.groupby.groupby_blockwise(
            array, by, axis=axis, func=_reduction_func, num_groups=num_groups
        )
        groups = (expected_groups,)
        return (result, groups)

    else:
        inds = tuple(range(array.ndim))

        by_input = by

        # Unifying chunks is necessary for argreductions.
        # We need to rechunk before zipping up with the index
        # let's always do it anyway
        if not is_chunked_array(by):
            # chunk numpy arrays like the input array
            chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0))

            by = cubed.from_array(by, chunks=chunks, spec=array.spec)
        _, (array, by) = cubed.core.unify_chunks(array, inds, by, inds[-by.ndim :])

        # Cubed's groupby_reduction handles the generation of "intermediates", and the
        # "map-reduce" combination step, so we don't have to do that here.
        # Only the equivalent of "_simple_combine" is supported, there is no
        # support for "_grouped_combine".
        labels_are_unknown = is_chunked_array(by_input) and expected_groups is None
        do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown

        assert do_simple_combine
        assert method == "map-reduce"
        assert expected_groups is not None
        assert reindex.blockwise is True
        assert len(axis) == 1  # one axis/grouping

        def _groupby_func(a, by, axis, intermediate_dtype, num_groups):
            blockwise_method = partial(
                _get_chunk_reduction(agg.reduction_type),
                func=agg.chunk,
                fill_value=agg.fill_value["intermediate"],
                dtype=agg.dtype["intermediate"],
                reindex=reindex,
                user_dtype=agg.dtype["user"],
                axis=axis,
                expected_groups=expected_groups,
                engine=engine,
                sort=sort,
            )
            out = blockwise_method(a, by)
            # Convert dict to one that cubed understands, dropping groups since they are
            # known, and the same for every block.
            return {f"f{idx}": intermediate for idx, intermediate in enumerate(out["intermediates"])}

        def _groupby_combine(a, axis, dummy_axis, dtype, keepdims):
            # this is similar to _simple_combine, except the dummy axis and concatenation is handled by cubed
            # only combine over the dummy axis, to preserve grouping along 'axis'
            dtype = dict(dtype)
            out = {}
            for idx, combine in enumerate(agg.simple_combine):
                field = f"f{idx}"
                out[field] = combine(a[field], axis=dummy_axis, keepdims=keepdims)
            return out

        def _groupby_aggregate(a, **kwargs):
            # Convert cubed dict to one that _finalize_results works with
            results = {"groups": expected_groups, "intermediates": a.values()}
            out = _finalize_results(results, agg, axis, expected_groups, reindex)
            return out[agg.name]

        # convert list of dtypes to a structured dtype for cubed
        intermediate_dtype = [(f"f{i}", dtype) for i, dtype in enumerate(agg.dtype["intermediate"])]
        dtype = agg.dtype["final"]
        num_groups = len(expected_groups)

        result = cubed.core.groupby.groupby_reduction(
            array,
            by,
            func=_groupby_func,
            combine_func=_groupby_combine,
            aggregate_func=_groupby_aggregate,
            axis=axis,
            intermediate_dtype=intermediate_dtype,
            dtype=dtype,
            num_groups=num_groups,
        )

        groups = (expected_groups,)

        return (result, groups)


__all__ = [
    "cubed_groupby_agg",
]
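
A toy illustration (not part of the commit) of the label-offset trick in _reduction_func above: each chunk shifts its global group labels to start at zero so the blockwise reduction can work against a local pd.RangeIndex.

    import numpy as np
    import pandas as pd

    by_chunk = np.array([3, 4, 4, 5])  # global labels seen by one hypothetical chunk
    start_group, num_groups = 3, 3     # this chunk covers groups 3..5

    by_for_chunk = by_chunk - start_group                  # -> [0, 1, 1, 2]
    expected_groups_for_chunk = pd.RangeIndex(num_groups)  # RangeIndex(start=0, stop=3)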

flox/dask.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
"""Dask-specific functions for groupby operations.

This module provides Dask-specific implementations for groupby operations.
Functions are re-exported from flox.core for backward compatibility,
with the intent to gradually move implementations here.
"""

from __future__ import annotations

# Re-export dask-specific functions from core for backward compatibility
from .core import (
    _collapse_blocks_along_axes,
    _extract_unknown_groups,
    _grouped_combine,
    _normalize_indexes,
    _simple_combine,
    _unify_chunks,
    dask_groupby_agg,
    dask_groupby_scan,
    subset_to_blocks,
)

__all__ = [
    "_collapse_blocks_along_axes",
    "_extract_unknown_groups",
    "_grouped_combine",
    "_normalize_indexes",
    "_simple_combine",
    "_unify_chunks",
    "dask_groupby_agg",
    "dask_groupby_scan",
    "subset_to_blocks",
]
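
Since flox.dask is a pure re-export layer, both import paths should resolve to the same objects; a quick check, assuming a flox build with this commit installed:

    import flox.core
    import flox.dask

    # Same function object, not a copy: the new module only re-exports.
    assert flox.dask.dask_groupby_agg is flox.core.dask_groupby_agg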
