diff --git a/src/databricks/labs/blueprint/installation.py b/src/databricks/labs/blueprint/installation.py index 5d028a2..7a2d22c 100644 --- a/src/databricks/labs/blueprint/installation.py +++ b/src/databricks/labs/blueprint/installation.py @@ -67,6 +67,8 @@ class Installation: T = TypeVar("T") _PRIMITIVES = (int, bool, float, str) + allow_raw_types = True + allow_weak_types = True def __init__(self, ws: WorkspaceClient, product: str, *, install_folder: str | None = None): """The `Installation` class constructor creates an `Installation` object for the given product in @@ -469,19 +471,27 @@ def _get_list_type_ref(inst: T) -> type[list[T]]: item_type = type(from_list[0]) # type: ignore[misc] return list[item_type] # type: ignore[valid-type] + # pylint: disable=too-complex def _marshal(self, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]: """The `_marshal` method is a private method that is used to serialize an object of type `type_ref` to a dictionary. This method is called by the `save` method.""" if inst is None: - return None, False + none_allowed = type_ref is types.NoneType or (isinstance(type_ref, types.UnionType) and types.NoneType in get_args(type_ref)) + return None, none_allowed if isinstance(inst, databricks.sdk.core.Config): return self._marshal_databricks_config(inst) if hasattr(inst, "as_dict"): return inst.as_dict(), True if dataclasses.is_dataclass(type_ref): return self._marshal_dataclass(type_ref, path, inst) - if type_ref == list: - return self._marshal_list(type_ref, path, inst) + if self.allow_raw_types: + if type_ref == list: + return self._marshal_raw_list(path, inst) + if type_ref == dict: + return self._marshal_raw_dict(path, inst) + if self.allow_weak_types: + if type_ref in (object, Any): + return self._marshal(type(inst), path, inst) if isinstance(type_ref, enum.EnumMeta): return self._marshal_enum(inst) if type_ref == types.NoneType: @@ -523,8 +533,8 @@ def _marshal_generic(self, type_ref: type, path: list[str], inst: 
Any) -> tuple[ if not type_args: raise SerdeError(f"Missing type arguments: {type_args}") if len(type_args) == 2: - return self._marshal_dict(type_args[1], path, inst) - return self._marshal_list(type_args[0], path, inst) + return self._marshal_generic_dict(type_args[1], path, inst) + return self._marshal_generic_list(type_args[0], path, inst) @staticmethod def _marshal_generic_alias(type_ref, inst): @@ -534,21 +544,34 @@ def _marshal_generic_alias(type_ref, inst): return None, False return inst, isinstance(inst, type_ref.__origin__) # type: ignore[attr-defined] - def _marshal_list(self, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]: - """The `_marshal_list` method is a private method that is used to serialize an object of type `type_ref` to - a dictionary. This method is called by the `save` method.""" + def _marshal_generic_list(self, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]: + """The `_marshal_generic_list` method is a private method that is used to serialize an object of type list[type_ref] to + an array. This method is called by the `save` method.""" as_list = [] if not isinstance(inst, list): return None, False for i, v in enumerate(inst): value, ok = self._marshal(type_ref, [*path, f"{i}"], v) if not ok: - raise SerdeError(self._explain_why(type_ref, [*path, f"{i}"], v)) + raise SerdeError(self._explain_why(type(v), [*path, f"{i}"], v)) + as_list.append(value) + return as_list, True + + def _marshal_raw_list(self, path: list[str], inst: Any) -> tuple[Any, bool]: + """The `_marshal_raw_list` method is a private method that is used to serialize an object of type list to + an array. 
This method is called by the `save` method.""" + as_list = [] + if not isinstance(inst, list): + return None, False + for i, v in enumerate(inst): + value, ok = self._marshal(type(v), [*path, f"{i}"], v) + if not ok: + raise SerdeError(self._explain_why(type(v), [*path, f"{i}"], v)) as_list.append(value) return as_list, True - def _marshal_dict(self, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]: - """The `_marshal_dict` method is a private method that is used to serialize an object of type `type_ref` to + def _marshal_generic_dict(self, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]: + """The `_marshal_generic_dict` method is a private method that is used to serialize an object of type dict[str, type_ref] to a dictionary. This method is called by the `save` method.""" if not isinstance(inst, dict): return None, False @@ -556,7 +579,19 @@ def _marshal_dict(self, type_ref: type, path: list[str], inst: Any) -> tuple[Any for k, v in inst.items(): as_dict[k], ok = self._marshal(type_ref, [*path, k], v) if not ok: - raise SerdeError(self._explain_why(type_ref, [*path, k], v)) + raise SerdeError(self._explain_why(type(v), [*path, k], v)) + return as_dict, True + + def _marshal_raw_dict(self, path: list[str], inst: Any) -> tuple[Any, bool]: + """The `_marshal_raw_dict` method is a private method that is used to serialize an object of type dict to + a dictionary. 
This method is called by the `save` method.""" + if not isinstance(inst, dict): + return None, False + as_dict = {} + for k, v in inst.items(): + as_dict[k], ok = self._marshal(type(v), [*path, k], v) + if not ok: + raise SerdeError(self._explain_why(type(v), [*path, k], v)) return as_dict, True def _marshal_dataclass(self, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]: @@ -616,6 +651,8 @@ def from_dict(cls, raw: dict): def _unmarshal(cls, inst: Any, path: list[str], type_ref: type[T]) -> T | None: """The `_unmarshal` method is a private method that is used to deserialize a dictionary to an object of type `type_ref`. This method is called by the `load` method.""" + if type_ref == types.NoneType: + return None if dataclasses.is_dataclass(type_ref): return cls._unmarshal_dataclass(inst, path, type_ref) if isinstance(type_ref, enum.EnumMeta): @@ -624,12 +661,14 @@ def _unmarshal(cls, inst: Any, path: list[str], type_ref: type[T]) -> T | None: return type_ref(inst) if type_ref in cls._PRIMITIVES: return cls._unmarshal_primitive(inst, type_ref) + if type_ref == list: + return cls._unmarshal_list(inst, path, Any) + if type_ref == dict: + return cls._unmarshal_dict(inst, path, Any) if type_ref == databricks.sdk.core.Config: if not inst: inst = {} return databricks.sdk.core.Config(**inst) # type: ignore[return-value] - if type_ref == types.NoneType: - return None if isinstance(type_ref, cls._FromDict): return type_ref.from_dict(inst) return cls._unmarshal_generic_types(type_ref, path, inst) @@ -646,8 +685,23 @@ def _unmarshal_generic_types(cls, type_ref, path, inst): return cls._unmarshal_union(inst, path, type_ref) if isinstance(type_ref, (_GenericAlias, types.GenericAlias)): return cls._unmarshal_generic(inst, path, type_ref) + if cls.allow_weak_types and type_ref in (object, Any): + return cls._unmarshal_object(inst, path) raise SerdeError(f'{".".join(path)}: unknown: {type_ref}: {inst}') + @classmethod + def _unmarshal_object(cls, inst, path): + if 
inst is None: + return None + if isinstance(inst, (bool, int, float, str)): + return cls._unmarshal_primitive(inst, type(inst)) + if cls.allow_raw_types: + if isinstance(inst, list): + return cls._unmarshal_list(inst, path, object) + if isinstance(inst, dict): + return cls._unmarshal_dict(inst, path, object) + raise SerdeError(f'{".".join(path)}: unknown: {type(inst)}: {inst}') + @classmethod def _unmarshal_dataclass(cls, inst, path, type_ref): """The `_unmarshal_dataclass` method is a private method that is used to deserialize a dictionary to an object @@ -682,9 +736,14 @@ def _unmarshal_union(cls, inst, path, type_ref): """The `_unmarshal_union` method is a private method that is used to deserialize a dictionary to an object of type `type_ref`. This method is called by the `load` method.""" for variant in get_args(type_ref): - value = cls._unmarshal(inst, path, variant) - if value: - return value + if variant == type(None) and inst is None: + return None + try: + value = cls._unmarshal(inst, path, variant) + if value is not None: + return value + except SerdeError: + pass return None @classmethod @@ -706,14 +765,16 @@ def _unmarshal_generic(cls, inst, path, type_ref): return cls._unmarshal_list(inst, path, type_args[0]) @classmethod - def _unmarshal_list(cls, inst, path, hint): - """The `_unmarshal_list` method is a private method that is used to deserialize a dictionary to an object + def _unmarshal_list(cls, inst, path, type_ref): + """The `_unmarshal_list` method is a private method that is used to deserialize an array to a list of type `type_ref`. 
This method is called by the `load` method.""" if inst is None: return None + if not isinstance(inst, list): + raise SerdeError(cls._explain_why(type_ref, path, inst)) as_list = [] for i, v in enumerate(inst): - as_list.append(cls._unmarshal(v, [*path, f"{i}"], hint)) + as_list.append(cls._unmarshal(v, [*path, f"{i}"], type_ref or type(v))) return as_list @classmethod @@ -733,10 +794,23 @@ def _unmarshal_dict(cls, inst, path, type_ref): def _unmarshal_primitive(cls, inst, type_ref): """The `_unmarshal_primitive` method is a private method that is used to deserialize a dictionary to an object of type `type_ref`. This method is called by the `load` method.""" - if not inst: + if inst is None: + return None + if isinstance(inst, type_ref): return inst - # convert from str to int if necessary - converted = type_ref(inst) # type: ignore[call-arg] + converted = inst + # convert from str + if isinstance(inst, str): + if type_ref in (int, float): + try: + converted = type_ref(inst) # type: ignore[call-arg] + except ValueError as exc: + raise SerdeError(f"Not a number {inst}!") from exc + elif type_ref == bool: + if inst.lower() == "true": + converted = True + elif inst.lower() == "false": + converted = False return converted @staticmethod @@ -745,7 +819,8 @@ def _explain_why(type_ref: type, path: list[str], raw: Any) -> str: type. 
This method is called by the `_unmarshal` and `_marshal` methods.""" if raw is None: raw = "value is missing" - return f'{".".join(path)}: not a {type_ref.__name__}: {raw}' + type_name = getattr(type_ref, "__name__", str(type_ref)) + return f'{".".join(path)}: not a {type_name}: {raw}' @staticmethod def _dump_json(as_dict: Json, _: type) -> bytes: diff --git a/src/databricks/labs/blueprint/paths.py b/src/databricks/labs/blueprint/paths.py index 6923b64..5197f9f 100644 --- a/src/databricks/labs/blueprint/paths.py +++ b/src/databricks/labs/blueprint/paths.py @@ -151,6 +151,7 @@ def __new__(cls, *args, **kwargs): # Force all initialisation to go via __init__() irrespective of the (Python-specific) base version. return object.__new__(cls) + # pylint: disable=super-init-not-called def __init__(self, ws: WorkspaceClient, *args: str | bytes | os.PathLike) -> None: # We deliberately do _not_ call the super initializer because we're taking over complete responsibility for the # implementation of the public API. 
@@ -398,6 +399,7 @@ def with_suffix(self: P, suffix: str) -> P: raise ValueError(msg) return self.with_name(stem + suffix) + # pylint: disable=arguments-differ def relative_to(self: P, *other: str | bytes | os.PathLike, walk_up: bool = False) -> P: normalized = self.with_segments(*other) if self.anchor != normalized.anchor: diff --git a/tests/unit/test_installation.py b/tests/unit/test_installation.py index e836e7a..aebe6f2 100644 --- a/tests/unit/test_installation.py +++ b/tests/unit/test_installation.py @@ -502,6 +502,30 @@ class SampleClass: assert loaded == saved +def test_generic_dict_object(): + @dataclass + class SampleClass: + field: dict[str, object] + + installation = MockInstallation() + saved = SampleClass(field={"a": ["x", "y"], "b": [], "c": 3, "d": True, "e": {"a": "b"}}) + installation.save(saved, filename="backups/SampleClass.json") + loaded = installation.load(SampleClass, filename="backups/SampleClass.json") + assert loaded == saved + + +def test_generic_dict_any(): + @dataclass + class SampleClass: + field: dict[str, typing.Any] + + installation = MockInstallation() + saved = SampleClass(field={"a": ["x", "y"], "b": [], "c": 3, "d": True, "e": {"a": "b"}}) + installation.save(saved, filename="backups/SampleClass.json") + loaded = installation.load(SampleClass, filename="backups/SampleClass.json") + assert loaded == saved + + def test_generic_list_str() -> None: @dataclass class SampleClass: @@ -548,3 +572,69 @@ class SampleClass: installation.save(saved, filename="backups/SampleClass.json") loaded = installation.load(SampleClass, filename="backups/SampleClass.json") assert loaded == saved + + +def test_generic_list_object(): + @dataclass + class SampleClass: + field: list[object] + + installation = MockInstallation() + saved = SampleClass(field=[["x", "y"], [], 3, True, {"a": "b"}]) + installation.save(saved, filename="backups/SampleClass.json") + loaded = installation.load(SampleClass, filename="backups/SampleClass.json") + assert loaded == 
saved + + +def test_generic_list_any(): + @dataclass + class SampleClass: + field: list[typing.Any] + + installation = MockInstallation() + saved = SampleClass(field=[["x", "y"], [], 3, True, {"a": "b"}]) + installation.save(saved, filename="backups/SampleClass.json") + loaded = installation.load(SampleClass, filename="backups/SampleClass.json") + assert loaded == saved + + +def test_bool_in_union(): + @dataclass + class SampleClass: + field: dict[str, bool | str] + + installation = MockInstallation() + saved = SampleClass(field={"a": "b"}) + installation.save(saved, filename="backups/SampleClass.json") + loaded = installation.load(SampleClass, filename="backups/SampleClass.json") + assert loaded == saved + + +JsonType: typing.TypeAlias = None | bool | int | float | str | list["JsonType"] | dict[str, "JsonType"] + + +def test_complex_union(): + @dataclass + class SampleClass: + field: dict[str, JsonType] + + installation = MockInstallation() + saved = SampleClass(field={"a": "b"}) + installation.save(saved, filename="backups/SampleClass.json") + loaded = installation.load(SampleClass, filename="backups/SampleClass.json") + assert loaded == saved + + +JsonType2: typing.TypeAlias = dict[str, "JsonType2"] | list["JsonType2"] | str | float | int | bool | None + + +def test_complex_union2(): + @dataclass + class SampleClass: + field: dict[str, JsonType2] + + installation = MockInstallation() + saved = SampleClass(field={"a": "b"}) + installation.save(saved, filename="backups/SampleClass.json") + loaded = installation.load(SampleClass, filename="backups/SampleClass.json") + assert loaded == saved diff --git a/tests/unit/test_installer.py b/tests/unit/test_installer.py index 0532174..e4c959e 100644 --- a/tests/unit/test_installer.py +++ b/tests/unit/test_installer.py @@ -31,7 +31,7 @@ def test_jobs_state(): state = InstallState(ws, "blueprint") - assert {"foo": "123"} == state.jobs + assert {"foo": 123} == state.jobs assert {} == state.dashboards 
ws.workspace.download.assert_called_with("/Users/foo/.blueprint/state.json")
diff --git a/tests/unit/test_marshalling_scenarios.py b/tests/unit/test_marshalling_scenarios.py
new file mode 100644
index 0000000..6670145
--- /dev/null
+++ b/tests/unit/test_marshalling_scenarios.py
@@ -0,0 +1,245 @@
+"""Round-trip scenarios for saving and loading loosely typed fields.
+
+Each scenario is parametrized over every `TypeSupport` level; `set_type_support`
+maps a level onto the class-level `Installation.allow_raw_types` and
+`Installation.allow_weak_types` flags.
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any
+
+import pytest
+
+from databricks.labs.blueprint.installation import Installation, MockInstallation
+
+
+class TypeSupport(Enum):
+    STRICT = "STRICT"
+    RAW_TYPES = "RAW"
+    WEAK_TYPES = "WEAK"
+
+
+@pytest.fixture(autouse=True)
+def _restore_type_support():
+    """Restore the class-level flags after each test so they cannot leak into other tests."""
+    raw_types, weak_types = Installation.allow_raw_types, Installation.allow_weak_types
+    yield
+    Installation.allow_raw_types = raw_types
+    Installation.allow_weak_types = weak_types
+
+
+def set_type_support(type_support: TypeSupport):
+    """Set the `Installation` class flags that correspond to `type_support`."""
+    if type_support == TypeSupport.STRICT:
+        Installation.allow_raw_types = False
+        Installation.allow_weak_types = False
+    elif type_support == TypeSupport.RAW_TYPES:
+        Installation.allow_raw_types = True
+        Installation.allow_weak_types = False
+    elif type_support == TypeSupport.WEAK_TYPES:
+        Installation.allow_raw_types = True
+        Installation.allow_weak_types = True
+
+
+@pytest.mark.parametrize("type_support", list(TypeSupport))
+def test_weak_typing_with_list(type_support) -> None:
+    set_type_support(type_support)
+    # this example corresponds to a frequent Python coding pattern
+    # where users don't specify the item type of a list
+
+    @dataclass
+    class SampleClass:
+        field: list
+
+    installation = MockInstallation()
+    saved = SampleClass(field=["a", 1, True])
+    installation.save(saved, filename="backups/SampleClass.json")
+    loaded = installation.load(SampleClass, filename="backups/SampleClass.json")
+    assert loaded == saved
+
+
+@pytest.mark.parametrize("type_support", list(TypeSupport))
+def test_weak_typing_with_dict(type_support) -> None:
+    set_type_support(type_support)
+    # this example corresponds to a frequent Python coding pattern
+    # where users don't specify the key and item types of a dict
+
+    @dataclass
+    class SampleClass:
+        field: dict
+
+    installation = MockInstallation()
+    saved = SampleClass(field={"x": "a", "y": 1, "z": True})
+    installation.save(saved, filename="backups/SampleClass.json")
+    loaded = installation.load(SampleClass, filename="backups/SampleClass.json")
+    assert loaded == saved
+
+
+@pytest.mark.parametrize("type_support", list(TypeSupport))
+def test_progressive_typing_with_list(type_support) -> None:
+    set_type_support(type_support)
+
+    # this example corresponds to a frequent Python coding pattern
+    # where users only specify the item type of a list once they need it
+
+    @dataclass
+    class SampleClassV1:
+        field: list
+
+    @dataclass
+    class SampleClassV2:
+        field: list[str]
+
+    installation = MockInstallation()
+    saved = SampleClassV1(field=["a", "b", "c"])
+    installation.save(saved, filename="backups/SampleClass.json")
+    # problem: can't directly use untyped item values
+    # loaded_v1 = installation.load(SampleClassV1, filename="backups/SampleClass.json")
+    # stuff = loaded_v1[0][1:2]
+    # so they've stored weakly typed data, and they need to read it as typed data
+    loaded = installation.load(SampleClassV2, filename="backups/SampleClass.json")
+    assert loaded == SampleClassV2(field=saved.field)
+
+
+@pytest.mark.parametrize("type_support", list(TypeSupport))
+def test_progressive_typing_with_dict(type_support) -> None:
+    set_type_support(type_support)
+
+    # this example corresponds to a frequent Python coding pattern
+    # where users only specify the item type of a dict once they need it
+
+    @dataclass
+    class SampleClassV1:
+        field: dict
+
+    @dataclass
+    class SampleClassV2:
+        field: dict[str, str]
+
+    installation = MockInstallation()
+    saved = SampleClassV1(field={"x": "abc", "y": "def", "z": "ghi"})
+    installation.save(saved, filename="backups/SampleClass.json")
+    # problem: can't directly use untyped item values
+    # loaded_v1 = installation.load(SampleClassV1, filename="backups/SampleClass.json")
+    # stuff = loaded_v1["x"][1:2]
+    # so they've stored weakly typed data, and they need to read it as typed data
+    loaded = installation.load(SampleClassV2, filename="backups/SampleClass.json")
+    assert loaded == SampleClassV2(field=saved.field)
+
+
+@pytest.mark.parametrize("type_support", list(TypeSupport))
+def test_type_migration(type_support) -> None:
+    set_type_support(type_support)
+
+    # this example corresponds to a frequent Python coding scenario
+    # where users change their mind about a type
+
+    @dataclass
+    class SampleClassV1:
+        field: list[str]
+
+    @dataclass
+    class SampleClassV2:
+        field: list[int | None]
+
+    installation = MockInstallation()
+    saved = SampleClassV1(field=["1", "2", ""])
+    installation.save(saved, filename="backups/SampleClass.json")
+    # problem: can't directly convert an item value
+    # loaded_v1 = installation.load(SampleClassV2, filename="backups/SampleClass.json")
+    # so they've stored strings, and they need to read ints
+    converted = SampleClassV2(field=[(int(val) if val else None) for val in saved.field])
+    loaded = installation.load(SampleClassV2, filename="backups/SampleClass.json")
+    assert loaded == converted
+
+
+@pytest.mark.parametrize("type_support", list(TypeSupport))
+def test_lost_code_with_list(type_support) -> None:
+    set_type_support(type_support)
+    # this example corresponds to a scenario where data was stored
+    # using code that is no longer available
+
+    @dataclass
+    class LostSampleClass:
+        field: list[str]
+
+    # we don't know the type of 'field'
+    # so we'll use code to restore the data
+    @dataclass
+    class RecoverySampleClass:
+        field: object
+
+    installation = MockInstallation()
+    saved = LostSampleClass(field=["a", "b", "c"])
+    installation.save(saved, filename="backups/SampleClass.json")
+    # problem: we don't know how SampleClass.json was stored
+    # so we're loading the data as weakly typed
+    loaded = installation.load(RecoverySampleClass, filename="backups/SampleClass.json")
+    assert loaded.field == saved.field
+
+
+@pytest.mark.parametrize("type_support", list(TypeSupport))
+def test_dynamic_config_data(type_support) -> None:
+    set_type_support(type_support)
+    # this example corresponds to a scenario where we store data provided
+    # by some object, without a schema for it
+
+    class AbstractDriver(ABC):
+
+        @abstractmethod
+        def get_config_data(self) -> object: ...
+
+    class XDriver(AbstractDriver):
+        def get_config_data(self) -> object:
+            return "oracle:jdbc:thin://my_login:my_password@myserver:2312"
+
+    class YDriver(AbstractDriver):
+        def get_config_data(self) -> object:
+            return {
+                "login": "my_login",
+                "password": "my_password",
+                "host": "myserver",
+                "port": 2312,
+            }
+
+    @dataclass
+    class SampleClass:
+        driver_class: str
+        driver_config: object
+
+    installation = MockInstallation()
+    saved_x = SampleClass(driver_class=XDriver.__name__, driver_config=XDriver().get_config_data())
+    installation.save(saved_x, filename="backups/SampleDriverX.json")
+    saved_y = SampleClass(driver_class=YDriver.__name__, driver_config=YDriver().get_config_data())
+    installation.save(saved_y, filename="backups/SampleDriverY.json")
+    loaded_x = installation.load(SampleClass, filename="backups/SampleDriverX.json")
+    assert loaded_x == saved_x
+    loaded_y = installation.load(SampleClass, filename="backups/SampleDriverY.json")
+    assert loaded_y == saved_y
+
+
+@pytest.mark.parametrize("type_support", list(TypeSupport))
+def test_lost_code_with_any(type_support) -> None:
+    set_type_support(type_support)
+    # this example corresponds to a scenario where data was stored
+    # using code that is no longer available
+
+    @dataclass
+    class LostSampleClass:
+        field: dict[str, str | int | None]
+
+    # we don't know the type of 'field'
+    # so we'll use code to restore the data
+    @dataclass
+    class RecoverySampleClass:
+        field: Any
+
+    installation = MockInstallation()
+    saved = LostSampleClass(field={"a": "b", "b": 2, "c": None})
+    installation.save(saved, filename="backups/SampleClass.json")
+    # problem: we don't know how SampleClass.json was stored
+    # so we're loading the data as weakly typed
+    loaded = installation.load(RecoverySampleClass, filename="backups/SampleClass.json")
+    assert loaded.field == saved.field