diff --git a/jaxtyping/_typeguard/__init__.py b/jaxtyping/_typeguard/__init__.py index da8d672..77d248c 100644 --- a/jaxtyping/_typeguard/__init__.py +++ b/jaxtyping/_typeguard/__init__.py @@ -14,7 +14,7 @@ from inspect import Parameter, isclass, isfunction, isgeneratorfunction from io import BufferedIOBase, IOBase, RawIOBase, TextIOBase from traceback import extract_stack, print_stack -from types import CodeType, FunctionType +from types import CodeType, FunctionType, UnionType from typing import ( IO, TYPE_CHECKING, AbstractSet, Any, AsyncIterable, AsyncIterator, BinaryIO, Callable, Dict, Generator, Iterable, Iterator, List, NewType, Optional, Sequence, Set, TextIO, Tuple, Type, @@ -752,6 +752,10 @@ def check_type(argname: str, value, expected_type, memo: Optional[_TypeCheckMemo expected_type = resolve_forwardref(expected_type, memo) origin_type = getattr(expected_type, '__origin__', None) + # jaxtyping fix to typeguard, see https://github.com/patrick-kidger/jaxtyping/issues/73#issuecomment-3888576066 + if isinstance(expected_type, UnionType): + origin_type = Union + # ~fix if origin_type is not None: checker_func = origin_type_checkers.get(origin_type) if checker_func: diff --git a/test/test_pytree.py b/test/test_pytree.py index cefc0ba..bfb2ea0 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -393,3 +393,15 @@ def test_pdoc(): wl.pformat(PyTree[None | Callable[[PyTree[int, " T"]], str]], width=2).strip() == expected ) + + +# https://github.com/patrick-kidger/jaxtyping/issues/73#issuecomment-3888576066 +def test_new_union(): + from jaxtyping._typeguard import typechecked + + @typechecked + def f(x: int | bool): + pass + + with pytest.raises(TypeError): + f(object())