This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit ac64604

Add a non-one shape constraint to TensorType
1 parent 6a003e5 commit ac64604
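
In short, this commit moves Aesara toward describing a TensorType by a static `shape` tuple, where `1` marks a dimension known to be broadcastable (length one) and `None` marks a dimension of unknown length, instead of a `broadcastable` tuple of booleans. The sketch below illustrates that convention using only constructors that appear in the diff; it is an illustration, not code from this commit:

import aesara.tensor as at
from aesara.tensor.type import TensorType

# Old spelling: broadcast pattern as booleans, e.g. TensorType("float64", (False, True)).
# New spelling: static shape, with 1 for broadcastable dims and None for unknown ones.
x = TensorType("float64", shape=(None, 1))()  # calling the type makes a variable

# Converting an old broadcastable pattern into the new shape encoding:
broadcastable = (False, True, False)
shape = tuple(1 if b else None for b in broadcastable)  # -> (None, 1, None)

# `at.tensor` accepts the same encoding through its `shape` argument.
y = at.tensor(dtype="float64", shape=shape)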

File tree: 17 files changed (+305, -181 lines)


aesara/gradient.py

Lines changed: 5 additions & 7 deletions
@@ -1802,13 +1802,11 @@ def verify_grad(
         mode=mode,
     )

-    tensor_pt = [
-        aesara.tensor.type.TensorType(
-            aesara.tensor.as_tensor_variable(p).dtype,
-            aesara.tensor.as_tensor_variable(p).broadcastable,
-        )(name=f"input {i}")
-        for i, p in enumerate(pt)
-    ]
+    tensor_pt = []
+    for i, p in enumerate(pt):
+        p_t = aesara.tensor.as_tensor_variable(p).type()
+        p_t.name = f"input {i}"
+        tensor_pt.append(p_t)

     # fun can be either a function or an actual Op instance
     o_output = fun(*tensor_pt)

aesara/link/c/params_type.py

Lines changed: 1 addition & 1 deletion
@@ -324,7 +324,7 @@ class ParamsType(CType):
     `ParamsType` constructor takes key-value args. Key will be the name of the
     attribute in the struct. Value is the Aesara type of this attribute,
     ie. an instance of (a subclass of) :class:`CType`
-    (eg. ``TensorType('int64', (False,))``).
+    (eg. ``TensorType('int64', (None,))``).

     In a Python code any attribute named ``key`` will be available via::
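
Per the docstring above, `ParamsType` maps keyword-argument names to Aesara `CType` instances; the only change here is how the example vector type is spelled. A hedged usage sketch (the attribute names are made up for illustration):

from aesara.link.c.params_type import ParamsType
from aesara.tensor.type import TensorType

# Hypothetical params struct with two int64 vector fields; under the new
# convention, an unknown-length vector type is written with shape=(None,).
params_type = ParamsType(
    row_offsets=TensorType("int64", shape=(None,)),
    col_offsets=TensorType("int64", shape=(None,)),
)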

aesara/sandbox/multinomial.py

Lines changed: 3 additions & 1 deletion
@@ -44,7 +44,9 @@ def make_node(self, pvals, unis, n=1):
             odtype = pvals.dtype
         else:
             odtype = self.odtype
-        out = at.tensor(dtype=odtype, shape=pvals.type.broadcastable)
+        out = at.tensor(
+            dtype=odtype, shape=tuple(1 if s == 1 else None for s in pvals.type.shape)
+        )
         return Apply(self, [pvals, unis, as_scalar(n)], [out])

     def grad(self, ins, outgrads):
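
The pattern above, `tuple(1 if s == 1 else None for s in pvals.type.shape)`, keeps only the "statically known to be length one" information from the input when building the output type. A small, self-contained illustration (the helper name is invented for this note):

from typing import Optional, Tuple

def as_broadcast_only_shape(shape: Tuple[Optional[int], ...]) -> Tuple[Optional[int], ...]:
    # Keep dimensions known to be 1; treat every other dimension as unknown.
    return tuple(1 if s == 1 else None for s in shape)

assert as_broadcast_only_shape((1, 3, None)) == (1, None, None)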

aesara/sandbox/rng_mrg.py

Lines changed: 7 additions & 4 deletions
@@ -379,10 +379,13 @@ def make_node(self, rstate, size):
         # this op should not be called directly.
         #
         # call through MRG_RandomStream instead.
-        broad = []
+        out_shape = []
         for i in range(self.output_type.ndim):
-            broad.append(at.extract_constant(size[i]) == 1)
-        output_type = self.output_type.clone(shape=broad)()
+            if at.extract_constant(size[i]) == 1:
+                out_shape.append(1)
+            else:
+                out_shape.append(None)
+        output_type = self.output_type.clone(shape=out_shape)()
         rstate = as_tensor_variable(rstate)
         size = as_tensor_variable(size)
         return Apply(self, [rstate, size], [rstate.type(), output_type])

@@ -392,7 +395,7 @@ def new(cls, rstate, ndim, dtype, size):
         v_size = as_tensor_variable(size)
         if ndim is None:
             ndim = get_vector_length(v_size)
-        op = cls(TensorType(dtype, (False,) * ndim))
+        op = cls(TensorType(dtype, shape=(None,) * ndim))
         return op(rstate, v_size)

     def perform(self, node, inp, out, params):

aesara/sparse/type.py

Lines changed: 2 additions & 4 deletions
@@ -70,9 +70,8 @@ def __init__(
         dtype: Union[str, np.dtype],
         shape: Optional[Iterable[Optional[Union[bool, int]]]] = None,
         name: Optional[str] = None,
-        broadcastable: Optional[Iterable[bool]] = None,
     ):
-        if shape is None and broadcastable is None:
+        if shape is None:
             shape = (None, None)

         if format not in self.format_cls:

@@ -82,13 +81,12 @@ def __init__(

         self.format = format

-        super().__init__(dtype, shape=shape, name=name, broadcastable=broadcastable)
+        super().__init__(dtype, shape=shape, name=name)

     def clone(
         self,
         dtype=None,
         shape=None,
-        broadcastable=None,
         **kwargs,
     ):
         format: Optional[SparsityTypes] = kwargs.pop("format", self.format)

aesara/tensor/basic.py

Lines changed: 23 additions & 18 deletions
@@ -110,10 +110,12 @@ def _as_tensor_Variable(x, name, ndim, **kwargs):

     if x.type.ndim > ndim:
         # Strip off leading broadcastable dimensions
-        non_broadcastables = [idx for idx in range(x.ndim) if not x.broadcastable[idx]]
+        non_broadcastables = [
+            idx for idx in range(x.type.ndim) if x.type.shape_encoded[idx] != 1
+        ]

         if non_broadcastables:
-            x = x.dimshuffle(list(range(x.ndim))[non_broadcastables[0] :])
+            x = x.dimshuffle(list(range(x.type.ndim))[non_broadcastables[0] :])
         else:
             x = x.dimshuffle()

@@ -2210,18 +2212,18 @@ def make_node(self, axis, *tensors):
                 "Join cannot handle arguments of dimension 0."
                 " Use `stack` to join scalar values."
             )
-        # Handle single-tensor joins immediately.
+
         if len(tensors) == 1:
-            bcastable = list(tensors[0].type.broadcastable)
+            out_shape = tensors[0].type.shape_encoded
         else:
             # When the axis is fixed, a dimension should be
             # broadcastable if at least one of the inputs is
             # broadcastable on that dimension (see justification below),
             # except for the axis dimension.
             # Initialize bcastable all false, and then fill in some trues with
             # the loops.
-            bcastable = [False] * len(tensors[0].type.broadcastable)
-            ndim = len(bcastable)
+            ndim = tensors[0].type.ndim
+            out_shape = [None] * ndim

             if not isinstance(axis, int):
                 try:

@@ -2246,25 +2248,25 @@ def make_node(self, axis, *tensors):
                    axis += ndim

                for x in tensors:
-                    for current_axis, bflag in enumerate(x.type.broadcastable):
+                    for current_axis, s in enumerate(x.type.shape_encoded):
                        # Constant negative axis can no longer be negative at
                        # this point. It safe to compare this way.
                        if current_axis == axis:
                            continue
-                        if bflag:
-                            bcastable[current_axis] = True
+                        if s == 1:
+                            out_shape[current_axis] = 1
                try:
-                    bcastable[axis] = False
+                    out_shape[axis] = None
                except IndexError:
                    raise ValueError(
                        f"Axis value {axis} is out of range for the given input dimensions"
                    )
            else:
                # When the axis may vary, no dimension can be guaranteed to be
                # broadcastable.
-                bcastable = [False] * len(tensors[0].type.broadcastable)
+                out_shape = [None] * tensors[0].type.ndim

-        if not builtins.all(x.ndim == len(bcastable) for x in tensors):
+        if not builtins.all(x.ndim == len(out_shape) for x in tensors):
             raise TypeError(
                 "Only tensors with the same number of dimensions can be joined"
             )

@@ -2274,7 +2276,7 @@ def make_node(self, axis, *tensors):
         if inputs[0].type.dtype not in int_dtypes:
             raise TypeError(f"Axis value {inputs[0]} must be an integer type")

-        return Apply(self, inputs, [tensor(dtype=out_dtype, shape=bcastable)])
+        return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)])

     def perform(self, node, axis_and_tensors, out_):
         (out,) = out_

@@ -2385,6 +2387,7 @@ def grad(self, axis_and_tensors, grads):
         # Split.make_node isn't always able to infer the right
         # broadcast. As the grad need to keep the information,
         # read it if needed.
+        # TODO FIXME: Remove all this broadcastable stuff.
         split_gz = [
             g
             if g.type.broadcastable == t.type.broadcastable

@@ -2771,6 +2774,8 @@ def flatten(x, ndim=1):
     else:
         dims = (-1,)
     x_reshaped = _x.reshape(dims)
+
+    # TODO FIXME: Remove all this broadcastable stuff.
     bcast_kept_dims = _x.broadcastable[: ndim - 1]
     bcast_new_dim = builtins.all(_x.broadcastable[ndim - 1 :])
     broadcastable = bcast_kept_dims + (bcast_new_dim,)

@@ -2882,7 +2887,7 @@ def make_node(self, start, stop, step):
         assert step.ndim == 0

         inputs = [start, stop, step]
-        outputs = [tensor(self.dtype, (False,))]
+        outputs = [tensor(self.dtype, shape=(None,))]

         return Apply(self, inputs, outputs)

@@ -3158,11 +3163,11 @@ def make_node(self, x, y, inverse):
         elif x_dim < y_dim:
             x = shape_padleft(x, n_ones=(y_dim - x_dim))

-        # Compute the broadcastable pattern of the output
-        out_broadcastable = [
-            xb and yb for xb, yb in zip(x.type.broadcastable, y.type.broadcastable)
+        out_shape = [
+            1 if xb == 1 and yb == 1 else None
+            for xb, yb in zip(x.type.shape_encoded, y.type.shape_encoded)
         ]
-        out_type = tensor(dtype=x.type.dtype, shape=out_broadcastable)
+        out_type = tensor(dtype=x.type.dtype, shape=out_shape)

         inputlist = [x, y, inverse]
         outputlist = [out_type]
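
For `Join.make_node`, the rewritten rule reads: with a constant join axis, an output dimension is marked as statically 1 when at least one input is statically 1 in that dimension, and nothing is assumed along the join axis itself. A hedged, stand-alone sketch of that rule (the helper name is invented for this note):

from typing import List, Optional, Sequence, Tuple

StaticShape = Tuple[Optional[int], ...]

def join_out_shape(axis: int, shapes: Sequence[StaticShape]) -> List[Optional[int]]:
    out: List[Optional[int]] = [None] * len(shapes[0])
    for shp in shapes:
        for dim, s in enumerate(shp):
            if dim != axis and s == 1:
                out[dim] = 1
    out[axis] = None  # the joined dimension's length is not inferred here
    return out

assert join_out_shape(0, [(1, 1), (None, 1)]) == [None, 1]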

aesara/tensor/blas.py

Lines changed: 21 additions & 19 deletions
@@ -167,6 +167,7 @@
 from aesara.tensor.shape import specify_broadcastable
 from aesara.tensor.type import (
     DenseTensorType,
+    TensorType,
     integer_dtypes,
     tensor,
     values_eq_approx_remove_inf_nan,

@@ -1204,11 +1205,11 @@ def _as_scalar(res, dtype=None):
     """Return ``None`` or a `TensorVariable` of float type"""
     if dtype is None:
         dtype = config.floatX
-    if all(res.type.broadcastable):
+    if all(s == 1 for s in res.type.shape):
         while res.owner and isinstance(res.owner.op, DimShuffle):
             res = res.owner.inputs[0]
         # may still have some number of True's
-        if res.type.broadcastable:
+        if res.type.ndim > 0:
             rval = res.dimshuffle()
         else:
             rval = res

@@ -1230,16 +1231,16 @@ def _is_real_matrix(res):
     return (
         res.type.dtype in ("float16", "float32", "float64")
         and res.type.ndim == 2
-        and res.type.broadcastable[0] is False
-        and res.type.broadcastable[1] is False
+        and res.type.shape[0] != 1
+        and res.type.shape[1] != 1
     )  # cope with tuple vs. list


 def _is_real_vector(res):
     return (
         res.type.dtype in ("float16", "float32", "float64")
         and res.type.ndim == 1
-        and res.type.broadcastable[0] is False
+        and res.type.shape[0] != 1
     )

@@ -1298,9 +1299,7 @@ def scaled(thing):
         else:
             return scale * thing

-    try:
-        r.type.broadcastable
-    except Exception:
+    if not isinstance(r.type, TensorType):
         return None

     if (r.type.ndim not in (1, 2)) or r.type.dtype not in (

@@ -1333,10 +1332,10 @@ def scaled(thing):
     vectors = []
     matrices = []
     for i in r.owner.inputs:
-        if all(i.type.broadcastable):
+        if all(s == 1 for s in i.type.shape):
             while i.owner and isinstance(i.owner.op, DimShuffle):
                 i = i.owner.inputs[0]
-            if i.type.broadcastable:
+            if i.type.ndim > 0:
                 scalars.append(i.dimshuffle())
             else:
                 scalars.append(i)

@@ -1681,8 +1680,7 @@ def make_node(self, x, y):
             raise TypeError(y)
         if y.type.dtype != x.type.dtype:
             raise TypeError("dtype mismatch to Dot22")
-        bz = (x.type.broadcastable[0], y.type.broadcastable[1])
-        outputs = [tensor(x.type.dtype, bz)]
+        outputs = [tensor(x.type.dtype, (x.type.shape[0], y.type.shape[1]))]
         return Apply(self, [x, y], outputs)

     def perform(self, node, inp, out):

@@ -1986,8 +1984,8 @@ def make_node(self, x, y, a):
         if not a.dtype.startswith("float") and not a.dtype.startswith("complex"):
             raise TypeError("Dot22Scalar requires float or complex args", a.dtype)

-        bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
-        outputs = [tensor(x.type.dtype, bz)]
+        sz = [x.type.shape[0], y.type.shape[1]]
+        outputs = [tensor(x.type.dtype, shape=sz)]
         return Apply(self, [x, y, a], outputs)

     def perform(self, node, inp, out):

@@ -2213,12 +2211,16 @@ def make_node(self, *inputs):
         dtype = aesara.scalar.upcast(*[input.type.dtype for input in inputs])
         # upcast inputs to common dtype if needed
         upcasted_inputs = [at.cast(input, dtype) for input in inputs]
-        broadcastable = (
-            (inputs[0].type.broadcastable[0] or inputs[1].type.broadcastable[0],)
-            + inputs[0].type.broadcastable[1:-1]
-            + inputs[1].type.broadcastable[2:]
+        out_shape = (
+            (
+                1
+                if inputs[0].type.shape[0] == 1 or inputs[1].type.shape[0] == 1
+                else None,
+            )
+            + inputs[0].type.shape[1:-1]
+            + inputs[1].type.shape[2:]
         )
-        return Apply(self, upcasted_inputs, [tensor(dtype, broadcastable)])
+        return Apply(self, upcasted_inputs, [tensor(dtype, out_shape)])

     def perform(self, node, inp, out):
         x, y = inp
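
The `Dot22` and `Dot22Scalar` changes above propagate static shapes directly: a matrix-matrix product keeps the first operand's row entry and the second operand's column entry, with `None` still meaning "unknown at compile time". A minimal sketch of that propagation (the helper name is invented for this note):

from typing import Optional, Tuple

StaticShape = Tuple[Optional[int], ...]

def dot22_out_shape(x: StaticShape, y: StaticShape) -> StaticShape:
    # Both operands are 2-D; the result keeps x's rows and y's columns.
    assert len(x) == 2 and len(y) == 2
    return (x[0], y[1])

assert dot22_out_shape((None, 5), (5, 1)) == (None, 1)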

aesara/tensor/math.py

Lines changed: 6 additions & 7 deletions
@@ -1911,15 +1911,14 @@ def make_node(self, *inputs):
                 "aesara.tensor.dot instead."
             )

-        i_broadcastables = [input.type.broadcastable for input in inputs]
-        bx, by = i_broadcastables
-        if len(by) == 2:  # y is a matrix
-            bz = bx[:-1] + by[-1:]
-        elif len(by) == 1:  # y is vector
-            bz = bx[:-1]
+        sx, sy = [input.type.shape for input in inputs]
+        if len(sy) == 2:  # y is a matrix
+            sz = sx[:-1] + sy[-1:]
+        elif len(sy) == 1:  # y is vector
+            sz = sx[:-1]

         i_dtypes = [input.type.dtype for input in inputs]
-        outputs = [tensor(aes.upcast(*i_dtypes), bz)]
+        outputs = [tensor(aes.upcast(*i_dtypes), shape=sz)]
         return Apply(self, inputs, outputs)

     def perform(self, node, inp, out):
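
The `Dot.make_node` change applies the same idea to general dot products: the output's static shape drops the last entry of the first operand and, when the second operand is a matrix, appends that operand's last entry. A hedged sketch (the helper name is invented for this note):

from typing import Optional, Tuple

StaticShape = Tuple[Optional[int], ...]

def dot_out_shape(sx: StaticShape, sy: StaticShape) -> StaticShape:
    if len(sy) == 2:  # y is a matrix
        return sx[:-1] + sy[-1:]
    if len(sy) == 1:  # y is a vector
        return sx[:-1]
    raise ValueError("y must be a vector or a matrix")

assert dot_out_shape((None, 3), (3, 4)) == (None, 4)
assert dot_out_shape((2, 3), (3,)) == (2,)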
