Skip to content

Commit 34384de

Browse files
authored
Set default plotting kwargs (#38)
* Add get_valid_keys * Decorate plotting functions. This allows us to set default plotting kwargs. 1. xincrease, yincrease deepending on .attrs["positive"] * Add test * fix test
1 parent baf204a commit 34384de

File tree

2 files changed

+105
-3
lines changed

2 files changed

+105
-3
lines changed

cf_xarray/accessor.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
import itertools
44
from collections import ChainMap
5+
from contextlib import suppress
56
from typing import Callable, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
67

78
import xarray as xr
@@ -226,6 +227,13 @@ def _get_measure(
226227
else:
227228
return default
228229
measures = dict(zip(strings[slice(0, None, 2)], strings[slice(1, None, 2)]))
230+
if key not in measures:
231+
if error:
232+
raise KeyError(
233+
f"Cell measure {key!r} not found. Please use .cf.describe() to see a list of key names that can be interpreted."
234+
)
235+
else:
236+
return default
229237
return measures[key]
230238

231239

@@ -244,6 +252,7 @@ def _getattr(
244252
accessor: "CFAccessor",
245253
key_mappers: Mapping[str, Mapper],
246254
wrap_classes: bool = False,
255+
extra_decorator: Callable = None,
247256
):
248257
"""
249258
Common getattr functionality.
@@ -261,13 +270,17 @@ def _getattr(
261270
Only True for the high level CFAccessor.
262271
Facilitates code reuse for _CFWrappedClass and _CFWrapppedPlotMethods
263272
For both of those, wrap_classes is False.
273+
extra_decorator: Callable (optional)
274+
An extra decorator, if necessary. This is used by _CFPlotMethods to set default
275+
kwargs based on CF attributes.
264276
"""
265-
func = getattr(obj, attr)
277+
func: Callable = getattr(obj, attr)
266278

267279
@functools.wraps(func)
268280
def wrapper(*args, **kwargs):
269281
arguments = accessor._process_signature(func, args, kwargs, key_mappers)
270-
result = func(**arguments)
282+
final_func = extra_decorator(func) if extra_decorator else func
283+
result = final_func(**arguments)
271284
if wrap_classes and isinstance(result, _WRAPPED_CLASSES):
272285
result = _CFWrappedClass(result, accessor)
273286

@@ -312,21 +325,58 @@ def __init__(self, obj, accessor):
312325
self.accessor = accessor
313326
self._keys = ("x", "y", "hue", "col", "row")
314327

328+
def _plot_decorator(self, func):
329+
"""
330+
This decorator is used to set kwargs on plotting functions.
331+
"""
332+
valid_keys = self.accessor.get_valid_keys()
333+
334+
@functools.wraps(func)
335+
def _plot_wrapper(*args, **kwargs):
336+
if "x" in kwargs:
337+
if kwargs["x"] in valid_keys:
338+
xvar = self.accessor[kwargs["x"]]
339+
else:
340+
xvar = self._obj[kwargs["x"]]
341+
if "positive" in xvar.attrs:
342+
if xvar.attrs["positive"] == "down":
343+
kwargs.setdefault("xincrease", False)
344+
else:
345+
kwargs.setdefault("xincrease", True)
346+
347+
if "y" in kwargs:
348+
if kwargs["y"] in valid_keys:
349+
yvar = self.accessor[kwargs["y"]]
350+
else:
351+
yvar = self._obj[kwargs["y"]]
352+
if "positive" in yvar.attrs:
353+
if yvar.attrs["positive"] == "down":
354+
kwargs.setdefault("yincrease", False)
355+
else:
356+
kwargs.setdefault("yincrease", True)
357+
358+
return func(*args, **kwargs)
359+
360+
return _plot_wrapper
361+
315362
def __call__(self, *args, **kwargs):
316363
plot = _getattr(
317364
obj=self._obj,
318365
attr="plot",
319366
accessor=self.accessor,
320367
key_mappers=dict.fromkeys(self._keys, _get_axis_coord_single),
321368
)
322-
return plot(*args, **kwargs)
369+
return self._plot_decorator(plot)(*args, **kwargs)
323370

324371
def __getattr__(self, attr):
325372
return _getattr(
326373
obj=self._obj.plot,
327374
attr=attr,
328375
accessor=self.accessor,
329376
key_mappers=dict.fromkeys(self._keys, _get_axis_coord_single),
377+
# TODO: "extra_decorator" is more complex than I would like it to be.
378+
# Not sure if there is a better way though
379+
extra_decorator=self._plot_decorator,
330380
)
331381

332382

@@ -458,6 +508,29 @@ def _describe(self):
458508
def describe(self):
459509
print(self._describe())
460510

511+
def get_valid_keys(self) -> Set[str]:
512+
"""
513+
Returns valid keys for .cf[]
514+
515+
Returns
516+
-------
517+
Set of valid key names that can be used with __getitem__ or .cf[key].
518+
"""
519+
varnames = [
520+
key
521+
for key in _AXIS_NAMES + _COORD_NAMES
522+
if _get_axis_coord(self._obj, key, error=False, default=None) != [None]
523+
]
524+
with suppress(NotImplementedError):
525+
measures = [
526+
key
527+
for key in _CELL_MEASURES
528+
if _get_measure(self._obj, key, error=False) is not None
529+
]
530+
if measures:
531+
varnames.append(*measures)
532+
return set(varnames)
533+
461534

462535
@xr.register_dataset_accessor("cf")
463536
class CFDatasetAccessor(CFAccessor):

cf_xarray/tests/test_accessor.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
ds.coords["cell_area"] = (
1616
xr.DataArray(np.cos(ds.lat * np.pi / 180)) * xr.ones_like(ds.lon) * 105e3 * 110e3
1717
)
18+
ds_no_attrs = ds.copy(deep=True)
19+
for variable in ds_no_attrs.variables:
20+
ds_no_attrs[variable].attrs = {}
21+
1822
datasets = [ds, ds.chunk({"lat": 5})]
1923
dataarrays = [ds.air, ds.air.chunk({"lat": 5})]
2024
objects = datasets + dataarrays
@@ -121,6 +125,19 @@ def test_kwargs_expand_key_to_multiple_keys():
121125
assert_identical(actual.mean(), expected.mean())
122126

123127

128+
@pytest.mark.parametrize(
129+
"obj, expected",
130+
[
131+
(ds, set(("latitude", "longitude", "time", "X", "Y", "T"))),
132+
(ds.air, set(("latitude", "longitude", "time", "X", "Y", "T", "area"))),
133+
(ds_no_attrs.air, set()),
134+
],
135+
)
136+
def test_get_valid_keys(obj, expected):
137+
actual = obj.cf.get_valid_keys()
138+
assert actual == expected
139+
140+
124141
@pytest.mark.parametrize("obj", objects)
125142
def test_args_methods(obj):
126143
with raise_if_dask_computes():
@@ -234,3 +251,15 @@ def test_getitem_uses_coordinates():
234251
)
235252
assert_identical(ds.UVEL.cf["X"], ds["ULONG"].reset_coords(drop=True))
236253
assert_identical(ds.TEMP.cf["X"], ds["TLONG"].reset_coords(drop=True))
254+
255+
256+
def test_plot_xincrease_yincrease():
257+
ds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(4), lon=slice(50))
258+
ds.lon.attrs["positive"] = "down"
259+
ds.lat.attrs["positive"] = "down"
260+
261+
f, ax = plt.subplots(1, 1)
262+
ds.air.isel(time=1).cf.plot(ax=ax, x="X", y="Y")
263+
264+
for lim in [ax.get_xlim(), ax.get_ylim()]:
265+
assert lim[0] > lim[1]

0 commit comments

Comments
 (0)