}


+def parse_bcast_and_shape(s):
+    if s is None:
+        return (None, False)
+    elif s == 1:
+        return (1, False)
+    elif s >= 0:
+        return (s, True)
+    elif s < 0:
+        # The second flag states that this dimension's size cannot be
+        # equal to 1
+        return (None, True)
+
+
+def shape_key(s):
+    if s is None:
+        return -2
+
+    return s
+
+
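A quick sketch of the encoding these two helpers implement (an illustrative interpreter session, not part of the patch): `shape_key` maps the public shape notation onto plain integers, with `-2` standing for "unknown" and `-1` for "known not to be 1", while `parse_bcast_and_shape` splits a value into `(size, cannot-be-1)`.

    >>> shape_key(None), shape_key(-1), shape_key(1), shape_key(5)
    (-2, -1, 1, 5)
    >>> parse_bcast_and_shape(None), parse_bcast_and_shape(1)
    ((None, False), (1, False))
    >>> parse_bcast_and_shape(5), parse_bcast_and_shape(-1)
    ((5, True), (None, True))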
class TensorType(CType[np.ndarray], HasDataType, HasShape):
    r"""Symbolic `Type` representing `numpy.ndarray`\s."""

@@ -72,7 +92,6 @@ def __init__(
        dtype: Union[str, np.dtype],
        shape: Optional[Iterable[Optional[Union[bool, int]]]] = None,
        name: Optional[str] = None,
-        broadcastable: Optional[Iterable[bool]] = None,
    ):
        r"""

@@ -82,7 +101,8 @@ def __init__(
            A NumPy dtype (e.g. ``"int64"``).
        shape
            The static shape information. ``None``\s are used to indicate
-            unknown shape values for their respective dimensions.
+            unknown shape values for their respective dimensions and ``-1`` to
+            indicate the constraint ``shape != 1``.
            If `shape` is a list of ``bool``\s, the ``True`` elements are
            converted to ``1``\s and the ``False`` values are converted to
            ``None``\s.
@@ -91,12 +111,7 @@ def __init__(

        """

-        if broadcastable is not None:
-            warnings.warn(
-                "The `broadcastable` keyword is deprecated; use `shape`.",
-                DeprecationWarning,
-            )
-            shape = broadcastable
+        self.name = name

        if str(dtype) == "floatX":
            self.dtype = config.floatX
@@ -106,30 +121,21 @@ def __init__(

        self.dtype = np.dtype(dtype).name

-        def parse_bcast_and_shape(s):
-            if isinstance(s, (bool, np.bool_)):
-                return 1 if s else None
-            else:
-                return s
-
-        self.shape = tuple(parse_bcast_and_shape(s) for s in shape)
        self.dtype_specs()  # error checking is done there
-        self.name = name
        self.numpy_dtype = np.dtype(self.dtype)

-    def clone(
-        self, dtype=None, shape=None, broadcastable=None, **kwargs
-    ) -> "TensorType":
-        if broadcastable is not None:
-            warnings.warn(
-                "The `broadcastable` keyword is deprecated; use `shape`.",
-                DeprecationWarning,
-            )
-            shape = broadcastable
+        self.shape_encoded = tuple(shape_key(s) for s in shape)
+
+        assert isinstance(self.shape_encoded, tuple)
+        assert all(
+            isinstance(s, int) and not isinstance(s, bool) for s in self.shape_encoded
+        )
+
+    def clone(self, dtype=None, shape=None, **kwargs) -> "TensorType":
        if dtype is None:
            dtype = self.dtype
        if shape is None:
-            shape = self.shape
+            shape = self.shape_encoded
        return type(self)(dtype, shape, name=self.name)

    def filter(self, data, strict=False, allow_downcast=None):
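To make the round trip concrete, an illustrative session (assuming the patched `TensorType` is importable, e.g. from `aesara.tensor.type`): `__init__` normalizes the user-facing shape into `shape_encoded`, the new `shape` property decodes it back, and `clone` passes the encoded form through unchanged.

    >>> t = TensorType("float64", (5, None, -1))
    >>> t.shape_encoded
    (5, -2, -1)
    >>> t.shape
    (5, None, None)
    >>> t.clone().shape_encoded
    (5, -2, -1)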
@@ -243,16 +249,24 @@ def filter(self, data, strict=False, allow_downcast=None):
                " Aesara C code does not support that.",
            )

-        if not all(
-            ds == ts if ts is not None else True
-            for ds, ts in zip(data.shape, self.shape)
-        ):
-            raise TypeError(
-                f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"
-            )
+        def check_shape_info(i, s_val, s_info):
+            if s_info == -1 and s_val == 1:
+                raise ValueError(
+                    f"Value's shape in dimension {i} is not compatible "
+                    f"with the constraint: {s_val} != 1"
+                )
+            if s_info > -1 and s_val != s_info:
+                raise ValueError(
+                    f"Value's shape in dimension {i} is not compatible "
+                    f"with the constraint: {s_val} == {s_info}"
+                )
+
+        for i, (s_val, s_info) in enumerate(zip(np.shape(data), self.shape_encoded)):
+            check_shape_info(i, s_val, s_info)

        if self.filter_checks_isfinite and not np.all(np.isfinite(data)):
            raise ValueError("Non-finite elements not allowed")
+
        return data

    def filter_variable(self, other, allow_convert=True):
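The per-dimension check is easiest to see with the new `-1` constraint; a sketch, eliding `filter`'s dtype and strictness handling:

    >>> t = TensorType("float64", (-1, 2))
    >>> _ = t.filter(np.zeros((3, 2)))  # passes: 3 != 1 and 2 == 2
    >>> t.filter(np.zeros((1, 2)))  # raises ValueError: dimension 0 violates != 1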
@@ -308,7 +322,10 @@ def in_same_class(self, otype):
        if (
            isinstance(otype, TensorType)
            and otype.dtype == self.dtype
-            and otype.broadcastable == self.broadcastable
+            and all(
+                s == o_s if s == 1 or o_s == 1 else True
+                for s, o_s in zip(self.shape, otype.shape)
+            )
        ):
            return True
        return False
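In other words, two types are now in the same class when they agree on which dimensions are broadcastable (shape 1), while other static sizes are free to differ. Illustrative:

    >>> TensorType("float64", (5, 1)).in_same_class(TensorType("float64", (3, 1)))
    True
    >>> TensorType("float64", (5, 1)).in_same_class(TensorType("float64", (5, 2)))
    False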
@@ -320,7 +337,11 @@ def is_super(self, otype):
            and otype.ndim == self.ndim
            # `otype` is allowed to be as or more shape-specific than `self`,
            # but not less
-            and all(sb == ob or sb is None for sb, ob in zip(self.shape, otype.shape))
+            and all(
+                s == o_s if s > -1 and o_s > -1 else s <= o_s
+                for s, o_s in zip(self.shape_encoded, otype.shape_encoded)
+            )
        ):
            return True

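The comparison leans on the encoded ordering `-2 < -1 < n`: an unknown dimension (`-2`) is a supertype of anything, `-1` is treated as less specific than any fixed size, and two fixed sizes must match exactly. Illustrative:

    >>> t_any = TensorType("float64", (None, None))  # encoded (-2, -2)
    >>> t_ne1 = TensorType("float64", (-1, None))  # encoded (-1, -2)
    >>> t_fix = TensorType("float64", (5, None))  # encoded (5, -2)
    >>> t_any.is_super(t_fix), t_ne1.is_super(t_fix), t_fix.is_super(t_any)
    (True, True, False)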
@@ -334,13 +355,22 @@ def convert_variable(self, var):
        if (self.ndim == var.type.ndim) and (self.dtype == var.type.dtype):
            # `var.type` only differs from `self` in that its shape is (at least partially)
            # less specific than `self`, so we convert `var` to `self`'s `Type`.
-            # `specify_shape` will combine the more precise shapes of the two types
-            return aesara.tensor.specify_shape(var, self.shape)
+            # `specify_shape` will combine the more precise shapes of the two types.
+
+            new_shape_encoded = ()
+            for s, o_s in zip(self.shape_encoded, var.type.shape_encoded):
+
+                if s > -1 and o_s > -1 and s != o_s:
+                    raise ValueError(
+                        f"Incompatible shapes: {self.shape_encoded}, {var.type.shape_encoded}"
+                    )
+
+                new_shape_encoded += (max(s, o_s),)
+
+            return aesara.tensor.specify_shape(var, new_shape_encoded)

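The merge reduces to an element-wise `max` over the encoded values, which is valid precisely because of the `-2 < -1 < n` ordering; conflicting fixed sizes are rejected first. In isolation (plain tuples, illustrative):

    >>> a = (5, -2, -1)  # self.shape_encoded
    >>> b = (-2, 3, -1)  # var.type.shape_encoded
    >>> tuple(max(s, o_s) for s, o_s in zip(a, b))
    (5, 3, -1)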
    @staticmethod
    def values_eq(a, b, force_same_dtype=True):
-        # TODO: check to see if the shapes must match; for now, we err on safe
-        # side...
        if a.shape != b.shape:
            return False
        if force_same_dtype and a.dtype != b.dtype:
@@ -367,14 +397,23 @@ def __eq__(self, other):
        if type(self) != type(other):
            return NotImplemented

-        return other.dtype == self.dtype and other.shape == self.shape
+        return other.dtype == self.dtype and other.shape_encoded == self.shape_encoded

    def __hash__(self):
-        return hash((type(self), self.dtype, self.shape))
+        return hash((type(self), self.dtype, self.shape_encoded))
+
+    @property
+    def shape(self) -> Tuple[Optional[int], ...]:
+        """Return a static shape tuple with unknown values equal to ``None``."""
+        return tuple(s if s > -1 else None for s in self.shape_encoded)

    @property
    def broadcastable(self):
        """A boolean tuple indicating which dimensions have a shape equal to one."""
+        warnings.warn(
+            "TensorType.broadcastable is deprecated; use TensorType.shape",
+            DeprecationWarning,
+        )
        return tuple(s == 1 for s in self.shape)

    @property
@@ -386,7 +425,18 @@ def __str__(self):
        if self.name:
            return self.name
        else:
-            return f"TensorType({self.dtype}, {self.shape})"
+
+            def shape_str(s):
+                if s == -1:
+                    return ">1"
+                elif s < -1:
+                    return "?"
+                else:
+                    return str(s)
+
+            formatted_shape = ", ".join([shape_str(s) for s in self.shape_encoded])
+
+            return f"TensorType({self.dtype}, ({formatted_shape}))"

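With the three encodings, the printed form now distinguishes known sizes, unknown sizes (`?`), and the `!= 1` constraint (`>1`). Illustrative:

    >>> str(TensorType("float64", (3, None, -1)))
    'TensorType(float64, (3, ?, >1))'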
    def __repr__(self):
        return str(self)