Skip to content

Commit 324c557

Browse files
authored
Use mean and median accessors (#374)
* use mean and median accessors * postrelease changes * fix error when dist is false
1 parent 8371e8b commit 324c557

File tree

8 files changed

+21
-19
lines changed

8 files changed

+21
-19
lines changed

docs/source/tutorials/overview.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1799,7 +1799,7 @@
17991799
],
18001800
"metadata": {
18011801
"kernelspec": {
1802-
"display_name": "Python 3 (ipykernel)",
1802+
"display_name": "arviz_1",
18031803
"language": "python",
18041804
"name": "python3"
18051805
},

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ classifiers = [
2626
]
2727
dynamic = ["version", "description"]
2828
dependencies = [
29-
"arviz-base >=0.7.0",
30-
"arviz-stats[xarray] >=0.7.0",
29+
"arviz-base @ git+https://github.com/arviz-devs/arviz-base",
30+
"arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats",
3131
]
3232

3333
[tool.flit.module]

src/arviz_plots/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
"""Base ArviZ version."""
2-
__version__ = "0.7.0"
2+
__version__ = "0.8.0dev0"

src/arviz_plots/plots/dist_plot.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def plot_dist(
379379
)
380380

381381
if (
382-
(density_kwargs is not None)
382+
(density_kwargs is not False or face_kwargs is not False)
383383
and ("model" in distribution)
384384
and (plot_collection.coords is None)
385385
):
@@ -417,11 +417,13 @@ def plot_dist(
417417
plot_collection, aes_by_visuals, "point_estimate", sample_dims
418418
)
419419
if point_estimate == "median":
420-
point = distribution.median(dim=pe_dims, **stats.get("point_estimate", {}))
420+
point = distribution.azstats.median(dim=pe_dims, **stats.get("point_estimate", {}))
421421
elif point_estimate == "mean":
422-
point = distribution.mean(dim=pe_dims, **stats.get("point_estimate", {}))
422+
point = distribution.azstats.mean(dim=pe_dims, **stats.get("point_estimate", {}))
423+
elif point_estimate == "mode":
424+
point = distribution.azstats.mode(dim=pe_dims, **stats.get("point_estimate", {}))
423425
else:
424-
raise NotImplementedError("coming soon")
426+
raise ValueError("point_estimate must be either 'mean', 'median' or 'mode'")
425427

426428
if pe_kwargs is not False:
427429
if "color" not in pe_aes:
@@ -437,7 +439,7 @@ def plot_dist(
437439
# point estimate text
438440
if pet_kwargs is not False:
439441
if density_kwargs is False and face_kwargs is False:
440-
point_y = xr.full_like(point, 0.02)
442+
point_y = xr.full_like(point, 0.05)
441443
elif kind == "kde":
442444
point_density_diff = [
443445
dim for dim in density.sel(plot_axis="y").dims if dim not in point.dims

src/arviz_plots/plots/forest_plot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -385,13 +385,13 @@ def plot_forest(
385385
if isinstance(pe_stats, xr.Dataset):
386386
point = pe_stats
387387
elif point_estimate == "median":
388-
point = distribution.median(dim=pe_dims, **pe_stats)
388+
point = distribution.azstats.median(dim=pe_dims, **pe_stats)
389389
elif point_estimate == "mean":
390-
point = distribution.mean(dim=pe_dims, **pe_stats)
390+
point = distribution.azstats.mean(dim=pe_dims, **pe_stats)
391391
elif point_estimate == "mode":
392392
point = distribution.azstats.mode(dim=pe_dims, **pe_stats)
393393
else:
394-
raise NotImplementedError(f"Point estimate '{point_estimate}' not implemented")
394+
raise ValueError("point_estimate must be either 'mean', 'median' or 'mode'")
395395

396396
if twig_kwargs is not False:
397397
x_range = ci_twig

src/arviz_plots/plots/lm_plot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,11 +377,11 @@ def plot_lm(
377377
plot_collection, aes_by_visuals, "pe_line", sample_dims
378378
)
379379
if point_estimate == "mean":
380-
pe_value = y_pred.mean(dim=pe_line_dims, **stats.get("point_estimate", {}))
380+
pe_value = y_pred.azstats.mean(dim=pe_line_dims, **stats.get("point_estimate", {}))
381381
elif point_estimate == "median":
382-
pe_value = y_pred.median(dim=pe_line_dims, **stats.get("point_estimate", {}))
382+
pe_value = y_pred.azstats.median(dim=pe_line_dims, **stats.get("point_estimate", {}))
383383
elif point_estimate == "mode":
384-
pe_value = azs.mode(y_pred, dim=pe_line_dims, **stats.get("point_estimate", {}))
384+
pe_value = y_pred.azstats.mode(dim=pe_line_dims, **stats.get("point_estimate", {}))
385385
else:
386386
raise ValueError(
387387
f"'{point_estimate}' is not a valid value for `point_estimate`. "

src/arviz_plots/plots/ppc_interval_plot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,9 @@ def plot_ppc_interval(
210210
ci_twig = ci_fun(prob=ci_probs[0], dim=sample_dims, **stats.get("twig", {}))
211211

212212
if point_estimate == "median":
213-
point = ds_predictive.median(dim=sample_dims, **stats.get("point_estimate", {}))
213+
point = ds_predictive.azstats.median(dim=sample_dims, **stats.get("point_estimate", {}))
214214
elif point_estimate == "mean":
215-
point = ds_predictive.mean(dim=sample_dims, **stats.get("point_estimate", {}))
215+
point = ds_predictive.azstats.mean(dim=sample_dims, **stats.get("point_estimate", {}))
216216
elif point_estimate == "mode":
217217
point = ds_predictive.azstats.mode(dim=sample_dims, **stats.get("point_estimate", {}))
218218
else:

src/arviz_plots/plots/psense_quantities_plot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def plot_psense_quantities(
225225
name_quantities = []
226226

227227
if "mean" in quantities:
228-
to_concat_quantities.append(distribution.mean(sample_dims))
228+
to_concat_quantities.append(distribution.azstats.mean(sample_dims))
229229
if mcse:
230230
to_concat_mcse.append(ds_posterior.azstats.mcse(method="mean"))
231231
name_quantities.append("mean")
@@ -235,7 +235,7 @@ def plot_psense_quantities(
235235
to_concat_mcse.append(ds_posterior.azstats.mcse(method="sd"))
236236
name_quantities.append("sd")
237237
if "median" in quantities:
238-
to_concat_quantities.append(distribution.median(sample_dims))
238+
to_concat_quantities.append(distribution.azstats.median(sample_dims))
239239
if mcse:
240240
to_concat_mcse.append(ds_posterior.azstats.mcse(method="median"))
241241
name_quantities.append("median")

0 commit comments

Comments
 (0)