Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ dependencies = [
"numpy",
"jsonschema>=2.6.0",
"pydantic>=2.7.1",
"mat3ra-esse>=2025.7.1-0",
"mat3ra-utils>=2024.5.15.post0",
"mat3ra-esse",
"mat3ra-utils"
]

[project.optional-dependencies]
Expand Down
28 changes: 21 additions & 7 deletions src/py/mat3ra/code/entity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from typing import Any, Dict, List, Optional, Type, TypeVar

from pydantic import BaseModel, ConfigDict
from mat3ra.utils.object import filter_out_none_values
from pydantic import AliasGenerator, BaseModel, ConfigDict
from pydantic.alias_generators import to_snake
from typing_extensions import Self

Expand Down Expand Up @@ -53,11 +54,19 @@ def get_data_model(self) -> Type[B]:
def get_cls_name(self) -> str:
return self.__class__.__name__

def to_dict(self, exclude: Optional[List[str]] = None) -> Dict[str, Any]:
return self.model_dump(mode="json", exclude=set(exclude) if exclude else None)
def to_dict(
self, exclude: Optional[List[str]] = None, keep_as_none: Optional[List[str]] = None
) -> Dict[str, Any]:
data = self.model_dump(
mode="json",
exclude=set(exclude) if exclude else None,
by_alias=True,
exclude_none=False,
)
return filter_out_none_values(data, keep_as_none=keep_as_none)

def to_json(self, exclude: Optional[List[str]] = None) -> str:
return self.model_dump_json(exclude=set(exclude) if exclude else None)
def to_json(self, exclude: Optional[List[str]] = None, keep_as_none: Optional[List[str]] = None) -> str:
return json.dumps(self.to_dict(exclude=exclude, keep_as_none=keep_as_none))

def clone(self: T, extra_context: Optional[Dict[str, Any]] = None, deep=True) -> T:
return self.model_copy(update=extra_context or {}, deep=deep)
Expand All @@ -66,12 +75,17 @@ def clone(self: T, extra_context: Optional[Dict[str, Any]] = None, deep=True) ->
class InMemoryEntitySnakeCase(InMemoryEntityPydantic):
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra='allow',
# Generate snake_case aliases for all fields (e.g. myField -> my_field)
alias_generator=to_snake,
alias_generator=AliasGenerator(validation_alias=to_snake, serialization_alias=lambda field_name: field_name),
# Allow populating fields using either the original name or the snake_case alias
populate_by_name=True,
)

def __init__(self, **data: Any) -> None:
"""Initialize with explicit **data to avoid parameter ordering issues in multiple inheritance."""
super().__init__(**data)

@staticmethod
def _create_property_from_camel_case(camel_name: str):
def getter(self):
Expand Down
118 changes: 98 additions & 20 deletions tests/py/unit/test_entity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json

import pytest
from pydantic import BaseModel, Field

from . import (
CAMEL_CASE_CONFIG,
Expand Down Expand Up @@ -28,6 +29,81 @@
SampleEntityWithEnum,
SnakeCaseEntity,
)
from mat3ra.code.entity import InMemoryEntitySnakeCase

ID_ALIAS = "_id"
ID_VALUE = "workflow_1"
EXPECTED_ID_OUTPUT = {ID_ALIAS: ID_VALUE}

BASE_APPLICATION_NAME = "espresso"
KEEP_AS_NONE_APPLICATION_VERSION = ["applicationVersion"]
EXPECTED_DEFAULT_NONE_OUTPUT = {"applicationName": BASE_APPLICATION_NAME}
EXPECTED_KEEP_AS_NONE_OUTPUT = {
"applicationName": BASE_APPLICATION_NAME,
"applicationVersion": None,
}

EXAMPLE_ENTITY_EXCLUDE_KEY2 = {"exclude": ["key2"]}
EXAMPLE_ENTITY_EXCLUDE_KEY2_OUTPUT = {"key1": "value1"}

SAMPLE_ENUM_ENTITY_OUTPUT = {"type": "value1", "name": "example"}
TYPE_KEY = "type"

KEEP_AS_NONE_KWARGS = {"keep_as_none": KEEP_AS_NONE_APPLICATION_VERSION}


class BaseIdSchema(BaseModel):
id: str = Field(alias=ID_ALIAS)


class BaseIdEntity(BaseIdSchema, InMemoryEntitySnakeCase):
pass


class ChildIdEntity(BaseIdEntity):
id: str = Field(alias=ID_ALIAS)


def _create_example_entity() -> ExampleClass:
return ExampleClass.create(REFERENCE_OBJECT_VALID)


