Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions src/mlwp_data_specs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TypeVar

import xarray as xr
from loguru import logger

from mlwp_data_specs.specs.reporting import ValidationReport
from mlwp_data_specs.specs.traits.spatial_coordinate import Space
Expand All @@ -21,6 +22,12 @@
validate_dataset as validate_uncertainty,
)

# Template for the global dataset attribute that stores a trait selection;
# the placeholder is filled with the trait kind ("time", "space", ...).
_TRAIT_ATTR_FORMAT = "mlwp_{}_trait"

# Concrete attribute names, one per supported trait kind, built from the
# template above so the naming scheme is defined in a single place.
TIME_TRAIT_ATTR, SPACE_TRAIT_ATTR, UNCERTAINTY_TRAIT_ATTR = (
    _TRAIT_ATTR_FORMAT.format(kind) for kind in ("time", "space", "uncertainty")
)

# Generic placeholder for "some Enum subclass", used to type the helpers that
# work on any of the trait enums.
EnumType = TypeVar("EnumType", bound=Enum)


Expand Down Expand Up @@ -61,6 +68,53 @@ def _coerce_enum(
) from exc


def _resolve_trait(
    ds: xr.Dataset,
    arg_value: EnumType | str | None,
    enum_cls: type[EnumType],
) -> EnumType | None:
    """Choose a trait, preferring the explicit argument over dataset attributes.

    The attribute consulted is ``_TRAIT_ATTR_FORMAT`` filled in with the
    lower-cased name of ``enum_cls``.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset whose global ``attrs`` may carry a stored trait value.
    arg_value : EnumType | str | None
        Trait supplied by the caller; it is passed through ``_coerce_enum``.
    enum_cls : type[EnumType]
        Enum class of the trait being resolved.

    Returns
    -------
    EnumType | None
        The coerced argument when it is not ``None``; otherwise the coerced
        attribute value. When the attribute is missing or cannot be coerced,
        the coerced argument is returned as-is, even if it is ``None``.

    Notes
    -----
    A warning is logged when coercing the attribute raises ``ValueError``
    (the attribute is then ignored) and when a valid attribute disagrees
    with the caller-supplied trait (the caller's value is used). Errors
    raised while coercing ``arg_value`` itself are not caught here.
    """
    trait_name = enum_cls.__name__.lower()
    attr_name = _TRAIT_ATTR_FORMAT.format(trait_name)

    # Coerce the argument first so a bad explicit value surfaces before the
    # dataset attributes are looked at.
    explicit = _coerce_enum(arg_value, enum_cls, trait_name)

    raw_attr = ds.attrs.get(attr_name)
    if raw_attr is None:
        # Nothing stored on the dataset: the argument is the only source.
        return explicit

    try:
        stored = _coerce_enum(raw_attr, enum_cls, f"attribute {attr_name}")
    except ValueError as exc:
        # An unusable attribute must not block validation; fall back to the
        # argument and let the user know the attribute was ignored.
        logger.warning(f"Invalid trait value in attribute '{attr_name}': {exc}")
        return explicit

    if explicit is None:
        return stored

    if explicit != stored:
        logger.warning(
            f"Provided {trait_name} trait '{explicit.value}' differs from "
            f"dataset attribute '{attr_name}' ('{stored.value}'). "
            f"Using provided trait value '{explicit.value}'."
        )
    return explicit


def validate_dataset(
ds: xr.Dataset,
*,
Expand Down Expand Up @@ -100,9 +154,9 @@ def validate_dataset(

uncertainty_value = uncertainty if uncertainty is not None else uncertaity

time_trait = _coerce_enum(time, Time, "time")
space_trait = _coerce_enum(space, Space, "space")
uncertainty_trait = _coerce_enum(uncertainty_value, Uncertainty, "uncertainty")
time_trait = _resolve_trait(ds, time, Time)
space_trait = _resolve_trait(ds, space, Space)
uncertainty_trait = _resolve_trait(ds, uncertainty_value, Uncertainty)

if not any([time_trait, space_trait, uncertainty_trait]):
raise ValueError("At least one trait must be selected")
Expand Down
47 changes: 47 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

from __future__ import annotations

from unittest.mock import patch

import pytest
import xarray as xr

from mlwp_data_specs import validate_dataset
from mlwp_data_specs.api import SPACE_TRAIT_ATTR, TIME_TRAIT_ATTR


def _forecast_grid_ds() -> xr.Dataset:
Expand Down Expand Up @@ -55,3 +58,47 @@ def test_validate_dataset_requires_trait() -> None:
"""API raises when no traits are selected."""
with pytest.raises(ValueError, match="At least one trait"):
validate_dataset(_forecast_grid_ds())


def test_validate_dataset_from_attributes() -> None:
    """Validation succeeds when traits are given only via dataset attributes."""
    dataset = _forecast_grid_ds()
    # No trait keyword arguments are passed below, so these two global
    # attributes are the only source of the time and space traits.
    dataset.attrs.update(
        {
            TIME_TRAIT_ATTR: "forecast",
            SPACE_TRAIT_ATTR: "grid",
        }
    )

    result = validate_dataset(dataset)

    assert not result.has_fails()


@patch("mlwp_data_specs.api.logger")
def test_validate_dataset_attribute_mismatch_warning(mock_logger) -> None:
    """An explicit trait wins over a conflicting attribute, with a warning."""
    dataset = _forecast_grid_ds()
    # Deliberately contradict the forecast coordinates of the fixture dataset.
    dataset.attrs[TIME_TRAIT_ATTR] = "observation"

    # The explicit time="forecast" argument should take precedence.
    result = validate_dataset(dataset, time="forecast", space="grid")

    # Validation ran against "forecast", so no check is expected to fail.
    assert not result.has_fails()

    # The conflict must have been reported through the module logger; inspect
    # the first positional argument of the most recent warning call.
    mock_logger.warning.assert_called()
    emitted = mock_logger.warning.call_args.args[0]
    assert "Provided time trait 'forecast' differs" in emitted
    assert "attribute 'mlwp_time_trait' ('observation')" in emitted


@patch("mlwp_data_specs.api.logger")
def test_validate_dataset_invalid_attribute_warning(mock_logger) -> None:
    """An unparseable attribute is warned about and the explicit trait is used."""
    dataset = _forecast_grid_ds()
    # Store a value that is not a valid time trait.
    dataset.attrs[TIME_TRAIT_ATTR] = "invalid_time_trait"

    result = validate_dataset(dataset, time="forecast", space="grid")

    # The valid explicit arguments drive validation, so nothing should fail.
    assert not result.has_fails()

    # The bad attribute must have been reported through the module logger;
    # inspect the first positional argument of the most recent warning call.
    mock_logger.warning.assert_called()
    emitted = mock_logger.warning.call_args.args[0]
    assert "Invalid trait value in attribute 'mlwp_time_trait'" in emitted
Loading