diff --git a/glue/sample/src/sinter/_data/_anon_task_stats.py b/glue/sample/src/sinter/_data/_anon_task_stats.py index 42281c07d..a303d5ea1 100644 --- a/glue/sample/src/sinter/_data/_anon_task_stats.py +++ b/glue/sample/src/sinter/_data/_anon_task_stats.py @@ -40,6 +40,16 @@ def __post_init__(self): assert isinstance(self.discards, (int, np.integer)) assert isinstance(self.seconds, (int, float, np.integer, np.floating)) assert isinstance(self.custom_counts, collections.Counter) + + if isinstance(self.errors, np.integer): + object.__setattr__(self, 'errors', int(self.errors)) + if isinstance(self.shots, np.integer): + object.__setattr__(self, 'shots', int(self.shots)) + if isinstance(self.discards, np.integer): + object.__setattr__(self, 'discards', int(self.discards)) + if isinstance(self.seconds, (np.integer, np.floating)): + object.__setattr__(self, 'seconds', float(self.seconds)) + assert self.errors >= 0 assert self.discards >= 0 assert self.seconds >= 0 diff --git a/glue/sample/src/sinter/_data/_anon_task_stats_test.py b/glue/sample/src/sinter/_data/_anon_task_stats_test.py index 3b2e0b0b7..1145745dc 100644 --- a/glue/sample/src/sinter/_data/_anon_task_stats_test.py +++ b/glue/sample/src/sinter/_data/_anon_task_stats_test.py @@ -1,4 +1,5 @@ import collections +import numpy as np import sinter @@ -33,3 +34,11 @@ def test_add(): assert a + b0 == sinter.AnonTaskStats(shots=270, errors=34, discards=43, seconds=52, custom_counts=collections.Counter({'a': 10, 'b': 20})) assert a0 + b == sinter.AnonTaskStats(shots=270, errors=34, discards=43, seconds=52, custom_counts=collections.Counter({'a': 1, 'c': 3})) + +def test_init_handles_np(): + a0 = sinter.AnonTaskStats(shots=np.int64(220), errors=np.int64(0), discards=np.int64(0), seconds=np.float64(0.0), custom_counts=collections.Counter({'a': 10, 'b': 20})) + + assert isinstance(a0.shots, int) + assert isinstance(a0.errors, int) + assert isinstance(a0.discards, int) + assert isinstance(a0.seconds, float) diff --git a/glue/sample/src/sinter/_data/_task_stats.py b/glue/sample/src/sinter/_data/_task_stats.py index f184e94d6..3512798b9 100644 --- a/glue/sample/src/sinter/_data/_task_stats.py +++ b/glue/sample/src/sinter/_data/_task_stats.py @@ -79,6 +79,16 @@ def __post_init__(self): assert isinstance(self.decoder, str) assert isinstance(self.strong_id, str) assert self.json_metadata is None or isinstance(self.json_metadata, (int, float, str, dict, list, tuple)) + + if isinstance(self.errors, np.integer): + object.__setattr__(self, 'errors', int(self.errors)) + if isinstance(self.shots, np.integer): + object.__setattr__(self, 'shots', int(self.shots)) + if isinstance(self.discards, np.integer): + object.__setattr__(self, 'discards', int(self.discards)) + if isinstance(self.seconds, (np.integer, np.floating)): + object.__setattr__(self, 'seconds', float(self.seconds)) + assert self.errors >= 0 assert self.discards >= 0 assert self.seconds >= 0 diff --git a/glue/sample/src/sinter/_data/_task_stats_test.py b/glue/sample/src/sinter/_data/_task_stats_test.py index 79f456f2a..1b928d77c 100644 --- a/glue/sample/src/sinter/_data/_task_stats_test.py +++ b/glue/sample/src/sinter/_data/_task_stats_test.py @@ -1,4 +1,5 @@ import collections +import numpy as np import pytest @@ -138,3 +139,19 @@ def test_is_equal_json_values(): assert not _is_equal_json_values({'x': (1, 2)}, {'x': (1, 3)}) assert not _is_equal_json_values(1, 2) assert _is_equal_json_values(1, 1) + +def test_init_handles_np(): + v = sinter.TaskStats( + strong_id='test', + json_metadata={'a': [1, 2, 3]}, + decoder='pymatching', + shots=np.int64(22), + errors=np.int64(3), + discards=np.int64(4), + seconds=np.float64(5), + ) + + assert isinstance(v.shots, int) + assert isinstance(v.errors, int) + assert isinstance(v.discards, int) + assert isinstance(v.seconds, float)