Skip to content

Commit f47e533

Browse files
✨ add typed accessors for inference fields
1 parent 424c942 commit f47e533

3 files changed

Lines changed: 88 additions & 36 deletions

File tree

mindee/v2/parsing/inference/field/inference_fields.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
from collections.abc import Callable
2+
from typing import TYPE_CHECKING, cast
23

34
from mindee.parsing.common import StringDict
45
from mindee.v2.parsing.inference.field.base_field import BaseField, FieldType
56

7+
if TYPE_CHECKING:
8+
from mindee.v2.parsing.inference.field.list_field import ListField
9+
from mindee.v2.parsing.inference.field.object_field import ObjectField
10+
from mindee.v2.parsing.inference.field.simple_field import SimpleField
11+
612

713
class InferenceFields(dict[str, BaseField]):
814
"""Inference fields dict."""
@@ -34,3 +40,24 @@ def __str__(self) -> str:
3440
else:
3541
str_fields += f"\n:{field_key}:{field_value}"
3642
return str_fields
43+
44+
def get_simple_field(self, field_name: str) -> "SimpleField":
45+
"""Retrieve the value of a simple field by its name."""
46+
field = self.get(field_name)
47+
if field and field.field_type == FieldType.SIMPLE:
48+
return cast("SimpleField", field)
49+
raise ValueError(f"Field {field_name} is not a SimpleField.")
50+
51+
def get_object_field(self, field_name: str) -> "ObjectField":
52+
"""Retrieve the value of an object field by its name."""
53+
field = self.get(field_name)
54+
if field and field.field_type == FieldType.OBJECT:
55+
return cast("ObjectField", field)
56+
raise ValueError(f"Field {field_name} is not an ObjectField.")
57+
58+
def get_list_field(self, field_name: str) -> "ListField":
59+
"""Retrieve the value of a list field by its name."""
60+
field = self.get(field_name)
61+
if field and field.field_type == FieldType.LIST:
62+
return cast("ListField", field)
63+
raise ValueError(f"Field {field_name} is not a ListField.")

tests/v2/file_operations/test_crop_operation_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_image_should_extract_crops():
5454
crop1size = os.path.getsize(OUTPUT_DIR / "crop_001.jpg")
5555
crop2size = os.path.getsize(OUTPUT_DIR / "crop_002.jpg")
5656
assert 187484 <= crop1size <= 199685
57-
assert 197978 <= crop2size <= 199433
57+
assert 194103 <= crop2size <= 199433
5858

5959

6060
@pytest.fixture(scope="module", autouse=True)

tests/v2/product/extraction/test_extraction_response.py

