diff --git a/jaxtyping/__init__.py b/jaxtyping/__init__.py index 8d13017..31a4c92 100644 --- a/jaxtyping/__init__.py +++ b/jaxtyping/__init__.py @@ -198,7 +198,9 @@ def __getattr__(item): # `uint32`. import jax - return Union[Key[jax.Array, ""], UInt32[jax.Array, "2"]] + return Union[ + Key[jax.Array, ""], UInt32[jax.Array, "2"], UInt32[jax.Array, "4"] + ] elif item == "DTypeLike": import jax.typing