Skip to content

Commit d454cb4

Browse files
authored
getitem support for standard_name (#39)
* Update pre-commit * Rework getitem for standard_name support * Add datasets.py
1 parent 34384de commit d454cb4

File tree

4 files changed

+226
-70
lines changed

4 files changed

+226
-70
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ repos:
88
files: .+\.py$
99
# https://github.com/python/black#version-control-integration
1010
- repo: https://github.com/python/black
11-
rev: stable
11+
rev: 19.10b0
1212
hooks:
1313
- id: black
1414
- repo: https://gitlab.com/pycqa/flake8
15-
rev: 3.7.9
15+
rev: 3.8.3
1616
hooks:
1717
- id: flake8
1818
- repo: https://github.com/pre-commit/mirrors-mypy
19-
rev: v0.761 # Must match ci/requirements/*.yml
19+
rev: v0.781 # Must match ci/requirements/*.yml
2020
hooks:
2121
- id: mypy
2222
# run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194

cf_xarray/accessor.py

Lines changed: 138 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
11
import functools
22
import inspect
33
import itertools
4+
import textwrap
45
from collections import ChainMap
56
from contextlib import suppress
6-
from typing import Callable, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
7+
from typing import (
8+
Callable,
9+
Hashable,
10+
List,
11+
Mapping,
12+
MutableMapping,
13+
Optional,
14+
Set,
15+
Tuple,
16+
Union,
17+
)
718

819
import xarray as xr
920
from xarray import DataArray, Dataset
@@ -106,6 +117,11 @@
106117
]
107118

108119

