22import inspect
33import itertools
44from collections import ChainMap
5+ from contextlib import suppress
56from typing import Callable , List , Mapping , MutableMapping , Optional , Set , Tuple , Union
67
78import xarray as xr
@@ -226,6 +227,13 @@ def _get_measure(
226227 else :
227228 return default
228229 measures = dict (zip (strings [slice (0 , None , 2 )], strings [slice (1 , None , 2 )]))
230+ if key not in measures :
231+ if error :
232+ raise KeyError (
233+ f"Cell measure { key !r} not found. Please use .cf.describe() to see a list of key names that can be interpreted."
234+ )
235+ else :
236+ return default
229237 return measures [key ]
230238
231239
@@ -244,6 +252,7 @@ def _getattr(
244252 accessor : "CFAccessor" ,
245253 key_mappers : Mapping [str , Mapper ],
246254 wrap_classes : bool = False ,
255+ extra_decorator : Callable = None ,
247256):
248257 """
249258 Common getattr functionality.
@@ -261,13 +270,17 @@ def _getattr(
261270 Only True for the high level CFAccessor.
262271 Facilitates code reuse for _CFWrappedClass and _CFWrapppedPlotMethods
263272 For both of those, wrap_classes is False.
273+ extra_decorator: Callable (optional)
274+ An extra decorator, if necessary. This is used by _CFPlotMethods to set default
275+ kwargs based on CF attributes.
264276 """
265- func = getattr (obj , attr )
277+ func : Callable = getattr (obj , attr )
266278
267279 @functools .wraps (func )
268280 def wrapper (* args , ** kwargs ):
269281 arguments = accessor ._process_signature (func , args , kwargs , key_mappers )
270- result = func (** arguments )
282+ final_func = extra_decorator (func ) if extra_decorator else func
283+ result = final_func (** arguments )
271284 if wrap_classes and isinstance (result , _WRAPPED_CLASSES ):
272285 result = _CFWrappedClass (result , accessor )
273286
@@ -312,21 +325,58 @@ def __init__(self, obj, accessor):
312325 self .accessor = accessor
313326 self ._keys = ("x" , "y" , "hue" , "col" , "row" )
314327
328+ def _plot_decorator (self , func ):
329+ """
330+ This decorator is used to set kwargs on plotting functions.
331+ """
332+ valid_keys = self .accessor .get_valid_keys ()
333+
334+ @functools .wraps (func )
335+ def _plot_wrapper (* args , ** kwargs ):
336+ if "x" in kwargs :
337+ if kwargs ["x" ] in valid_keys :
338+ xvar = self .accessor [kwargs ["x" ]]
339+ else :
340+ xvar = self ._obj [kwargs ["x" ]]
341+ if "positive" in xvar .attrs :
342+ if xvar .attrs ["positive" ] == "down" :
343+ kwargs .setdefault ("xincrease" , False )
344+ else :
345+ kwargs .setdefault ("xincrease" , True )
346+
347+ if "y" in kwargs :
348+ if kwargs ["y" ] in valid_keys :
349+ yvar = self .accessor [kwargs ["y" ]]
350+ else :
351+ yvar = self ._obj [kwargs ["y" ]]
352+ if "positive" in yvar .attrs :
353+ if yvar .attrs ["positive" ] == "down" :
354+ kwargs .setdefault ("yincrease" , False )
355+ else :
356+ kwargs .setdefault ("yincrease" , True )
357+
358+ return func (* args , ** kwargs )
359+
360+ return _plot_wrapper
361+
315362 def __call__ (self , * args , ** kwargs ):
316363 plot = _getattr (
317364 obj = self ._obj ,
318365 attr = "plot" ,
319366 accessor = self .accessor ,
320367 key_mappers = dict .fromkeys (self ._keys , _get_axis_coord_single ),
321368 )
322- return plot (* args , ** kwargs )
369+ return self . _plot_decorator ( plot ) (* args , ** kwargs )
323370
324371 def __getattr__ (self , attr ):
325372 return _getattr (
326373 obj = self ._obj .plot ,
327374 attr = attr ,
328375 accessor = self .accessor ,
329376 key_mappers = dict .fromkeys (self ._keys , _get_axis_coord_single ),
377+ # TODO: "extra_decorator" is more complex than I would like it to be.
378+ # Not sure if there is a better way though
379+ extra_decorator = self ._plot_decorator ,
330380 )
331381
332382
@@ -458,6 +508,29 @@ def _describe(self):
458508 def describe (self ):
459509 print (self ._describe ())
460510
511+ def get_valid_keys (self ) -> Set [str ]:
512+ """
513+ Returns valid keys for .cf[]
514+
515+ Returns
516+ -------
517+ Set of valid key names that can be used with __getitem__ or .cf[key].
518+ """
519+ varnames = [
520+ key
521+ for key in _AXIS_NAMES + _COORD_NAMES
522+ if _get_axis_coord (self ._obj , key , error = False , default = None ) != [None ]
523+ ]
524+ with suppress (NotImplementedError ):
525+ measures = [
526+ key
527+ for key in _CELL_MEASURES
528+ if _get_measure (self ._obj , key , error = False ) is not None
529+ ]
530+ if measures :
531+ varnames .append (* measures )
532+ return set (varnames )
533+
461534
462535@xr .register_dataset_accessor ("cf" )
463536class CFDatasetAccessor (CFAccessor ):
0 commit comments