@@ -110,10 +110,12 @@ def _as_tensor_Variable(x, name, ndim, **kwargs):
 
     if x.type.ndim > ndim:
         # Strip off leading broadcastable dimensions
-        non_broadcastables = [idx for idx in range(x.ndim) if not x.broadcastable[idx]]
+        non_broadcastables = [
+            idx for idx in range(x.type.ndim) if x.type.shape_encoded[idx] != 1
+        ]
 
         if non_broadcastables:
-            x = x.dimshuffle(list(range(x.ndim))[non_broadcastables[0] :])
+            x = x.dimshuffle(list(range(x.type.ndim))[non_broadcastables[0] :])
         else:
             x = x.dimshuffle()
 
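Note: throughout this diff, `shape_encoded` replaces `broadcastable` by encoding each dimension's static length as `1` (broadcastable) or `None` (unknown). A minimal plain-Python sketch of the leading-dimension strip above, under that encoding (the helper name `strip_leading_broadcastable` is illustrative, not part of the change):

```python
from typing import Optional, Tuple

# Static shapes encode a broadcastable dimension as 1 and an
# unknown-length dimension as None (hypothetical standalone helper).
def strip_leading_broadcastable(
    shape: Tuple[Optional[int], ...]
) -> Tuple[Optional[int], ...]:
    non_broadcastables = [i for i, s in enumerate(shape) if s != 1]
    if non_broadcastables:
        # Keep everything from the first non-broadcastable dimension on,
        # mirroring the `dimshuffle` slice in the diff.
        return shape[non_broadcastables[0]:]
    # Every dimension is statically 1: collapse to a scalar shape.
    return ()

assert strip_leading_broadcastable((1, 1, None, 3)) == (None, 3)
assert strip_leading_broadcastable((1, 1)) == ()
```
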
@@ -2210,18 +2212,18 @@ def make_node(self, axis, *tensors):
                 "Join cannot handle arguments of dimension 0."
                 " Use `stack` to join scalar values."
             )
-        # Handle single-tensor joins immediately.
+
         if len(tensors) == 1:
-            bcastable = list(tensors[0].type.broadcastable)
+            out_shape = tensors[0].type.shape_encoded
         else:
             # When the axis is fixed, a dimension should be
             # broadcastable if at least one of the inputs is
             # broadcastable on that dimension (see justification below),
             # except for the axis dimension.
             # Initialize bcastable all false, and then fill in some trues with
             # the loops.
-            bcastable = [False] * len(tensors[0].type.broadcastable)
-            ndim = len(bcastable)
+            ndim = tensors[0].type.ndim
+            out_shape = [None] * ndim
 
         if not isinstance(axis, int):
             try:
@@ -2246,25 +2248,25 @@ def make_node(self, axis, *tensors):
                 axis += ndim
 
             for x in tensors:
-                for current_axis, bflag in enumerate(x.type.broadcastable):
+                for current_axis, s in enumerate(x.type.shape_encoded):
                     # Constant negative axis can no longer be negative at
                     # this point. It safe to compare this way.
                     if current_axis == axis:
                         continue
-                    if bflag:
-                        bcastable[current_axis] = True
+                    if s == 1:
+                        out_shape[current_axis] = 1
             try:
-                bcastable[axis] = False
+                out_shape[axis] = None
             except IndexError:
                 raise ValueError(
                     f"Axis value {axis} is out of range for the given input dimensions"
                 )
         else:
             # When the axis may vary, no dimension can be guaranteed to be
             # broadcastable.
-            bcastable = [False] * len(tensors[0].type.broadcastable)
+            out_shape = [None] * tensors[0].type.ndim
 
-        if not builtins.all(x.ndim == len(bcastable) for x in tensors):
+        if not builtins.all(x.ndim == len(out_shape) for x in tensors):
             raise TypeError(
                 "Only tensors with the same number of dimensions can be joined"
             )
@@ -2274,7 +2276,7 @@ def make_node(self, axis, *tensors):
         if inputs[0].type.dtype not in int_dtypes:
             raise TypeError(f"Axis value {inputs[0]} must be an integer type")
 
-        return Apply(self, inputs, [tensor(dtype=out_dtype, shape=bcastable)])
+        return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)])
 
     def perform(self, node, axis_and_tensors, out_):
         (out,) = out_
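For the `Join` hunks above: with a constant axis, an output dimension is statically 1 when at least one input is statically 1 there (broadcasting stretches it to match the others), while the joined axis itself is always unknown. A standalone sketch of that rule, using a hypothetical `join_static_shape` helper:

```python
from typing import List, Optional, Sequence, Tuple

def join_static_shape(
    axis: int, shapes: Sequence[Tuple[Optional[int], ...]]
) -> List[Optional[int]]:
    # Assumes a constant, already-normalized (non-negative) axis and
    # inputs of equal rank, as make_node guarantees at this point.
    out_shape: List[Optional[int]] = [None] * len(shapes[0])
    for shape in shapes:
        for dim, s in enumerate(shape):
            if dim == axis:
                continue
            if s == 1:
                # One statically-broadcastable input is enough: the others
                # either match it or are broadcast up to it at runtime.
                out_shape[dim] = 1
    out_shape[axis] = None  # the joined length is never known statically
    return out_shape

assert join_static_shape(0, [(None, 1), (2, 1)]) == [None, 1]
assert join_static_shape(1, [(1, None), (1, 3)]) == [1, None]
```
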
@@ -2385,6 +2387,7 @@ def grad(self, axis_and_tensors, grads):
             # Split.make_node isn't always able to infer the right
             # broadcast. As the grad need to keep the information,
             # read it if needed.
+            # TODO FIXME: Remove all this broadcastable stuff.
             split_gz = [
                 g
                 if g.type.broadcastable == t.type.broadcastable
@@ -2771,6 +2774,8 @@ def flatten(x, ndim=1):
     else:
         dims = (-1,)
     x_reshaped = _x.reshape(dims)
+
+    # TODO FIXME: Remove all this broadcastable stuff.
     bcast_kept_dims = _x.broadcastable[: ndim - 1]
     bcast_new_dim = builtins.all(_x.broadcastable[ndim - 1 :])
     broadcastable = bcast_kept_dims + (bcast_new_dim,)
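The `flatten` logic flagged here maps to the new encoding the same way: kept leading dimensions carry their entries over, and the collapsed trailing dimension is statically 1 only when every collapsed dimension is. A sketch in shape terms (the `flatten_static_shape` name is hypothetical, and like the flag-based code it reports any non-trivial collapsed product as unknown):

```python
from typing import Optional, Tuple

def flatten_static_shape(
    shape: Tuple[Optional[int], ...], ndim: int = 1
) -> Tuple[Optional[int], ...]:
    kept = shape[: ndim - 1]  # leading dims survive unchanged
    # The collapsed dim has static length 1 only if every collapsed input
    # dim is 1; a known non-1 product is still reported as None here,
    # matching the boolean-flag logic in the original code.
    new_dim = 1 if all(s == 1 for s in shape[ndim - 1 :]) else None
    return kept + (new_dim,)

assert flatten_static_shape((2, 1, None), ndim=2) == (2, None)
assert flatten_static_shape((1, 1, 1), ndim=1) == (1,)
```
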
@@ -2882,7 +2887,7 @@ def make_node(self, start, stop, step):
         assert step.ndim == 0
 
         inputs = [start, stop, step]
-        outputs = [tensor(self.dtype, (False,))]
+        outputs = [tensor(self.dtype, shape=(None,))]
 
         return Apply(self, inputs, outputs)
 
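`shape=(None,)` is the right encoding for `ARange`: the output length depends on the runtime values of `start`, `stop`, and `step`, so it can never be pinned down statically. For illustration only, the runtime length follows the usual arange formula for a positive step:

```python
import math

# Illustrative only: the length numpy.arange-style ops produce for a
# positive step, which is unknowable while start/stop/step are symbolic.
def arange_length(start: float, stop: float, step: float) -> int:
    return max(0, math.ceil((stop - start) / step))

assert arange_length(0, 10, 3) == 4  # 0, 3, 6, 9
assert arange_length(5, 5, 1) == 0
```
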
@@ -3158,11 +3163,11 @@ def make_node(self, x, y, inverse):
         elif x_dim < y_dim:
             x = shape_padleft(x, n_ones=(y_dim - x_dim))
 
-        # Compute the broadcastable pattern of the output
-        out_broadcastable = [
-            xb and yb for xb, yb in zip(x.type.broadcastable, y.type.broadcastable)
+        out_shape = [
+            1 if xb == 1 and yb == 1 else None
+            for xb, yb in zip(x.type.shape_encoded, y.type.shape_encoded)
         ]
-        out_type = tensor(dtype=x.type.dtype, shape=out_broadcastable)
+        out_type = tensor(dtype=x.type.dtype, shape=out_shape)
 
         inputlist = [x, y, inverse]
         outputlist = [out_type]
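This last hunk is the elementwise broadcast rule in the new encoding: a result dimension is statically 1 only when both operands are statically 1 there, and unknown otherwise. As a one-function sketch (`broadcast_static_shape` is an illustrative name):

```python
from typing import List, Optional, Tuple

def broadcast_static_shape(
    x_shape: Tuple[Optional[int], ...],
    y_shape: Tuple[Optional[int], ...],
) -> List[Optional[int]]:
    # Equal-rank inputs assumed; padding already happened via shape_padleft.
    return [
        1 if xs == 1 and ys == 1 else None
        for xs, ys in zip(x_shape, y_shape)
    ]

assert broadcast_static_shape((1, None, 1), (1, 1, 3)) == [1, None, None]
```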