120+
def _strip_none_list(lst: List[Optional[str]]) -> List[str]:
121+
""" The mappers can return [None]. Strip that when necessary. Keeps mypy happy."""
122+
return [item for item in lst if item != [None]] # type: ignore
123+
124+
109125
def _get_axis_coord_single(
110126
var: Union[xr.DataArray, xr.Dataset],
111127
key: str,
@@ -176,8 +192,8 @@ def _get_axis_coord(
176192
results: Set = set()
177193
for coord in search_in:
178194
for criterion, valid_values in coordinate_criteria.items():
179-
if key in valid_values: # type: ignore
180-
expected = valid_values[key] # type: ignore
195+
if key in valid_values:
196+
expected = valid_values[key]
181197
if var.coords[coord].attrs.get(criterion, None) in expected:
182198
results.update((coord,))
183199

@@ -246,6 +262,31 @@ def _get_measure(
246262
}
247263

248264

265+
def _filter_by_standard_names(ds: xr.Dataset, name: Union[str, List[str]]) -> List[str]:
266+
""" returns a list of variable names with standard names matching name. """
267+
if isinstance(name, str):
268+
name = [name]
269+
270+
varnames = []
271+
counts = dict.fromkeys(name, 0)
272+
for vname, var in ds.variables.items():
273+
stdname = var.attrs.get("standard_name", None)
274+
if stdname in name:
275+
varnames.append(str(vname))
276+
counts[stdname] += 1
277+
278+
return varnames
279+
280+
281+
def _get_list_standard_names(obj: xr.Dataset) -> List[str]:
282+
""" Returns a sorted list of standard names in Dataset. """
283+
names = []
284+
for k, v in obj.variables.items():
285+
if "standard_name" in v.attrs:
286+
names.append(v.attrs["standard_name"])
287+
return sorted(names)
288+
289+
249290
def _getattr(
250291
obj: Union[DataArray, Dataset],
251292
attr: str,
@@ -503,6 +544,16 @@ def _describe(self):
503544
text += f"\t{measure}: unsupported\n"
504545
else:
505546
text += f"\t{measure}: {_get_measure(self._obj, measure, error=False, default=None)}\n"
547+
548+
text += "\nStandard Names:\n"
549+
if isinstance(self._obj, xr.DataArray):
550+
text += "\tunsupported\n"
551+
else:
552+
stdnames = _get_list_standard_names(self._obj)
553+
text += "\t"
554+
text += "\n".join(
555+
textwrap.wrap(f"{stdnames!r}", 70, break_long_words=False)
556+
)
506557
return text
507558

508559
def describe(self):
@@ -529,32 +580,96 @@ def get_valid_keys(self) -> Set[str]:
529580
]
530581
if measures:
531582
varnames.append(*measures)
583+
584+
if not isinstance(self._obj, xr.DataArray):
585+
varnames.extend(_get_list_standard_names(self._obj))
532586
return set(varnames)
533587

588+
def __getitem__(self, key: Union[str, List[str]]):
589+
590+
kind = str(type(self._obj).__name__)
591+
scalar_key = isinstance(key, str)
592+
if scalar_key:
593+
key = (key,) # type: ignore
594+
595+
varnames: List[Hashable] = []
596+
coords: List[Hashable] = []
597+
successful = dict.fromkeys(key, False)
598+
for k in key:
599+
if k in _AXIS_NAMES + _COORD_NAMES:
600+
names = _get_axis_coord(self._obj, k)
601+
successful[k] = bool(names)
602+
varnames.extend(_strip_none_list(names))
603+
coords.extend(_strip_none_list(names))
604+
elif k in _CELL_MEASURES:
605+
if isinstance(self._obj, xr.Dataset):
606+
raise NotImplementedError(
607+
"Invalid key {k!r}. Cell measures not implemented for Dataset yet."
608+
)
609+
else:
610+
measure = _get_measure(self._obj, k)
611+
successful[k] = bool(measure)
612+
if measure:
613+
varnames.append(measure)
614+
elif not isinstance(self._obj, xr.DataArray):
615+
stdnames = _filter_by_standard_names(self._obj, k)
616+
successful[k] = bool(stdnames)
617+
varnames.extend(stdnames)
618+
coords.extend(list(set(stdnames).intersection(set(self._obj.coords))))
619+
620+
# these are not special names but could be variable names in underlying object
621+
# we allow this so that we can return variables with appropriate CF auxiliary variables
622+
varnames.extend([k for k, v in successful.items() if not v])
623+
assert len(varnames) > 0
624+
625+
try:
626+
# TODO: make this a get_auxiliary_variables function
627+
# make sure to set coordinate variables referred to in "coordinates" attribute
628+
for name in varnames:
629+
attrs = self._obj[name].attrs
630+
if "coordinates" in attrs:
631+
coords.extend(attrs.get("coordinates").split(" "))
632+
633+
if "cell_measures" in attrs:
634+
measures = [
635+
_get_measure(self._obj[name], measure)
636+
for measure in _CELL_MEASURES
637+
if measure in attrs["cell_measures"]
638+
]
639+
coords.extend(_strip_none_list(measures))
640+
641+
varnames.extend(coords)
642+
if isinstance(self._obj, xr.DataArray):
643+
ds = self._obj._to_temp_dataset()
644+
else:
645+
ds = self._obj
646+
ds = ds.reset_coords()[varnames]
647+
if isinstance(self._obj, DataArray):
648+
if scalar_key and len(ds.variables) == 1:
649+
# single dimension coordinates
650+
return ds[list(ds.variables.keys())[0]].squeeze(drop=True)
651+
elif scalar_key and len(ds.coords) > 1:
652+
raise NotImplementedError(
653+
"Not sure what to return when given scalar key for DataArray and it has multiple values. "
654+
"Please open an issue."
655+
)
656+
elif not scalar_key:
657+
return ds.set_coords(coords)
658+
else:
659+
return ds.set_coords(coords)
660+
661+
except KeyError:
662+
raise KeyError(
663+
f"{kind}.cf does not understand the key {k!r}. "
664+
f"Use {kind}.cf.describe() to see a list of key names that can be interpreted."
665+
)
666+
534667

535668
@xr.register_dataset_accessor("cf")
536669
class CFDatasetAccessor(CFAccessor):
537-
def __getitem__(self, key):
538-
if key in _AXIS_NAMES + _COORD_NAMES:
539-
varnames = _get_axis_coord(self._obj, key)
540-
return self._obj.reset_coords()[varnames].set_coords(varnames)
541-
elif key in _CELL_MEASURES:
542-
raise NotImplementedError("measures not implemented for Dataset yet.")
543-
else:
544-
raise KeyError(
545-
f"Dataset.cf does not understand the key {key!r}. Use Dataset.cf.describe() to see a list of key names that can be interpreted."
546-
)
670+
pass
547671

548672

549673
@xr.register_dataarray_accessor("cf")
550674
class CFDataArrayAccessor(CFAccessor):
551-
def __getitem__(self, key):
552-
if key in _AXIS_NAMES + _COORD_NAMES:
553-
varname = _get_axis_coord_single(self._obj, key)
554-
return self._obj[varname].reset_coords(drop=True)
555-
elif key in _CELL_MEASURES:
556-
return self._obj[_get_measure(self._obj, key)]
557-
else:
558-
raise KeyError(
559-
f"DataArray.cf does not understand the key {key!r}. Use DataArray.cf.describe() to see a list of key names that can be interpreted."
560-
)
675+
pass

cf_xarray/tests/datasets.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import numpy as np
2+
import xarray as xr
3+
4+
airds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(4), lon=slice(50))
5+
airds.air.attrs["cell_measures"] = "area: cell_area"
6+
airds.air.attrs["standard_name"] = "air_temperature"
7+
airds.coords["cell_area"] = (
8+
xr.DataArray(np.cos(airds.lat * np.pi / 180))
9+
* xr.ones_like(airds.lon)
10+
* 105e3
11+
* 110e3
12+
)
13+
14+
ds_no_attrs = airds.copy(deep=True)
15+
for variable in ds_no_attrs.variables:
16+
ds_no_attrs[variable].attrs = {}
17+
18+
19+
popds = xr.Dataset()
20+
popds.coords["TLONG"] = (
21+
("nlat", "nlon"),
22+
np.ones((20, 30)),
23+
{"axis": "X", "units": "degrees_east"},
24+
)
25+
popds.coords["TLAT"] = (
26+
("nlat", "nlon"),
27+
2 * np.ones((20, 30)),
28+
{"axis": "Y", "units": "degrees_north"},
29+
)
30+
popds.coords["ULONG"] = (
31+
("nlat", "nlon"),
32+
0.5 * np.ones((20, 30)),
33+
{"axis": "X", "units": "degrees_east"},
34+
)
35+
popds.coords["ULAT"] = (
36+
("nlat", "nlon"),
37+
2.5 * np.ones((20, 30)),
38+
{"axis": "Y", "units": "degrees_north"},
39+
)
40+
popds["UVEL"] = (
41+
("nlat", "nlon"),
42+
np.ones((20, 30)) * 15,
43+
{"coordinates": "ULONG ULAT", "standard_name": "sea_water_x_velocity"},
44+
)
45+
popds["TEMP"] = (
46+
("nlat", "nlon"),
47+
np.ones((20, 30)) * 15,
48+
{"coordinates": "TLONG TLAT", "standard_name": "sea_water_potential_temperature"},
49+
)

0 commit comments

Comments
 (0)