
Commit 92b3309

Add a non-one shape constraint to TensorType
1 parent 454f8ae commit 92b3309

File tree: 8 files changed, +196 −81 lines

aesara/sandbox/rng_mrg.py

Lines changed: 1 addition & 1 deletion

@@ -395,7 +395,7 @@ def new(cls, rstate, ndim, dtype, size):
         v_size = as_tensor_variable(size)
         if ndim is None:
             ndim = get_vector_length(v_size)
-        op = cls(TensorType(dtype, (False,) * ndim))
+        op = cls(TensorType(dtype, shape=(None,) * ndim))
         return op(rstate, v_size)
 
     def perform(self, node, inp, out, params):
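For context, a hedged sketch of the convention this hunk migrates: the old call encoded broadcast information as booleans (`False` meaning an unknown, non-broadcastable size), while the new `shape` keyword uses `None` for unknown sizes.

    from aesara.tensor.type import TensorType

    # Old: TensorType(dtype, (False,) * ndim) marked every dimension as
    # non-broadcastable via booleans.
    # New: `None` marks a dimension of unknown static size.
    out_type = TensorType("float64", shape=(None, None))
    assert out_type.shape == (None, None)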

aesara/sparse/type.py

Lines changed: 2 additions & 4 deletions

@@ -70,9 +70,8 @@ def __init__(
         dtype: Union[str, np.dtype],
         shape: Optional[Iterable[Optional[Union[bool, int]]]] = None,
         name: Optional[str] = None,
-        broadcastable: Optional[Iterable[bool]] = None,
     ):
-        if shape is None and broadcastable is None:
+        if shape is None:
             shape = (None, None)
 
         if format not in self.format_cls:

@@ -82,13 +81,12 @@ def __init__(
 
         self.format = format
 
-        super().__init__(dtype, shape=shape, name=name, broadcastable=broadcastable)
+        super().__init__(dtype, shape=shape, name=name)
 
     def clone(
         self,
         dtype=None,
         shape=None,
-        broadcastable=None,
         **kwargs,
     ):
         format: Optional[SparsityTypes] = kwargs.pop("format", self.format)
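A hedged usage sketch of the cleaned-up constructor (assuming this `__init__` belongs to `SparseTensorType`; the class name is inferred from the module path and is not shown in the hunk):

    from aesara.sparse.type import SparseTensorType

    # The deprecated `broadcastable` keyword is gone; `shape` alone is
    # accepted and defaults to (None, None), i.e. both sizes unknown.
    st = SparseTensorType("csr", dtype="float64")
    assert st.shape == (None, None)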

aesara/tensor/shape.py

Lines changed: 58 additions & 15 deletions

@@ -1,7 +1,7 @@
 import warnings
 from numbers import Number
 from textwrap import dedent
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Sequence, Tuple, Union
 
 import numpy as np
 
@@ -16,11 +16,65 @@
 from aesara.tensor import basic as at
 from aesara.tensor import get_vector_length
 from aesara.tensor.exceptions import NotScalarConstantError
-from aesara.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
+from aesara.tensor.type import (
+    DenseTensorType,
+    TensorType,
+    int_dtypes,
+    shape_key,
+    tensor,
+)
 from aesara.tensor.type_other import NoneConst
 from aesara.tensor.var import TensorConstant, TensorVariable
 
 
+def filter_shape_vars(
+    ref_shape: Tuple[int, ...], shape: Sequence[Variable], shape_is_encoded: bool = True
+) -> Tuple[int, ...]:
+    r"""Compute the most "informative" shape based on a static reference.
+
+    Parameters
+    ----------
+    ref_shape
+        A static shape reference using static shape constraint encoding.
+    shape
+        A symbolic shape.
+    shape_is_encoded
+        If ``True``, `shape` is assumed to be static shape constraint encoded.
+
+    Returns
+    -------
+    The most specific static shape constraint encoded values that are
+    compatible with `ref_shape`.
+    """
+    shape_bottom = shape_key(None)
+    type_shape = ()
+    for i, (xts, s) in enumerate(zip(ref_shape, shape)):
+
+        try:
+            # TODO FIXME: We shouldn't need to do this; let a rewrite
+            # do constant folding and update the `TensorType`s.
+            s_val = at.get_scalar_constant_value(s)
+
+            if isinstance(s_val, np.ndarray):
+                s_val = s_val.item()
+
+            if shape_is_encoded or (s_val is not None and s_val > 0):
+                type_s = shape_key(s_val)
+            else:
+                type_s = shape_bottom
+        except NotScalarConstantError:
+            type_s = shape_bottom
+
+        if not (xts <= -1 or type_s <= -1 or type_s == xts):
+            raise AssertionError(
+                f"SpecifyShape: Got shape {xts} at index {i}, expected {type_s}."
+            )
+
+        type_shape += (max(type_s, xts),)
+
+    return type_shape
+
+
 def register_shape_c_code(type, code, version=()):
     """
     Tell Shape Op how to generate C code for an Aesara Type.
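To make `filter_shape_vars` concrete, a hedged sketch of its behavior under the encoding defined in `aesara/tensor/type.py` below (`-2` for unknown, `-1` for "known not to be 1", non-negative integers for exact sizes):

    import aesara.tensor as at
    from aesara.tensor.shape import filter_shape_vars

    # Reference static shape: (unknown, exactly 5), encoded as (-2, 5).
    ref_shape = (-2, 5)

    # A constant first entry and a non-constant second entry.
    s0 = at.as_tensor_variable(3)
    s1 = at.iscalar("s1")

    # The constant 3 refines the unknown dimension; the non-constant
    # entry falls back to "unknown" and the exact size 5 wins via max().
    assert filter_shape_vars(ref_shape, (s0, s1)) == (3, 5)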
@@ -383,7 +437,6 @@ class SpecifyShape(COp):
     _f16_ok = True
 
     def make_node(self, x, *shape):
-        from aesara.tensor.basic import get_scalar_constant_value
 
         x = at.as_tensor_variable(x)
 
@@ -406,18 +459,7 @@ def make_node(self, x, *shape):
                 f"Input `x` is {x.type.ndim}-dimensional and will never match a shape of length {len(shape)}."
             )
 
-        type_shape = [None] * x.ndim
-        for i, (xts, s) in enumerate(zip(x.type.shape, shape)):
-            if xts is not None:
-                type_shape[i] = xts
-            else:
-                try:
-                    type_s = get_scalar_constant_value(s)
-                    if type_s is not None:
-                        type_shape[i] = int(type_s)
-                except NotScalarConstantError:
-                    pass
-
+        type_shape = filter_shape_vars(x.type.shape_encoded, shape)
         out_var = x.type.clone(shape=type_shape)()
 
         return Apply(self, [x, *shape], [out_var])

@@ -601,6 +643,7 @@ def make_node(self, x, shp):
         x = at.as_tensor_variable(x)
         shp_orig = shp
         shp = at.as_tensor_variable(shp, ndim=1)
+
         if not (
             shp.dtype in int_dtypes
             or (isinstance(shp, TensorConstant) and shp.data.size == 0)
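A hedged sketch of the user-visible effect on `SpecifyShape`: constant shape arguments now refine the output's static type through `filter_shape_vars`.

    import aesara.tensor as at
    from aesara.tensor.shape import specify_shape

    x = at.matrix("x")            # static shape (None, None)
    y = specify_shape(x, (3, 5))  # constants refine the output type

    # The output's type now carries the static shape (3, 5).
    assert y.type.shape == (3, 5)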

aesara/tensor/subtensor.py

Lines changed: 8 additions & 2 deletions

@@ -1549,7 +1549,10 @@ def make_node(self, x, y, *inputs):
                 f"Wrong type for Subtensor template. Expected {input.type}, got {expected_type}."
             )
 
-        return Apply(self, (x, y) + inputs, [x.type()])
+        out_var = x.type.clone(
+            shape=tuple(1 if s == 1 else None for s in x.type.shape)
+        )()
+        return Apply(self, (x, y) + inputs, [out_var])
 
     def decl_view(self):
         return "PyArrayObject * zview = NULL;"

@@ -2180,7 +2183,10 @@ def make_node(self, x, y, ilist):
                 % (opname, x_.type.ndim, y_.type.ndim)
             )
 
-        return Apply(self, [x_, y_, ilist_], [x_.type()])
+        out_var = x_.type.clone(
+            shape=tuple(1 if s == 1 else None for s in x_.type.shape)
+        )()
+        return Apply(self, [x_, y_, ilist_], [out_var])
 
     def copy_of_x(self, x):
         """

aesara/tensor/type.py

Lines changed: 92 additions & 42 deletions

@@ -54,6 +54,26 @@
 }
 
 
+def parse_bcast_and_shape(s):
+    if s is None:
+        return (None, False)
+    elif s == 1:
+        return (1, False)
+    elif s >= 0:
+        return (s, True)
+    elif s < 0:
+        # The second flag states that this dimension's size cannot be
+        # equal to 1
+        return (None, True)
+
+
+def shape_key(s):
+    if s is None:
+        return -2
+
+    return s
+
+
 class TensorType(CType[np.ndarray], HasDataType, HasShape):
     r"""Symbolic `Type` representing `numpy.ndarray`\s."""

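A hedged summary of the encoding these helpers introduce: `shape_key` maps the public static shape values onto integers so that more informative constraints compare as larger, which lets the code below combine constraints with `max`.

    from aesara.tensor.type import shape_key

    # None (unknown)        -> -2
    # -1   (size != 1)      -> -1
    # n >= 0 (exact size n) ->  n
    assert shape_key(None) == -2
    assert shape_key(-1) == -1
    assert shape_key(5) == 5

    # Ordering by this key makes `max` pick the more specific constraint:
    assert max(shape_key(None), shape_key(-1)) == -1  # "!= 1" beats unknown
    assert max(shape_key(-1), shape_key(5)) == 5      # exact size beats "!= 1"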
@@ -72,7 +92,6 @@ def __init__(
         dtype: Union[str, np.dtype],
         shape: Optional[Iterable[Optional[Union[bool, int]]]] = None,
         name: Optional[str] = None,
-        broadcastable: Optional[Iterable[bool]] = None,
     ):
         r"""

@@ -82,7 +101,8 @@ def __init__(
             A NumPy dtype (e.g. ``"int64"``).
         shape
             The static shape information. ``None``\s are used to indicate
-            unknown shape values for their respective dimensions.
+            unknown shape values for their respective dimensions and ``-1`` to
+            indicate the constraint ``shape != 1``.
             If `shape` is a list of ``bool``\s, the ``True`` elements are
             converted to ``1``\s and the ``False`` values are converted to
             ``None``\s.

@@ -91,12 +111,7 @@ def __init__(
 
         """
 
-        if broadcastable is not None:
-            warnings.warn(
-                "The `broadcastable` keyword is deprecated; use `shape`.",
-                DeprecationWarning,
-            )
-            shape = broadcastable
+        self.name = name
 
         if str(dtype) == "floatX":
             self.dtype = config.floatX

@@ -106,30 +121,21 @@ def __init__(
 
             self.dtype = np.dtype(dtype).name
 
-        def parse_bcast_and_shape(s):
-            if isinstance(s, (bool, np.bool_)):
-                return 1 if s else None
-            else:
-                return s
-
-        self.shape = tuple(parse_bcast_and_shape(s) for s in shape)
         self.dtype_specs()  # error checking is done there
-        self.name = name
         self.numpy_dtype = np.dtype(self.dtype)
 
-    def clone(
-        self, dtype=None, shape=None, broadcastable=None, **kwargs
-    ) -> "TensorType":
-        if broadcastable is not None:
-            warnings.warn(
-                "The `broadcastable` keyword is deprecated; use `shape`.",
-                DeprecationWarning,
-            )
-            shape = broadcastable
+        self.shape_encoded = tuple(shape_key(s) for s in shape)
+
+        assert isinstance(self.shape_encoded, tuple)
+        assert all(
+            isinstance(s, int) and not isinstance(s, bool) for s in self.shape_encoded
+        )
+
+    def clone(self, dtype=None, shape=None, **kwargs) -> "TensorType":
         if dtype is None:
             dtype = self.dtype
         if shape is None:
-            shape = self.shape
+            shape = self.shape_encoded
         return type(self)(dtype, shape, name=self.name)
 
     def filter(self, data, strict=False, allow_downcast=None):

@@ -243,16 +249,24 @@ def filter(self, data, strict=False, allow_downcast=None):
                 " Aesara C code does not support that.",
             )
 
-        if not all(
-            ds == ts if ts is not None else True
-            for ds, ts in zip(data.shape, self.shape)
-        ):
-            raise TypeError(
-                f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"
-            )
+        def check_shape_info(i, s_val, s_info):
+            if s_info == -1 and s_val == 1:
+                raise ValueError(
+                    f"Value's shape in dimension {i} is not compatible "
+                    f"with the constraint: {s_val} != 1"
+                )
+            if s_info > -1 and s_val != s_info:
+                raise ValueError(
+                    f"Value's shape in dimension {i} is not compatible "
+                    f"with the constraint: {s_val} == {s_info}"
+                )
+
+        for i, (s_val, s_info) in enumerate(zip(np.shape(data), self.shape_encoded)):
+            check_shape_info(i, s_val, s_info)
 
         if self.filter_checks_isfinite and not np.all(np.isfinite(data)):
             raise ValueError("Non-finite elements not allowed")
+
         return data
 
     def filter_variable(self, other, allow_convert=True):
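A hedged sketch of what the reworked `filter` checks reject, assuming a type built with the new `-1` ("not 1") constraint:

    import numpy as np
    from aesara.tensor.type import TensorType

    # First dimension constrained to be != 1; second fixed to exactly 3.
    tt = TensorType("float64", shape=(-1, 3))

    tt.filter(np.zeros((4, 3)))  # accepted: 4 != 1 and 3 == 3

    try:
        tt.filter(np.zeros((1, 3)))  # rejected: violates "!= 1"
    except ValueError as exc:
        print(exc)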
@@ -308,7 +322,10 @@ def in_same_class(self, otype):
         if (
             isinstance(otype, TensorType)
             and otype.dtype == self.dtype
-            and otype.broadcastable == self.broadcastable
+            and all(
+                s == o_s if s == 1 or o_s == 1 else True
+                for s, o_s in zip(self.shape, otype.shape)
+            )
         ):
             return True
         return False

@@ -320,7 +337,11 @@ def is_super(self, otype):
             and otype.ndim == self.ndim
             # `otype` is allowed to be as or more shape-specific than `self`,
             # but not less
-            and all(sb == ob or sb is None for sb, ob in zip(self.shape, otype.shape))
+            and all(
+                s == o_s if s > -1 and o_s > -1 else s <= o_s
+                # not (s is not None and s >= 0 and max(s, o_s, key=shape_key) != s)
+                for s, o_s in zip(self.shape_encoded, otype.shape_encoded)
+            )
         ):
             return True

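A hedged sketch of the resulting subtyping order (unknown ≤ "not 1" ≤ exact size):

    from aesara.tensor.type import TensorType

    general = TensorType("float64", shape=(None, None))  # encoded (-2, -2)
    not_one = TensorType("float64", shape=(-1, None))    # encoded (-1, -2)
    exact = TensorType("float64", shape=(3, None))       # encoded (3, -2)

    assert general.is_super(not_one)    # unknown admits "!= 1"
    assert not_one.is_super(exact)      # 3 satisfies "!= 1"
    assert not exact.is_super(general)  # an exact size admits nothing weaker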
@@ -334,13 +355,22 @@ def convert_variable(self, var):
         if (self.ndim == var.type.ndim) and (self.dtype == var.type.dtype):
             # `var.type` only differs from `self` in that its shape is (at least partially)
             # less specific than `self`, so we convert `var` to `self`'s `Type`.
-            # `specify_shape` will combine the more precise shapes of the two types
-            return aesara.tensor.specify_shape(var, self.shape)
+            # `specify_shape` will combine the more precise shapes of the two types.
+
+            new_shape_encoded = ()
+            for s, o_s in zip(self.shape_encoded, var.type.shape_encoded):
+
+                if s > -1 and o_s > -1 and s != o_s:
+                    raise ValueError(
+                        f"Incompatible shapes: {self.shape_encoded}, {var.type.shape_encoded}"
+                    )
+
+                new_shape_encoded += (max(s, o_s),)
+
+            return aesara.tensor.specify_shape(var, new_shape_encoded)
 
     @staticmethod
     def values_eq(a, b, force_same_dtype=True):
-        # TODO: check to see if the shapes must match; for now, we err on safe
-        # side...
         if a.shape != b.shape:
             return False
         if force_same_dtype and a.dtype != b.dtype:

@@ -367,14 +397,23 @@ def __eq__(self, other):
         if type(self) != type(other):
             return NotImplemented
 
-        return other.dtype == self.dtype and other.shape == self.shape
+        return other.dtype == self.dtype and other.shape_encoded == self.shape_encoded
 
     def __hash__(self):
-        return hash((type(self), self.dtype, self.shape))
+        return hash((type(self), self.dtype, self.shape_encoded))
+
+    @property
+    def shape(self) -> Tuple[Optional[int], ...]:
+        """Return a static shape tuple with unknown values equal to ``None``."""
+        return tuple(s if s > -1 else None for s in self.shape_encoded)
 
     @property
     def broadcastable(self):
         """A boolean tuple indicating which dimensions have a shape equal to one."""
+        warnings.warn(
+            "TensorType.broadcastable is deprecated; use TensorType.shape",
+            DeprecationWarning,
+        )
         return tuple(s == 1 for s in self.shape)
 
     @property
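A hedged sketch of the two read-only views over `shape_encoded` (note that `broadcastable` now warns on access):

    from aesara.tensor.type import TensorType

    tt = TensorType("float64", shape=(1, -1, 5, None))

    # `shape` hides the encoding: both -1 ("!= 1") and -2 (unknown)
    # are rendered as None.
    assert tt.shape == (1, None, 5, None)

    # `broadcastable` is derived from `shape` and emits a DeprecationWarning.
    assert tt.broadcastable == (True, False, False, False)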
@@ -386,7 +425,18 @@ def __str__(self):
         if self.name:
             return self.name
         else:
-            return f"TensorType({self.dtype}, {self.shape})"
+
+            def shape_str(s):
+                if s == -1:
+                    return ">1"
+                elif s < -1:
+                    return "?"
+                else:
+                    return str(s)
+
+            formatted_shape = ", ".join([shape_str(s) for s in self.shape_encoded])
+
+            return f"TensorType({self.dtype}, ({formatted_shape}))"
 
     def __repr__(self):
         return str(self)
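Finally, a hedged sketch of the new string rendering: `?` for unknown dimensions and `>1` for the non-one constraint.

    from aesara.tensor.type import TensorType

    tt = TensorType("float64", shape=(None, -1, 5))
    print(tt)  # TensorType(float64, (?, >1, 5))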
