Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/tutorials/overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1799,7 +1799,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "arviz_1",
"language": "python",
"name": "python3"
},
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ classifiers = [
]
dynamic = ["version", "description"]
dependencies = [
"arviz-base >=0.7.0",
"arviz-stats[xarray] >=0.7.0",
"arviz-base @ git+https://github.com/arviz-devs/arviz-base",
"arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats",
]

[tool.flit.module]
Expand Down
2 changes: 1 addition & 1 deletion src/arviz_plots/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Base ArviZ version."""
__version__ = "0.7.0"
__version__ = "0.8.0dev0"
12 changes: 7 additions & 5 deletions src/arviz_plots/plots/dist_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def plot_dist(
)

if (
(density_kwargs is not None)
(density_kwargs is not False or face_kwargs is not False)
and ("model" in distribution)
and (plot_collection.coords is None)
):
Expand Down Expand Up @@ -417,11 +417,13 @@ def plot_dist(
plot_collection, aes_by_visuals, "point_estimate", sample_dims
)
if point_estimate == "median":
point = distribution.median(dim=pe_dims, **stats.get("point_estimate", {}))
point = distribution.azstats.median(dim=pe_dims, **stats.get("point_estimate", {}))
elif point_estimate == "mean":
point = distribution.mean(dim=pe_dims, **stats.get("point_estimate", {}))
point = distribution.azstats.mean(dim=pe_dims, **stats.get("point_estimate", {}))
elif point_estimate == "mode":
point = distribution.azstats.mode(dim=pe_dims, **stats.get("point_estimate", {}))
else:
raise NotImplementedError("coming soon")
raise ValueError("point_estimate must be either 'mean', 'median' or 'mode'")

if pe_kwargs is not False:
if "color" not in pe_aes:
Expand All @@ -437,7 +439,7 @@ def plot_dist(
# point estimate text
if pet_kwargs is not False:
if density_kwargs is False and face_kwargs is False:
point_y = xr.full_like(point, 0.02)
point_y = xr.full_like(point, 0.05)
elif kind == "kde":
point_density_diff = [
dim for dim in density.sel(plot_axis="y").dims if dim not in point.dims
Expand Down
6 changes: 3 additions & 3 deletions src/arviz_plots/plots/forest_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,13 +385,13 @@ def plot_forest(
if isinstance(pe_stats, xr.Dataset):
point = pe_stats
elif point_estimate == "median":
point = distribution.median(dim=pe_dims, **pe_stats)
point = distribution.azstats.median(dim=pe_dims, **pe_stats)
elif point_estimate == "mean":
point = distribution.mean(dim=pe_dims, **pe_stats)
point = distribution.azstats.mean(dim=pe_dims, **pe_stats)
elif point_estimate == "mode":
point = distribution.azstats.mode(dim=pe_dims, **pe_stats)
else:
raise NotImplementedError(f"Point estimate '{point_estimate}' not implemented")
raise ValueError("point_estimate must be either 'mean', 'median' or 'mode'")

if twig_kwargs is not False:
x_range = ci_twig
Expand Down
6 changes: 3 additions & 3 deletions src/arviz_plots/plots/lm_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,11 +377,11 @@ def plot_lm(
plot_collection, aes_by_visuals, "pe_line", sample_dims
)
if point_estimate == "mean":
pe_value = y_pred.mean(dim=pe_line_dims, **stats.get("point_estimate", {}))
pe_value = y_pred.azstats.mean(dim=pe_line_dims, **stats.get("point_estimate", {}))
elif point_estimate == "median":
pe_value = y_pred.median(dim=pe_line_dims, **stats.get("point_estimate", {}))
pe_value = y_pred.azstats.median(dim=pe_line_dims, **stats.get("point_estimate", {}))
elif point_estimate == "mode":
pe_value = azs.mode(y_pred, dim=pe_line_dims, **stats.get("point_estimate", {}))
pe_value = y_pred.azstats.mode(dim=pe_line_dims, **stats.get("point_estimate", {}))
else:
raise ValueError(
f"'{point_estimate}' is not a valid value for `point_estimate`. "
Expand Down
4 changes: 2 additions & 2 deletions src/arviz_plots/plots/ppc_interval_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,9 @@ def plot_ppc_interval(
ci_twig = ci_fun(prob=ci_probs[0], dim=sample_dims, **stats.get("twig", {}))

if point_estimate == "median":
point = ds_predictive.median(dim=sample_dims, **stats.get("point_estimate", {}))
point = ds_predictive.azstats.median(dim=sample_dims, **stats.get("point_estimate", {}))
elif point_estimate == "mean":
point = ds_predictive.mean(dim=sample_dims, **stats.get("point_estimate", {}))
point = ds_predictive.azstats.mean(dim=sample_dims, **stats.get("point_estimate", {}))
elif point_estimate == "mode":
point = ds_predictive.azstats.mode(dim=sample_dims, **stats.get("point_estimate", {}))
else:
Expand Down
4 changes: 2 additions & 2 deletions src/arviz_plots/plots/psense_quantities_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def plot_psense_quantities(
name_quantities = []

if "mean" in quantities:
to_concat_quantities.append(distribution.mean(sample_dims))
to_concat_quantities.append(distribution.azstats.mean(sample_dims))
if mcse:
to_concat_mcse.append(ds_posterior.azstats.mcse(method="mean"))
name_quantities.append("mean")
Expand All @@ -235,7 +235,7 @@ def plot_psense_quantities(
to_concat_mcse.append(ds_posterior.azstats.mcse(method="sd"))
name_quantities.append("sd")
if "median" in quantities:
to_concat_quantities.append(distribution.median(sample_dims))
to_concat_quantities.append(distribution.azstats.median(sample_dims))
if mcse:
to_concat_mcse.append(ds_posterior.azstats.mcse(method="median"))
name_quantities.append("median")
Expand Down