33import numpy as np
44import pytest
55import scipy .stats .distributions as sp
6- from aesara .graph .basic import Variable
6+ from aesara .graph .basic import Variable , equal_computations
77from aesara .tensor .random .basic import CategoricalRV
88from aesara .tensor .shape import shape_tuple
99from aesara .tensor .subtensor import as_index_constant
1010
1111from aeppl .joint_logprob import factorized_joint_logprob , joint_logprob
12- from aeppl .mixture import expand_indices
12+ from aeppl .mixture import MixtureRV , expand_indices
13+ from aeppl .rewriting import construct_ir_fgraph
1314from tests .test_logprob import scipy_logprob
1415from tests .utils import assert_no_rvs
1516
@@ -67,7 +68,14 @@ def create_mix_model(size, axis):
6768
6869
6970@aesara .config .change_flags (compute_test_value = "warn" )
70- def test_compute_test_value ():
71+ @pytest .mark .parametrize (
72+ "op_constructor" ,
73+ [
74+ lambda _I , _X , _Y : at .stack ([_X , _Y ])[_I ],
75+ lambda _I , _X , _Y : at .switch (_I , _X , _Y ),
76+ ],
77+ )
78+ def test_compute_test_value (op_constructor ):
7179
7280 srng = at .random .RandomStream (29833 )
7381
@@ -82,7 +90,7 @@ def test_compute_test_value():
8290 i_vv = I_rv .clone ()
8391 i_vv .name = "i"
8492
85- M_rv = at . stack ([ X_rv , Y_rv ])[ I_rv ]
93+ M_rv = op_constructor ( I_rv , X_rv , Y_rv )
8694 M_rv .name = "M"
8795
8896 m_vv = M_rv .clone ()
@@ -705,3 +713,49 @@ def test_mixture_with_DiracDelta():
705713 logp_res = factorized_joint_logprob ({M_rv : m_vv , I_rv : i_vv })
706714
707715 assert m_vv in logp_res
716+
717+
718+ def test_switch_mixture ():
719+ srng = at .random .RandomStream (29833 )
720+
721+ X_rv = srng .normal (- 10.0 , 0.1 , name = "X" )
722+ Y_rv = srng .normal (10.0 , 0.1 , name = "Y" )
723+
724+ I_rv = srng .bernoulli (0.5 , name = "I" )
725+ i_vv = I_rv .clone ()
726+ i_vv .name = "i"
727+
728+ Z1_rv = at .switch (I_rv , X_rv , Y_rv )
729+ z_vv = Z1_rv .clone ()
730+ z_vv .name = "z1"
731+
732+ fgraph , _ , _ = construct_ir_fgraph ({Z1_rv : z_vv , I_rv : i_vv })
733+
734+ assert isinstance (fgraph .outputs [0 ].owner .op , MixtureRV )
735+ assert not hasattr (
736+ fgraph .outputs [0 ].tag , "test_value"
737+ ) # aesara.config.compute_test_value == "off"
738+ assert fgraph .outputs [0 ].name is None
739+
740+ Z1_rv .name = "Z1"
741+
742+ fgraph , _ , _ = construct_ir_fgraph ({Z1_rv : z_vv , I_rv : i_vv })
743+
744+ assert fgraph .outputs [0 ].name == "Z1-mixture"
745+
746+ # building the identical graph but with a stack to check that mixture computations are identical
747+
748+ Z2_rv = at .stack ((X_rv , Y_rv ))[I_rv ]
749+
750+ fgraph2 , _ , _ = construct_ir_fgraph ({Z2_rv : z_vv , I_rv : i_vv })
751+
752+ assert equal_computations (fgraph .outputs , fgraph2 .outputs )
753+
754+ z1_logp = joint_logprob ({Z1_rv : z_vv , I_rv : i_vv })
755+ z2_logp = joint_logprob ({Z2_rv : z_vv , I_rv : i_vv })
756+
757+ # below should follow immediately from the equal_computations assertion above
758+ assert equal_computations ([z1_logp ], [z2_logp ])
759+
760+ np .testing .assert_almost_equal (0.69049938 , z1_logp .eval ({z_vv : - 10 , i_vv : 0 }))
761+ np .testing .assert_almost_equal (0.69049938 , z2_logp .eval ({z_vv : - 10 , i_vv : 0 }))
0 commit comments