diff --git a/jaxtyping/__init__.py b/jaxtyping/__init__.py index 8d13017..ee2d9bc 100644 --- a/jaxtyping/__init__.py +++ b/jaxtyping/__init__.py @@ -38,6 +38,7 @@ ) from ._import_hook import install_import_hook as install_import_hook from ._ipython_extension import load_ipython_extension as load_ipython_extension +from ._ipython_extension import unload_ipython_extension as unload_ipython_extension from ._storage import print_bindings as print_bindings diff --git a/jaxtyping/_ipython_extension.py b/jaxtyping/_ipython_extension.py index ff5219b..0dfc632 100644 --- a/jaxtyping/_ipython_extension.py +++ b/jaxtyping/_ipython_extension.py @@ -17,6 +17,7 @@ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +from ._config import config from ._import_hook import JaxtypingTransformer, Typechecker @@ -54,3 +55,27 @@ def load_ipython_extension(ipython): raise RuntimeError("Failed to define jaxtyping.typechecker magic") from e ipython.register_magics(ChooseTypecheckerMagics) + + +def unload_ipython_extension(ipython): + """ + Support `%unload_ext jaxtyping` to remove the jaxtyping AST transformer + and unregister the `%jaxtyping.typechecker` magic. + """ + # Names registered via @line_magic use the explicit string we provided. + extension_name = "jaxtyping.typechecker" + try: + # 0) Disable runtime typechecking globally (covers already-decorated functions). + config.jaxtyping_disable = True + + # 1) Remove any JaxtypingTransformer from the AST transformers. + ipython.ast_transformers = [ + t for t in getattr(ipython, "ast_transformers", None) + if not isinstance(t, JaxtypingTransformer)] + + # 2) Unregister the `%jaxtyping.typechecker` magic. + mm = getattr(ipython, "magics_manager", None) + magics = getattr(mm, "magics", None) + magics["line"].pop(extension_name) + except Exception as e: + RuntimeError("Failed to unload jaxtyping.typechecker magic") \ No newline at end of file diff --git a/test/test_ipython_extension.py b/test/test_ipython_extension.py index 1c3c7d8..9f4ab66 100644 --- a/test/test_ipython_extension.py +++ b/test/test_ipython_extension.py @@ -74,6 +74,23 @@ def g(x: Float[Array, "1"]): ip.run_cell(raw_cell='g("string")').raise_error() +def test_unload_extension_disables_typechecking(ip): + ip.run_cell( + raw_cell=""" + from jaxtyping import Float, Array + import jax + + def g(x: Float[Array, "1"]): + return x + 1 + + int_arr = jax.numpy.array([1]) + """ + ).raise_error() + + ip.run_cell(raw_cell="%unload_ext jaxtyping").raise_error() + ip.run_cell(raw_cell="g(int_arr)").raise_error() + + def test_function_jaxtyped_and_jitted(ip): ip.run_cell( raw_cell="""