def _create_sample_enum_entity() -> SampleEntityWithEnum:
return SampleEntityWithEnum(type=SampleEnum.VALUE1, name="example")


def _create_base_id_entity() -> BaseIdEntity:
return BaseIdEntity(**{ID_ALIAS: ID_VALUE})


def _create_child_id_entity() -> ChildIdEntity:
return ChildIdEntity(**{ID_ALIAS: ID_VALUE})


def _create_snake_case_entity_with_nones() -> SnakeCaseEntity:
return SnakeCaseEntity(applicationName=BASE_APPLICATION_NAME, applicationVersion=None, executableName=None)


TO_DICT_CASES = [
(_create_example_entity, {}, {"key1": "value1", "key2": 1}, {}, {}, "example_entity"),
(_create_example_entity, EXAMPLE_ENTITY_EXCLUDE_KEY2, EXAMPLE_ENTITY_EXCLUDE_KEY2_OUTPUT, {}, {},
"example_entity_exclude"),
(_create_sample_enum_entity, {}, SAMPLE_ENUM_ENTITY_OUTPUT, {}, {TYPE_KEY: str}, "enum_entity"),
(_create_base_id_entity, {}, EXPECTED_ID_OUTPUT, {"id": ID_VALUE}, {}, "base_id_entity"),
(_create_child_id_entity, {}, EXPECTED_ID_OUTPUT, {"id": ID_VALUE}, {}, "child_id_entity"),
(_create_snake_case_entity_with_nones, {}, EXPECTED_DEFAULT_NONE_OUTPUT, {}, {}, "snake_case_default_none"),
(_create_snake_case_entity_with_nones, KEEP_AS_NONE_KWARGS, EXPECTED_KEEP_AS_NONE_OUTPUT, {}, {},
"snake_case_keep_as_none"),
]

TO_JSON_CASES = [
(_create_example_entity, {}, json.loads(REFERENCE_OBJECT_VALID_JSON), "example_entity"),
(_create_base_id_entity, {}, EXPECTED_ID_OUTPUT, "base_id_entity"),
(_create_child_id_entity, {}, EXPECTED_ID_OUTPUT, "child_id_entity"),
(_create_snake_case_entity_with_nones, {}, EXPECTED_DEFAULT_NONE_OUTPUT, "snake_case_default_none"),
(_create_snake_case_entity_with_nones, KEEP_AS_NONE_KWARGS, EXPECTED_KEEP_AS_NONE_OUTPUT,
"snake_case_keep_as_none"),
]


def test_create():
Expand Down Expand Up @@ -163,32 +239,34 @@ def test_get_cls_name():
assert ExampleClass.__name__ == "ExampleClass"


def test_to_dict():
entity = ExampleClass.create(REFERENCE_OBJECT_VALID)
# Test to_dict method
result = entity.to_dict()
@pytest.mark.parametrize(
"entity_factory,to_dict_kwargs,expected_output,expected_attrs,expected_types,_case_id",
TO_DICT_CASES,
ids=[case[-1] for case in TO_DICT_CASES],
)
def test_to_dict(entity_factory, to_dict_kwargs, expected_output, expected_attrs, expected_types, _case_id):
entity = entity_factory()
result = entity.to_dict(**to_dict_kwargs)
assert isinstance(result, dict)
assert result == {"key1": "value1", "key2": 1}
# Test with exclude
result_exclude = entity.to_dict(exclude=["key2"])
assert result_exclude == {"key1": "value1"}
assert result == expected_output

for attr_name, expected_value in expected_attrs.items():
assert getattr(entity, attr_name) == expected_value

def test_to_dict_with_enum():
entity = SampleEntityWithEnum(type=SampleEnum.VALUE1, name="example")
result = entity.to_dict()

assert isinstance(result, dict)
assert not isinstance(result["type"], SampleEnum) # Should not be an enum object
assert result == {"type": "value1", "name": "example"}
assert result["type"] == "value1" # String, not enum object
for key, expected_type in expected_types.items():
assert isinstance(result[key], expected_type)


def test_to_json():
entity = ExampleClass.create(REFERENCE_OBJECT_VALID)
result = entity.to_json()
@pytest.mark.parametrize(
"entity_factory,to_json_kwargs,expected_output,_case_id",
TO_JSON_CASES,
ids=[case[-1] for case in TO_JSON_CASES],
)
def test_to_json(entity_factory, to_json_kwargs, expected_output, _case_id):
entity = entity_factory()
result = entity.to_json(**to_json_kwargs)
assert isinstance(result, str)
assert json.loads(result) == json.loads(REFERENCE_OBJECT_VALID_JSON)
assert json.loads(result) == expected_output


def test_clone():
Expand Down
Loading