@@ -180,17 +180,17 @@ def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable:
 class MixtureRV(Op):
     """A placeholder used to specify a log-likelihood for a mixture sub-graph."""
 
-    __props__ = ("indices_end_idx", "out_dtype", "out_broadcastable")
+    __props__ = ("indices_end_idx", "out_dtype", "out_shape")
 
-    def __init__(self, indices_end_idx, out_dtype, out_broadcastable):
+    def __init__(self, indices_end_idx, out_dtype, out_shape):
         super().__init__()
         self.indices_end_idx = indices_end_idx
         self.out_dtype = out_dtype
-        self.out_broadcastable = out_broadcastable
+        self.out_shape = out_shape
 
     def make_node(self, *inputs):
         return Apply(
-            self, list(inputs), [TensorType(self.out_dtype, self.out_broadcastable)()]
+            self, list(inputs), [TensorType(self.out_dtype, self.out_shape)()]
         )
 
     def perform(self, node, inputs, outputs):
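Context for the hunk above (not part of the diff): in recent Aesara/PyTensor releases, TensorType is built from a static shape tuple (ints, or None for dimensions of unknown length) rather than a broadcastable pattern, and the broadcastable flags are derived from that shape. A minimal sketch, assuming that API; the dtype and shape values are illustrative only:

# A minimal sketch, assuming the newer aesara TensorType API (not part of the diff).
from aesara.tensor.type import TensorType

# Old style passed a broadcastable pattern, e.g. TensorType("float64", broadcastable=(False, True)).
# New style passes the static shape; None marks a dimension of unknown length.
out_type = TensorType("float64", shape=(None, 1))

x = out_type()           # a symbolic variable of that type
print(x.type.shape)      # (None, 1)
print(x.broadcastable)   # (False, True), derived from the static shape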
@@ -285,7 +285,7 @@ def mixture_replace(fgraph, node):
     mix_op = MixtureRV(
         1 + len(mixing_indices),
         old_mixture_rv.dtype,
-        old_mixture_rv.broadcastable,
+        old_mixture_rv.type.shape,
     )
     new_node = mix_op.make_node(*([join_axis] + mixing_indices + mixture_rvs))
 
@@ -337,7 +337,7 @@ def switch_mixture_replace(fgraph, node):
     mix_op = MixtureRV(
         2,
         old_mixture_rv.dtype,
-        old_mixture_rv.broadcastable,
+        old_mixture_rv.type.shape,
     )
     if node.inputs[0].ndim == 0:
         # as_nontensor_scalar to allow graphs to be identical to mixture sub-graphs
@@ -397,7 +397,7 @@ def ifelse_mixture_replace(fgraph, node):
     mix_op = MixtureRV(
         2,
         old_mixture_rv.dtype,
-        old_mixture_rv.broadcastable,
+        old_mixture_rv.type.shape,
     )
 
     if node.inputs[0].ndim == 0:
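The rewrite hunks above all make the same substitution: they forward old_mixture_rv.type.shape to MixtureRV instead of the removed .broadcastable pattern. A rough sketch of why this is equivalent (not part of the diff), assuming the aesara random API; the variable and shape values are illustrative only:

# A rough sketch, assuming the aesara API (not part of the diff).
import aesara.tensor as at
from aesara.tensor.type import TensorType

rv = at.random.normal(0, 1, size=(5, 1))
static_shape = rv.type.shape            # e.g. (5, 1), or (None, 1) if a length is unknown

# Rebuilding the output type from the static shape preserves broadcastability,
# because the broadcastable flags are derived from that same shape information.
out_type = TensorType(rv.dtype, static_shape)
assert out_type.broadcastable == rv.broadcastable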