diff --git a/.github/workflows/check-pr.yml b/.github/workflows/check-pr.yml index 5f87e6b88..17f79942c 100644 --- a/.github/workflows/check-pr.yml +++ b/.github/workflows/check-pr.yml @@ -63,7 +63,7 @@ jobs: id: changes with: filters: | # this is intentionally a string - relnotes: 'docs/release-notes/${{ github.event.pull_request.number }}.${{ needs.check-milestone.outputs.type }}.md' + relnotes: 'docs/release-notes/${{ github.event.pull_request.number }}.${{ (contains(github.event.pull_request.title, '!') && 'breaking') || needs.check-milestone.outputs.type }}.md' - name: Check if a relevant release fragment is added uses: flying-sheep/check@v1 with: diff --git a/docs/api.md b/docs/api.md index 279070e50..5bdac918e 100644 --- a/docs/api.md +++ b/docs/api.md @@ -264,6 +264,13 @@ Types used by the former: abc.CSCDataset ``` +```{eval-rst} +.. autosummary:: + :toctree: generated/ + + types.ReduceFunc +``` + ```{eval-rst} diff --git a/docs/release-notes/0.6.0.md b/docs/release-notes/0.6.0.md index dc9f2e981..127f4d72c 100644 --- a/docs/release-notes/0.6.0.md +++ b/docs/release-notes/0.6.0.md @@ -2,7 +2,7 @@ ### 0.6.0 {small}`1 May, 2018` - compatibility with Seurat converter -- tremendous speedup for {meth}`~anndata.AnnData.concatenate` +- tremendous speedup for `~anndata.AnnData.concatenate` - bug fix for deep copy of unstructured annotation after slicing - bug fix for reading HDF5 stored single-category annotations - `'outer join'` concatenation: adds zeros for concatenation of sparse data and nans for dense data diff --git a/docs/release-notes/0.9.0.md b/docs/release-notes/0.9.0.md index 3481ade4c..6f1af32b9 100644 --- a/docs/release-notes/0.9.0.md +++ b/docs/release-notes/0.9.0.md @@ -39,7 +39,7 @@ #### Deprecations -- {meth}`AnnData.concatenate() ` is now deprecated in favour of {func}`anndata.concat` {pr}`845` {user}`ivirshup` +- `AnnData.concatenate()` is now deprecated in favour of {func}`anndata.concat` {pr}`845` {user}`ivirshup` #### Bug fixes diff --git a/docs/release-notes/2367.breaking.md b/docs/release-notes/2367.breaking.md new file mode 100644 index 000000000..2541b0fc5 --- /dev/null +++ b/docs/release-notes/2367.breaking.md @@ -0,0 +1 @@ +Remove `Anndata.__{set,del}item__` {user}`ilan-gold` diff --git a/docs/release-notes/2370.breaking.md b/docs/release-notes/2370.breaking.md new file mode 100644 index 000000000..a3bdee43c --- /dev/null +++ b/docs/release-notes/2370.breaking.md @@ -0,0 +1 @@ +Remove `AnnData.concatenate` {user}`ilan-gold` diff --git a/docs/release-notes/2372.feat.md b/docs/release-notes/2372.feat.md new file mode 100644 index 000000000..0a0feef86 --- /dev/null +++ b/docs/release-notes/2372.feat.md @@ -0,0 +1 @@ +New {meth}`AnnData.reduce` for crawling the "elems" and accumulating a value over these, and then {meth}`AnnData.can_write` built on top {user}`ilan-gold` diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index b4e7fb3c2..15fb447ae 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -4,10 +4,10 @@ from __future__ import annotations -from collections import OrderedDict +from collections import OrderedDict, defaultdict from collections.abc import Mapping, MutableMapping, Sequence from copy import copy, deepcopy -from functools import partial, singledispatchmethod +from functools import singledispatchmethod from pathlib import Path from textwrap import dedent from typing import TYPE_CHECKING, cast, overload @@ -18,15 +18,14 @@ from natsort import natsorted from numpy import ma from pandas.api.types import infer_dtype -from scipy import sparse from scipy.sparse import issparse from anndata._warnings import ImplicitModificationWarning +from anndata.acc import A, AdAcc, AdRef, GraphAcc, LayerAcc, MultiAcc from .. import utils from .._settings import settings from ..compat import ( - CSArray, DaskArray, IndexManager, ZarrArray, @@ -61,10 +60,14 @@ from os import PathLike from typing import Any, ClassVar, Literal + from scipy import sparse from zarr.storage import StoreLike - from ..acc import AdRef, Array, MapAcc, RefAcc - from ..compat import XDataset + from anndata.types import ReduceFunc + from anndata.typing import RWAble + + from ..acc import Array, MapAcc, RefAcc + from ..compat import CSArray, CSMatrix, XDataset from ..typing import Index, Index1D, _Index1DNorm, _XDataType from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView @@ -513,36 +516,57 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 def __sizeof__( self, *, show_stratified: bool = False, with_disk: bool = False ) -> int: - def get_size(X) -> int: - def cs_to_bytes(X) -> int: - return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes) + def cs_to_bytes(X: CSArray | CSMatrix) -> int: + return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes) + def get_size(X: RWAble) -> int: if isinstance(X, h5py.Dataset) and with_disk: return int(np.array(X.shape).prod() * X.dtype.itemsize) elif isinstance(X, BaseCompressedSparseDataset) and with_disk: return cs_to_bytes(X._to_backed()) elif issparse(X): return cs_to_bytes(X) + elif isinstance(X, dict): + return sum(get_size(v) for v in X.values()) else: return X.__sizeof__() - sizes = {} - attrs = ["X", "_obs", "_var"] - attrs_multi = ["_uns", "_obsm", "_varm", "varp", "_obsp", "_layers"] - for attr in attrs + attrs_multi: - if attr in attrs_multi: - keys = getattr(self, attr).keys() - s = sum(get_size(getattr(self, attr)[k]) for k in keys) - else: - s = get_size(getattr(self, attr)) - if s > 0 and show_stratified: - from tqdm import tqdm + def fold_size[R: dict[type[RefAcc | MapAcc | AdAcc | Raw] | None, int]]( + X: RWAble, + *, + accumulate: R, + ref_acc: RefAcc | AdRef | MapAcc | None, + ) -> R: + if isinstance(X, Raw): + ref_acc = X # type: ignore[assignment] + accumulate[Raw] += get_size(X.X) + accumulate[Raw] += get_size(X.var) + for key in X.varm: + accumulate[Raw] += get_size(X.varm[key]) + elif ref_acc is None: # "None but not Raw" is uns + accumulate[None] = get_size(self.uns) + if is_elem := ( + # an array of some sort i.e., from AdRef (from obs/var) or a reference to one + (is_ad_ref := isinstance(ref_acc, AdRef)) + or isinstance(ref_acc, LayerAcc | MultiAcc | GraphAcc) + ): + key = type(ref_acc.acc) if is_ad_ref else ref_acc.parent_type + accumulate[key] += get_size(X) + # if this is X or a parent elem maybe print it out. + if (is_x := ref_acc is A.X) or not is_elem: + if ref_acc is not None: + s = accumulate[AdAcc if is_x else type(ref_acc)] # type: ignore[assignment] + else: + s = accumulate[None] + if s > 0 and show_stratified: + from tqdm import tqdm - print( - f"Size of {attr.replace('_', '.'):<7}: {tqdm.format_sizeof(s, 'B')}" - ) - sizes[attr] = s - return sum(sizes.values()) + print( + f"Size of {repr(ref_acc).replace('A.', '') if ref_acc is not None else 'uns'}: {tqdm.format_sizeof(s, 'B')}" + ) + return accumulate + + return sum(self.reduce(fold_size, init=defaultdict(int)).values()) def _gen_repr(self, n_obs, n_vars) -> str: backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else "" @@ -1081,21 +1105,6 @@ def _normalize_indices( ) -> tuple[_Index1DNorm | int | np.integer, _Index1DNorm | int | np.integer]: return _normalize_indices(index, self.obs_names, self.var_names) - # TODO: this is not quite complete... - def __delitem__(self, index: Index) -> None: - obs, var = self._normalize_indices(index) - # TODO: does this really work? - if not self.isbacked: - del self._X[obs, var] - else: - X = self.file["X"] - del X[obs, var] - self._set_backed("X", X) - if var == slice(None): - del self._obs.iloc[obs, :] - if obs == slice(None): - del self._var.iloc[var, :] - @overload def __getitem__(self, index: AdRef) -> Array: ... @overload @@ -1264,19 +1273,6 @@ def _inplace_subset_obs(self, index: Index1D): self._init_as_actual(adata_subset) - # TODO: Update, possibly remove - def __setitem__(self, index: Index, val: float | _XDataType): - if self.is_view: - msg = "Object is view and cannot be accessed with `[]`." - raise ValueError(msg) - obs, var = self._normalize_indices(index) - if not self.isbacked: - self._X[obs, var] = val - else: - X = self.file["X"] - X[obs, var] = val - self._set_backed("X", X) - def __len__(self) -> int: return self.shape[0] @@ -1504,293 +1500,109 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: write_h5ad(filename, self) return read_h5ad(filename, backed=mode) - @deprecated( - deprecation_msg( - *("AnnData.concatenate", "anndata.concat"), - "See the tutorial for concat at: " - "https://anndata.readthedocs.io/en/latest/concatenation.html", - ) - ) - def concatenate( + def reduce[T]( self, - *adatas: AnnData, - join: str = "inner", - batch_key: str = "batch", - batch_categories: Sequence[Any] | None = None, - uns_merge: str | None = None, - index_unique: str | None = "-", - fill_value=None, - ) -> AnnData: - """\ - Concatenate along the observations axis. - - The :attr:`uns`, :attr:`varm` and :attr:`obsm` attributes are ignored. - - Currently, this works only in `'memory'` mode. - - .. note:: - - For more flexible and efficient concatenation, see: :func:`~anndata.concat`. + func: ReduceFunc[T], + *, + init: T, + order: Literal["DFS-pre", "DFS-post"] = "DFS-post", + ) -> T: + """Accumulate a value starting from init by iterating over the "elems"/leaf nodes of the AnnData object. + + All visits inside the user-defined `func` (see :func:`types.ReduceFunc`) are distinguishable via the `ref_acc` + `elem` args. + Visits to {attr}`~AnnData.raw` pass `ref_acc is None` and `isinstance(elem, Raw)` to the :func:`types.ReduceFunc`. + Visits to {attr}`~AnnData.uns` pass `ref_acc is None` and `isinstance(elem, dict)` to the :func:`types.ReduceFunc`. + Furthermore, neither element is descended into. + This behavior could change where a new `ref_acc` type will be available, in which case we could start descending in these cases. + All other elements will have a non-`None` `ref_acc` argument indicating the path at which `elem` was created in the `AnnData`. Parameters ---------- - adatas - AnnData matrices to concatenate with. Each matrix is referred to as - a “batch”. - join - Use intersection (`'inner'`) or union (`'outer'`) of variables. - batch_key - Add the batch annotation to :attr:`obs` using this key. - batch_categories - Use these as categories for the batch annotation. By default, use increasing numbers. - uns_merge - Strategy to use for merging entries of uns. These strategies are applied recusivley. - Currently implemented strategies include: - - * `None`: The default. The concatenated object will just have an empty dict for `uns`. - * `"same"`: Only entries which have the same value in all AnnData objects are kept. - * `"unique"`: Only entries which have one unique value in all AnnData objects are kept. - * `"first"`: The first non-missing value is used. - * `"only"`: A value is included if only one of the AnnData objects has a value at this - path. - index_unique - Make the index unique by joining the existing index names with the - batch category, using `index_unique='-'`, for instance. Provide - `None` to keep existing indices. - fill_value - Scalar value to fill newly missing values in arrays with. Note: only applies to arrays - and sparse matrices (not dataframes) and will only be used if `join="outer"`. + func + The function that performs the accumulation. + init + The starting value + order + How to visit the items in the reduce. + "DFS-pre" indicates that parent-elements like layers, obs, and varp get visited first. + "DFS-post" means they get visited afterwards. + The `AnnData` itself is not visited. + - .. note:: - If not provided, the default value is `0` for sparse matrices and `np.nan` - for numpy arrays. See the examples below for more information. Returns ------- - :class:`~anndata.AnnData` - The concatenated :class:`~anndata.AnnData`, where `adata.obs[batch_key]` - stores a categorical variable labeling the batch. - - Notes - ----- + An accumulated value + """ + accumulate = init + for attr_name in [ + "X", + "obs", + "var", + "obsm", + "varm", + "obsp", + "varp", + "layers", + ]: + attr = getattr(self, attr_name) + acc = getattr(A, attr_name) + if order == "DFS-pre": + accumulate = func(attr, accumulate=accumulate, ref_acc=acc) + if attr_name != "X": + for elem_name in attr: + ref = acc[elem_name] if acc is not None else None + accumulate = func( + attr[elem_name], accumulate=accumulate, ref_acc=ref + ) + if order == "DFS-post": + accumulate = func(attr, accumulate=accumulate, ref_acc=acc) + accumulate = func(self.uns, accumulate=accumulate, ref_acc=None) + accumulate = func(self.raw, accumulate=accumulate, ref_acc=None) + return accumulate - .. warning:: + def can_write(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: + """Whether or not an `AnnData` object can be written to disk for a given store type. - If you use `join='outer'` this fills 0s for sparse data when - variables are absent in a batch. Use this with care. Dense data is - filled with `NaN`. See the examples. + Parameters + ---------- + store_type + Which backing store - `None` indicates that it can be writeable to either. - Examples - -------- - Joining on intersection of variables. - - >>> adata1 = AnnData( - ... np.array([[1, 2, 3], [4, 5, 6]]), - ... dict(obs_names=['s1', 's2'], anno1=['c1', 'c2']), - ... dict(var_names=['a', 'b', 'c'], annoA=[0, 1, 2]), - ... ) - >>> adata2 = AnnData( - ... np.array([[1, 2, 3], [4, 5, 6]]), - ... dict(obs_names=['s3', 's4'], anno1=['c3', 'c4']), - ... dict(var_names=['d', 'c', 'b'], annoA=[0, 1, 2]), - ... ) - >>> adata3 = AnnData( - ... np.array([[1, 2, 3], [4, 5, 6]]), - ... dict(obs_names=['s1', 's2'], anno2=['d3', 'd4']), - ... dict(var_names=['d', 'c', 'b'], annoA=[0, 2, 3], annoB=[0, 1, 2]), - ... ) - >>> adata = adata1.concatenate(adata2, adata3) - >>> adata - AnnData object with n_obs × n_vars = 6 × 2 - obs: 'anno1', 'anno2', 'batch' - var: 'annoA-0', 'annoA-1', 'annoA-2', 'annoB-2' - >>> adata.X - array([[2, 3], - [5, 6], - [3, 2], - [6, 5], - [3, 2], - [6, 5]]) - >>> adata.obs - anno1 anno2 batch - s1-0 c1 NaN 0 - s2-0 c2 NaN 0 - s3-1 c3 NaN 1 - s4-1 c4 NaN 1 - s1-2 NaN d3 2 - s2-2 NaN d4 2 - >>> adata.var.T - b c - annoA-0 1 2 - annoA-1 2 1 - annoA-2 3 2 - annoB-2 2 1 - - Joining on the union of variables. - - >>> outer = adata1.concatenate(adata2, adata3, join='outer') - >>> outer - AnnData object with n_obs × n_vars = 6 × 4 - obs: 'anno1', 'anno2', 'batch' - var: 'annoA-0', 'annoA-1', 'annoA-2', 'annoB-2' - >>> outer.var.T - a b c d - annoA-0 0.0 1.0 2.0 NaN - annoA-1 NaN 2.0 1.0 0.0 - annoA-2 NaN 3.0 2.0 0.0 - annoB-2 NaN 2.0 1.0 0.0 - >>> outer.var_names.astype("string") - Index(['a', 'b', 'c', 'd'], dtype='string') - >>> outer.X - array([[ 1., 2., 3., nan], - [ 4., 5., 6., nan], - [nan, 3., 2., 1.], - [nan, 6., 5., 4.], - [nan, 3., 2., 1.], - [nan, 6., 5., 4.]]) - >>> outer.X.sum(axis=0) - array([nan, 25., 23., nan]) - >>> import pandas as pd - >>> Xdf = pd.DataFrame(outer.X, columns=outer.var_names) - >>> Xdf - a b c d - 0 1.0 2.0 3.0 NaN - 1 4.0 5.0 6.0 NaN - 2 NaN 3.0 2.0 1.0 - 3 NaN 6.0 5.0 4.0 - 4 NaN 3.0 2.0 1.0 - 5 NaN 6.0 5.0 4.0 - >>> Xdf.sum() - a 5.0 - b 25.0 - c 23.0 - d 10.0 - dtype: float64 - - One way to deal with missing values is to use masked arrays: - - >>> from numpy import ma - >>> outer.X = ma.masked_invalid(outer.X) - >>> outer.X - masked_array( - data=[[1.0, 2.0, 3.0, --], - [4.0, 5.0, 6.0, --], - [--, 3.0, 2.0, 1.0], - [--, 6.0, 5.0, 4.0], - [--, 3.0, 2.0, 1.0], - [--, 6.0, 5.0, 4.0]], - mask=[[False, False, False, True], - [False, False, False, True], - [ True, False, False, False], - [ True, False, False, False], - [ True, False, False, False], - [ True, False, False, False]], - fill_value=1e+20) - >>> outer.X.sum(axis=0).data - array([ 5., 25., 23., 10.]) - - The masked array is not saved but has to be reinstantiated after saving. - - >>> outer.write('./test.h5ad') - >>> from anndata import read_h5ad - >>> outer = read_h5ad('./test.h5ad') - >>> outer.X - array([[ 1., 2., 3., nan], - [ 4., 5., 6., nan], - [nan, 3., 2., 1.], - [nan, 6., 5., 4.], - [nan, 3., 2., 1.], - [nan, 6., 5., 4.]]) - - For sparse data, everything behaves similarly, - except that for `join='outer'`, zeros are added. - - >>> from scipy.sparse import csr_matrix - >>> adata1 = AnnData( - ... csr_matrix([[0, 2, 3], [0, 5, 6]], dtype=np.float32), - ... dict(obs_names=['s1', 's2'], anno1=['c1', 'c2']), - ... dict(var_names=['a', 'b', 'c']), - ... ) - >>> adata2 = AnnData( - ... csr_matrix([[0, 2, 3], [0, 5, 6]], dtype=np.float32), - ... dict(obs_names=['s3', 's4'], anno1=['c3', 'c4']), - ... dict(var_names=['d', 'c', 'b']), - ... ) - >>> adata3 = AnnData( - ... csr_matrix([[1, 2, 0], [0, 5, 6]], dtype=np.float32), - ... dict(obs_names=['s5', 's6'], anno2=['d3', 'd4']), - ... dict(var_names=['d', 'c', 'b']), - ... ) - >>> adata = adata1.concatenate(adata2, adata3, join='outer') - >>> adata.var_names.astype("string") - Index(['a', 'b', 'c', 'd'], dtype='string') - >>> adata.X.toarray() - array([[0., 2., 3., 0.], - [0., 5., 6., 0.], - [0., 3., 2., 0.], - [0., 6., 5., 0.], - [0., 0., 2., 1.], - [0., 6., 5., 0.]], dtype=float32) + Returns + ------- + Whether or not this object is writable. """ - from .merge import concat, merge_dataframes, merge_outer, merge_same + from anndata._io.specs.registry import _REGISTRY - if self.isbacked: - msg = "Currently, concatenate only works in memory mode." - raise ValueError(msg) + writeable_elems = _REGISTRY.get_writeable_types(store_type) - if len(adatas) == 0: - return self.copy() - elif len(adatas) == 1 and not isinstance(adatas[0], AnnData): - adatas = adatas[0] # backwards compatibility - all_adatas = (self, *adatas) - - out = concat( - all_adatas, - axis=0, - join=join, - label=batch_key, - keys=batch_categories, - uns_merge=uns_merge, - fill_value=fill_value, - index_unique=index_unique, - pairwise=False, - ) - - # Backwards compat (some of this could be more efficient) - # obs used to always be an outer join - sparse_class = sparse.csr_matrix - if any(isinstance(a.X, CSArray) for a in all_adatas): - sparse_class = sparse.csr_array - out.obs = concat( - [AnnData(sparse_class(a.shape), obs=a.obs) for a in all_adatas], - axis=0, - join="outer", - label=batch_key, - keys=batch_categories, - index_unique=index_unique, - ).obs - # Removing varm - del out.varm - # Implementing old-style merging of var - if batch_categories is None: - batch_categories = np.arange(len(all_adatas)).astype(str) - pat = rf"-({'|'.join(batch_categories)})$" - out.var = merge_dataframes( - [a.var for a in all_adatas], - out.var_names, - partial(merge_outer, batch_keys=batch_categories, merge=merge_same), - ) - out.var = out.var.iloc[ - :, - ( - out.var.columns.str - .extract(pat, expand=False) - .fillna("") - .argsort(kind="stable") - ), - ] - - return out + def predicate( + elem: RWAble, + *, + accumulate: bool, + ref_acc: AdAcc | RefAcc | AdRef | MapAcc | None, + ): + if isinstance(elem, Raw): + accumulate = accumulate and type(elem.X) in writeable_elems + return accumulate and all( + type(e[attr]) in writeable_elems + for e in [elem.var, elem.varm] + for attr in e + ) + if ref_acc is None and isinstance(elem, dict): + return accumulate and all( + predicate(e, accumulate=accumulate, ref_acc=None) + for e in elem.values() + ) + if isinstance(ref_acc, AdRef) or ref_acc is None: + if isinstance(elem, pd.Series): + # matches behavior in methods.py + elem = elem._values + return accumulate and type(elem) in writeable_elems + return accumulate + + return self.reduce(predicate, init=True) def var_names_make_unique(self, join: str = "-") -> None: # Important to go through the setter so obsm dataframes are updated too diff --git a/src/anndata/_io/specs/registry.py b/src/anndata/_io/specs/registry.py index 51726e4e2..975a4fbbf 100644 --- a/src/anndata/_io/specs/registry.py +++ b/src/anndata/_io/specs/registry.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator, Iterable - from typing import Any + from typing import Any, Literal from anndata._types import ( ReadCallback, @@ -221,6 +221,31 @@ def get_partial_read( name = "read_partial" raise IORegistryError._from_read_parts(name, self.read_partial, src_type, spec) + def get_writeable_types( + self, + store_type: Literal["h5", "zarr"] | None = None, + ) -> set[type]: + """Get the set of source types that have a registered writer. + + Parameters + ---------- + store_type + Filter by storage backend. ``None`` means any backend. + """ + return { + src_type + for (dest_type, src_type, _modifiers) in self.write + if store_type is None or store_type in dest_type.__module__ + } + + def has_spec(self, elem: Any) -> bool: + """Check whether *elem*'s type has a registered write spec.""" + try: + self.get_spec(elem) + return True + except (KeyError, TypeError): + return False + def get_spec(self, elem: Any) -> IOSpec: if isinstance(elem, DaskArray): if (typ_meta := (DaskArray, type(elem._meta))) in self.write_specs: diff --git a/src/anndata/_repr/formatters.py b/src/anndata/_repr/formatters.py index 8407455f3..421710568 100644 --- a/src/anndata/_repr/formatters.py +++ b/src/anndata/_repr/formatters.py @@ -84,13 +84,9 @@ def _check_array_has_writer(array: object) -> bool: This uses the actual IO registry, making it future-proof: if a writer is registered for a new type (e.g., datetime64), this will detect it. """ - try: - from .._io.specs.registry import _REGISTRY + from .._io.specs.registry import _REGISTRY - _REGISTRY.get_spec(array) - return True - except (KeyError, TypeError): - return False + return _REGISTRY.has_spec(array) def _check_series_backing_array(series: pd.Series) -> tuple[bool, str]: diff --git a/src/anndata/_repr/html.py b/src/anndata/_repr/html.py index 5b9c7ded0..95860da6c 100644 --- a/src/anndata/_repr/html.py +++ b/src/anndata/_repr/html.py @@ -44,6 +44,7 @@ render_search_box, ) from .core import ( + render_empty_section, render_formatted_entry, render_section, render_truncation_indicator, @@ -58,9 +59,8 @@ ) from .sections import ( _detect_unknown_sections, - _render_dataframe_section, + _render_entry_row, _render_error_entry, - _render_mapping_section, _render_raw_section, _render_unknown_sections, _render_uns_section, @@ -357,56 +357,272 @@ def _render_all_sections( adata: AnnData, context: FormatterContext, ) -> list[str]: - """Render all standard and custom sections.""" - parts: list[str] = [] + """Render all standard and custom sections using AnnData.reduce(). + + Uses ``reduce`` with ``DFS-pre`` order to traverse the AnnData element tree. + Standard sections are rendered inside the reduce callback; custom sections + and unknown sections are handled separately since ``reduce`` only visits + the hardcoded AnnData attributes and has no extension mechanism. + """ + from dataclasses import replace as dc_replace + + from anndata._core.raw import Raw + from anndata.acc import ( + AdRef, + GraphAcc, + GraphMapAcc, + LayerAcc, + LayerMapAcc, + MapAcc, + MetaAcc, + MultiAcc, + MultiMapAcc, + RefAcc, + ) + custom_sections_after = _get_custom_sections_by_position(adata) - for section in SECTION_ORDER: - parts.append(_render_section(adata, section, context)) + # ── accumulator type ────────────────────────────────────────────── + # Without enter/exit events on reduce, we must carry open-section + # state through the accumulator so we can finalize the previous + # section when the next one starts (and after reduce returns). + # + # This 4-tuple is the "pain point" described in the PR discussion: + # reduce already knows section boundaries internally (it's the + # `for attr_name in [...]` loop), but doesn't surface them to the + # callback, so we have to re-derive them here. + # + # accumulator = ( + # finished_sections: list[str], # completed section HTML strings + # current_rows: list[str] | None, # entry rows being collected + # current_section: str | None, # section name (for header/metadata) + # current_n_items: int, # total items in section (for count) + # ) + + def _section_name_from_acc(ref_acc) -> str: + """Derive section name from accessor. + + PAIN POINT: reduce doesn't pass the section name, only the + accessor object. We have to repr() it and strip the 'A.' prefix, + or isinstance-check to figure out what section we're in. + """ + return repr(ref_acc).replace("A.", "") + + def _is_section_visit(ref_acc) -> bool: + """Detect whether this is a section-level (parent) visit. + + PAIN POINT: reduce visits both sections and their children through + the same callback. The only way to distinguish them is by + isinstance-checking the accessor type hierarchy. This is fragile: + LayerAcc is used for both X (k=None) and individual layer entries + (k='counts'), so we must also check the key to avoid treating + layer entries as new sections. + """ + if isinstance(ref_acc, AdRef | MultiAcc | GraphAcc): + return False + if isinstance(ref_acc, LayerAcc): + # LayerAcc with k=None is X (handled by _is_x), k!=None is a leaf + return False + return isinstance( + ref_acc, + MetaAcc | LayerMapAcc | MultiMapAcc | GraphMapAcc, + ) - # Render custom sections after this section - if section in custom_sections_after: - parts.extend( - _render_custom_section(adata, section_formatter, context) - for section_formatter in custom_sections_after[section] + def _is_x(ref_acc) -> bool: + """Check if this is the X section (LayerAcc with k=None).""" + return isinstance(ref_acc, LayerAcc) and ref_acc.k is None + + def _finalize_section( + finished: list[str], + section_name: str, + rows: list[str], + n_items: int, + ) -> None: + """Wrap up collected rows into a complete section HTML string. + + PAIN POINT: this logic must be called in two places: + 1. Inside the callback, when the next section starts + 2. After reduce returns, for the last section + A proper enter/exit API would eliminate this duplication. + """ + from . import get_section_doc_url + from .core import get_section_tooltip + + doc_url = get_section_doc_url(section_name) + tooltip = get_section_tooltip(section_name) + + if n_items == 0: + finished.append(render_empty_section(section_name, doc_url, tooltip)) + else: + count_str = ( + f"({n_items} columns)" + if section_name in ("obs", "var") + else None + ) + finished.append( + render_section( + section_name, + "\n".join(rows), + n_items=n_items, + doc_url=doc_url, + tooltip=tooltip, + should_collapse=n_items > context.fold_threshold, + count_str=count_str, + ) + ) + + # Insert custom sections after this standard section + if section_name in custom_sections_after: + finished.extend( + _render_custom_section(adata, sf, context) + for sf in custom_sections_after[section_name] ) - # Custom sections at end (no specific position) + def _render_via_reduce(elem, *, accumulate, ref_acc): + finished, current_rows, current_section, current_n_items = accumulate + + # ── uns (ref_acc=None, dict) ────────────────────────────────── + # PAIN POINT: uns and raw are passed with ref_acc=None and are + # not descended into by reduce. We must handle them as opaque + # blobs and fall back to the dedicated renderers. + if ref_acc is None and isinstance(elem, dict): + if current_rows is not None: + _finalize_section(finished, current_section, current_rows, current_n_items) + try: + finished.append(_render_uns_section(adata, context)) + except Exception as e: # noqa: BLE001 + finished.append(_render_error_entry("uns", str(e))) + if "uns" in custom_sections_after: + finished.extend( + _render_custom_section(adata, sf, context) + for sf in custom_sections_after["uns"] + ) + return finished, None, None, 0 + + # ── raw (ref_acc=None, Raw or None) ─────────────────────────── + # Same issue as uns: reduce passes raw as an opaque blob. + if ref_acc is None: + if isinstance(elem, Raw): + try: + finished.append(_render_raw_section(adata, context)) + except Exception as e: # noqa: BLE001 + finished.append(_render_error_entry("raw", str(e))) + if "raw" in custom_sections_after: + finished.extend( + _render_custom_section(adata, sf, context) + for sf in custom_sections_after["raw"] + ) + # raw is None (no raw data) — nothing to render + return finished, None, None, 0 + + # ── X section (special rendering) ───────────────────────────── + if _is_x(ref_acc): + try: + finished.append(render_x_entry(adata, context)) + except Exception as e: # noqa: BLE001 + finished.append(_render_error_entry("X", str(e))) + if "X" in custom_sections_after: + finished.extend( + _render_custom_section(adata, sf, context) + for sf in custom_sections_after["X"] + ) + return finished, None, None, 0 + + # ── Section entry (parent visit in DFS-pre) ────────────────── + if _is_section_visit(ref_acc): + # Finalize previous section if one was open + # PAIN POINT: previous section's exit is detected here, + # inside the *next* section's entry. Mixed concerns. + if current_rows is not None: + _finalize_section(finished, current_section, current_rows, current_n_items) + + section_name = _section_name_from_acc(ref_acc) + # PAIN POINT: reduce passes the container elem but we need the + # *entry count*, which differs by section type: + # DataFrame → len(df.columns), not len(df) (which is n_rows) + # Mapping → len(mapping) (number of keys) + # A dedicated enter event could pass this directly. + if isinstance(ref_acc, MetaAcc): + import pandas as pd + n_items = len(elem.columns) if isinstance(elem, pd.DataFrame) else 0 + else: + try: + n_items = len(elem) + except TypeError: + n_items = 0 + return finished, [], section_name, n_items + + # ── Leaf visit (entry within a section) ────────────────────── + # PAIN POINT: we must derive the section name and key from + # ref_acc. For MetaAcc children (obs/var columns), the key is + # in the AdRef. For mapping children, it's in the RefAcc. + if current_rows is not None and current_section is not None: + # Determine the entry key + if isinstance(ref_acc, AdRef): + key = ref_acc.idx + elif isinstance(ref_acc, (MultiAcc, GraphAcc, LayerAcc)): + key = ref_acc.k + else: + key = str(ref_acc) + + # Truncation: skip entries beyond max_items + entry_index = len(current_rows) + if entry_index >= context.max_items: + if entry_index == context.max_items: + remaining = current_n_items - context.max_items + current_rows.append(render_truncation_indicator(remaining)) + return accumulate + + # Format via the registry (same as the dedicated renderers) + try: + section_context = dc_replace(context, section=current_section) + key_context = dc_replace(section_context, key=key) + output = formatter_registry.format_value(elem, key_context) + append_type = current_section not in ("obs", "var") + current_rows.append( + _render_entry_row(str(key), output, append_type_html=append_type) + ) + except Exception as e: # noqa: BLE001 + current_rows.append( + f'
Error: {escape_html(str(e))}
' + ) + + return finished, current_rows, current_section, current_n_items + + # ── Run reduce ──────────────────────────────────────────────────── + try: + finished, leftover_rows, last_section, last_n_items = adata.reduce( + _render_via_reduce, + init=([], None, None, 0), + order="DFS-pre", + ) + except Exception as e: # noqa: BLE001 + # If reduce itself fails, fall back to an error message + return [f'
reduce failed: {escape_html(str(e))}
'] + + # PAIN POINT: the *last* section never gets a "next section entry" + # to trigger finalization. We must duplicate the finalize call here. + if leftover_rows is not None and last_section is not None: + _finalize_section(finished, last_section, leftover_rows, last_n_items) + + # ── Custom sections at end (no specific position) ───────────────── + # LIMITATION: reduce has no concept of custom/unknown sections. + # These must be handled entirely outside reduce. if None in custom_sections_after: - parts.extend( - _render_custom_section(adata, section_formatter, context) - for section_formatter in custom_sections_after[None] + finished.extend( + _render_custom_section(adata, sf, context) + for sf in custom_sections_after[None] ) - # Detect and show unknown sections (mapping-like attributes not in SECTION_ORDER) + # ── Unknown sections (extension attributes) ────────────────────── + # LIMITATION: reduce hardcodes its attribute list. Extension packages + # (TreeData adding .obst/.vart) are invisible to reduce. We must + # detect and render them separately, same as before. unknown_sections = _detect_unknown_sections(adata) if unknown_sections: - parts.append(_render_unknown_sections(unknown_sections)) + finished.append(_render_unknown_sections(unknown_sections)) - return parts - - -def _render_section( - adata: AnnData, - section: str, - context: FormatterContext, -) -> str: - """Render a single standard section.""" - from .._repr_constants import SECTION_RAW, SECTION_UNS - - try: - if section == SECTION_X: - return render_x_entry(adata, context) - if section == SECTION_RAW: - return _render_raw_section(adata, context) - if section in (SECTION_OBS, SECTION_VAR): - return _render_dataframe_section(adata, section, context) - if section == SECTION_UNS: - return _render_uns_section(adata, context) - return _render_mapping_section(adata, section, context) - except Exception as e: # noqa: BLE001 - # Show error instead of hiding the section - return _render_error_entry(section, str(e)) + return finished def _get_custom_sections_by_position( diff --git a/src/anndata/_repr/utils.py b/src/anndata/_repr/utils.py index 5704497ce..b2402bd20 100644 --- a/src/anndata/_repr/utils.py +++ b/src/anndata/_repr/utils.py @@ -41,14 +41,11 @@ def _check_serializable_single(obj: object) -> tuple[bool, str]: if obj is None: return True, "" - # Use the actual IO registry - try: - from .._io.specs.registry import _REGISTRY + # Check the IO write registry + from .._io.specs.registry import _REGISTRY - _REGISTRY.get_spec(obj) + if _REGISTRY.has_spec(obj): return True, "" - except (KeyError, TypeError): - pass # Check for basic Python types that are serializable if isinstance(obj, (bool, int, float, str, bytes)): diff --git a/src/anndata/acc/__init__.py b/src/anndata/acc/__init__.py index ff6b80b6d..86215827b 100644 --- a/src/anndata/acc/__init__.py +++ b/src/anndata/acc/__init__.py @@ -192,6 +192,11 @@ def _maybe_flatten(self, idx: I, a: Array) -> Array: return a.__array_namespace__().reshape(a, (a.size,)) return a.ravel() + @property + @abc.abstractmethod + def parent_type(self) -> type[MapAcc | AdAcc]: + """Get the parent to this reference accessor""" + @dataclass(frozen=True) class LayerAcc[R: AdRef[Idx2D]](RefAcc[R, Idx2D]): @@ -209,6 +214,10 @@ class LayerAcc[R: AdRef[Idx2D]](RefAcc[R, Idx2D]): k: str | None """Key this accessor refers to, e.g. `A.layers['counts'].k == 'counts'`.""" + @property + def parent_type(self) -> type[MapAcc | AdAcc]: + return LayerMapAcc if self.k is not None else AdAcc + @overload def __getitem__(self, idx: Idx2D, /) -> R: ... @overload @@ -298,6 +307,10 @@ class MetaAcc[R: AdRef[str | None]](RefAcc[R, str | None]): dim: Literal["obs", "var"] """Axis this accessor refers to, e.g. `A.obs.dim == 'obs'`.""" + @property + def parent_type(self) -> type[MapAcc | AdAcc]: + return AdAcc + @property def index(self) -> R: """Index :class:`AdRef`, i.e. `A.obs.index` or `A.var.index`.""" @@ -380,6 +393,10 @@ class MultiAcc[R: AdRef[int]](RefAcc[R, int]): k: str """Key this accessor refers to, e.g. `A.varm['x'].k == 'x'`.""" + @property + def parent_type(self) -> type[MapAcc | AdAcc]: + return MultiMapAcc + @staticmethod def process_idx(i: object, /) -> int | list[int] | pd.Index[int]: if isinstance(i, tuple): @@ -463,6 +480,10 @@ class GraphAcc[R: AdRef[Idx2D]](RefAcc[R, Idx2D]): k: str """Key this accessor refers to, e.g. `A.obsp['x'].k == 'x'`.""" + @property + def parent_type(self) -> type[MapAcc | AdAcc]: + return GraphMapAcc + def process_idx(self, idx: Idx2D, /) -> Idx2D: if not all(isinstance(i, str | slice) for i in idx): msg = f"Unsupported index {idx!r}" diff --git a/src/anndata/types.py b/src/anndata/types.py index aa23d10f2..712c09d51 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -8,6 +8,9 @@ from array_api.latest import ArrayNamespace + from anndata.acc import AdAcc, AdRef, MapAcc, RefAcc + from anndata.typing import RWAble + from ._core.anndata import AnnData @@ -48,3 +51,29 @@ def __dlpack__( copy: bool | None = None, ) -> Any: ... def __dlpack_device__(self) -> tuple[int, int]: ... + + +class ReduceFunc[T](Protocol): + def __call__( + self, + elem: RWAble, + *, + accumulate: T, + ref_acc: AdAcc | RefAcc | AdRef | MapAcc | None, + ) -> T: + """Function to be called on each visit within :meth:`anndata.AnnData.reduce`. + + Parameters + ---------- + elem + The current element. + accumulate + The value being accumulated. + ref_acc + A reference to help uses distinguish where they are in the `AnnData` object. + + Returns + ------- + An accumulated value + """ + ... diff --git a/tests/test_base.py b/tests/test_base.py index 254e483b8..fe6c66ba3 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -186,6 +186,29 @@ def test_df_warnings(): adata.X = df +@pytest.mark.parametrize("use_raw", [True, False], ids=["raw", "no_raw"]) +@pytest.mark.parametrize("use_uns", [True, False], ids=["uns", "no_uns"]) +def test_sizeof_print_stratified(capsys, *, use_raw: bool, use_uns: bool): + adata = gen_adata((10, 20)) + if use_uns: + adata.uns = {"foo": np.arange(10), "nested": {"here": np.arange(10)}} + if use_raw: + adata.raw = adata.copy() + adata.__sizeof__(show_stratified=True) + captured = capsys.readouterr() + for attr in [ + "X", + "layers", + "obsm", + "varm", + "obsp", + "varp", + *(["uns"] if use_uns else []), + *(["raw"] if use_raw else []), + ]: + assert attr in captured.out + + @pytest.mark.parametrize("attr", ["X", "layers", "obsm", "varm", "obsp", "varp"]) @pytest.mark.parametrize("when", ["init", "assign"]) def test_convert_matrix(attr, when): diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index cfa57719c..3e9c82959 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -48,10 +48,6 @@ from anndata._types import Join_T -mark_legacy_concatenate = pytest.mark.filterwarnings( - r"ignore:.*AnnData\.concatenate is deprecated:FutureWarning" -) - @singledispatch def filled_like(a, fill_value=None): @@ -167,9 +163,7 @@ def force_lazy(request): return request.param -def fix_known_differences( - orig: AnnData, result: AnnData, *, backwards_compat: bool = True -): +def fix_known_differences(orig: AnnData, result: AnnData): """ Helper function for reducing anndata's to only the elements we expect to be equivalent after concatenation. @@ -181,14 +175,6 @@ def fix_known_differences( orig = orig.copy() result = result.copy() - if backwards_compat: - del orig.varm - del orig.varp - if isinstance(result.obs, Dataset2D): - result.obs = result.obs.ds.drop_vars(["batch"]) - else: - result.obs.drop(columns=["batch"], inplace=True) - for attrname in ("obs", "var"): if isinstance(getattr(result, attrname), Dataset2D): for adata in (orig, result): @@ -240,28 +226,11 @@ def test_concat_interface_errors(use_xdataset): concat([]) -@pytest.mark.parametrize( - ("concat_func", "backwards_compat"), - [ - pytest.param(partial(concat, merge="unique"), False, id="concat"), - pytest.param( - lambda x, **kwargs: x[0].concatenate(x[1:], **kwargs), - True, - marks=mark_legacy_concatenate, - id="concatenate", - ), - ], -) def test_concatenate_roundtrip( join_type, array_type, - concat_func, - backwards_compat, use_xdataset, - force_lazy, ): - if backwards_compat and force_lazy: - pytest.skip("unsupported") adata = gen_adata( (100, 10), X_type=array_type, @@ -277,17 +246,12 @@ def test_concatenate_roundtrip( subset_idx = np.random.choice(remaining, n, replace=False) subsets.append(adata[subset_idx]) remaining = remaining.difference(subset_idx) - result = concat_func(subsets, join=join_type, uns_merge="same", index_unique=None) - if backwards_compat and use_xdataset: - import xarray as xr - - # backwards compat always returns a dataframe - result.var = xr.Dataset.from_dataframe(result.var) + result = concat( + subsets, join=join_type, uns_merge="same", index_unique=None, merge="unique" + ) # Correcting for known differences - orig, result = fix_known_differences( - adata, result, backwards_compat=backwards_compat - ) + orig, result = fix_known_differences(adata, result) assert_equal(result[orig.obs_names].copy(), orig) base_type = type(orig.X) @@ -298,7 +262,6 @@ def test_concatenate_roundtrip( assert isinstance(result.X, base_type) -@mark_legacy_concatenate def test_concatenate_dense(): # dense data X1 = np.array([[1, 2, 3], [4, 5, 6]]) @@ -328,25 +291,24 @@ def test_concatenate_dense(): ) # inner join - adata = adata1.concatenate(adata2, adata3) - X_combined = [[2, 3], [5, 6], [3, 2], [6, 5], [3, 2], [6, 5]] - assert adata.X.astype(int).tolist() == X_combined - assert adata.layers["Xs"].astype(int).tolist() == X_combined - assert adata.obs.columns.tolist() == ["anno1", "anno2", "batch"] - assert adata.var.columns.tolist() == ["annoA-0", "annoA-1", "annoB-2"] - assert adata.var.values.tolist() == [[1, 2, 2], [2, 1, 1]] + adata = concat([adata1, adata2, adata3], merge="first", label="batch") + X_combined = np.array([[2, 3], [5, 6], [3, 2], [6, 5], [3, 2], [6, 5]]) + assert_equal(X_combined, adata.X) + assert_equal(adata.layers["Xs"], X_combined) + assert adata.obs.columns.tolist() == ["batch"] + assert adata.var.columns.tolist() == ["annoA", "annoB"] + assert adata.var.values.tolist() == [[1, 2], [2, 1]] assert adata.obsm.keys() == {"X_1", "X_2"} assert adata.obsm["X_1"].tolist() == np.concatenate([X1, X1, X1]).tolist() - # with batch_key and batch_categories - adata = adata1.concatenate(adata2, adata3, batch_key="batch1") - assert adata.obs.columns.tolist() == ["anno1", "anno2", "batch1"] - adata = adata1.concatenate(adata2, adata3, batch_categories=["a1", "a2", "a3"]) - assert adata.obs["batch"].cat.categories.tolist() == ["a1", "a2", "a3"] + adata = concat([adata1, adata2, adata3], label="batch1") + assert adata.obs.columns.tolist() == ["batch1"] + adata = concat([adata1, adata2, adata3], label="batch1", keys=["a1", "a2", "a3"]) + assert adata.obs["batch1"].cat.categories.tolist() == ["a1", "a2", "a3"] assert adata.var_names.tolist() == ["b", "c"] # outer join - adata = adata1.concatenate(adata2, adata3, join="outer") + adata = concat([adata1, adata2, adata3], join="outer", merge="first") X_ref = np.array([ [1.0, 2.0, 3.0, np.nan], @@ -360,24 +322,23 @@ def test_concatenate_dense(): var_ma = ma.masked_invalid(adata.var.values.tolist()) var_ma_ref = ma.masked_invalid( np.array([ - [0.0, np.nan, np.nan], - [1.0, 2.0, 2.0], - [2.0, 1.0, 1.0], - [np.nan, 0.0, 0.0], + [0.0, np.nan], + [1.0, 2.0], + [2.0, 1.0], + [np.nan, 0.0], ]) ) assert np.array_equal(var_ma.mask, var_ma_ref.mask) assert np.allclose(var_ma.compressed(), var_ma_ref.compressed()) -@mark_legacy_concatenate def test_concatenate_layers(array_type, join_type): adatas = [] for _ in range(5): a = array_type(sparse.random(100, 200, format="csr")) adatas.append(AnnData(X=a, layers={"a": a})) - merged = adatas[0].concatenate(adatas[1:], join=join_type) + merged = concat(adatas, join=join_type) assert_equal(merged.X, merged.layers["a"]) @@ -430,9 +391,8 @@ def gen_index(n): ] -@mark_legacy_concatenate def test_concatenate_obsm_inner(obsm_adatas): - adata = obsm_adatas[0].concatenate(obsm_adatas[1:], join="inner") + adata = concat(obsm_adatas, join="inner") assert set(adata.obsm.keys()) == {"dense", "df"} assert adata.obsm["dense"].shape == (9, 2) @@ -460,13 +420,10 @@ def test_concatenate_obsm_inner(obsm_adatas): pd.testing.assert_frame_equal(true_df, cur_df) -@mark_legacy_concatenate def test_concatenate_obsm_outer(obsm_adatas, fill_val): - outer = obsm_adatas[0].concatenate( - obsm_adatas[1:], join="outer", fill_value=fill_val - ) + outer = concat(obsm_adatas, join="outer", fill_value=fill_val) - inner = obsm_adatas[0].concatenate(obsm_adatas[1:], join="inner") + inner = concat(obsm_adatas, join="inner") for k, inner_v in inner.obsm.items(): assert np.array_equal( _subset(outer.obsm[k], (slice(None), slice(None, inner_v.shape[1]))), @@ -536,7 +493,6 @@ def test_concat_annot_join(obsm_adatas, join_type): ) -@mark_legacy_concatenate def test_concatenate_layers_misaligned(array_type, join_type): adatas = [] for _ in range(5): @@ -546,11 +502,10 @@ def test_concatenate_layers_misaligned(array_type, join_type): adata[:, np.random.choice(adata.var_names, 150, replace=False)].copy() ) - merged = adatas[0].concatenate(adatas[1:], join=join_type) + merged = concat(adatas, join=join_type) assert_equal(merged.X, merged.layers["a"]) -@mark_legacy_concatenate def test_concatenate_layers_outer(array_type, fill_val): # Testing that issue #368 is fixed a = AnnData( @@ -559,14 +514,15 @@ def test_concatenate_layers_outer(array_type, fill_val): ) b = AnnData(X=np.ones((10, 20))) - c = a.concatenate(b, join="outer", fill_value=fill_val, batch_categories=["a", "b"]) + c = concat( + [a, b], join="outer", fill_value=fill_val, label="batch", keys=["a", "b"] + ) np.testing.assert_array_equal( asarray(c[c.obs["batch"] == "b"].layers["a"]), fill_val ) -@mark_legacy_concatenate def test_concatenate_fill_value(fill_val): def get_obs_els(adata): return { @@ -598,7 +554,7 @@ def get_obs_els(adata): for k in [k for k, v in tmp_ad.varm.items() if isinstance(v, AwkArray)]: del tmp_ad.varm[k] - joined = adata1.concatenate([adata2, adata3], join="outer", fill_value=fill_val) + joined = concat([adata1, adata2, adata3], join="outer", fill_value=fill_val) ptr = 0 for orig in [adata1, adata2, adata3]: @@ -612,8 +568,19 @@ def get_obs_els(adata): ptr += orig.n_obs -@mark_legacy_concatenate -def test_concatenate_dense_duplicates(): +@pytest.mark.parametrize( + ("merge", "expected_cols"), + [ + ("first", ["annoA", "annoB", "annoC", "annoD", "annoE"]), + ("same", ["annoA", "annoB"]), + ("unique", ["annoA", "annoB", "annoC", "annoE"]), + ("only", ["annoE"]), + (None, []), + ], +) +def test_concatenate_merge( + merge: Literal["first", "unique", "same", "only"] | None, expected_cols: list[str] +): X1 = np.array([[1, 2, 3], [4, 5, 6]]) X2 = np.array([[1, 2, 3], [4, 5, 6]]) X3 = np.array([[1, 2, 3], [4, 5, 6]]) @@ -649,22 +616,14 @@ def test_concatenate_dense_duplicates(): annoA=[0, 1, 2], annoB=[1.1, 1.0, 2.0], annoD=[2.1, 2.0, 3.1], + annoE=[2.1, 2.0, 3.1], ), ) - adata = adata1.concatenate(adata2, adata3) - assert adata.var.columns.tolist() == [ - "annoA", - "annoB", - "annoC-0", - "annoD-0", - "annoC-1", - "annoD-1", - "annoD-2", - ] + adata = concat([adata1, adata2, adata3], merge=merge) + assert adata.var.columns.tolist() == expected_cols -@mark_legacy_concatenate def test_concatenate_sparse(): # sparse data from scipy.sparse import csr_matrix @@ -693,13 +652,13 @@ def test_concatenate_sparse(): ) # inner join - adata = adata1.concatenate(adata2, adata3) + adata = concat([adata1, adata2, adata3]) X_combined = [[2, 3], [5, 6], [3, 2], [6, 5], [0, 2], [6, 5]] assert adata.X.toarray().astype(int).tolist() == X_combined assert adata.layers["Xs"].toarray().astype(int).tolist() == X_combined # outer join - adata = adata1.concatenate(adata2, adata3, join="outer") + adata = concat([adata1, adata2, adata3], join="outer") assert adata.X.toarray().tolist() == [ [0.0, 2.0, 3.0, 0.0], [0.0, 5.0, 6.0, 0.0], @@ -710,7 +669,6 @@ def test_concatenate_sparse(): ] -@mark_legacy_concatenate def test_concatenate_mixed(): X1 = sparse.csr_matrix(np.array([[1, 2, 0], [4, 0, 6], [0, 0, 9]])) X2 = sparse.csr_matrix(np.array([[0, 2, 3], [4, 0, 0], [7, 0, 9]])) @@ -741,12 +699,11 @@ def test_concatenate_mixed(): layers=dict(counts=X2), # sic ) - adata_all = AnnData.concatenate(adata1, adata2, adata3, adata4) + adata_all = concat([adata1, adata2, adata3, adata4]) assert isinstance(adata_all.X, sparse.csr_matrix) assert isinstance(adata_all.layers["counts"], sparse.csr_matrix) -@mark_legacy_concatenate def test_concatenate_with_raw(): # dense data X1 = np.array([[1, 2, 3], [4, 5, 6]]) @@ -785,20 +742,20 @@ def test_concatenate_with_raw(): adata2.raw = adata2.copy() adata3.raw = adata3.copy() - adata_all = AnnData.concatenate(adata1, adata2, adata3) + adata_all = concat([adata1, adata2, adata3]) assert isinstance(adata_all.raw, Raw) assert set(adata_all.raw.var_names) == {"b", "c"} assert_equal(adata_all.raw.to_adata().obs, adata_all.obs) assert np.array_equal(adata_all.raw.X, adata_all.X) - adata_all = AnnData.concatenate(adata1, adata2, adata3, join="outer") + adata_all = concat([adata1, adata2, adata3], join="outer") assert isinstance(adata_all.raw, Raw) assert set(adata_all.raw.var_names) == set("abcd") assert_equal(adata_all.raw.to_adata().obs, adata_all.obs) assert np.array_equal(np.nan_to_num(adata_all.raw.X), np.nan_to_num(adata_all.X)) adata3.raw = adata4.copy() - adata_all = AnnData.concatenate(adata1, adata2, adata3, join="outer") + adata_all = concat([adata1, adata2, adata3], join="outer") assert isinstance(adata_all.raw, Raw) assert set(adata_all.raw.var_names) == set("abcdz") assert set(adata_all.var_names) == set("abcd") @@ -814,13 +771,13 @@ def test_concatenate_with_raw(): "not concatenating `.raw` attributes." ), ): - adata_all = AnnData.concatenate(adata1, adata2, adata3) + adata_all = concat([adata1, adata2, adata3]) assert adata_all.raw is None del adata1.raw del adata2.raw assert all(_adata.raw is None for _adata in (adata1, adata2, adata3)) - adata_all = AnnData.concatenate(adata1, adata2, adata3) + adata_all = concat([adata1, adata2, adata3]) assert adata_all.raw is None @@ -1232,11 +1189,9 @@ def test_concatenate_uns(unss, merge_strategy, result, value_gen): to `[{"a": [1, 2, 3]}, {"a": [1, 2, 3]}]`. """ # So we can see what the initial pattern was meant to be - print(merge_strategy, "\n", unss, "\n", result) result, *unss = permute_nested_values([result, *unss], value_gen) adatas = [uns_ad(uns) for uns in unss] - with pytest.warns(FutureWarning, match=r"concatenate is deprecated"): - merged = AnnData.concatenate(*adatas, uns_merge=merge_strategy).uns + merged = concat(adatas, uns_merge=merge_strategy).uns assert_equal(merged, result, elem_name="uns") @@ -1634,7 +1589,6 @@ def test_concat_outer_aligned_mapping(elem, axis, use_xdataset, force_lazy): check_filled_like(result, elem_name=f"{axis}m/{elem}") -@mark_legacy_concatenate def test_concatenate_size_0_axis(): # https://github.com/scverse/anndata/issues/526 @@ -1642,8 +1596,7 @@ def test_concatenate_size_0_axis(): b = gen_adata((5, 0)) # Mostly testing that this doesn't error - assert a.concatenate([b]).shape == (10, 0) - assert b.concatenate([a]).shape == (10, 0) + assert concat([a, b]).shape == (10, 0) def test_concat_null_X(use_xdataset): diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 3359b2ff8..43694da77 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -99,7 +99,7 @@ def dataset_kwargs(request): @pytest.fixture -def rw(backing_h5ad): +def rw(backing_h5ad) -> tuple[ad.AnnData, ad.AnnData]: M, N = 100, 101 orig = gen_adata((M, N), **GEN_ADATA_NO_XARRAY_ARGS) orig.write(backing_h5ad) @@ -126,6 +126,35 @@ def dtype(request): # ------------------------------------------------------------------------------ +@pytest.mark.parametrize("store_type", ["h5", "zarr", None]) +def test_can_write( + rw: tuple[ad.AnnData, ad.AnnData], store_type: Literal["h5", "zarr"] | None +): + adata, _ = rw + assert adata.can_write(store_type=store_type) + + +@pytest.mark.parametrize("store_type", ["h5", "zarr", None]) +@pytest.mark.parametrize("parent_elem", ["var", "uns", "raw"]) +def test_can_not_write_with_custom_array( + rw: tuple[ad.AnnData, ad.AnnData], + store_type: Literal["h5", "zarr"] | None, + parent_elem: Literal["obs", "uns", "raw"], +): + import pyarrow as pa + + adata, _ = rw + if parent_elem == "raw": + adata.raw = adata.copy() + getter = lambda: getattr(adata, parent_elem).var + else: + getter = lambda: getattr(adata, parent_elem) + getter()["arrow_array"] = pd.arrays.ArrowExtensionArray( + pa.array([{"x": 1, "y": True}] * adata.shape[1]) + ) + assert not adata.can_write(store_type=store_type) + + @pytest.mark.parametrize("typ", ARRAY_TYPES) def test_readwrite_roundtrip(typ, tmp_path, diskfmt, diskfmt2): pth1 = tmp_path / f"first.{diskfmt}"