Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 23 additions & 88 deletions src/py/mat3ra/code/entity.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,15 @@
from typing import Any, Dict, List, Optional, Type, TypeVar

import jsonschema
from mat3ra.utils import object as object_utils
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_snake
from typing_extensions import Self

from . import BaseUnderscoreJsonPropsHandler
from .mixins import DefaultableMixin, HasDescriptionMixin, HasMetadataMixin, NamedMixin

T = TypeVar("T", bound="InMemoryEntityPydantic")
B = TypeVar("B", bound="BaseModel")


# TODO: remove in the next PR
class ValidationErrorCode:
IN_MEMORY_ENTITY_DATA_INVALID = "IN_MEMORY_ENTITY_DATA_INVALID"


# TODO: remove in the next PR
class ErrorDetails:
def __init__(self, error: Optional[Dict[str, Any]], json: Dict[str, Any], schema: Dict):
self.error = error
self.json = json
self.schema = schema


# TODO: remove in the next PR
class EntityError(Exception):
def __init__(self, code: ValidationErrorCode, details: Optional[ErrorDetails] = None):
super().__init__(code)
self.code = code
self.details = details


class InMemoryEntityPydantic(BaseModel):
model_config = {"arbitrary_types_allowed": True}

Expand Down Expand Up @@ -90,82 +66,41 @@ def clone(self: T, extra_context: Optional[Dict[str, Any]] = None, deep=True) ->
class InMemoryEntitySnakeCase(InMemoryEntityPydantic):
model_config = ConfigDict(
arbitrary_types_allowed=True,
# Generate snake_case aliases for all fields (e.g. myField -> my_field)
alias_generator=to_snake,
# Allow populating fields using either the original name or the snake_case alias
populate_by_name=True,
)

@staticmethod
def _create_property_from_camel_case(camel_name: str):
def getter(self):
return getattr(self, camel_name)

# TODO: remove in the next PR
class InMemoryEntity(BaseUnderscoreJsonPropsHandler):
jsonSchema: Optional[Dict] = None

@classmethod
def get_cls(cls) -> str:
return cls.__name__

@property
def cls(self) -> str:
return self.__class__.__name__

def get_cls_name(self) -> str:
return self.__class__.__name__

@classmethod
def create(cls, config: Dict[str, Any]) -> Any:
return cls(config)
def setter(self, value: Any):
setattr(self, camel_name, value)

def to_json(self, exclude: List[str] = []) -> Dict[str, Any]:
return self.clean(object_utils.clone_deep(object_utils.omit(self._json, exclude)))
return property(getter, setter)

def clone(self, extra_context: Dict[str, Any] = {}) -> Any:
config = self.to_json()
config.update(extra_context)
# To avoid:
# Argument 1 to "__init__" of "BaseUnderscoreJsonPropsHandler" has incompatible type "Dict[str, Any]";
# expected "BaseUnderscoreJsonPropsHandler"
return self.__class__(config)
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not issubclass(cls, BaseModel):
return

@staticmethod
def validate_data(data: Dict[str, Any], clean: bool = False):
if clean:
print("Error: clean is not supported for InMemoryEntity.validateData")
if InMemoryEntity.jsonSchema:
jsonschema.validate(data, InMemoryEntity.jsonSchema)

def validate(self) -> None:
if self._json:
self.__class__.validate_data(self._json)

def clean(self, config: Dict[str, Any]) -> Dict[str, Any]:
# Not implemented, consider the below for the implementation
# https://stackoverflow.com/questions/44694835/remove-properties-from-json-object-not-present-in-schema
return config

def is_valid(self) -> bool:
try:
self.validate()
return True
except EntityError:
return False

# Properties
@property
def id(self) -> str:
return self.prop("_id", "")
model_fields = cls.model_fields
except Exception:
return

@id.setter
def id(self, id: str) -> None:
self.set_prop("_id", id)
for field_name, field_info in model_fields.items():
if field_name == to_snake(field_name):
continue

@property
def slug(self) -> str:
return self.prop("slug", "")
snake_case_name = to_snake(field_name)
if hasattr(cls, snake_case_name):
continue

def get_as_entity_reference(self, by_id_only: bool = False) -> Dict[str, str]:
if by_id_only:
return {"_id": self.id}
else:
return {"_id": self.id, "slug": self.slug, "cls": self.get_cls_name()}
setattr(cls, snake_case_name, cls._create_property_from_camel_case(field_name))