Lines changed: 60 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -29,65 +29,74 @@ def test_deep_nested_fields():
2929
assert isinstance(response.inference.result.fields["field_simple"], SimpleField)
3030
assert isinstance(response.inference.result.fields["field_object"], ObjectField)
3131
assert isinstance(
32-
response.inference.result.fields["field_object"].fields["sub_object_list"],
32+
response.inference.result.fields.get_object_field(
33+
"field_object"
34+
).get_list_field("sub_object_list"),
3335
ListField,
3436
)
3537
assert isinstance(
36-
response.inference.result.fields["field_object"].fields["sub_object_object"],
38+
response.inference.result.fields.get_object_field(
39+
"field_object"
40+
).get_object_field("sub_object_object"),
3741
ObjectField,
3842
)
3943
fields = response.inference.result.fields
40-
assert isinstance(fields.get("field_object"), ObjectField)
44+
assert isinstance(fields.get_object_field("field_object"), ObjectField)
4145
assert isinstance(
42-
fields.get("field_object").get_simple_field("sub_object_simple"), SimpleField
46+
fields.get_object_field("field_object").get_simple_field("sub_object_simple"),
47+
SimpleField,
4348
)
4449
assert isinstance(
45-
fields.get("field_object").get_list_field("sub_object_list"), ListField
50+
fields.get_object_field("field_object").get_list_field("sub_object_list"),
51+
ListField,
4652
)
4753
assert isinstance(
48-
fields.get("field_object").get_object_field("sub_object_object"), ObjectField
54+
fields.get_object_field("field_object").get_object_field("sub_object_object"),
55+
ObjectField,
4956
)
50-
assert len(fields.get("field_object").simple_fields) == 1
51-
assert len(fields.get("field_object").list_fields) == 1
52-
assert len(fields.get("field_object").object_fields) == 1
57+
assert len(fields.get_object_field("field_object").simple_fields) == 1
58+
assert len(fields.get_object_field("field_object").list_fields) == 1
59+
assert len(fields.get_object_field("field_object").object_fields) == 1
5360
assert isinstance(
54-
fields["field_object"].fields["sub_object_object"].fields,
61+
fields.get_object_field("field_object")
62+
.fields.get_object_field("sub_object_object")
63+
.fields,
5564
dict,
5665
)
5766
assert isinstance(
58-
fields["field_object"]
59-
.fields["sub_object_object"]
60-
.fields["sub_object_object_sub_object_list"],
67+
fields.get_object_field("field_object")
68+
.fields.get_object_field("sub_object_object")
69+
.fields.get_list_field("sub_object_object_sub_object_list"),
6170
ListField,
6271
)
6372
assert isinstance(
64-
fields["field_object"]
65-
.fields["sub_object_object"]
66-
.fields["sub_object_object_sub_object_list"]
73+
fields.get_object_field("field_object")
74+
.fields.get_object_field("sub_object_object")
75+
.fields.get_list_field("sub_object_object_sub_object_list")
6776
.items,
6877
list,
6978
)
7079
assert isinstance(
71-
fields["field_object"]
72-
.fields["sub_object_object"]
73-
.fields["sub_object_object_sub_object_list"]
80+
fields.get_object_field("field_object")
81+
.fields.get_object_field("sub_object_object")
82+
.fields.get_list_field("sub_object_object_sub_object_list")
7483
.items[0],
7584
ObjectField,
7685
)
7786
assert isinstance(
78-
fields["field_object"]
79-
.fields["sub_object_object"]
80-
.fields["sub_object_object_sub_object_list"]
87+
fields.get_object_field("field_object")
88+
.fields.get_object_field("sub_object_object")
89+
.fields.get_list_field("sub_object_object_sub_object_list")
8190
.items[0]
82-
.fields["sub_object_object_sub_object_list_simple"],
91+
.fields.get_simple_field("sub_object_object_sub_object_list_simple"),
8392
SimpleField,
8493
)
8594
assert (
86-
fields["field_object"]
87-
.fields["sub_object_object"]
88-
.fields["sub_object_object_sub_object_list"]
95+
fields.get_object_field("field_object")
96+
.fields.get_object_field("sub_object_object")
97+
.fields.get_list_field("sub_object_object_sub_object_list")
8998
.items[0]
90-
.fields["sub_object_object_sub_object_list_simple"]
99+
.fields.get_simple_field("sub_object_object_sub_object_list_simple")
91100
.value
92101
== "value_9"
93102
)
@@ -101,7 +110,9 @@ def test_standard_field_types():
101110
response = ExtractionResponse(json_sample)
102111
assert isinstance(response.inference, ExtractionInference)
103112

104-
field_simple_string = response.inference.result.fields["field_simple_string"]
113+
field_simple_string = response.inference.result.fields.get_simple_field(
114+
"field_simple_string"
115+
)
105116
assert isinstance(field_simple_string, SimpleField)
106117
assert field_simple_string.value == "field_simple_string-value"
107118
assert field_simple_string.confidence == FieldConfidence.CERTAIN
@@ -228,16 +239,30 @@ def test_full_inference_response():
228239

229240
assert isinstance(response.inference, ExtractionInference)
230241
assert response.inference.id == "12345678-1234-1234-1234-123456789abc"
231-
assert isinstance(response.inference.result.fields["date"], SimpleField)
232-
assert response.inference.result.fields["date"].value == "2019-11-02"
233-
assert isinstance(response.inference.result.fields["taxes"], ListField)
234-
assert isinstance(response.inference.result.fields["taxes"].items[0], ObjectField)
242+
assert isinstance(
243+
response.inference.result.fields.get_simple_field("date"), SimpleField
244+
)
235245
assert (
236-
response.inference.result.fields["customer_address"].fields["city"].value
246+
response.inference.result.fields.get_simple_field("date").value == "2019-11-02"
247+
)
248+
assert isinstance(
249+
response.inference.result.fields.get_list_field("taxes"), ListField
250+
)
251+
assert isinstance(
252+
response.inference.result.fields.get_list_field("taxes").items[0], ObjectField
253+
)
254+
assert (
255+
response.inference.result.fields.get_object_field("customer_address")
256+
.fields.get_simple_field("city")
257+
.value
237258
== "New York"
238259
)
239260
assert (
240-
response.inference.result.fields["taxes"].items[0].fields["base"].value == 31.5
261+
response.inference.result.fields.get_list_field("taxes")
262+
.items[0]
263+
.fields.get_simple_field("base")
264+
.value
265+
== 31.5
241266
)
242267

243268
assert isinstance(response.inference.model, InferenceModel)
@@ -263,7 +288,7 @@ def test_field_locations_and_confidence() -> None:
263288

264289
response = ExtractionResponse(json_sample)
265290

266-
date_field: SimpleField = response.inference.result.fields["date"]
291+
date_field: SimpleField = response.inference.result.fields.get_simple_field("date")
267292

268293
assert date_field.locations, "date field should expose locations"
269294
location = date_field.locations[0]

0 commit comments

Comments
 (0)