11import functools
22import inspect
33import itertools
4+ import textwrap
45from collections import ChainMap
56from contextlib import suppress
6- from typing import Callable , List , Mapping , MutableMapping , Optional , Set , Tuple , Union
7+ from typing import (
8+ Callable ,
9+ Hashable ,
10+ List ,
11+ Mapping ,
12+ MutableMapping ,
13+ Optional ,
14+ Set ,
15+ Tuple ,
16+ Union ,
17+ )
718
819import xarray as xr
920from xarray import DataArray , Dataset
106117]
107118
108119
120+ def _strip_none_list (lst : List [Optional [str ]]) -> List [str ]:
121+ """ The mappers can return [None]. Strip that when necessary. Keeps mypy happy."""
122+ return [item for item in lst if item != [None ]] # type: ignore
123+
124+
109125def _get_axis_coord_single (
110126 var : Union [xr .DataArray , xr .Dataset ],
111127 key : str ,
@@ -176,8 +192,8 @@ def _get_axis_coord(
176192 results : Set = set ()
177193 for coord in search_in :
178194 for criterion , valid_values in coordinate_criteria .items ():
179- if key in valid_values : # type: ignore
180- expected = valid_values [key ] # type: ignore
195+ if key in valid_values :
196+ expected = valid_values [key ]
181197 if var .coords [coord ].attrs .get (criterion , None ) in expected :
182198 results .update ((coord ,))
183199
@@ -246,6 +262,31 @@ def _get_measure(
246262}
247263
248264
265+ def _filter_by_standard_names (ds : xr .Dataset , name : Union [str , List [str ]]) -> List [str ]:
266+ """ returns a list of variable names with standard names matching name. """
267+ if isinstance (name , str ):
268+ name = [name ]
269+
270+ varnames = []
271+ counts = dict .fromkeys (name , 0 )
272+ for vname , var in ds .variables .items ():
273+ stdname = var .attrs .get ("standard_name" , None )
274+ if stdname in name :
275+ varnames .append (str (vname ))
276+ counts [stdname ] += 1
277+
278+ return varnames
279+
280+
281+ def _get_list_standard_names (obj : xr .Dataset ) -> List [str ]:
282+ """ Returns a sorted list of standard names in Dataset. """
283+ names = []
284+ for k , v in obj .variables .items ():
285+ if "standard_name" in v .attrs :
286+ names .append (v .attrs ["standard_name" ])
287+ return sorted (names )
288+
289+
249290def _getattr (
250291 obj : Union [DataArray , Dataset ],
251292 attr : str ,
@@ -503,6 +544,16 @@ def _describe(self):
503544 text += f"\t { measure } : unsupported\n "
504545 else :
505546 text += f"\t { measure } : { _get_measure (self ._obj , measure , error = False , default = None )} \n "
547+
548+ text += "\n Standard Names:\n "
549+ if isinstance (self ._obj , xr .DataArray ):
550+ text += "\t unsupported\n "
551+ else :
552+ stdnames = _get_list_standard_names (self ._obj )
553+ text += "\t "
554+ text += "\n " .join (
555+ textwrap .wrap (f"{ stdnames !r} " , 70 , break_long_words = False )
556+ )
506557 return text
507558
508559 def describe (self ):
@@ -529,32 +580,96 @@ def get_valid_keys(self) -> Set[str]:
529580 ]
530581 if measures :
531582 varnames .append (* measures )
583+
584+ if not isinstance (self ._obj , xr .DataArray ):
585+ varnames .extend (_get_list_standard_names (self ._obj ))
532586 return set (varnames )
533587
588+ def __getitem__ (self , key : Union [str , List [str ]]):
589+
590+ kind = str (type (self ._obj ).__name__ )
591+ scalar_key = isinstance (key , str )
592+ if scalar_key :
593+ key = (key ,) # type: ignore
594+
595+ varnames : List [Hashable ] = []
596+ coords : List [Hashable ] = []
597+ successful = dict .fromkeys (key , False )
598+ for k in key :
599+ if k in _AXIS_NAMES + _COORD_NAMES :
600+ names = _get_axis_coord (self ._obj , k )
601+ successful [k ] = bool (names )
602+ varnames .extend (_strip_none_list (names ))
603+ coords .extend (_strip_none_list (names ))
604+ elif k in _CELL_MEASURES :
605+ if isinstance (self ._obj , xr .Dataset ):
606+ raise NotImplementedError (
607+ "Invalid key {k!r}. Cell measures not implemented for Dataset yet."
608+ )
609+ else :
610+ measure = _get_measure (self ._obj , k )
611+ successful [k ] = bool (measure )
612+ if measure :
613+ varnames .append (measure )
614+ elif not isinstance (self ._obj , xr .DataArray ):
615+ stdnames = _filter_by_standard_names (self ._obj , k )
616+ successful [k ] = bool (stdnames )
617+ varnames .extend (stdnames )
618+ coords .extend (list (set (stdnames ).intersection (set (self ._obj .coords ))))
619+
620+ # these are not special names but could be variable names in underlying object
621+ # we allow this so that we can return variables with appropriate CF auxiliary variables
622+ varnames .extend ([k for k , v in successful .items () if not v ])
623+ assert len (varnames ) > 0
624+
625+ try :
626+ # TODO: make this a get_auxiliary_variables function
627+ # make sure to set coordinate variables referred to in "coordinates" attribute
628+ for name in varnames :
629+ attrs = self ._obj [name ].attrs
630+ if "coordinates" in attrs :
631+ coords .extend (attrs .get ("coordinates" ).split (" " ))
632+
633+ if "cell_measures" in attrs :
634+ measures = [
635+ _get_measure (self ._obj [name ], measure )
636+ for measure in _CELL_MEASURES
637+ if measure in attrs ["cell_measures" ]
638+ ]
639+ coords .extend (_strip_none_list (measures ))
640+
641+ varnames .extend (coords )
642+ if isinstance (self ._obj , xr .DataArray ):
643+ ds = self ._obj ._to_temp_dataset ()
644+ else :
645+ ds = self ._obj
646+ ds = ds .reset_coords ()[varnames ]
647+ if isinstance (self ._obj , DataArray ):
648+ if scalar_key and len (ds .variables ) == 1 :
649+ # single dimension coordinates
650+ return ds [list (ds .variables .keys ())[0 ]].squeeze (drop = True )
651+ elif scalar_key and len (ds .coords ) > 1 :
652+ raise NotImplementedError (
653+ "Not sure what to return when given scalar key for DataArray and it has multiple values. "
654+ "Please open an issue."
655+ )
656+ elif not scalar_key :
657+ return ds .set_coords (coords )
658+ else :
659+ return ds .set_coords (coords )
660+
661+ except KeyError :
662+ raise KeyError (
663+ f"{ kind } .cf does not understand the key { k !r} . "
664+ f"Use { kind } .cf.describe() to see a list of key names that can be interpreted."
665+ )
666+
534667
535668@xr .register_dataset_accessor ("cf" )
536669class CFDatasetAccessor (CFAccessor ):
537- def __getitem__ (self , key ):
538- if key in _AXIS_NAMES + _COORD_NAMES :
539- varnames = _get_axis_coord (self ._obj , key )
540- return self ._obj .reset_coords ()[varnames ].set_coords (varnames )
541- elif key in _CELL_MEASURES :
542- raise NotImplementedError ("measures not implemented for Dataset yet." )
543- else :
544- raise KeyError (
545- f"Dataset.cf does not understand the key { key !r} . Use Dataset.cf.describe() to see a list of key names that can be interpreted."
546- )
670+ pass
547671
548672
549673@xr .register_dataarray_accessor ("cf" )
550674class CFDataArrayAccessor (CFAccessor ):
551- def __getitem__ (self , key ):
552- if key in _AXIS_NAMES + _COORD_NAMES :
553- varname = _get_axis_coord_single (self ._obj , key )
554- return self ._obj [varname ].reset_coords (drop = True )
555- elif key in _CELL_MEASURES :
556- return self ._obj [_get_measure (self ._obj , key )]
557- else :
558- raise KeyError (
559- f"DataArray.cf does not understand the key { key !r} . Use DataArray.cf.describe() to see a list of key names that can be interpreted."
560- )
675+ pass
0 commit comments