Skip to content

Commit 91544e0

Browse files
cached the algorithms in pytest
For speed, I've cached the Pytest algorithms such that they are not re-initiated every run. Testing is faster now. Especially good for Matlab-based testing where the Matlab engine needs initiating (different PR)
1 parent 0785232 commit 91544e0

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

tests/IVIMmodels/unit_tests/test_ivim_fit.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def tolerances_helper(tolerances, data):
2929
tolerances["atol"] = tolerances.get("atol", {"f": 2e-1, "D": 5e-4, "Dp": 4e-2})
3030
return tolerances
3131

32+
3233
def data_ivim_fit_saved():
3334
# Find the algorithms from algorithms.json
3435
file = pathlib.Path(__file__)
@@ -45,8 +46,8 @@ def data_ivim_fit_saved():
4546
bvals = all_data.pop('config')
4647
bvals = bvals['bvalues']
4748
first=True
48-
for name, data in all_data.items():
49-
for algorithm in algorithms:
49+
for algorithm in algorithms:
50+
for name, data in all_data.items():
5051
algorithm_dict = algorithm_information.get(algorithm, {})
5152
xfail = {"xfail": name in algorithm_dict.get("xfail_names", {}),
5253
"strict": algorithm_dict.get("xfail_names", {}).get(name, True)}
@@ -59,15 +60,38 @@ def data_ivim_fit_saved():
5960
first = False
6061
yield name, bvals, data, algorithm, xfail, kwargs, tolerances, skiptime
6162

63+
64+
def make_hashable(obj):
65+
if isinstance(obj, dict):
66+
return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
67+
elif isinstance(obj, (list, tuple)):
68+
return tuple(make_hashable(i) for i in obj)
69+
else:
70+
return obj
71+
72+
73+
@pytest.fixture(scope="module")
74+
def algorithm_cache():
75+
cache = {}
76+
77+
def get_instance(algorithm, kwargs):
78+
hashable_key = (algorithm, make_hashable(kwargs))
79+
if hashable_key not in cache:
80+
cache[hashable_key] = OsipiBase(algorithm=algorithm, **kwargs)
81+
return cache[hashable_key]
82+
83+
return get_instance
84+
85+
6286
@pytest.mark.parametrize("name, bvals, data, algorithm, xfail, kwargs, tolerances, skiptime", data_ivim_fit_saved())
63-
def test_ivim_fit_saved(name, bvals, data, algorithm, xfail, kwargs, tolerances,skiptime, request, record_property):
87+
def test_ivim_fit_saved(name, bvals, data, algorithm, xfail, kwargs, tolerances, skiptime, request, record_property, algorithm_cache):
6488
if xfail["xfail"]:
6589
mark = pytest.mark.xfail(reason="xfail", strict=xfail["strict"])
6690
request.node.add_marker(mark)
6791
signal = signal_helper(data["data"])
6892
tolerances = tolerances_helper(tolerances, data)
93+
fit = algorithm_cache(algorithm, kwargs)
6994
start_time = time.time() # Record the start time
70-
fit = OsipiBase(algorithm=algorithm, **kwargs)
7195
fit_result = fit.osipi_fit(signal, bvals)
7296
elapsed_time = time.time() - start_time # Calculate elapsed time
7397
def to_list_if_needed(value):
@@ -153,10 +177,14 @@ def bound_input():
153177

154178

155179
@pytest.mark.parametrize("name, bvals, data, algorithm, xfail, kwargs, tolerances", bound_input())
156-
def test_bounds(name, bvals, data, algorithm, xfail, kwargs, tolerances, request):
180+
def test_bounds(name, bvals, data, algorithm, xfail, kwargs, tolerances, request, algorithm_cache):
181+
if xfail["xfail"]:
182+
mark = pytest.mark.xfail(reason="xfail", strict=xfail["strict"])
183+
request.node.add_marker(mark)
157184
bounds = ([0.0008, 0.2, 0.01, 1.1], [0.0012, 0.3, 0.02, 1.3])
158185
# deliberately have silly bounds to see whether they are used
159-
fit = OsipiBase(algorithm=algorithm, bounds=bounds, initial_guess = [0.001, 0.25, 0.015, 1.2], **kwargs)
186+
extended_kwargs = {**kwargs, "bounds": bounds, "initial_guess": [0.001, 0.25, 0.015, 1.2]}
187+
fit = algorithm_cache(algorithm, extended_kwargs)
160188
if fit.use_bounds:
161189
signal = signal_helper(data["data"])
162190
fit_result = fit.osipi_fit(signal, bvals)

0 commit comments

Comments
 (0)