Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 8088741

Browse files
committed
Add an option to disable numba JIT
1 parent 984ee55 commit 8088741

File tree

4 files changed

+21
-70
lines changed

4 files changed

+21
-70
lines changed

aesara/configdefaults.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,13 @@ def add_compile_configvars():
497497
in_c_key=False,
498498
)
499499

500+
config.add(
501+
"disable_numba_jit",
502+
"Disable numba JIT",
503+
BoolParam(False),
504+
in_c_key=False,
505+
)
506+
500507
# Keep the default optimizer the same as the one for the mode FAST_RUN
501508
config.add(
502509
"optimizer",

aesara/link/numba/linker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44

55
import aesara
6+
from aesara.configdefaults import config
67
from aesara.link.basic import JITLinker
78

89

@@ -27,6 +28,8 @@ def fgraph_convert(self, fgraph, **kwargs):
2728
return numba_funcify(fgraph, **kwargs)
2829

2930
def jit_compile(self, fn):
31+
if config.disable_numba_jit:
32+
return fn
3033
from aesara.link.numba.dispatch import numba_njit
3134

3235
jitted_fn = numba_njit(fn)

aesara/link/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,8 @@ def fgraph_to_python(
738738
compiled_func = op_conversion_fn(
739739
node.op, node=node, storage_map=storage_map, **kwargs
740740
)
741-
741+
if config.disable_numba_jit:
742+
compiled_func = compiled_func.py_func
742743
# Create a local alias with a unique name
743744
local_compiled_func_name = unique_name(compiled_func)
744745
global_env[local_compiled_func_name] = compiled_func

tests/link/numba/test_basic.py

Lines changed: 9 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import contextlib
2-
import inspect
32
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union
4-
from unittest import mock
53

64
import numba
75
import numpy as np
@@ -108,73 +106,15 @@ def compare_shape_dtype(x, y):
108106
def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
109107
"""Evaluate the Numba implementation in pure Python for coverage purposes."""
110108

111-
def py_tuple_setitem(t, i, v):
112-
ll = list(t)
113-
ll[i] = v
114-
return tuple(ll)
115-
116-
def py_to_scalar(x):
117-
if isinstance(x, np.ndarray):
118-
return x.item()
119-
else:
120-
return x
121-
122-
def njit_noop(*args, **kwargs):
123-
if len(args) == 1 and callable(args[0]):
124-
return args[0]
125-
else:
126-
return lambda x: x
127-
128-
def vectorize_noop(*args, **kwargs):
129-
def wrap(fn):
130-
# `numba.vectorize` allows an `out` positional argument. We need
131-
# to account for that
132-
sig = inspect.signature(fn)
133-
nparams = len(sig.parameters)
134-
135-
def inner_vec(*args):
136-
if len(args) > nparams:
137-
# An `out` argument has been specified for an in-place
138-
# operation
139-
out = args[-1]
140-
out[...] = np.vectorize(fn)(*args[:nparams])
141-
return out
142-
else:
143-
return np.vectorize(fn)(*args)
144-
145-
return inner_vec
146-
147-
if len(args) == 1 and callable(args[0]):
148-
return wrap(args[0], **kwargs)
149-
else:
150-
return wrap
151-
152-
mocks = [
153-
mock.patch("numba.njit", njit_noop),
154-
mock.patch("numba.vectorize", vectorize_noop),
155-
mock.patch("aesara.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem),
156-
mock.patch("aesara.link.numba.dispatch.basic.numba_njit", njit_noop),
157-
mock.patch("aesara.link.numba.dispatch.basic.numba_vectorize", vectorize_noop),
158-
mock.patch("aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x),
159-
mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar),
160-
mock.patch(
161-
"aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype",
162-
lambda dtype: dtype,
163-
),
164-
mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)),
165-
]
166-
167-
with contextlib.ExitStack() as stack:
168-
for ctx in mocks:
169-
stack.enter_context(ctx)
170-
171-
aesara_numba_fn = function(
172-
fn_inputs,
173-
fn_outputs,
174-
mode=mode,
175-
accept_inplace=True,
176-
)
177-
_ = aesara_numba_fn(*inputs)
109+
config.disable_numba_jit = True
110+
aesara_numba_fn = function(
111+
fn_inputs,
112+
fn_outputs,
113+
mode=mode,
114+
accept_inplace=True,
115+
)
116+
_ = aesara_numba_fn(*inputs)
117+
config.disable_numba_jit = False
178118

179119

180120
def compare_numba_and_py(

0 commit comments

Comments
 (0)