This code:
    # pip install cloudpickle jaxtyping typeguard
    import cloudpickle
    from jaxtyping import jaxtyped
    import typeguard

    def typed(function):
        return jaxtyped(function, typechecker=typeguard.typechecked)

    @typed
    def f():
        return 1

    def unwrapped_f():
        return 2

    # This will succeed
    pickled_unwrapped = cloudpickle.dumps(unwrapped_f)
    print("Successfully pickled unwrapped_f")

    # This will fail
    try:
        pickled_f = cloudpickle.dumps(f)
    except Exception as e:
        print(f"\nFailed to pickle decorated function f:\n{type(e).__name__}: {e}")
Prints out this:
    [...]/lib/python3.13/site-packages/jaxtyping/_decorator.py:71: InstrumentationWarning: instrumentor did not find the target function -- not typechecking __main__.f
      return typechecker(fn)
Successfully pickled unwrapped_f
Failed to pickle decorated function f:
TypeError: cannot pickle 'weakref.ReferenceType' object
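This shows that a function decorated with jaxtyping's `jaxtyped` (with typeguard as the typechecker) can't be serialized by cloudpickle, while the same kind of undecorated function can :(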
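The error message suggests that something reachable from the decorated `f` holds a `weakref.ref`, which pickle cannot serialize. Because cloudpickle serializes functions defined in `__main__` by value, it has to walk into that state and hits the weak reference. Below is a minimal, stdlib-only sketch of that failure mode. It assumes (not confirmed from jaxtyping's source) that the wrapper keeps such a reference; `Registry`, `WeakWrapper`, `PicklableWrapper`, and `try_pickle` are hypothetical names, and plain `pickle` is used instead of cloudpickle to avoid the extra dependency. The `__reduce__` override shows one possible direction, not jaxtyping's actual behavior.

```python
import pickle
import weakref


class Registry:
    """Hypothetical object that the wrapper refers to weakly."""


REGISTRY = Registry()


def plain(x):
    return x + 1


class WeakWrapper:
    """Hypothetical decorator wrapper that stores a weak reference,
    assumed to mimic whatever the jaxtyping wrapper reaches."""

    def __init__(self, fn):
        self.fn = fn
        # weakref.ref objects cannot be pickled, so pickling self.__dict__ fails here
        self.ref = weakref.ref(REGISTRY)

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


class PicklableWrapper(WeakWrapper):
    """Same wrapper, but pickles only the underlying function;
    __init__ re-creates the weak reference when unpickled."""

    def __reduce__(self):
        return (PicklableWrapper, (self.fn,))


def try_pickle(obj):
    """Return "ok" if obj pickles, otherwise the exception type and message."""
    try:
        pickle.dumps(obj)
        return "ok"
    except Exception as e:
        return f"{type(e).__name__}: {e}"


print(try_pickle(plain))                    # module-level function, pickled by reference
print(try_pickle(WeakWrapper(plain)))       # fails on the weakref in the instance state
print(try_pickle(PicklableWrapper(plain)))  # succeeds: only plain is serialized

restored = pickle.loads(pickle.dumps(PicklableWrapper(plain)))
print(restored(1))  # the rebuilt wrapper still calls plain
```

The exact `TypeError` text depends on the Python version (the report above shows the 3.13 wording), so only the exception type should be relied on.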
CC @JoshEngels @jkramar