@@ -29,6 +29,7 @@ def tolerances_helper(tolerances, data):
2929 tolerances ["atol" ] = tolerances .get ("atol" , {"f" : 2e-1 , "D" : 5e-4 , "Dp" : 4e-2 })
3030 return tolerances
3131
32+
3233def data_ivim_fit_saved ():
3334 # Find the algorithms from algorithms.json
3435 file = pathlib .Path (__file__ )
@@ -45,8 +46,8 @@ def data_ivim_fit_saved():
4546 bvals = all_data .pop ('config' )
4647 bvals = bvals ['bvalues' ]
4748 first = True
48- for name , data in all_data . items () :
49- for algorithm in algorithms :
49+ for algorithm in algorithms :
50+ for name , data in all_data . items () :
5051 algorithm_dict = algorithm_information .get (algorithm , {})
5152 xfail = {"xfail" : name in algorithm_dict .get ("xfail_names" , {}),
5253 "strict" : algorithm_dict .get ("xfail_names" , {}).get (name , True )}
@@ -59,15 +60,38 @@ def data_ivim_fit_saved():
5960 first = False
6061 yield name , bvals , data , algorithm , xfail , kwargs , tolerances , skiptime
6162
63+
64+ def make_hashable (obj ):
65+ if isinstance (obj , dict ):
66+ return tuple (sorted ((k , make_hashable (v )) for k , v in obj .items ()))
67+ elif isinstance (obj , (list , tuple )):
68+ return tuple (make_hashable (i ) for i in obj )
69+ else :
70+ return obj
71+
72+
73+ @pytest .fixture (scope = "module" )
74+ def algorithm_cache ():
75+ cache = {}
76+
77+ def get_instance (algorithm , kwargs ):
78+ hashable_key = (algorithm , make_hashable (kwargs ))
79+ if hashable_key not in cache :
80+ cache [hashable_key ] = OsipiBase (algorithm = algorithm , ** kwargs )
81+ return cache [hashable_key ]
82+
83+ return get_instance
84+
85+
6286@pytest .mark .parametrize ("name, bvals, data, algorithm, xfail, kwargs, tolerances, skiptime" , data_ivim_fit_saved ())
63- def test_ivim_fit_saved (name , bvals , data , algorithm , xfail , kwargs , tolerances ,skiptime , request , record_property ):
87+ def test_ivim_fit_saved (name , bvals , data , algorithm , xfail , kwargs , tolerances , skiptime , request , record_property , algorithm_cache ):
6488 if xfail ["xfail" ]:
6589 mark = pytest .mark .xfail (reason = "xfail" , strict = xfail ["strict" ])
6690 request .node .add_marker (mark )
6791 signal = signal_helper (data ["data" ])
6892 tolerances = tolerances_helper (tolerances , data )
93+ fit = algorithm_cache (algorithm , kwargs )
6994 start_time = time .time () # Record the start time
70- fit = OsipiBase (algorithm = algorithm , ** kwargs )
7195 fit_result = fit .osipi_fit (signal , bvals )
7296 elapsed_time = time .time () - start_time # Calculate elapsed time
7397 def to_list_if_needed (value ):
@@ -153,10 +177,14 @@ def bound_input():
153177
154178
155179@pytest .mark .parametrize ("name, bvals, data, algorithm, xfail, kwargs, tolerances" , bound_input ())
156- def test_bounds (name , bvals , data , algorithm , xfail , kwargs , tolerances , request ):
180+ def test_bounds (name , bvals , data , algorithm , xfail , kwargs , tolerances , request , algorithm_cache ):
181+ if xfail ["xfail" ]:
182+ mark = pytest .mark .xfail (reason = "xfail" , strict = xfail ["strict" ])
183+ request .node .add_marker (mark )
157184 bounds = ([0.0008 , 0.2 , 0.01 , 1.1 ], [0.0012 , 0.3 , 0.02 , 1.3 ])
158185 # deliberately have silly bounds to see whether they are used
159- fit = OsipiBase (algorithm = algorithm , bounds = bounds , initial_guess = [0.001 , 0.25 , 0.015 , 1.2 ], ** kwargs )
186+ extended_kwargs = {** kwargs , "bounds" : bounds , "initial_guess" : [0.001 , 0.25 , 0.015 , 1.2 ]}
187+ fit = algorithm_cache (algorithm , extended_kwargs )
160188 if fit .use_bounds :
161189 signal = signal_helper (data ["data" ])
162190 fit_result = fit .osipi_fit (signal , bvals )
0 commit comments