Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit e146896

Browse files
Reset shared values between runs in compare_numba_and_py
1 parent b720544 commit e146896

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

tests/link/numba/test_basic.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import contextlib
22
import inspect
3+
from copy import copy
34
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Tuple, Union
45
from unittest import mock
56

@@ -16,7 +17,7 @@
1617
from aesara.compile.mode import Mode
1718
from aesara.compile.ops import ViewOp
1819
from aesara.compile.sharedvalue import SharedVariable
19-
from aesara.graph.basic import Apply, Constant
20+
from aesara.graph.basic import Apply, Constant, vars_between
2021
from aesara.graph.fg import FunctionGraph
2122
from aesara.graph.op import Op, get_test_value
2223
from aesara.graph.rewriting.db import RewriteDatabaseQuery
@@ -218,6 +219,11 @@ def assert_fn(x, y):
218219

219220
fn_inputs = [i for i in fn_inputs if not isinstance(i, SharedVariable)]
220221

222+
shared_vars_to_init_vals = {}
223+
for v in vars_between(fn_inputs, fn_outputs):
224+
if isinstance(v, SharedVariable):
225+
shared_vars_to_init_vals[v] = copy(v.get_value(borrow=True))
226+
221227
aesara_py_fn = function(
222228
fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates
223229
)
@@ -230,6 +236,12 @@ def assert_fn(x, y):
230236
accept_inplace=True,
231237
updates=updates,
232238
)
239+
240+
# Reset shared variables so that the results will match between the two
241+
# runs
242+
for v, val in shared_vars_to_init_vals.items():
243+
v.set_value(val)
244+
233245
numba_res = aesara_numba_fn(*inputs)
234246

235247
# Get some coverage

0 commit comments

Comments
 (0)