Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from ._import_hook import install_import_hook as install_import_hook
from ._ipython_extension import load_ipython_extension as load_ipython_extension
from ._ipython_extension import unload_ipython_extension as unload_ipython_extension
from ._storage import print_bindings as print_bindings


Expand Down
25 changes: 25 additions & 0 deletions jaxtyping/_ipython_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from ._config import config
from ._import_hook import JaxtypingTransformer, Typechecker


Expand Down Expand Up @@ -54,3 +55,27 @@ def load_ipython_extension(ipython):
raise RuntimeError("Failed to define jaxtyping.typechecker magic") from e

ipython.register_magics(ChooseTypecheckerMagics)


def unload_ipython_extension(ipython):
"""
Support `%unload_ext jaxtyping` to remove the jaxtyping AST transformer
and unregister the `%jaxtyping.typechecker` magic.
"""
# Names registered via @line_magic use the explicit string we provided.
extension_name = "jaxtyping.typechecker"
try:
# 0) Disable runtime typechecking globally (covers already-decorated functions).
config.jaxtyping_disable = True

# 1) Remove any JaxtypingTransformer from the AST transformers.
ipython.ast_transformers = [
t for t in getattr(ipython, "ast_transformers", None)
if not isinstance(t, JaxtypingTransformer)]

# 2) Unregister the `%jaxtyping.typechecker` magic.
mm = getattr(ipython, "magics_manager", None)
magics = getattr(mm, "magics", None)
magics["line"].pop(extension_name)
except Exception as e:
RuntimeError("Failed to unload jaxtyping.typechecker magic")
17 changes: 17 additions & 0 deletions test/test_ipython_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,23 @@ def g(x: Float[Array, "1"]):
ip.run_cell(raw_cell='g("string")').raise_error()


def test_unload_extension_disables_typechecking(ip):
ip.run_cell(
raw_cell="""
from jaxtyping import Float, Array
import jax

def g(x: Float[Array, "1"]):
return x + 1

int_arr = jax.numpy.array([1])
"""
).raise_error()

ip.run_cell(raw_cell="%unload_ext jaxtyping").raise_error()
ip.run_cell(raw_cell="g(int_arr)").raise_error()


def test_function_jaxtyped_and_jitted(ip):
ip.run_cell(
raw_cell="""
Expand Down