11import contextlib
22import inspect
3+ from copy import copy
34from typing import TYPE_CHECKING , Callable , Optional , Sequence , Tuple , Union
45from unittest import mock
56
1617from aesara .compile .mode import Mode
1718from aesara .compile .ops import ViewOp
1819from aesara .compile .sharedvalue import SharedVariable
19- from aesara .graph .basic import Apply , Constant
20+ from aesara .graph .basic import Apply , Constant , vars_between
2021from aesara .graph .fg import FunctionGraph
2122from aesara .graph .op import Op , get_test_value
2223from aesara .graph .rewriting .db import RewriteDatabaseQuery
@@ -218,6 +219,11 @@ def assert_fn(x, y):
218219
219220 fn_inputs = [i for i in fn_inputs if not isinstance (i , SharedVariable )]
220221
222+ shared_vars_to_init_vals = {}
223+ for v in vars_between (fn_inputs , fn_outputs ):
224+ if isinstance (v , SharedVariable ):
225+ shared_vars_to_init_vals [v ] = copy (v .get_value (borrow = True ))
226+
221227 aesara_py_fn = function (
222228 fn_inputs , fn_outputs , mode = py_mode , accept_inplace = True , updates = updates
223229 )
@@ -230,6 +236,12 @@ def assert_fn(x, y):
230236 accept_inplace = True ,
231237 updates = updates ,
232238 )
239+
240+ # Reset shared variables so that the results will match between the two
241+ # runs
242+ for v , val in shared_vars_to_init_vals .items ():
243+ v .set_value (val )
244+
233245 numba_res = aesara_numba_fn (* inputs )
234246
235247 # Get some coverage
0 commit comments