From b907554d48326d0705289d326fa53a24487c2851 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 19 Mar 2026 14:16:06 +0100 Subject: [PATCH 01/21] feat: `AnnData.can_write` based on `AnnData.fold` --- docs/api.md | 7 +++++++ src/anndata/_core/anndata.py | 39 ++++++++++++++++++++++++++++++++++++ src/anndata/types.py | 6 ++++++ tests/test_readwrite.py | 23 ++++++++++++++++++++- 4 files changed, 74 insertions(+), 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index 279070e50..0f42a381d 100644 --- a/docs/api.md +++ b/docs/api.md @@ -264,6 +264,13 @@ Types used by the former: abc.CSCDataset ``` +```{eval-rst} +.. autosummary:: + :toctree: generated/ + + types.FoldFunc +``` + ```{eval-rst} diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 03ac68dad..699c9832a 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -63,6 +63,9 @@ from zarr.storage import StoreLike + from anndata.types import FoldFunc + from anndata.typing import RWAble + from ..acc import AdRef, Array, MapAcc, RefAcc from ..compat import XDataset from ..typing import Index, Index1D, _Index1DNorm, _XDataType @@ -1446,6 +1449,42 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: write_h5ad(filename, self) return read_h5ad(filename, backed=mode) + def fold[T](self, func: FoldFunc[T], *, init: T) -> T: + acc = init + for attr_name in [ + "X", + "obs", + "var", + "obsm", + "varm", + "obsp", + "varp", + "layers", + "uns", + ]: + attr = getattr(self, attr_name) + if attr_name != "X": + for elem_name in attr: + acc = func(attr[elem_name], acc=acc) + return acc + + def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: + from anndata._io.specs.registry import _REGISTRY + + writeable_elems = { + src_type + for (dest_type, src_type, __) in _REGISTRY.write + if store_type is None or store_type in dest_type.__module__ + } + + def predicate(x: RWAble, *, acc: bool): + if isinstance(x, pd.Series): + # matches behavior 
in methods.py + x = x._values + return acc and type(x) in writeable_elems + + return self.fold(predicate, init=True) + @deprecated( deprecation_msg( *("AnnData.concatenate", "anndata.concat"), diff --git a/src/anndata/types.py b/src/anndata/types.py index aa23d10f2..add66c1c6 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -8,6 +8,8 @@ from array_api.latest import ArrayNamespace + from anndata.typing import RWAble + from ._core.anndata import AnnData @@ -48,3 +50,7 @@ def __dlpack__( copy: bool | None = None, ) -> Any: ... def __dlpack_device__(self) -> tuple[int, int]: ... + + +class FoldFunc[T](Protocol): + def __call__(self, elem: RWAble, *, acc: T | None) -> T | None: ... diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 3359b2ff8..c6468e86b 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -99,7 +99,7 @@ def dataset_kwargs(request): @pytest.fixture -def rw(backing_h5ad): +def rw(backing_h5ad) -> tuple[ad.AnnData, ad.AnnData]: M, N = 100, 101 orig = gen_adata((M, N), **GEN_ADATA_NO_XARRAY_ARGS) orig.write(backing_h5ad) @@ -126,6 +126,27 @@ def dtype(request): # ------------------------------------------------------------------------------ +@pytest.mark.parametrize("store_type", ["h5", "zarr", None]) +def test_can_write( + rw: tuple[ad.AnnData, ad.AnnData], store_type: Literal["h5", "zarr"] | None +): + adata, _ = rw + assert adata.can_write(store_type=store_type) + + +@pytest.mark.parametrize("store_type", ["h5", "zarr", None]) +def test_can_not_write_with_custom_array( + rw: tuple[ad.AnnData, ad.AnnData], store_type: Literal["h5", "zarr"] | None +): + import pyarrow as pa + + adata, _ = rw + adata.obs["arrow_array"] = pd.arrays.ArrowExtensionArray( + pa.array([{"x": 1, "y": True}] * adata.shape[0]) + ) + assert not adata.can_write(store_type=store_type) + + @pytest.mark.parametrize("typ", ARRAY_TYPES) def test_readwrite_roundtrip(typ, tmp_path, diskfmt, diskfmt2): pth1 = tmp_path / f"first.{diskfmt}" From 
19daed554b9d6f4b22801f023dbee553caad8857 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 19 Mar 2026 14:42:22 +0100 Subject: [PATCH 02/21] chore: docs --- docs/release-notes/2327.feat.md | 1 + src/anndata/_core/anndata.py | 35 ++++++++++++++++++++++++++++----- 2 files changed, 31 insertions(+), 5 deletions(-) create mode 100644 docs/release-notes/2327.feat.md diff --git a/docs/release-notes/2327.feat.md b/docs/release-notes/2327.feat.md new file mode 100644 index 000000000..a66fef550 --- /dev/null +++ b/docs/release-notes/2327.feat.md @@ -0,0 +1 @@ +New {meth}`AnnData.fold` for crawling the "elems" and accumulating a value over these, and then {meth}`AnnData.can_write` built on top {user}`ilan-gold` diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 699c9832a..6f4765c16 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1450,7 +1450,21 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: return read_h5ad(filename, backed=mode) def fold[T](self, func: FoldFunc[T], *, init: T) -> T: - acc = init + """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object. + + Parameters + ---------- + func + The function that performs the accumulation + init + The starting value + + + Returns + ------- + An accumulated value + """ + accumulate = init for attr_name in [ "X", "obs", @@ -1465,10 +1479,21 @@ def fold[T](self, func: FoldFunc[T], *, init: T) -> T: attr = getattr(self, attr_name) if attr_name != "X": for elem_name in attr: - acc = func(attr[elem_name], acc=acc) - return acc + accumulate = func(attr[elem_name], accumulate=accumulate) + return accumulate def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: + """Whether or not an `AnnData` object can be written to disk for a given store type. + + Parameters + ---------- + store_type + Which backing store - `None` indicates that it can be writeable to either. 
+ + Returns + ------- + Whether or not this object is writable. + """ from anndata._io.specs.registry import _REGISTRY writeable_elems = { @@ -1477,11 +1502,11 @@ def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: if store_type is None or store_type in dest_type.__module__ } - def predicate(x: RWAble, *, acc: bool): + def predicate(x: RWAble, *, accumulate: bool): if isinstance(x, pd.Series): # matches behavior in methods.py x = x._values - return acc and type(x) in writeable_elems + return accumulate and type(x) in writeable_elems return self.fold(predicate, init=True) From 4125375612ec6b702a0032526e8b5b217b1921da Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 19 Mar 2026 15:38:40 +0100 Subject: [PATCH 03/21] refactor: use accessors --- src/anndata/_core/anndata.py | 84 ++++++++++++++++++++---------------- src/anndata/types.py | 5 ++- 2 files changed, 52 insertions(+), 37 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 6f4765c16..6c46d5d1e 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -4,7 +4,7 @@ from __future__ import annotations -from collections import OrderedDict +from collections import OrderedDict, defaultdict from collections.abc import Mapping, MutableMapping, Sequence from copy import copy, deepcopy from functools import partial, singledispatchmethod @@ -22,6 +22,7 @@ from scipy.sparse import issparse from anndata._warnings import ImplicitModificationWarning +from anndata.acc import A, AdRef, GraphAcc, LayerAcc, MultiAcc from .. 
import utils from .._settings import settings @@ -66,7 +67,7 @@ from anndata.types import FoldFunc from anndata.typing import RWAble - from ..acc import AdRef, Array, MapAcc, RefAcc + from ..acc import Array, MapAcc, RefAcc from ..compat import XDataset from ..typing import Index, Index1D, _Index1DNorm, _XDataType from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView @@ -516,36 +517,39 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 def __sizeof__( self, *, show_stratified: bool = False, with_disk: bool = False ) -> int: - def get_size(X) -> int: + def get_size[R: dict[RefAcc | None, int]]( + X: RWAble, + *, + accumulate: R, + ref_acc: RefAcc | AdRef | None, + ) -> R: def cs_to_bytes(X) -> int: return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes) - if isinstance(X, h5py.Dataset) and with_disk: - return int(np.array(X.shape).prod() * X.dtype.itemsize) - elif isinstance(X, BaseCompressedSparseDataset) and with_disk: - return cs_to_bytes(X._to_backed()) - elif issparse(X): - return cs_to_bytes(X) - else: - return X.__sizeof__() - - sizes = {} - attrs = ["X", "_obs", "_var"] - attrs_multi = ["_uns", "_obsm", "_varm", "varp", "_obsp", "_layers"] - for attr in attrs + attrs_multi: - if attr in attrs_multi: - keys = getattr(self, attr).keys() - s = sum(get_size(getattr(self, attr)[k]) for k in keys) - else: - s = get_size(getattr(self, attr)) - if s > 0 and show_stratified: - from tqdm import tqdm + if is_elem := ( + (is_ad_ref := isinstance(ref_acc, AdRef)) + or isinstance(ref_acc, LayerAcc | MultiAcc | GraphAcc) + ) or (ref_acc is None and X is not self.uns): + key = ref_acc.acc if is_ad_ref else ref_acc + if isinstance(X, h5py.Dataset) and with_disk: + accumulate[key] += int(np.array(X.shape).prod() * X.dtype.itemsize) + elif isinstance(X, BaseCompressedSparseDataset) and with_disk: + accumulate[key] += cs_to_bytes(X._to_backed()) + elif issparse(X): + accumulate[key] += cs_to_bytes(X) + else: + accumulate[key] += 
X.__sizeof__() + if not is_elem or ref_acc is A.X: + s = accumulate[ref_acc] + if s > 0 and show_stratified: + from tqdm import tqdm + + print( + f"Size of {repr(ref_acc).replace('A.', '') if ref_acc is not None else 'uns'}: {tqdm.format_sizeof(s, 'B')}" + ) + return accumulate - print( - f"Size of {attr.replace('_', '.'):<7}: {tqdm.format_sizeof(s, 'B')}" - ) - sizes[attr] = s - return sum(sizes.values()) + return sum(self.fold(get_size, init=defaultdict(int)).values()) def _gen_repr(self, n_obs, n_vars) -> str: backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else "" @@ -1450,7 +1454,7 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: return read_h5ad(filename, backed=mode) def fold[T](self, func: FoldFunc[T], *, init: T) -> T: - """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object. + """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object in DFS order. 
Parameters ---------- @@ -1474,12 +1478,18 @@ def fold[T](self, func: FoldFunc[T], *, init: T) -> T: "obsp", "varp", "layers", - "uns", ]: attr = getattr(self, attr_name) + acc = getattr(A, attr_name) if attr_name != "X": for elem_name in attr: - accumulate = func(attr[elem_name], accumulate=accumulate) + ref = acc[elem_name] + accumulate = func( + attr[elem_name], accumulate=accumulate, ref_acc=ref + ) + accumulate = func(attr, accumulate=accumulate, ref_acc=acc) + for elem in self.uns: + accumulate = func(elem, accumulate=accumulate, ref_acc=None) return accumulate def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: @@ -1502,11 +1512,13 @@ def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: if store_type is None or store_type in dest_type.__module__ } - def predicate(x: RWAble, *, accumulate: bool): - if isinstance(x, pd.Series): - # matches behavior in methods.py - x = x._values - return accumulate and type(x) in writeable_elems + def predicate(x: RWAble, *, accumulate: bool, ref_acc: AdRef | RefAcc | None): + if isinstance(ref_acc, AdRef) or ref_acc is None: + if isinstance(x, pd.Series): + # matches behavior in methods.py + x = x._values + return accumulate and type(x) in writeable_elems + return accumulate return self.fold(predicate, init=True) diff --git a/src/anndata/types.py b/src/anndata/types.py index add66c1c6..06a80a784 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -8,6 +8,7 @@ from array_api.latest import ArrayNamespace + from anndata.acc import AdRef, RefAcc from anndata.typing import RWAble from ._core.anndata import AnnData @@ -53,4 +54,6 @@ def __dlpack_device__(self) -> tuple[int, int]: ... class FoldFunc[T](Protocol): - def __call__(self, elem: RWAble, *, acc: T | None) -> T | None: ... + def __call__( + self, elem: RWAble, *, accumulate: T, ref_acc: RefAcc | AdRef | None + ) -> T | None: ... 
From 8be5ba2849f09d62818fa6fcd795227de6dcce09 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 19 Mar 2026 16:27:45 +0100 Subject: [PATCH 04/21] fix: DFS order + fixes --- src/anndata/_core/anndata.py | 64 +++++++++++++++++++++++++++--------- src/anndata/acc/__init__.py | 21 ++++++++++++ src/anndata/types.py | 10 ++++-- 3 files changed, 76 insertions(+), 19 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 6c46d5d1e..a6ac363ff 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -22,7 +22,7 @@ from scipy.sparse import issparse from anndata._warnings import ImplicitModificationWarning -from anndata.acc import A, AdRef, GraphAcc, LayerAcc, MultiAcc +from anndata.acc import A, AdAcc, AdRef, GraphAcc, LayerAcc, MultiAcc from .. import utils from .._settings import settings @@ -517,20 +517,34 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 def __sizeof__( self, *, show_stratified: bool = False, with_disk: bool = False ) -> int: - def get_size[R: dict[RefAcc | None, int]]( + def get_size[R: dict[type[RefAcc | MapAcc | AdAcc] | None, int]]( X: RWAble, *, accumulate: R, - ref_acc: RefAcc | AdRef | None, + ref_acc: RefAcc | AdRef | MapAcc | None, ) -> R: def cs_to_bytes(X) -> int: return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes) if is_elem := ( - (is_ad_ref := isinstance(ref_acc, AdRef)) - or isinstance(ref_acc, LayerAcc | MultiAcc | GraphAcc) - ) or (ref_acc is None and X is not self.uns): - key = ref_acc.acc if is_ad_ref else ref_acc + ( + # an array of some sort i.e., from AdRef (from obs/var) or a reference to one + (is_ad_ref := isinstance(ref_acc, AdRef)) + or ( + is_ref_acc := isinstance( + ref_acc, LayerAcc | MultiAcc | GraphAcc + ) + ) + ) + # an element of uns + or (ref_acc is None and X is not self.uns) + ): + if is_ad_ref: + key = type(ref_acc.acc) + elif is_ref_acc: + key = ref_acc.parent_type + else: + key = None if isinstance(X, h5py.Dataset) and with_disk: 
accumulate[key] += int(np.array(X.shape).prod() * X.dtype.itemsize) elif isinstance(X, BaseCompressedSparseDataset) and with_disk: @@ -539,8 +553,12 @@ def cs_to_bytes(X) -> int: accumulate[key] += cs_to_bytes(X) else: accumulate[key] += X.__sizeof__() - if not is_elem or ref_acc is A.X: - s = accumulate[ref_acc] + # if this is X or a parent elem maybe print it out. + if (is_x := ref_acc is A.X) or not is_elem: + if ref_acc is not None: + s = accumulate[AdAcc if is_x else type(ref_acc)] + else: + s = accumulate[None] if s > 0 and show_stratified: from tqdm import tqdm @@ -1453,8 +1471,14 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: write_h5ad(filename, self) return read_h5ad(filename, backed=mode) - def fold[T](self, func: FoldFunc[T], *, init: T) -> T: - """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object in DFS order. + def fold[T]( + self, + func: FoldFunc[T], + *, + init: T, + order: Literal["DFS-pre", "DFS-post"] = "DFS-post", + ) -> T: + """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object. Parameters ---------- @@ -1462,6 +1486,12 @@ def fold[T](self, func: FoldFunc[T], *, init: T) -> T: The function that performs the accumulation init The starting value + order + How to visit the items in the fold. + "DFS-pre" indicates that parent-elements like uns, obs, and varp get visited first. + "DFS-post" means they get visited afterwards. + The `AnnData` itself is not visited. 
+ Returns @@ -1478,18 +1508,20 @@ def fold[T](self, func: FoldFunc[T], *, init: T) -> T: "obsp", "varp", "layers", + "uns", ]: attr = getattr(self, attr_name) - acc = getattr(A, attr_name) + acc = getattr(A, attr_name) if attr_name != "uns" else None + if order == "DFS-pre": + accumulate = func(attr, accumulate=accumulate, ref_acc=acc) if attr_name != "X": for elem_name in attr: - ref = acc[elem_name] + ref = acc[elem_name] if acc is not None else None accumulate = func( attr[elem_name], accumulate=accumulate, ref_acc=ref ) - accumulate = func(attr, accumulate=accumulate, ref_acc=acc) - for elem in self.uns: - accumulate = func(elem, accumulate=accumulate, ref_acc=None) + if order == "DFS-post": + accumulate = func(attr, accumulate=accumulate, ref_acc=acc) return accumulate def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: diff --git a/src/anndata/acc/__init__.py b/src/anndata/acc/__init__.py index ff6b80b6d..86215827b 100644 --- a/src/anndata/acc/__init__.py +++ b/src/anndata/acc/__init__.py @@ -192,6 +192,11 @@ def _maybe_flatten(self, idx: I, a: Array) -> Array: return a.__array_namespace__().reshape(a, (a.size,)) return a.ravel() + @property + @abc.abstractmethod + def parent_type(self) -> type[MapAcc | AdAcc]: + """Get the parent to this reference accessor""" + @dataclass(frozen=True) class LayerAcc[R: AdRef[Idx2D]](RefAcc[R, Idx2D]): @@ -209,6 +214,10 @@ class LayerAcc[R: AdRef[Idx2D]](RefAcc[R, Idx2D]): k: str | None """Key this accessor refers to, e.g. `A.layers['counts'].k == 'counts'`.""" + @property + def parent_type(self) -> type[MapAcc | AdAcc]: + return LayerMapAcc if self.k is not None else AdAcc + @overload def __getitem__(self, idx: Idx2D, /) -> R: ... @overload @@ -298,6 +307,10 @@ class MetaAcc[R: AdRef[str | None]](RefAcc[R, str | None]): dim: Literal["obs", "var"] """Axis this accessor refers to, e.g. 
`A.obs.dim == 'obs'`.""" + @property + def parent_type(self) -> type[MapAcc | AdAcc]: + return AdAcc + @property def index(self) -> R: """Index :class:`AdRef`, i.e. `A.obs.index` or `A.var.index`.""" @@ -380,6 +393,10 @@ class MultiAcc[R: AdRef[int]](RefAcc[R, int]): k: str """Key this accessor refers to, e.g. `A.varm['x'].k == 'x'`.""" + @property + def parent_type(self) -> type[MapAcc | AdAcc]: + return MultiMapAcc + @staticmethod def process_idx(i: object, /) -> int | list[int] | pd.Index[int]: if isinstance(i, tuple): @@ -463,6 +480,10 @@ class GraphAcc[R: AdRef[Idx2D]](RefAcc[R, Idx2D]): k: str """Key this accessor refers to, e.g. `A.obsp['x'].k == 'x'`.""" + @property + def parent_type(self) -> type[MapAcc | AdAcc]: + return GraphMapAcc + def process_idx(self, idx: Idx2D, /) -> Idx2D: if not all(isinstance(i, str | slice) for i in idx): msg = f"Unsupported index {idx!r}" diff --git a/src/anndata/types.py b/src/anndata/types.py index 06a80a784..126eb1d02 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -8,7 +8,7 @@ from array_api.latest import ArrayNamespace - from anndata.acc import AdRef, RefAcc + from anndata.acc import AdAcc, AdRef, MapAcc, RefAcc from anndata.typing import RWAble from ._core.anndata import AnnData @@ -55,5 +55,9 @@ def __dlpack_device__(self) -> tuple[int, int]: ... class FoldFunc[T](Protocol): def __call__( - self, elem: RWAble, *, accumulate: T, ref_acc: RefAcc | AdRef | None - ) -> T | None: ... + self, + elem: RWAble, + *, + accumulate: T, + ref_acc: AdAcc | RefAcc | AdRef | MapAcc | None, + ) -> T: ... 
From 0f4d1b0417b73e0e74609641a1e0c990bc5eeb06 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 19 Mar 2026 16:30:04 +0100 Subject: [PATCH 05/21] chore: add test for `uns` --- tests/test_readwrite.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index c6468e86b..429a093d0 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -135,13 +135,16 @@ def test_can_write( @pytest.mark.parametrize("store_type", ["h5", "zarr", None]) +@pytest.mark.parametrize("parent_elem", ["obs", "uns"]) def test_can_not_write_with_custom_array( - rw: tuple[ad.AnnData, ad.AnnData], store_type: Literal["h5", "zarr"] | None + rw: tuple[ad.AnnData, ad.AnnData], + store_type: Literal["h5", "zarr"] | None, + parent_elem: Literal["obs", "uns"], ): import pyarrow as pa adata, _ = rw - adata.obs["arrow_array"] = pd.arrays.ArrowExtensionArray( + getattr(adata, parent_elem)["arrow_array"] = pd.arrays.ArrowExtensionArray( pa.array([{"x": 1, "y": True}] * adata.shape[0]) ) assert not adata.can_write(store_type=store_type) From be98d32c9d0a16c53206c03d17f94b59fe9bc73b Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Mon, 23 Mar 2026 13:06:37 +0100 Subject: [PATCH 06/21] fix!: remove `__delitem__` and `__setitem__` from the `AnnData` object (#2367) --- .github/workflows/check-pr.yml | 2 +- docs/release-notes/2367.breaking.md | 1 + src/anndata/_core/anndata.py | 28 ---------------------------- 3 files changed, 2 insertions(+), 29 deletions(-) create mode 100644 docs/release-notes/2367.breaking.md diff --git a/.github/workflows/check-pr.yml b/.github/workflows/check-pr.yml index 5f87e6b88..17f79942c 100644 --- a/.github/workflows/check-pr.yml +++ b/.github/workflows/check-pr.yml @@ -63,7 +63,7 @@ jobs: id: changes with: filters: | # this is intentionally a string - relnotes: 'docs/release-notes/${{ github.event.pull_request.number }}.${{ needs.check-milestone.outputs.type }}.md' + relnotes: 'docs/release-notes/${{ 
github.event.pull_request.number }}.${{ (contains(github.event.pull_request.title, '!') && 'breaking') || needs.check-milestone.outputs.type }}.md' - name: Check if a relevant release fragment is added uses: flying-sheep/check@v1 with: diff --git a/docs/release-notes/2367.breaking.md b/docs/release-notes/2367.breaking.md new file mode 100644 index 000000000..2541b0fc5 --- /dev/null +++ b/docs/release-notes/2367.breaking.md @@ -0,0 +1 @@ +Remove `Anndata.__{set,del}item__` {user}`ilan-gold` diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 03ac68dad..f11763e72 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1023,21 +1023,6 @@ def _normalize_indices( ) -> tuple[_Index1DNorm | int | np.integer, _Index1DNorm | int | np.integer]: return _normalize_indices(index, self.obs_names, self.var_names) - # TODO: this is not quite complete... - def __delitem__(self, index: Index) -> None: - obs, var = self._normalize_indices(index) - # TODO: does this really work? - if not self.isbacked: - del self._X[obs, var] - else: - X = self.file["X"] - del X[obs, var] - self._set_backed("X", X) - if var == slice(None): - del self._obs.iloc[obs, :] - if obs == slice(None): - del self._var.iloc[var, :] - @overload def __getitem__(self, index: AdRef) -> Array: ... @overload @@ -1206,19 +1191,6 @@ def _inplace_subset_obs(self, index: Index1D): self._init_as_actual(adata_subset) - # TODO: Update, possibly remove - def __setitem__(self, index: Index, val: float | _XDataType): - if self.is_view: - msg = "Object is view and cannot be accessed with `[]`." 
- raise ValueError(msg) - obs, var = self._normalize_indices(index) - if not self.isbacked: - self._X[obs, var] = val - else: - X = self.file["X"] - X[obs, var] = val - self._set_backed("X", X) - def __len__(self) -> int: return self.shape[0] From fbc696f306bdd603e75d5d991054bc39f71b0765 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Mon, 23 Mar 2026 13:52:41 +0100 Subject: [PATCH 07/21] chore!: remove `AnnData.concatenate` (#2370) Co-authored-by: Philipp A. --- docs/release-notes/0.6.0.md | 2 +- docs/release-notes/0.9.0.md | 2 +- docs/release-notes/2370.breaking.md | 1 + src/anndata/_core/anndata.py | 293 +--------------------------- tests/test_concatenate.py | 159 ++++++--------- 5 files changed, 61 insertions(+), 396 deletions(-) create mode 100644 docs/release-notes/2370.breaking.md diff --git a/docs/release-notes/0.6.0.md b/docs/release-notes/0.6.0.md index dc9f2e981..127f4d72c 100644 --- a/docs/release-notes/0.6.0.md +++ b/docs/release-notes/0.6.0.md @@ -2,7 +2,7 @@ ### 0.6.0 {small}`1 May, 2018` - compatibility with Seurat converter -- tremendous speedup for {meth}`~anndata.AnnData.concatenate` +- tremendous speedup for `~anndata.AnnData.concatenate` - bug fix for deep copy of unstructured annotation after slicing - bug fix for reading HDF5 stored single-category annotations - `'outer join'` concatenation: adds zeros for concatenation of sparse data and nans for dense data diff --git a/docs/release-notes/0.9.0.md b/docs/release-notes/0.9.0.md index 3481ade4c..6f1af32b9 100644 --- a/docs/release-notes/0.9.0.md +++ b/docs/release-notes/0.9.0.md @@ -39,7 +39,7 @@ #### Deprecations -- {meth}`AnnData.concatenate() ` is now deprecated in favour of {func}`anndata.concat` {pr}`845` {user}`ivirshup` +- `AnnData.concatenate()` is now deprecated in favour of {func}`anndata.concat` {pr}`845` {user}`ivirshup` #### Bug fixes diff --git a/docs/release-notes/2370.breaking.md b/docs/release-notes/2370.breaking.md new file mode 100644 index 000000000..a3bdee43c --- /dev/null 
+++ b/docs/release-notes/2370.breaking.md @@ -0,0 +1 @@ +Remove `AnnData.concatenate` {user}`ilan-gold` diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index f11763e72..34a16226e 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -7,7 +7,7 @@ from collections import OrderedDict from collections.abc import Mapping, MutableMapping, Sequence from copy import copy, deepcopy -from functools import partial, singledispatchmethod +from functools import singledispatchmethod from pathlib import Path from textwrap import dedent from typing import TYPE_CHECKING, cast, overload @@ -18,7 +18,6 @@ from natsort import natsorted from numpy import ma from pandas.api.types import infer_dtype -from scipy import sparse from scipy.sparse import issparse from anndata._warnings import ImplicitModificationWarning @@ -26,7 +25,6 @@ from .. import utils from .._settings import settings from ..compat import ( - CSArray, DaskArray, IndexManager, ZarrArray, @@ -61,6 +59,7 @@ from os import PathLike from typing import Any, ClassVar, Literal + from scipy import sparse from zarr.storage import StoreLike from ..acc import AdRef, Array, MapAcc, RefAcc @@ -1418,294 +1417,6 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: write_h5ad(filename, self) return read_h5ad(filename, backed=mode) - @deprecated( - deprecation_msg( - *("AnnData.concatenate", "anndata.concat"), - "See the tutorial for concat at: " - "https://anndata.readthedocs.io/en/latest/concatenation.html", - ) - ) - def concatenate( - self, - *adatas: AnnData, - join: str = "inner", - batch_key: str = "batch", - batch_categories: Sequence[Any] | None = None, - uns_merge: str | None = None, - index_unique: str | None = "-", - fill_value=None, - ) -> AnnData: - """\ - Concatenate along the observations axis. - - The :attr:`uns`, :attr:`varm` and :attr:`obsm` attributes are ignored. - - Currently, this works only in `'memory'` mode. - - .. 
note:: - - For more flexible and efficient concatenation, see: :func:`~anndata.concat`. - - Parameters - ---------- - adatas - AnnData matrices to concatenate with. Each matrix is referred to as - a “batch”. - join - Use intersection (`'inner'`) or union (`'outer'`) of variables. - batch_key - Add the batch annotation to :attr:`obs` using this key. - batch_categories - Use these as categories for the batch annotation. By default, use increasing numbers. - uns_merge - Strategy to use for merging entries of uns. These strategies are applied recusivley. - Currently implemented strategies include: - - * `None`: The default. The concatenated object will just have an empty dict for `uns`. - * `"same"`: Only entries which have the same value in all AnnData objects are kept. - * `"unique"`: Only entries which have one unique value in all AnnData objects are kept. - * `"first"`: The first non-missing value is used. - * `"only"`: A value is included if only one of the AnnData objects has a value at this - path. - index_unique - Make the index unique by joining the existing index names with the - batch category, using `index_unique='-'`, for instance. Provide - `None` to keep existing indices. - fill_value - Scalar value to fill newly missing values in arrays with. Note: only applies to arrays - and sparse matrices (not dataframes) and will only be used if `join="outer"`. - - .. note:: - If not provided, the default value is `0` for sparse matrices and `np.nan` - for numpy arrays. See the examples below for more information. - - Returns - ------- - :class:`~anndata.AnnData` - The concatenated :class:`~anndata.AnnData`, where `adata.obs[batch_key]` - stores a categorical variable labeling the batch. - - Notes - ----- - - .. warning:: - - If you use `join='outer'` this fills 0s for sparse data when - variables are absent in a batch. Use this with care. Dense data is - filled with `NaN`. See the examples. - - Examples - -------- - Joining on intersection of variables. 
- - >>> adata1 = AnnData( - ... np.array([[1, 2, 3], [4, 5, 6]]), - ... dict(obs_names=['s1', 's2'], anno1=['c1', 'c2']), - ... dict(var_names=['a', 'b', 'c'], annoA=[0, 1, 2]), - ... ) - >>> adata2 = AnnData( - ... np.array([[1, 2, 3], [4, 5, 6]]), - ... dict(obs_names=['s3', 's4'], anno1=['c3', 'c4']), - ... dict(var_names=['d', 'c', 'b'], annoA=[0, 1, 2]), - ... ) - >>> adata3 = AnnData( - ... np.array([[1, 2, 3], [4, 5, 6]]), - ... dict(obs_names=['s1', 's2'], anno2=['d3', 'd4']), - ... dict(var_names=['d', 'c', 'b'], annoA=[0, 2, 3], annoB=[0, 1, 2]), - ... ) - >>> adata = adata1.concatenate(adata2, adata3) - >>> adata - AnnData object with n_obs × n_vars = 6 × 2 - obs: 'anno1', 'anno2', 'batch' - var: 'annoA-0', 'annoA-1', 'annoA-2', 'annoB-2' - >>> adata.X - array([[2, 3], - [5, 6], - [3, 2], - [6, 5], - [3, 2], - [6, 5]]) - >>> adata.obs - anno1 anno2 batch - s1-0 c1 NaN 0 - s2-0 c2 NaN 0 - s3-1 c3 NaN 1 - s4-1 c4 NaN 1 - s1-2 NaN d3 2 - s2-2 NaN d4 2 - >>> adata.var.T - b c - annoA-0 1 2 - annoA-1 2 1 - annoA-2 3 2 - annoB-2 2 1 - - Joining on the union of variables. 
- - >>> outer = adata1.concatenate(adata2, adata3, join='outer') - >>> outer - AnnData object with n_obs × n_vars = 6 × 4 - obs: 'anno1', 'anno2', 'batch' - var: 'annoA-0', 'annoA-1', 'annoA-2', 'annoB-2' - >>> outer.var.T - a b c d - annoA-0 0.0 1.0 2.0 NaN - annoA-1 NaN 2.0 1.0 0.0 - annoA-2 NaN 3.0 2.0 0.0 - annoB-2 NaN 2.0 1.0 0.0 - >>> outer.var_names.astype("string") - Index(['a', 'b', 'c', 'd'], dtype='string') - >>> outer.X - array([[ 1., 2., 3., nan], - [ 4., 5., 6., nan], - [nan, 3., 2., 1.], - [nan, 6., 5., 4.], - [nan, 3., 2., 1.], - [nan, 6., 5., 4.]]) - >>> outer.X.sum(axis=0) - array([nan, 25., 23., nan]) - >>> import pandas as pd - >>> Xdf = pd.DataFrame(outer.X, columns=outer.var_names) - >>> Xdf - a b c d - 0 1.0 2.0 3.0 NaN - 1 4.0 5.0 6.0 NaN - 2 NaN 3.0 2.0 1.0 - 3 NaN 6.0 5.0 4.0 - 4 NaN 3.0 2.0 1.0 - 5 NaN 6.0 5.0 4.0 - >>> Xdf.sum() - a 5.0 - b 25.0 - c 23.0 - d 10.0 - dtype: float64 - - One way to deal with missing values is to use masked arrays: - - >>> from numpy import ma - >>> outer.X = ma.masked_invalid(outer.X) - >>> outer.X - masked_array( - data=[[1.0, 2.0, 3.0, --], - [4.0, 5.0, 6.0, --], - [--, 3.0, 2.0, 1.0], - [--, 6.0, 5.0, 4.0], - [--, 3.0, 2.0, 1.0], - [--, 6.0, 5.0, 4.0]], - mask=[[False, False, False, True], - [False, False, False, True], - [ True, False, False, False], - [ True, False, False, False], - [ True, False, False, False], - [ True, False, False, False]], - fill_value=1e+20) - >>> outer.X.sum(axis=0).data - array([ 5., 25., 23., 10.]) - - The masked array is not saved but has to be reinstantiated after saving. - - >>> outer.write('./test.h5ad') - >>> from anndata import read_h5ad - >>> outer = read_h5ad('./test.h5ad') - >>> outer.X - array([[ 1., 2., 3., nan], - [ 4., 5., 6., nan], - [nan, 3., 2., 1.], - [nan, 6., 5., 4.], - [nan, 3., 2., 1.], - [nan, 6., 5., 4.]]) - - For sparse data, everything behaves similarly, - except that for `join='outer'`, zeros are added. 
- - >>> from scipy.sparse import csr_matrix - >>> adata1 = AnnData( - ... csr_matrix([[0, 2, 3], [0, 5, 6]], dtype=np.float32), - ... dict(obs_names=['s1', 's2'], anno1=['c1', 'c2']), - ... dict(var_names=['a', 'b', 'c']), - ... ) - >>> adata2 = AnnData( - ... csr_matrix([[0, 2, 3], [0, 5, 6]], dtype=np.float32), - ... dict(obs_names=['s3', 's4'], anno1=['c3', 'c4']), - ... dict(var_names=['d', 'c', 'b']), - ... ) - >>> adata3 = AnnData( - ... csr_matrix([[1, 2, 0], [0, 5, 6]], dtype=np.float32), - ... dict(obs_names=['s5', 's6'], anno2=['d3', 'd4']), - ... dict(var_names=['d', 'c', 'b']), - ... ) - >>> adata = adata1.concatenate(adata2, adata3, join='outer') - >>> adata.var_names.astype("string") - Index(['a', 'b', 'c', 'd'], dtype='string') - >>> adata.X.toarray() - array([[0., 2., 3., 0.], - [0., 5., 6., 0.], - [0., 3., 2., 0.], - [0., 6., 5., 0.], - [0., 0., 2., 1.], - [0., 6., 5., 0.]], dtype=float32) - """ - from .merge import concat, merge_dataframes, merge_outer, merge_same - - if self.isbacked: - msg = "Currently, concatenate only works in memory mode." 
- raise ValueError(msg) - - if len(adatas) == 0: - return self.copy() - elif len(adatas) == 1 and not isinstance(adatas[0], AnnData): - adatas = adatas[0] # backwards compatibility - all_adatas = (self, *adatas) - - out = concat( - all_adatas, - axis=0, - join=join, - label=batch_key, - keys=batch_categories, - uns_merge=uns_merge, - fill_value=fill_value, - index_unique=index_unique, - pairwise=False, - ) - - # Backwards compat (some of this could be more efficient) - # obs used to always be an outer join - sparse_class = sparse.csr_matrix - if any(isinstance(a.X, CSArray) for a in all_adatas): - sparse_class = sparse.csr_array - out.obs = concat( - [AnnData(sparse_class(a.shape), obs=a.obs) for a in all_adatas], - axis=0, - join="outer", - label=batch_key, - keys=batch_categories, - index_unique=index_unique, - ).obs - # Removing varm - del out.varm - # Implementing old-style merging of var - if batch_categories is None: - batch_categories = np.arange(len(all_adatas)).astype(str) - pat = rf"-({'|'.join(batch_categories)})$" - out.var = merge_dataframes( - [a.var for a in all_adatas], - out.var_names, - partial(merge_outer, batch_keys=batch_categories, merge=merge_same), - ) - out.var = out.var.iloc[ - :, - ( - out.var.columns.str - .extract(pat, expand=False) - .fillna("") - .argsort(kind="stable") - ), - ] - - return out - def var_names_make_unique(self, join: str = "-") -> None: # Important to go through the setter so obsm dataframes are updated too self.var_names = utils.make_index_unique(self.var.index, join) diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index cfa57719c..3e9c82959 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -48,10 +48,6 @@ from anndata._types import Join_T -mark_legacy_concatenate = pytest.mark.filterwarnings( - r"ignore:.*AnnData\.concatenate is deprecated:FutureWarning" -) - @singledispatch def filled_like(a, fill_value=None): @@ -167,9 +163,7 @@ def force_lazy(request): return 
request.param -def fix_known_differences( - orig: AnnData, result: AnnData, *, backwards_compat: bool = True -): +def fix_known_differences(orig: AnnData, result: AnnData): """ Helper function for reducing anndata's to only the elements we expect to be equivalent after concatenation. @@ -181,14 +175,6 @@ def fix_known_differences( orig = orig.copy() result = result.copy() - if backwards_compat: - del orig.varm - del orig.varp - if isinstance(result.obs, Dataset2D): - result.obs = result.obs.ds.drop_vars(["batch"]) - else: - result.obs.drop(columns=["batch"], inplace=True) - for attrname in ("obs", "var"): if isinstance(getattr(result, attrname), Dataset2D): for adata in (orig, result): @@ -240,28 +226,11 @@ def test_concat_interface_errors(use_xdataset): concat([]) -@pytest.mark.parametrize( - ("concat_func", "backwards_compat"), - [ - pytest.param(partial(concat, merge="unique"), False, id="concat"), - pytest.param( - lambda x, **kwargs: x[0].concatenate(x[1:], **kwargs), - True, - marks=mark_legacy_concatenate, - id="concatenate", - ), - ], -) def test_concatenate_roundtrip( join_type, array_type, - concat_func, - backwards_compat, use_xdataset, - force_lazy, ): - if backwards_compat and force_lazy: - pytest.skip("unsupported") adata = gen_adata( (100, 10), X_type=array_type, @@ -277,17 +246,12 @@ def test_concatenate_roundtrip( subset_idx = np.random.choice(remaining, n, replace=False) subsets.append(adata[subset_idx]) remaining = remaining.difference(subset_idx) - result = concat_func(subsets, join=join_type, uns_merge="same", index_unique=None) - if backwards_compat and use_xdataset: - import xarray as xr - - # backwards compat always returns a dataframe - result.var = xr.Dataset.from_dataframe(result.var) + result = concat( + subsets, join=join_type, uns_merge="same", index_unique=None, merge="unique" + ) # Correcting for known differences - orig, result = fix_known_differences( - adata, result, backwards_compat=backwards_compat - ) + orig, result = 
fix_known_differences(adata, result) assert_equal(result[orig.obs_names].copy(), orig) base_type = type(orig.X) @@ -298,7 +262,6 @@ def test_concatenate_roundtrip( assert isinstance(result.X, base_type) -@mark_legacy_concatenate def test_concatenate_dense(): # dense data X1 = np.array([[1, 2, 3], [4, 5, 6]]) @@ -328,25 +291,24 @@ def test_concatenate_dense(): ) # inner join - adata = adata1.concatenate(adata2, adata3) - X_combined = [[2, 3], [5, 6], [3, 2], [6, 5], [3, 2], [6, 5]] - assert adata.X.astype(int).tolist() == X_combined - assert adata.layers["Xs"].astype(int).tolist() == X_combined - assert adata.obs.columns.tolist() == ["anno1", "anno2", "batch"] - assert adata.var.columns.tolist() == ["annoA-0", "annoA-1", "annoB-2"] - assert adata.var.values.tolist() == [[1, 2, 2], [2, 1, 1]] + adata = concat([adata1, adata2, adata3], merge="first", label="batch") + X_combined = np.array([[2, 3], [5, 6], [3, 2], [6, 5], [3, 2], [6, 5]]) + assert_equal(X_combined, adata.X) + assert_equal(adata.layers["Xs"], X_combined) + assert adata.obs.columns.tolist() == ["batch"] + assert adata.var.columns.tolist() == ["annoA", "annoB"] + assert adata.var.values.tolist() == [[1, 2], [2, 1]] assert adata.obsm.keys() == {"X_1", "X_2"} assert adata.obsm["X_1"].tolist() == np.concatenate([X1, X1, X1]).tolist() - # with batch_key and batch_categories - adata = adata1.concatenate(adata2, adata3, batch_key="batch1") - assert adata.obs.columns.tolist() == ["anno1", "anno2", "batch1"] - adata = adata1.concatenate(adata2, adata3, batch_categories=["a1", "a2", "a3"]) - assert adata.obs["batch"].cat.categories.tolist() == ["a1", "a2", "a3"] + adata = concat([adata1, adata2, adata3], label="batch1") + assert adata.obs.columns.tolist() == ["batch1"] + adata = concat([adata1, adata2, adata3], label="batch1", keys=["a1", "a2", "a3"]) + assert adata.obs["batch1"].cat.categories.tolist() == ["a1", "a2", "a3"] assert adata.var_names.tolist() == ["b", "c"] # outer join - adata = 
adata1.concatenate(adata2, adata3, join="outer") + adata = concat([adata1, adata2, adata3], join="outer", merge="first") X_ref = np.array([ [1.0, 2.0, 3.0, np.nan], @@ -360,24 +322,23 @@ def test_concatenate_dense(): var_ma = ma.masked_invalid(adata.var.values.tolist()) var_ma_ref = ma.masked_invalid( np.array([ - [0.0, np.nan, np.nan], - [1.0, 2.0, 2.0], - [2.0, 1.0, 1.0], - [np.nan, 0.0, 0.0], + [0.0, np.nan], + [1.0, 2.0], + [2.0, 1.0], + [np.nan, 0.0], ]) ) assert np.array_equal(var_ma.mask, var_ma_ref.mask) assert np.allclose(var_ma.compressed(), var_ma_ref.compressed()) -@mark_legacy_concatenate def test_concatenate_layers(array_type, join_type): adatas = [] for _ in range(5): a = array_type(sparse.random(100, 200, format="csr")) adatas.append(AnnData(X=a, layers={"a": a})) - merged = adatas[0].concatenate(adatas[1:], join=join_type) + merged = concat(adatas, join=join_type) assert_equal(merged.X, merged.layers["a"]) @@ -430,9 +391,8 @@ def gen_index(n): ] -@mark_legacy_concatenate def test_concatenate_obsm_inner(obsm_adatas): - adata = obsm_adatas[0].concatenate(obsm_adatas[1:], join="inner") + adata = concat(obsm_adatas, join="inner") assert set(adata.obsm.keys()) == {"dense", "df"} assert adata.obsm["dense"].shape == (9, 2) @@ -460,13 +420,10 @@ def test_concatenate_obsm_inner(obsm_adatas): pd.testing.assert_frame_equal(true_df, cur_df) -@mark_legacy_concatenate def test_concatenate_obsm_outer(obsm_adatas, fill_val): - outer = obsm_adatas[0].concatenate( - obsm_adatas[1:], join="outer", fill_value=fill_val - ) + outer = concat(obsm_adatas, join="outer", fill_value=fill_val) - inner = obsm_adatas[0].concatenate(obsm_adatas[1:], join="inner") + inner = concat(obsm_adatas, join="inner") for k, inner_v in inner.obsm.items(): assert np.array_equal( _subset(outer.obsm[k], (slice(None), slice(None, inner_v.shape[1]))), @@ -536,7 +493,6 @@ def test_concat_annot_join(obsm_adatas, join_type): ) -@mark_legacy_concatenate def 
test_concatenate_layers_misaligned(array_type, join_type): adatas = [] for _ in range(5): @@ -546,11 +502,10 @@ def test_concatenate_layers_misaligned(array_type, join_type): adata[:, np.random.choice(adata.var_names, 150, replace=False)].copy() ) - merged = adatas[0].concatenate(adatas[1:], join=join_type) + merged = concat(adatas, join=join_type) assert_equal(merged.X, merged.layers["a"]) -@mark_legacy_concatenate def test_concatenate_layers_outer(array_type, fill_val): # Testing that issue #368 is fixed a = AnnData( @@ -559,14 +514,15 @@ def test_concatenate_layers_outer(array_type, fill_val): ) b = AnnData(X=np.ones((10, 20))) - c = a.concatenate(b, join="outer", fill_value=fill_val, batch_categories=["a", "b"]) + c = concat( + [a, b], join="outer", fill_value=fill_val, label="batch", keys=["a", "b"] + ) np.testing.assert_array_equal( asarray(c[c.obs["batch"] == "b"].layers["a"]), fill_val ) -@mark_legacy_concatenate def test_concatenate_fill_value(fill_val): def get_obs_els(adata): return { @@ -598,7 +554,7 @@ def get_obs_els(adata): for k in [k for k, v in tmp_ad.varm.items() if isinstance(v, AwkArray)]: del tmp_ad.varm[k] - joined = adata1.concatenate([adata2, adata3], join="outer", fill_value=fill_val) + joined = concat([adata1, adata2, adata3], join="outer", fill_value=fill_val) ptr = 0 for orig in [adata1, adata2, adata3]: @@ -612,8 +568,19 @@ def get_obs_els(adata): ptr += orig.n_obs -@mark_legacy_concatenate -def test_concatenate_dense_duplicates(): +@pytest.mark.parametrize( + ("merge", "expected_cols"), + [ + ("first", ["annoA", "annoB", "annoC", "annoD", "annoE"]), + ("same", ["annoA", "annoB"]), + ("unique", ["annoA", "annoB", "annoC", "annoE"]), + ("only", ["annoE"]), + (None, []), + ], +) +def test_concatenate_merge( + merge: Literal["first", "unique", "same", "only"] | None, expected_cols: list[str] +): X1 = np.array([[1, 2, 3], [4, 5, 6]]) X2 = np.array([[1, 2, 3], [4, 5, 6]]) X3 = np.array([[1, 2, 3], [4, 5, 6]]) @@ -649,22 +616,14 @@ def 
test_concatenate_dense_duplicates(): annoA=[0, 1, 2], annoB=[1.1, 1.0, 2.0], annoD=[2.1, 2.0, 3.1], + annoE=[2.1, 2.0, 3.1], ), ) - adata = adata1.concatenate(adata2, adata3) - assert adata.var.columns.tolist() == [ - "annoA", - "annoB", - "annoC-0", - "annoD-0", - "annoC-1", - "annoD-1", - "annoD-2", - ] + adata = concat([adata1, adata2, adata3], merge=merge) + assert adata.var.columns.tolist() == expected_cols -@mark_legacy_concatenate def test_concatenate_sparse(): # sparse data from scipy.sparse import csr_matrix @@ -693,13 +652,13 @@ def test_concatenate_sparse(): ) # inner join - adata = adata1.concatenate(adata2, adata3) + adata = concat([adata1, adata2, adata3]) X_combined = [[2, 3], [5, 6], [3, 2], [6, 5], [0, 2], [6, 5]] assert adata.X.toarray().astype(int).tolist() == X_combined assert adata.layers["Xs"].toarray().astype(int).tolist() == X_combined # outer join - adata = adata1.concatenate(adata2, adata3, join="outer") + adata = concat([adata1, adata2, adata3], join="outer") assert adata.X.toarray().tolist() == [ [0.0, 2.0, 3.0, 0.0], [0.0, 5.0, 6.0, 0.0], @@ -710,7 +669,6 @@ def test_concatenate_sparse(): ] -@mark_legacy_concatenate def test_concatenate_mixed(): X1 = sparse.csr_matrix(np.array([[1, 2, 0], [4, 0, 6], [0, 0, 9]])) X2 = sparse.csr_matrix(np.array([[0, 2, 3], [4, 0, 0], [7, 0, 9]])) @@ -741,12 +699,11 @@ def test_concatenate_mixed(): layers=dict(counts=X2), # sic ) - adata_all = AnnData.concatenate(adata1, adata2, adata3, adata4) + adata_all = concat([adata1, adata2, adata3, adata4]) assert isinstance(adata_all.X, sparse.csr_matrix) assert isinstance(adata_all.layers["counts"], sparse.csr_matrix) -@mark_legacy_concatenate def test_concatenate_with_raw(): # dense data X1 = np.array([[1, 2, 3], [4, 5, 6]]) @@ -785,20 +742,20 @@ def test_concatenate_with_raw(): adata2.raw = adata2.copy() adata3.raw = adata3.copy() - adata_all = AnnData.concatenate(adata1, adata2, adata3) + adata_all = concat([adata1, adata2, adata3]) assert 
isinstance(adata_all.raw, Raw) assert set(adata_all.raw.var_names) == {"b", "c"} assert_equal(adata_all.raw.to_adata().obs, adata_all.obs) assert np.array_equal(adata_all.raw.X, adata_all.X) - adata_all = AnnData.concatenate(adata1, adata2, adata3, join="outer") + adata_all = concat([adata1, adata2, adata3], join="outer") assert isinstance(adata_all.raw, Raw) assert set(adata_all.raw.var_names) == set("abcd") assert_equal(adata_all.raw.to_adata().obs, adata_all.obs) assert np.array_equal(np.nan_to_num(adata_all.raw.X), np.nan_to_num(adata_all.X)) adata3.raw = adata4.copy() - adata_all = AnnData.concatenate(adata1, adata2, adata3, join="outer") + adata_all = concat([adata1, adata2, adata3], join="outer") assert isinstance(adata_all.raw, Raw) assert set(adata_all.raw.var_names) == set("abcdz") assert set(adata_all.var_names) == set("abcd") @@ -814,13 +771,13 @@ def test_concatenate_with_raw(): "not concatenating `.raw` attributes." ), ): - adata_all = AnnData.concatenate(adata1, adata2, adata3) + adata_all = concat([adata1, adata2, adata3]) assert adata_all.raw is None del adata1.raw del adata2.raw assert all(_adata.raw is None for _adata in (adata1, adata2, adata3)) - adata_all = AnnData.concatenate(adata1, adata2, adata3) + adata_all = concat([adata1, adata2, adata3]) assert adata_all.raw is None @@ -1232,11 +1189,9 @@ def test_concatenate_uns(unss, merge_strategy, result, value_gen): to `[{"a": [1, 2, 3]}, {"a": [1, 2, 3]}]`. 
""" # So we can see what the initial pattern was meant to be - print(merge_strategy, "\n", unss, "\n", result) result, *unss = permute_nested_values([result, *unss], value_gen) adatas = [uns_ad(uns) for uns in unss] - with pytest.warns(FutureWarning, match=r"concatenate is deprecated"): - merged = AnnData.concatenate(*adatas, uns_merge=merge_strategy).uns + merged = concat(adatas, uns_merge=merge_strategy).uns assert_equal(merged, result, elem_name="uns") @@ -1634,7 +1589,6 @@ def test_concat_outer_aligned_mapping(elem, axis, use_xdataset, force_lazy): check_filled_like(result, elem_name=f"{axis}m/{elem}") -@mark_legacy_concatenate def test_concatenate_size_0_axis(): # https://github.com/scverse/anndata/issues/526 @@ -1642,8 +1596,7 @@ def test_concatenate_size_0_axis(): b = gen_adata((5, 0)) # Mostly testing that this doesn't error - assert a.concatenate([b]).shape == (10, 0) - assert b.concatenate([a]).shape == (10, 0) + assert concat([a, b]).shape == (10, 0) def test_concat_null_X(use_xdataset): From 69daf90bb0ae54a567ef3139ccd5d64d414e728d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:07:59 +0100 Subject: [PATCH 08/21] feat: `raw` + `uns` traversal --- src/anndata/_core/anndata.py | 85 ++++++++++++++++++++++-------------- tests/test_base.py | 23 ++++++++++ tests/test_readwrite.py | 13 ++++-- 3 files changed, 85 insertions(+), 36 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 59a802b69..929e08edb 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -68,7 +68,7 @@ from anndata.typing import RWAble from ..acc import Array, MapAcc, RefAcc - from ..compat import XDataset + from ..compat import CSMatrix, XDataset from ..typing import Index, Index1D, _Index1DNorm, _XDataType from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView @@ -517,27 +517,37 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 def __sizeof__( self, *, show_stratified: bool = 
False, with_disk: bool = False ) -> int: - def get_size[R: dict[type[RefAcc | MapAcc | AdAcc] | None, int]]( + def cs_to_bytes(X: CSArray | CSMatrix) -> int: + return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes) + + def get_size(X: RWAble) -> int: + if isinstance(X, h5py.Dataset) and with_disk: + return int(np.array(X.shape).prod() * X.dtype.itemsize) + elif isinstance(X, BaseCompressedSparseDataset) and with_disk: + return cs_to_bytes(X._to_backed()) + elif issparse(X): + return cs_to_bytes(X) + else: + return X.__sizeof__() + + def fold_size[R: dict[type[RefAcc | MapAcc | AdAcc | Raw] | None, int]]( X: RWAble, *, accumulate: R, ref_acc: RefAcc | AdRef | MapAcc | None, ) -> R: - def cs_to_bytes(X) -> int: - return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes) - + if isinstance(X, Raw): + ref_acc = X # type: ignore[assignment] + accumulate[Raw] += get_size(X.X) + accumulate[Raw] += get_size(X.var) + for key in X.varm: + accumulate[Raw] += get_size(X.varm[key]) + elif ref_acc is None: # "None but not Raw" is uns + accumulate[None] = sum(get_size(v) for v in self.uns.values()) if is_elem := ( - ( - # an array of some sort i.e., from AdRef (from obs/var) or a reference to one - (is_ad_ref := isinstance(ref_acc, AdRef)) - or ( - is_ref_acc := isinstance( - ref_acc, LayerAcc | MultiAcc | GraphAcc - ) - ) - ) - # an element of uns - or (ref_acc is None and X is not self.uns) + # an array of some sort i.e., from AdRef (from obs/var) or a reference to one + (is_ad_ref := isinstance(ref_acc, AdRef)) + or (is_ref_acc := isinstance(ref_acc, LayerAcc | MultiAcc | GraphAcc)) ): if is_ad_ref: key = type(ref_acc.acc) @@ -545,18 +555,11 @@ def cs_to_bytes(X) -> int: key = ref_acc.parent_type else: key = None - if isinstance(X, h5py.Dataset) and with_disk: - accumulate[key] += int(np.array(X.shape).prod() * X.dtype.itemsize) - elif isinstance(X, BaseCompressedSparseDataset) and with_disk: - accumulate[key] += cs_to_bytes(X._to_backed()) - elif issparse(X): - 
accumulate[key] += cs_to_bytes(X) - else: - accumulate[key] += X.__sizeof__() + accumulate[key] += get_size(X) # if this is X or a parent elem maybe print it out. if (is_x := ref_acc is A.X) or not is_elem: if ref_acc is not None: - s = accumulate[AdAcc if is_x else type(ref_acc)] + s = accumulate[AdAcc if is_x else type(ref_acc)] # type: ignore[assignment] else: s = accumulate[None] if s > 0 and show_stratified: @@ -567,7 +570,7 @@ def cs_to_bytes(X) -> int: ) return accumulate - return sum(self.fold(get_size, init=defaultdict(int)).values()) + return sum(self.fold(fold_size, init=defaultdict(int)).values()) def _gen_repr(self, n_obs, n_vars) -> str: backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else "" @@ -1480,10 +1483,9 @@ def fold[T]( "obsp", "varp", "layers", - "uns", ]: attr = getattr(self, attr_name) - acc = getattr(A, attr_name) if attr_name != "uns" else None + acc = getattr(A, attr_name) if order == "DFS-pre": accumulate = func(attr, accumulate=accumulate, ref_acc=acc) if attr_name != "X": @@ -1494,6 +1496,8 @@ def fold[T]( ) if order == "DFS-post": accumulate = func(attr, accumulate=accumulate, ref_acc=acc) + accumulate = func(self.uns, accumulate=accumulate, ref_acc=None) + accumulate = func(self.raw, accumulate=accumulate, ref_acc=None) return accumulate def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: @@ -1516,12 +1520,29 @@ def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: if store_type is None or store_type in dest_type.__module__ } - def predicate(x: RWAble, *, accumulate: bool, ref_acc: AdRef | RefAcc | None): + def predicate( + elem: RWAble, + *, + accumulate: bool, + ref_acc: AdAcc | RefAcc | AdRef | MapAcc | None, + ): + if isinstance(elem, Raw): + accumulate = accumulate and type(elem.X) in writeable_elems + return accumulate and all( + type(e[attr]) in writeable_elems + for e in [elem.var, elem.varm] + for attr in e + ) + if ref_acc is None and isinstance(elem, dict): + 
return accumulate and all( + predicate(e, accumulate=accumulate, ref_acc=None) + for e in elem.values() + ) if isinstance(ref_acc, AdRef) or ref_acc is None: - if isinstance(x, pd.Series): + if isinstance(elem, pd.Series): # matches behavior in methods.py - x = x._values - return accumulate and type(x) in writeable_elems + elem = elem._values + return accumulate and type(elem) in writeable_elems return accumulate return self.fold(predicate, init=True) diff --git a/tests/test_base.py b/tests/test_base.py index 254e483b8..b24197956 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -186,6 +186,29 @@ def test_df_warnings(): adata.X = df +@pytest.mark.parametrize("use_raw", [True, False], ids=["raw", "no_raw"]) +@pytest.mark.parametrize("use_uns", [True, False], ids=["uns", "no_uns"]) +def test_sizeof_print_stratified(capsys, *, use_raw: bool, use_uns: bool): + adata = gen_adata((10, 20)) + if use_uns: + adata.uns = {"foo": np.arange(10)} + if use_raw: + adata.raw = adata.copy() + adata.__sizeof__(show_stratified=True) + captured = capsys.readouterr() + for attr in [ + "X", + "layers", + "obsm", + "varm", + "obsp", + "varp", + *(["uns"] if use_uns else []), + *(["raw"] if use_raw else []), + ]: + assert attr in captured.out + + @pytest.mark.parametrize("attr", ["X", "layers", "obsm", "varm", "obsp", "varp"]) @pytest.mark.parametrize("when", ["init", "assign"]) def test_convert_matrix(attr, when): diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 429a093d0..43694da77 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -135,17 +135,22 @@ def test_can_write( @pytest.mark.parametrize("store_type", ["h5", "zarr", None]) -@pytest.mark.parametrize("parent_elem", ["obs", "uns"]) +@pytest.mark.parametrize("parent_elem", ["var", "uns", "raw"]) def test_can_not_write_with_custom_array( rw: tuple[ad.AnnData, ad.AnnData], store_type: Literal["h5", "zarr"] | None, - parent_elem: Literal["obs", "uns"], + parent_elem: Literal["obs", "uns", 
"raw"], ): import pyarrow as pa adata, _ = rw - getattr(adata, parent_elem)["arrow_array"] = pd.arrays.ArrowExtensionArray( - pa.array([{"x": 1, "y": True}] * adata.shape[0]) + if parent_elem == "raw": + adata.raw = adata.copy() + getter = lambda: getattr(adata, parent_elem).var + else: + getter = lambda: getattr(adata, parent_elem) + getter()["arrow_array"] = pd.arrays.ArrowExtensionArray( + pa.array([{"x": 1, "y": True}] * adata.shape[1]) ) assert not adata.can_write(store_type=store_type) From 932d766db3800c1b418857124f991d2ebe6e04e7 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:08:20 +0100 Subject: [PATCH 09/21] fix: `fold` -> `reduce` --- src/anndata/_core/anndata.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 929e08edb..d692e242f 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -570,7 +570,7 @@ def fold_size[R: dict[type[RefAcc | MapAcc | AdAcc | Raw] | None, int]]( ) return accumulate - return sum(self.fold(fold_size, init=defaultdict(int)).values()) + return sum(self.reduce(fold_size, init=defaultdict(int)).values()) def _gen_repr(self, n_obs, n_vars) -> str: backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else "" @@ -1446,7 +1446,7 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: write_h5ad(filename, self) return read_h5ad(filename, backed=mode) - def fold[T]( + def reduce[T]( self, func: FoldFunc[T], *, @@ -1462,7 +1462,7 @@ def fold[T]( init The starting value order - How to visit the items in the fold. + How to visit the items in the reduce. "DFS-pre" indicates that parent-elements like uns, obs, and varp get visited first. "DFS-post" means they get visited afterwards. The `AnnData` itself is not visited. 
@@ -1545,7 +1545,7 @@ def predicate( return accumulate and type(elem) in writeable_elems return accumulate - return self.fold(predicate, init=True) + return self.reduce(predicate, init=True) @deprecated( deprecation_msg( From e0f3ee24da7f93d5ca6b8d8edfa521f29965a70c Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:42:19 +0100 Subject: [PATCH 10/21] chore: docs --- docs/api.md | 2 +- src/anndata/_core/anndata.py | 15 +++++++++++---- src/anndata/types.py | 20 ++++++++++++++++++-- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/docs/api.md b/docs/api.md index 0f42a381d..5bdac918e 100644 --- a/docs/api.md +++ b/docs/api.md @@ -268,7 +268,7 @@ Types used by the former: .. autosummary:: :toctree: generated/ - types.FoldFunc + types.ReduceFunc ``` diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index d692e242f..f968a0b33 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -64,7 +64,7 @@ from zarr.storage import StoreLike - from anndata.types import FoldFunc + from anndata.types import ReduceFunc from anndata.typing import RWAble from ..acc import Array, MapAcc, RefAcc @@ -1448,22 +1448,29 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: def reduce[T]( self, - func: FoldFunc[T], + func: ReduceFunc[T], *, init: T, order: Literal["DFS-pre", "DFS-post"] = "DFS-post", ) -> T: """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object. + All visits inside the user-defined `func` are distinguishable via the `ref_acc` + `elem` args. + Visits to {attr}`~AnnData.raw` pass `ref_acc is None` and `isinstance(elem, Raw)` to the :func:`types.ReduceFunc`. + Visits to {attr}`~AnnData.uns` pass `ref_acc is None` and `isinstance(elem, dict)` to the :func:`types.ReduceFunc`. + Furthermore, neither element is descended into. 
+ This behavior could change where a new `ref_acc` type will be available, in which case we could start descending in these cases. + All other elements will have a non-`None` `ref_acc` argument. + Parameters ---------- func - The function that performs the accumulation + The function that performs the accumulation. init The starting value order How to visit the items in the reduce. - "DFS-pre" indicates that parent-elements like uns, obs, and varp get visited first. + "DFS-pre" indicates that parent-elements like layers, obs, and varp get visited first. "DFS-post" means they get visited afterwards. The `AnnData` itself is not visited. diff --git a/src/anndata/types.py b/src/anndata/types.py index 126eb1d02..d168d3134 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -53,11 +53,27 @@ def __dlpack__( def __dlpack_device__(self) -> tuple[int, int]: ... -class FoldFunc[T](Protocol): +class ReduceFunc[T](Protocol): def __call__( self, elem: RWAble, *, accumulate: T, ref_acc: AdAcc | RefAcc | AdRef | MapAcc | None, - ) -> T: ... + ) -> T: + """Function to be called on each visit within :func:`AnnData.reduce`. + + Parameters + ---------- + elem + The current element. + accumulate + The value being accumulated. + ref_acc + A reference to help users distinguish where they are in the `AnnData` object. + + Returns + ------- + An accumulated value + """ + ... From ee04741f5a5d40e86fa9726dfabd93e2584c5e1d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:46:15 +0100 Subject: [PATCH 11/21] fix: `meth` not `func` --- src/anndata/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anndata/types.py b/src/anndata/types.py index d168d3134..4262c639c 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -61,7 +61,7 @@ def __call__( accumulate: T, ref_acc: AdAcc | RefAcc | AdRef | MapAcc | None, ) -> T: - """Function to be called on each visit within :func:`AnnData.reduce`.
+ """Function to be called on each visit within :meth:`AnnData.reduce`. Parameters ---------- From 6d6f4548c19c44d03099672e676a7264f5ce0d41 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:46:58 +0100 Subject: [PATCH 12/21] fix: `fold` not `reduce` in relnote --- docs/release-notes/2327.feat.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/release-notes/2327.feat.md b/docs/release-notes/2327.feat.md index a66fef550..0a0feef86 100644 --- a/docs/release-notes/2327.feat.md +++ b/docs/release-notes/2327.feat.md @@ -1 +1 @@ -New {meth}`AnnData.fold` for crawling the "elems" and accumulating a value over these, and then {meth}`AnnData.can_write` built on top {user}`ilan-gold` +New {meth}`AnnData.reduce` for crawling the "elems" and accumulating a value over these, and then {meth}`AnnData.can_write` built on top {user}`ilan-gold` From 1f77a4c719a374984f6490c238c988a9cb8d5fce Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:50:50 +0100 Subject: [PATCH 13/21] fix: nested --- src/anndata/_core/anndata.py | 6 ++++-- tests/test_base.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 2a135eb95..ec9f2e789 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -67,7 +67,7 @@ from anndata.typing import RWAble from ..acc import Array, MapAcc, RefAcc - from ..compat import CSMatrix, XDataset + from ..compat import CSArray, CSMatrix, XDataset from ..typing import Index, Index1D, _Index1DNorm, _XDataType from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView @@ -526,6 +526,8 @@ def get_size(X: RWAble) -> int: return cs_to_bytes(X._to_backed()) elif issparse(X): return cs_to_bytes(X) + elif isinstance(X, dict): + return sum(get_size(v) for v in X.values()) else: return X.__sizeof__() @@ -542,7 +544,7 @@ def fold_size[R: dict[type[RefAcc | MapAcc | AdAcc | Raw] | None, int]]( for key in X.varm: 
accumulate[Raw] += get_size(X.varm[key]) elif ref_acc is None: # "None but not Raw" is uns - accumulate[None] = sum(get_size(v) for v in self.uns.values()) + accumulate[None] = get_size(self.uns) if is_elem := ( # an array of some sort i.e., from AdRef (from obs/var) or a reference to one (is_ad_ref := isinstance(ref_acc, AdRef)) diff --git a/tests/test_base.py b/tests/test_base.py index b24197956..fe6c66ba3 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -191,7 +191,7 @@ def test_df_warnings(): def test_sizeof_print_stratified(capsys, *, use_raw: bool, use_uns: bool): adata = gen_adata((10, 20)) if use_uns: - adata.uns = {"foo": np.arange(10)} + adata.uns = {"foo": np.arange(10), "nested": {"here": np.arange(10)}} if use_raw: adata.raw = adata.copy() adata.__sizeof__(show_stratified=True) From 91adffe29383a4a81165be9abd34c2dec7385a63 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:53:04 +0100 Subject: [PATCH 14/21] chore: more `func` clarification --- src/anndata/_core/anndata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index ec9f2e789..90d2fcb8b 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1461,7 +1461,7 @@ def reduce[T]( Visits to {attr}`~AnnData.uns` pass `ref_acc is None` and `isinstance(elem, dict)` to the :func:`types.ReduceFunc`. Furthermore, neither element is descended into. This behavior could change where a new `ref_acc` type will be available, in which case we could start descending in these cases. - All other elements will have a non-`None` `ref_acc` argument. + All other elements will have a non-`None` `ref_acc` argument indicating the path at which `elem` was created in the `AnnData`. 
Parameters ---------- From 928b72af17297e9a4d532a17d5d71160c98761e8 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:53:36 +0100 Subject: [PATCH 15/21] fix: link --- src/anndata/_core/anndata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 90d2fcb8b..c1ae4fdb5 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1456,7 +1456,7 @@ def reduce[T]( ) -> T: """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object. - All visits inside the user-defined `func` are distinguishable via the `ref_acc` + `elem` args. + All visits inside the user-defined `func` (see :func:`types.ReduceFunc`) are distinguishable via the `ref_acc` + `elem` args. Visits to {attr}`~AnnData.raw` pass `ref_acc is None` and `isinstance(elem, Raw)` to the :func:`types.ReduceFunc`. Visits to {attr}`~AnnData.uns` pass `ref_acc is None` and `isinstance(elem, dict)` to the :func:`types.ReduceFunc`. Furthermore, neither element is descended into. From 19a915d5f3bd8f201e66b91b0dac65ee8c64200e Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:58:00 +0100 Subject: [PATCH 16/21] fix: link --- src/anndata/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anndata/types.py b/src/anndata/types.py index 4262c639c..712c09d51 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -61,7 +61,7 @@ def __call__( accumulate: T, ref_acc: AdAcc | RefAcc | AdRef | MapAcc | None, ) -> T: - """Function to be called on each visit within :meth:`AnnData.reduce`. + """Function to be called on each visit within :meth:`anndata.AnnData.reduce`. 
Parameters ---------- From c0886fef44448ec5cb2ffb2cf2c2e70ef17ab9ae Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 16:59:54 +0100 Subject: [PATCH 17/21] refactor: simpler --- src/anndata/_core/anndata.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index c1ae4fdb5..e059bb7fc 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -548,14 +548,9 @@ def fold_size[R: dict[type[RefAcc | MapAcc | AdAcc | Raw] | None, int]]( if is_elem := ( # an array of some sort i.e., from AdRef (from obs/var) or a reference to one (is_ad_ref := isinstance(ref_acc, AdRef)) - or (is_ref_acc := isinstance(ref_acc, LayerAcc | MultiAcc | GraphAcc)) + or isinstance(ref_acc, LayerAcc | MultiAcc | GraphAcc) ): - if is_ad_ref: - key = type(ref_acc.acc) - elif is_ref_acc: - key = ref_acc.parent_type - else: - key = None + key = type(ref_acc.acc) if is_ad_ref else ref_acc.parent_type accumulate[key] += get_size(X) # if this is X or a parent elem maybe print it out. if (is_x := ref_acc is A.X) or not is_elem: From 6cffc05c480914e414dd9000e51e5b51ec757195 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 23 Mar 2026 17:02:45 +0100 Subject: [PATCH 18/21] fix: relnote number --- docs/release-notes/{2327.feat.md => 2372.feat.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/release-notes/{2327.feat.md => 2372.feat.md} (100%) diff --git a/docs/release-notes/2327.feat.md b/docs/release-notes/2372.feat.md similarity index 100% rename from docs/release-notes/2327.feat.md rename to docs/release-notes/2372.feat.md From b8dfaea72dceedcc6d0ad130b3fbe8560c560069 Mon Sep 17 00:00:00 2001 From: Dominik Date: Mon, 23 Mar 2026 13:08:58 -0700 Subject: [PATCH 19/21] refactor: use AnnData.reduce() for HTML repr section traversal Replace _render_all_sections with a reduce-based traversal to evaluate how well the new reduce API serves as a rendering backbone. 
Changes: - Extract IORegistry.get_writeable_types() and has_spec() utilities - Simplify can_write, _check_serializable_single, _check_array_has_writer to use the shared utilities - Rewrite _render_all_sections using adata.reduce(DFS-pre) Pain points documented inline: - 4-tuple accumulator to track open section state (no enter/exit events) - Duplicated _finalize_section calls (callback + post-reduce) - isinstance checks to distinguish section vs leaf visits - len(DataFrame) returns rows not columns, needs per-type handling - reduce crashes on non-string column names (tuple keys in obs/var) - uns/raw passed as opaque blobs, fall back to dedicated renderers - Custom/unknown sections entirely outside reduce's scope --- src/anndata/_core/anndata.py | 6 +- src/anndata/_io/specs/registry.py | 27 ++- src/anndata/_repr/formatters.py | 8 +- src/anndata/_repr/html.py | 294 +++++++++++++++++++++++++----- src/anndata/_repr/utils.py | 9 +- 5 files changed, 285 insertions(+), 59 deletions(-) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index ac17af62e..15fb447ae 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -1575,11 +1575,7 @@ def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: """ from anndata._io.specs.registry import _REGISTRY - writeable_elems = { - src_type - for (dest_type, src_type, __) in _REGISTRY.write - if store_type is None or store_type in dest_type.__module__ - } + writeable_elems = _REGISTRY.get_writeable_types(store_type) def predicate( elem: RWAble, diff --git a/src/anndata/_io/specs/registry.py b/src/anndata/_io/specs/registry.py index 51726e4e2..975a4fbbf 100644 --- a/src/anndata/_io/specs/registry.py +++ b/src/anndata/_io/specs/registry.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator, Iterable - from typing import Any + from typing import Any, Literal from anndata._types import ( ReadCallback, @@ -221,6 +221,31 @@ def get_partial_read( 
name = "read_partial" raise IORegistryError._from_read_parts(name, self.read_partial, src_type, spec) + def get_writeable_types( + self, + store_type: Literal["h5", "zarr"] | None = None, + ) -> set[type]: + """Get the set of source types that have a registered writer. + + Parameters + ---------- + store_type + Filter by storage backend. ``None`` means any backend. + """ + return { + src_type + for (dest_type, src_type, _modifiers) in self.write + if store_type is None or store_type in dest_type.__module__ + } + + def has_spec(self, elem: Any) -> bool: + """Check whether *elem*'s type has a registered write spec.""" + try: + self.get_spec(elem) + return True + except (KeyError, TypeError): + return False + def get_spec(self, elem: Any) -> IOSpec: if isinstance(elem, DaskArray): if (typ_meta := (DaskArray, type(elem._meta))) in self.write_specs: diff --git a/src/anndata/_repr/formatters.py b/src/anndata/_repr/formatters.py index 8407455f3..421710568 100644 --- a/src/anndata/_repr/formatters.py +++ b/src/anndata/_repr/formatters.py @@ -84,13 +84,9 @@ def _check_array_has_writer(array: object) -> bool: This uses the actual IO registry, making it future-proof: if a writer is registered for a new type (e.g., datetime64), this will detect it. 
""" - try: - from .._io.specs.registry import _REGISTRY + from .._io.specs.registry import _REGISTRY - _REGISTRY.get_spec(array) - return True - except (KeyError, TypeError): - return False + return _REGISTRY.has_spec(array) def _check_series_backing_array(series: pd.Series) -> tuple[bool, str]: diff --git a/src/anndata/_repr/html.py b/src/anndata/_repr/html.py index 4b14dc580..8f16c716e 100644 --- a/src/anndata/_repr/html.py +++ b/src/anndata/_repr/html.py @@ -44,6 +44,7 @@ render_search_box, ) from .core import ( + render_empty_section, render_formatted_entry, render_section, render_truncation_indicator, @@ -58,9 +59,8 @@ ) from .sections import ( _detect_unknown_sections, - _render_dataframe_section, + _render_entry_row, _render_error_entry, - _render_mapping_section, _render_raw_section, _render_unknown_sections, _render_uns_section, @@ -352,56 +352,268 @@ def _render_all_sections( adata: AnnData, context: FormatterContext, ) -> list[str]: - """Render all standard and custom sections.""" - parts: list[str] = [] + """Render all standard and custom sections using AnnData.reduce(). + + Uses ``reduce`` with ``DFS-pre`` order to traverse the AnnData element tree. + Standard sections are rendered inside the reduce callback; custom sections + and unknown sections are handled separately since ``reduce`` only visits + the hardcoded AnnData attributes and has no extension mechanism. 
+ """ + from dataclasses import replace as dc_replace + + from anndata._core.raw import Raw + from anndata.acc import ( + AdRef, + GraphAcc, + GraphMapAcc, + LayerAcc, + LayerMapAcc, + MapAcc, + MetaAcc, + MultiAcc, + MultiMapAcc, + RefAcc, + ) + custom_sections_after = _get_custom_sections_by_position(adata) - for section in SECTION_ORDER: - parts.append(_render_section(adata, section, context)) + # ── accumulator type ────────────────────────────────────────────── + # Without enter/exit events on reduce, we must carry open-section + # state through the accumulator so we can finalize the previous + # section when the next one starts (and after reduce returns). + # + # This 4-tuple is the "pain point" described in the PR discussion: + # reduce already knows section boundaries internally (it's the + # `for attr_name in [...]` loop), but doesn't surface them to the + # callback, so we have to re-derive them here. + # + # accumulator = ( + # finished_sections: list[str], # completed section HTML strings + # current_rows: list[str] | None, # entry rows being collected + # current_section: str | None, # section name (for header/metadata) + # current_n_items: int, # total items in section (for count) + # ) + + def _section_name_from_acc(ref_acc) -> str: + """Derive section name from accessor. + + PAIN POINT: reduce doesn't pass the section name, only the + accessor object. We have to repr() it and strip the 'A.' prefix, + or isinstance-check to figure out what section we're in. + """ + return repr(ref_acc).replace("A.", "") + + def _is_section_visit(ref_acc) -> bool: + """Detect whether this is a section-level (parent) visit. + + PAIN POINT: reduce visits both sections and their children through + the same callback. The only way to distinguish them is by + isinstance-checking the accessor type hierarchy. 
+ """ + return isinstance( + ref_acc, + MetaAcc | LayerAcc | LayerMapAcc | MultiMapAcc | GraphMapAcc, + ) and not isinstance(ref_acc, AdRef | MultiAcc | GraphAcc) + + def _is_x(ref_acc) -> bool: + """Check if this is the X section (LayerAcc with k=None).""" + return isinstance(ref_acc, LayerAcc) and ref_acc.k is None + + def _finalize_section( + finished: list[str], + section_name: str, + rows: list[str], + n_items: int, + ) -> None: + """Wrap up collected rows into a complete section HTML string. + + PAIN POINT: this logic must be called in two places: + 1. Inside the callback, when the next section starts + 2. After reduce returns, for the last section + A proper enter/exit API would eliminate this duplication. + """ + from . import get_section_doc_url + from .core import get_section_tooltip + + doc_url = get_section_doc_url(section_name) + tooltip = get_section_tooltip(section_name) + if section_name == "obs": + tooltip = "Observation annotations" + elif section_name == "var": + tooltip = "Variable annotations" + + if n_items == 0: + finished.append(render_empty_section(section_name, doc_url, tooltip)) + else: + count_str = ( + f"({n_items} columns)" + if section_name in ("obs", "var") + else None + ) + finished.append( + render_section( + section_name, + "\n".join(rows), + n_items=n_items, + doc_url=doc_url, + tooltip=tooltip, + should_collapse=n_items > context.fold_threshold, + count_str=count_str, + ) + ) - # Render custom sections after this section - if section in custom_sections_after: - parts.extend( - _render_custom_section(adata, section_formatter, context) - for section_formatter in custom_sections_after[section] + # Insert custom sections after this standard section + if section_name in custom_sections_after: + finished.extend( + _render_custom_section(adata, sf, context) + for sf in custom_sections_after[section_name] ) - # Custom sections at end (no specific position) + def _render_via_reduce(elem, *, accumulate, ref_acc): + finished, current_rows, 
current_section, current_n_items = accumulate + + # ── uns (ref_acc=None, dict) ────────────────────────────────── + # PAIN POINT: uns and raw are passed with ref_acc=None and are + # not descended into by reduce. We must handle them as opaque + # blobs and fall back to the dedicated renderers. + if ref_acc is None and isinstance(elem, dict): + if current_rows is not None: + _finalize_section(finished, current_section, current_rows, current_n_items) + try: + finished.append(_render_uns_section(adata, context)) + except Exception as e: # noqa: BLE001 + finished.append(_render_error_entry("uns", str(e))) + if "uns" in custom_sections_after: + finished.extend( + _render_custom_section(adata, sf, context) + for sf in custom_sections_after["uns"] + ) + return finished, None, None, 0 + + # ── raw (ref_acc=None, Raw or None) ─────────────────────────── + # Same issue as uns: reduce passes raw as an opaque blob. + if ref_acc is None: + if isinstance(elem, Raw): + try: + finished.append(_render_raw_section(adata, context)) + except Exception as e: # noqa: BLE001 + finished.append(_render_error_entry("raw", str(e))) + if "raw" in custom_sections_after: + finished.extend( + _render_custom_section(adata, sf, context) + for sf in custom_sections_after["raw"] + ) + # raw is None (no raw data) — nothing to render + return finished, None, None, 0 + + # ── X section (special rendering) ───────────────────────────── + if _is_x(ref_acc): + try: + finished.append(render_x_entry(adata, context)) + except Exception as e: # noqa: BLE001 + finished.append(_render_error_entry("X", str(e))) + if "X" in custom_sections_after: + finished.extend( + _render_custom_section(adata, sf, context) + for sf in custom_sections_after["X"] + ) + return finished, None, None, 0 + + # ── Section entry (parent visit in DFS-pre) ────────────────── + if _is_section_visit(ref_acc): + # Finalize previous section if one was open + # PAIN POINT: previous section's exit is detected here, + # inside the *next* 
section's entry. Mixed concerns. + if current_rows is not None: + _finalize_section(finished, current_section, current_rows, current_n_items) + + section_name = _section_name_from_acc(ref_acc) + # PAIN POINT: reduce passes the container elem but we need the + # *entry count*, which differs by section type: + # DataFrame → len(df.columns), not len(df) (which is n_rows) + # Mapping → len(mapping) (number of keys) + # A dedicated enter event could pass this directly. + if isinstance(ref_acc, MetaAcc): + import pandas as pd + n_items = len(elem.columns) if isinstance(elem, pd.DataFrame) else 0 + else: + try: + n_items = len(elem) + except TypeError: + n_items = 0 + return finished, [], section_name, n_items + + # ── Leaf visit (entry within a section) ────────────────────── + # PAIN POINT: we must derive the section name and key from + # ref_acc. For MetaAcc children (obs/var columns), the key is + # in the AdRef. For mapping children, it's in the RefAcc. + if current_rows is not None and current_section is not None: + # Determine the entry key + if isinstance(ref_acc, AdRef): + key = ref_acc.idx + elif isinstance(ref_acc, (MultiAcc, GraphAcc, LayerAcc)): + key = ref_acc.k + else: + key = str(ref_acc) + + # Truncation: skip entries beyond max_items + entry_index = len(current_rows) + if entry_index >= context.max_items: + if entry_index == context.max_items: + remaining = current_n_items - context.max_items + current_rows.append(render_truncation_indicator(remaining)) + return accumulate + + # Format via the registry (same as the dedicated renderers) + try: + section_context = dc_replace(context, section=current_section) + key_context = dc_replace(section_context, key=key) + output = formatter_registry.format_value(elem, key_context) + append_type = current_section not in ("obs", "var") + current_rows.append( + _render_entry_row(str(key), output, append_type_html=append_type) + ) + except Exception as e: # noqa: BLE001 + current_rows.append( + f'
Error: {escape_html(str(e))}
' + ) + + return finished, current_rows, current_section, current_n_items + + # ── Run reduce ──────────────────────────────────────────────────── + try: + finished, leftover_rows, last_section, last_n_items = adata.reduce( + _render_via_reduce, + init=([], None, None, 0), + order="DFS-pre", + ) + except Exception as e: # noqa: BLE001 + # If reduce itself fails, fall back to an error message + return [f'
reduce failed: {escape_html(str(e))}
'] + + # PAIN POINT: the *last* section never gets a "next section entry" + # to trigger finalization. We must duplicate the finalize call here. + if leftover_rows is not None and last_section is not None: + _finalize_section(finished, last_section, leftover_rows, last_n_items) + + # ── Custom sections at end (no specific position) ───────────────── + # LIMITATION: reduce has no concept of custom/unknown sections. + # These must be handled entirely outside reduce. if None in custom_sections_after: - parts.extend( - _render_custom_section(adata, section_formatter, context) - for section_formatter in custom_sections_after[None] + finished.extend( + _render_custom_section(adata, sf, context) + for sf in custom_sections_after[None] ) - # Detect and show unknown sections (mapping-like attributes not in SECTION_ORDER) + # ── Unknown sections (extension attributes) ────────────────────── + # LIMITATION: reduce hardcodes its attribute list. Extension packages + # (TreeData adding .obst/.vart) are invisible to reduce. We must + # detect and render them separately, same as before. 
unknown_sections = _detect_unknown_sections(adata) if unknown_sections: - parts.append(_render_unknown_sections(unknown_sections)) - - return parts - - -def _render_section( - adata: AnnData, - section: str, - context: FormatterContext, -) -> str: - """Render a single standard section.""" - from .._repr_constants import SECTION_RAW, SECTION_UNS + finished.append(_render_unknown_sections(unknown_sections)) - try: - if section == SECTION_X: - return render_x_entry(adata, context) - if section == SECTION_RAW: - return _render_raw_section(adata, context) - if section in (SECTION_OBS, SECTION_VAR): - return _render_dataframe_section(adata, section, context) - if section == SECTION_UNS: - return _render_uns_section(adata, context) - return _render_mapping_section(adata, section, context) - except Exception as e: # noqa: BLE001 - # Show error instead of hiding the section - return _render_error_entry(section, str(e)) + return finished def _get_custom_sections_by_position( diff --git a/src/anndata/_repr/utils.py b/src/anndata/_repr/utils.py index 5704497ce..b2402bd20 100644 --- a/src/anndata/_repr/utils.py +++ b/src/anndata/_repr/utils.py @@ -41,14 +41,11 @@ def _check_serializable_single(obj: object) -> tuple[bool, str]: if obj is None: return True, "" - # Use the actual IO registry - try: - from .._io.specs.registry import _REGISTRY + # Check the IO write registry + from .._io.specs.registry import _REGISTRY - _REGISTRY.get_spec(obj) + if _REGISTRY.has_spec(obj): return True, "" - except (KeyError, TypeError): - pass # Check for basic Python types that are serializable if isinstance(obj, (bool, int, float, str, bytes)): From 5c31253997da09b9babfaf8da629fe1edbd98c4b Mon Sep 17 00:00:00 2001 From: Dominik Date: Mon, 23 Mar 2026 15:28:16 -0700 Subject: [PATCH 20/21] fix: LayerAcc children incorrectly detected as section visits LayerAcc is used for both X (k=None) and individual layer entries (k='counts'). 
The _is_section_visit check matched both, causing each layer to render as its own section instead of as an entry within the layers section. This is another instance of the isinstance fragility documented in the reduce experiment: the accessor type hierarchy doesn't cleanly separate parent visits from leaf visits. --- src/anndata/_repr/html.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/anndata/_repr/html.py b/src/anndata/_repr/html.py index 8f16c716e..25392b837 100644 --- a/src/anndata/_repr/html.py +++ b/src/anndata/_repr/html.py @@ -408,12 +408,20 @@ def _is_section_visit(ref_acc) -> bool: PAIN POINT: reduce visits both sections and their children through the same callback. The only way to distinguish them is by - isinstance-checking the accessor type hierarchy. + isinstance-checking the accessor type hierarchy. This is fragile: + LayerAcc is used for both X (k=None) and individual layer entries + (k='counts'), so we must also check the key to avoid treating + layer entries as new sections. """ + if isinstance(ref_acc, AdRef | MultiAcc | GraphAcc): + return False + if isinstance(ref_acc, LayerAcc): + # LayerAcc with k=None is X (handled by _is_x), k!=None is a leaf + return False return isinstance( ref_acc, - MetaAcc | LayerAcc | LayerMapAcc | MultiMapAcc | GraphMapAcc, - ) and not isinstance(ref_acc, AdRef | MultiAcc | GraphAcc) + MetaAcc | LayerMapAcc | MultiMapAcc | GraphMapAcc, + ) def _is_x(ref_acc) -> bool: """Check if this is the X section (LayerAcc with k=None).""" From eb84064002ccfc942c3400b5773ea7b9df629c06 Mon Sep 17 00:00:00 2001 From: Dominik Date: Tue, 24 Mar 2026 14:08:41 -0700 Subject: [PATCH 21/21] fix: remove redundant tooltip override in _finalize_section get_section_tooltip() already returns the correct tooltips for obs/var. The hardcoded override was an artifact of porting from _render_dataframe_section which didn't use get_section_tooltip. 
--- src/anndata/_repr/html.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/anndata/_repr/html.py b/src/anndata/_repr/html.py index 25392b837..926c174d0 100644 --- a/src/anndata/_repr/html.py +++ b/src/anndata/_repr/html.py @@ -445,10 +445,6 @@ def _finalize_section( doc_url = get_section_doc_url(section_name) tooltip = get_section_tooltip(section_name) - if section_name == "obs": - tooltip = "Observation annotations" - elif section_name == "var": - tooltip = "Variable annotations" if n_items == 0: finished.append(render_empty_section(section_name, doc_url, tooltip))