class HasDescriptionHasMetadataNamedDefaultableInMemoryEntityPydantic(
Expand Down
11 changes: 11 additions & 0 deletions tests/py/unit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,14 @@ class SnakeCaseEntity(CamelCaseSchema, InMemoryEntitySnakeCase):
"applicationVersion": "7.2",
"executable_name": "pw.x",
}


class AutoSnakeCaseTestSchema(BaseModel):
contextProviders: list = []
applicationName: str
applicationVersion: Optional[str] = None
executableName: Optional[str] = None


class AutoSnakeCaseTestEntity(AutoSnakeCaseTestSchema, InMemoryEntitySnakeCase):
pass
1 change: 0 additions & 1 deletion tests/py/unit/test_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,4 +244,3 @@ def test_create_entity_snake_case(config, expected_output):

entity_from_create = SnakeCaseEntity.create(config)
assert entity_from_create.to_dict() == expected_output

91 changes: 91 additions & 0 deletions tests/py/unit/test_entity_snake_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import pytest
from mat3ra.utils import assertion
from . import AutoSnakeCaseTestEntity

BASE = {
"applicationName": "camelCasedValue",
"applicationVersion": "camelCasedVersion",
"executableName": "camelCasedExecutable",
"contextProviders": [],
}

INSTANTIATION = [
{"application_name": BASE["applicationName"], "application_version": BASE["applicationVersion"],
"executable_name": BASE["executableName"]},
{"applicationName": BASE["applicationName"], "applicationVersion": BASE["applicationVersion"],
"executableName": BASE["executableName"]},
{"application_name": BASE["applicationName"], "applicationVersion": BASE["applicationVersion"],
"executable_name": BASE["executableName"]},
]

UPDATES = [
(
{"application_name": "new_value", "context_providers": ["item_snake"]},
{"applicationName": "new_value", "contextProviders": ["item_snake"]},
{"application_name": "new_value", "context_providers": ["item_snake"]},
),
(
{"applicationName": "newValueCamel", "contextProviders": ["itemCamel"]},
{"applicationName": "newValueCamel", "contextProviders": ["itemCamel"]},
{"application_name": "newValueCamel", "context_providers": ["itemCamel"]},
),
(
{"application_name": "new_value_snake", "applicationVersion": "newVersionCamel"},
{"applicationName": "new_value_snake", "applicationVersion": "newVersionCamel"},
{"application_name": "new_value_snake", "application_version": "newVersionCamel"},
),
(
{"application_name": "new_val", "application_version": "new_version",
"executable_name": "new_exec", "context_providers": ["a", "b"]},
{"applicationName": "new_val", "applicationVersion": "new_version",
"executableName": "new_exec", "contextProviders": ["a", "b"]},
{"application_name": "new_val", "application_version": "new_version",
"executable_name": "new_exec", "context_providers": ["a", "b"]},
),
]


def camel(entity):
return dict(
applicationName=entity.applicationName,
applicationVersion=entity.applicationVersion,
executableName=entity.executableName,
contextProviders=entity.contextProviders,
)


def snake(entity):
return dict(
application_name=entity.application_name,
application_version=entity.application_version,
executable_name=entity.executable_name,
context_providers=entity.context_providers,
)


@pytest.mark.parametrize("cfg", INSTANTIATION)
def test_instantiation(cfg):
entity = AutoSnakeCaseTestEntity(**cfg)
assertion.assert_deep_almost_equal(BASE, camel(entity))
assertion.assert_deep_almost_equal(
dict(application_name=BASE["applicationName"],
application_version=BASE["applicationVersion"],
executable_name=BASE["executableName"],
context_providers=[]),
snake(entity),
)


@pytest.mark.parametrize("updates, exp_camel, exp_snake", UPDATES)
def test_updates(updates, exp_camel, exp_snake):
entity = AutoSnakeCaseTestEntity(**BASE)
for k, v in updates.items():
setattr(entity, k, v)
assertion.assert_deep_almost_equal({**BASE, **exp_camel}, camel(entity))
assertion.assert_deep_almost_equal(
{**snake(AutoSnakeCaseTestEntity(**BASE)), **exp_snake},
snake(entity),
)
out = entity.to_dict()
assertion.assert_deep_almost_equal({**BASE, **exp_camel}, out)
assert "application_name" not in out and "context_providers" not in out
Loading