diff --git a/pyproject.toml b/pyproject.toml index 314349dc..2979efed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ dynamic = ["version", "description"] dependencies = [ "arviz-base @ git+https://github.com/arviz-devs/arviz-base", - "arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats", + "arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats@top_level_viz", ] [tool.flit.module] diff --git a/src/arviz_plots/plots/utils.py b/src/arviz_plots/plots/utils.py index 4e4c90d8..fc632419 100644 --- a/src/arviz_plots/plots/utils.py +++ b/src/arviz_plots/plots/utils.py @@ -4,8 +4,9 @@ import numpy as np import xarray as xr -from arviz_base import references_to_dataset +from arviz_base import rcParams, references_to_dataset from arviz_base.utils import _var_names +from arviz_stats import ecdf, histogram, kde from arviz_plots.plot_collection import concat_model_dict, process_facet_dims from arviz_plots.visuals import hline, hspan, vline, vspan @@ -74,16 +75,23 @@ def process_group_variables_coords(dt, group, var_names, filter_vars, coords, al distribution = distribution.sel(coords) return distribution - def filter_aes(pc, aes_by_visuals, visual, sample_dims): + reduce_dims, _, artist_aes, ignore_aes = filter_aes_new(pc, aes_by_visuals, visual, sample_dims) + return reduce_dims, artist_aes, ignore_aes + +def filter_aes_new(pc, aes_by_visuals, visual, sample_dims): """Split aesthetics and get relevant dimensions. Returns ------- - artist_dims : list + reduce_dims : list Dimensions that should be reduced for this visual. That is, all dimensions in `sample_dims` that are not mapped to any aesthetic. + active_dims : list + Dimensions that have either faceting or aesthetic mappings + active for that visual. Should not be reduced and should have + a groupby performed on them if computing summaries. artist_aes : iterable ignore_aes : set """ @@ -91,8 +99,9 @@ def filter_aes(pc, aes_by_visuals, visual, sample_dims): pc_aes = pc.aes_set ignore_aes = set(pc_aes).difference(artist_aes) _, all_loop_dims = pc.update_aes(ignore_aes=ignore_aes) - artist_dims = [dim for dim in sample_dims if dim not in all_loop_dims] - return artist_dims, artist_aes, ignore_aes + reduce_dims = [dim for dim in sample_dims if dim not in all_loop_dims] + active_dims = [dim for dim in all_loop_dims if dim not in sample_dims] + return reduce_dims, active_dims, artist_aes, ignore_aes def set_wrap_layout(pc_kwargs, plot_bknd, ds): @@ -167,6 +176,67 @@ def set_grid_layout(pc_kwargs, plot_bknd, ds, num_rows=None, num_cols=None): pc_kwargs["figure_kwargs"]["figsize_units"] = figsize_units return pc_kwargs +def compute_dist(data, reduce_dims, active_dims, kind=None, stats=None): + if stats is None: + stats = {} + # quick exit if pre-computed elements in `stats` + if any(isinstance(stats.get(viz, None), xr.Dataset) for viz in ("ecdf", "hist", "kde")): + return (stats.get(viz, xr.Dataset()) for viz in ("ecdf", "hist", "kde")) + if kind is None: + kind = rcParams["plot.density_kind"] + if set(reduce_dims).intersection(active_dims): + raise ValueError("'reduce_dims' and 'active_dims' can't share elements") + ecdf_vars = [] + hist_vars = [] + kde_vars = [] + if kind == "auto": + for var_name, da in data.items(): + reduced_size = np.prod([da.sizes[dim] for dim in reduce_dims if dim in da.dims]) + groupby_dims = [dim for dim in active_dims if dim in da.dims] + if groupby_dims: + reduced_size *= np.prod([np.min(np.unique(da.coords[dim], return_counts=True)[1]) for dim in groupby_dims]) + if reduced_size < 100: + ecdf_vars.append(var_name) + elif da.dtype.kind == "f": + kde_vars.append(var_name) + else: + hist_vars.append(var_name) + elif kind == "ecdf": + ecdf_vars == list(data.data_vars) + elif kind == "hist": + hist_vars == list(data.data_vars) + elif kind == "kde": + kde_vars = list(data.data_vars) + + if ecdf_vars: + ecdf_data = data[ecdf_vars] + groupby_dims = [dim for dim in active_dims if dim in ecdf_data.dims] + if groupby_dims: + ecdf_data = ecdf_data.groupby(groupby_dims) + ecdf_out = ecdf(ecdf_data, dim=reduce_dims, **stats.get("ecdf", {})) + else: + ecdf_out = xr.Dataset() + + if hist_vars: + hist_data = data[hist_vars] + groupby_dims = [dim for dim in active_dims if dim in hist_data.dims] + if groupby_dims: + hist_data = hist_data.groupby(groupby_dims) + hist_out = histogram(hist_data, dim=reduce_dims, **stats.get("hist", {})) + else: + hist_out = xr.Dataset() + + if kde_vars: + kde_data = data[kde_vars] + groupby_dims = [dim for dim in active_dims if dim in kde_data.dims] + if groupby_dims: + kde_data = kde_data.groupby(groupby_dims) + kde_out = kde(kde_data, dim=reduce_dims, **stats.get("kde", {})) + else: + kde_out = xr.Dataset() + + return ecdf_out, hist_out, kde_out + def add_lines( plot_collection,