Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion jaxtyping/_typeguard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from inspect import Parameter, isclass, isfunction, isgeneratorfunction
from io import BufferedIOBase, IOBase, RawIOBase, TextIOBase
from traceback import extract_stack, print_stack
from types import CodeType, FunctionType
from types import CodeType, FunctionType, UnionType
from typing import (
IO, TYPE_CHECKING, AbstractSet, Any, AsyncIterable, AsyncIterator, BinaryIO, Callable, Dict,
Generator, Iterable, Iterator, List, NewType, Optional, Sequence, Set, TextIO, Tuple, Type,
Expand Down Expand Up @@ -752,6 +752,10 @@ def check_type(argname: str, value, expected_type, memo: Optional[_TypeCheckMemo

expected_type = resolve_forwardref(expected_type, memo)
origin_type = getattr(expected_type, '__origin__', None)
# jaxtyping fix to typeguard, see https://github.com/patrick-kidger/jaxtyping/issues/73#issuecomment-3888576066
if isinstance(expected_type, UnionType):
origin_type = Union
# ~fix
if origin_type is not None:
checker_func = origin_type_checkers.get(origin_type)
if checker_func:
Expand Down
12 changes: 12 additions & 0 deletions test/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,15 @@ def test_pdoc():
wl.pformat(PyTree[None | Callable[[PyTree[int, " T"]], str]], width=2).strip()
== expected
)


# https://github.com/patrick-kidger/jaxtyping/issues/73#issuecomment-3888576066
def test_new_union():
from jaxtyping._typeguard import typechecked

@typechecked
def f(x: int | bool):
pass

with pytest.raises(TypeError):
f(object())