From 0aa5c63c4183988aaedcf14b2c354d264842243f Mon Sep 17 00:00:00 2001 From: Nikolas Claussen Date: Wed, 4 Feb 2026 18:14:40 -0500 Subject: [PATCH 1/4] added support for unloading the jaxtyping ipython extension --- jaxtyping/__init__.py | 1 + jaxtyping/_ipython_extension.py | 43 +++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/jaxtyping/__init__.py b/jaxtyping/__init__.py index 8d13017..ee2d9bc 100644 --- a/jaxtyping/__init__.py +++ b/jaxtyping/__init__.py @@ -38,6 +38,7 @@ ) from ._import_hook import install_import_hook as install_import_hook from ._ipython_extension import load_ipython_extension as load_ipython_extension +from ._ipython_extension import unload_ipython_extension as unload_ipython_extension from ._storage import print_bindings as print_bindings diff --git a/jaxtyping/_ipython_extension.py b/jaxtyping/_ipython_extension.py index ff5219b..6eccbce 100644 --- a/jaxtyping/_ipython_extension.py +++ b/jaxtyping/_ipython_extension.py @@ -17,6 +17,7 @@ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +from ._config import config from ._import_hook import JaxtypingTransformer, Typechecker @@ -54,3 +55,45 @@ def load_ipython_extension(ipython): raise RuntimeError("Failed to define jaxtyping.typechecker magic") from e ipython.register_magics(ChooseTypecheckerMagics) + + +def unload_ipython_extension(ipython): + """ + Support `%unload_ext jaxtyping` to remove the jaxtyping AST transformer + and unregister the `%jaxtyping.typechecker` magic. + """ + if ipython is None: + return + + # Disable runtime typechecking globally (covers already-decorated functions). + try: + config.jaxtyping_disable = True + except Exception: + pass + + # 1) Remove any JaxtypingTransformer from the AST transformers. + try: + ipython.ast_transformers = [ + t for t in getattr(ipython, "ast_transformers", []) + if not isinstance(t, JaxtypingTransformer) + ] + except Exception: + # Be permissive: if IPython internals change, don't hard-fail. + pass + + # 2) Unregister the `%jaxtyping.typechecker` magic. + try: + mm = getattr(ipython, "magics_manager", None) + if mm is not None: + for kind in ("line", "cell", "line_cell"): + d = mm.magics.get(kind, {}) + # Names registered via @line_magic use the explicit string we provided. + for name in ("jaxtyping.typechecker",): + if name in d: + try: + del d[name] + except Exception: + pass + except Exception: + # Also permissive here. + pass \ No newline at end of file From ed4115d12766fb1660ff8071f570c398e2422a92 Mon Sep 17 00:00:00 2001 From: Nikolas Claussen Date: Wed, 4 Feb 2026 18:32:21 -0500 Subject: [PATCH 2/4] added test to check unloading disables type checking --- test/test_ipython_extension.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/test_ipython_extension.py b/test/test_ipython_extension.py index 1c3c7d8..9f4ab66 100644 --- a/test/test_ipython_extension.py +++ b/test/test_ipython_extension.py @@ -74,6 +74,23 @@ def g(x: Float[Array, "1"]): ip.run_cell(raw_cell='g("string")').raise_error() +def test_unload_extension_disables_typechecking(ip): + ip.run_cell( + raw_cell=""" + from jaxtyping import Float, Array + import jax + + def g(x: Float[Array, "1"]): + return x + 1 + + int_arr = jax.numpy.array([1]) + """ + ).raise_error() + + ip.run_cell(raw_cell="%unload_ext jaxtyping").raise_error() + ip.run_cell(raw_cell="g(int_arr)").raise_error() + + def test_function_jaxtyped_and_jitted(ip): ip.run_cell( raw_cell=""" From 63b49511205458063bafdf21453fc72a07ffa82d Mon Sep 17 00:00:00 2001 From: Nikolas Claussen Date: Sun, 8 Feb 2026 15:25:38 -0500 Subject: [PATCH 3/4] streamlined unload_ipython_extension --- jaxtyping/_ipython_extension.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/jaxtyping/_ipython_extension.py b/jaxtyping/_ipython_extension.py index 6eccbce..427e3d6 100644 --- a/jaxtyping/_ipython_extension.py +++ b/jaxtyping/_ipython_extension.py @@ -61,11 +61,14 @@ def unload_ipython_extension(ipython): """ Support `%unload_ext jaxtyping` to remove the jaxtyping AST transformer and unregister the `%jaxtyping.typechecker` magic. + Permissive; does not raise errors if, e.g., the magic is not found. """ + # Names registered via @line_magic use the explicit string we provided. + extension_name = "jaxtyping.typechecker" if ipython is None: return - # Disable runtime typechecking globally (covers already-decorated functions). + # 0) Disable runtime typechecking globally (covers already-decorated functions). try: config.jaxtyping_disable = True except Exception: @@ -74,26 +77,20 @@ def unload_ipython_extension(ipython): # 1) Remove any JaxtypingTransformer from the AST transformers. try: ipython.ast_transformers = [ - t for t in getattr(ipython, "ast_transformers", []) + t for t in getattr(ipython, "ast_transformers", None) if not isinstance(t, JaxtypingTransformer) ] except Exception: - # Be permissive: if IPython internals change, don't hard-fail. pass # 2) Unregister the `%jaxtyping.typechecker` magic. try: mm = getattr(ipython, "magics_manager", None) if mm is not None: - for kind in ("line", "cell", "line_cell"): - d = mm.magics.get(kind, {}) - # Names registered via @line_magic use the explicit string we provided. - for name in ("jaxtyping.typechecker",): - if name in d: - try: - del d[name] - except Exception: - pass + magics = getattr(mm, "magics", None) + if isinstance(magics, dict): + for registry in magics.values(): + if isinstance(registry, dict): + registry.pop(extension_name, None) except Exception: - # Also permissive here. pass \ No newline at end of file From 9493bbc604973d0e6f2db00e643a568520833d5c Mon Sep 17 00:00:00 2001 From: Nikolas Claussen Date: Sun, 8 Feb 2026 15:36:17 -0500 Subject: [PATCH 4/4] streamlined unload_ipython_extension --- jaxtyping/_ipython_extension.py | 33 +++++++++------------------------ 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/jaxtyping/_ipython_extension.py b/jaxtyping/_ipython_extension.py index 427e3d6..0dfc632 100644 --- a/jaxtyping/_ipython_extension.py +++ b/jaxtyping/_ipython_extension.py @@ -61,36 +61,21 @@ def unload_ipython_extension(ipython): """ Support `%unload_ext jaxtyping` to remove the jaxtyping AST transformer and unregister the `%jaxtyping.typechecker` magic. - Permissive; does not raise errors if, e.g., the magic is not found. """ # Names registered via @line_magic use the explicit string we provided. extension_name = "jaxtyping.typechecker" - if ipython is None: - return - - # 0) Disable runtime typechecking globally (covers already-decorated functions). try: + # 0) Disable runtime typechecking globally (covers already-decorated functions). config.jaxtyping_disable = True - except Exception: - pass - - # 1) Remove any JaxtypingTransformer from the AST transformers. - try: + + # 1) Remove any JaxtypingTransformer from the AST transformers. ipython.ast_transformers = [ t for t in getattr(ipython, "ast_transformers", None) - if not isinstance(t, JaxtypingTransformer) - ] - except Exception: - pass + if not isinstance(t, JaxtypingTransformer)] - # 2) Unregister the `%jaxtyping.typechecker` magic. - try: + # 2) Unregister the `%jaxtyping.typechecker` magic. mm = getattr(ipython, "magics_manager", None) - if mm is not None: - magics = getattr(mm, "magics", None) - if isinstance(magics, dict): - for registry in magics.values(): - if isinstance(registry, dict): - registry.pop(extension_name, None) - except Exception: - pass \ No newline at end of file + magics = getattr(mm, "magics", None) + magics["line"].pop(extension_name) + except Exception as e: + RuntimeError("Failed to unload jaxtyping.typechecker magic") \ No newline at end of file