Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 8ce3c63

Browse files
larryshamalamabrandonwillard
authored andcommitted
Introduce graph rewrite for mixture sub-graphs defined via Switch Op
1 parent 8b298d1 commit 8ce3c63

File tree

2 files changed

+113
-6
lines changed

2 files changed

+113
-6
lines changed

aeppl/mixture.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,17 @@
1111
pre_greedy_node_rewriter,
1212
)
1313
from aesara.ifelse import ifelse
14+
from aesara.scalar.basic import Switch
1415
from aesara.tensor.basic import Join, MakeVector
16+
from aesara.tensor.elemwise import Elemwise
1517
from aesara.tensor.random.rewriting import (
1618
local_dimshuffle_rv_lift,
1719
local_subtensor_rv_lift,
1820
)
1921
from aesara.tensor.shape import shape_tuple
2022
from aesara.tensor.subtensor import (
2123
as_index_literal,
24+
as_nontensor_scalar,
2225
get_canonical_form_slice,
2326
is_basic_idx,
2427
)
@@ -251,7 +254,6 @@ def mixture_replace(fgraph, node):
251254
From these terms, new terms ``Z_rv[i] = mixture_comps[i][i == I_rv]`` are
252255
created for each ``i`` in ``enumerate(mixture_comps)``.
253256
"""
254-
255257
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
256258

257259
if rv_map_feature is None:
@@ -303,6 +305,56 @@ def mixture_replace(fgraph, node):
303305
return [new_mixture_rv]
304306

305307

308+
@node_rewriter((Elemwise,))
309+
def switch_mixture_replace(fgraph, node):
310+
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
311+
312+
if rv_map_feature is None:
313+
return None # pragma: no cover
314+
315+
if not isinstance(node.op.scalar_op, Switch):
316+
return None # pragma: no cover
317+
318+
old_mixture_rv = node.default_output()
319+
# idx, component_1, component_2 = node.inputs
320+
321+
mixture_rvs = []
322+
323+
for component_rv in node.inputs[1:]:
324+
if not (
325+
component_rv.owner
326+
and isinstance(component_rv.owner.op, MeasurableVariable)
327+
and component_rv not in rv_map_feature.rv_values
328+
):
329+
return None
330+
new_node = assign_custom_measurable_outputs(component_rv.owner)
331+
out_idx = component_rv.owner.outputs.index(component_rv)
332+
new_comp_rv = new_node.outputs[out_idx]
333+
mixture_rvs.append(new_comp_rv)
334+
335+
mix_op = MixtureRV(
336+
2,
337+
old_mixture_rv.dtype,
338+
old_mixture_rv.broadcastable,
339+
)
340+
new_node = mix_op.make_node(
341+
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
342+
)
343+
344+
new_mixture_rv = new_node.default_output()
345+
346+
if aesara.config.compute_test_value != "off":
347+
if not hasattr(old_mixture_rv.tag, "test_value"):
348+
compute_test_value(node)
349+
350+
new_mixture_rv.tag.test_value = old_mixture_rv.tag.test_value
351+
352+
if old_mixture_rv.name:
353+
new_mixture_rv.name = f"{old_mixture_rv.name}-mixture"
354+
355+
return [new_mixture_rv]
356+
357+
306358
@_logprob.register(MixtureRV)
307359
def logprob_MixtureRV(
308360
op, values, *inputs: Optional[Union[TensorVariable, slice]], name=None, **kwargs
@@ -368,7 +420,8 @@ def logprob_MixtureRV(
368420
logprob_rewrites_db.register(
369421
"mixture_replace",
370422
EquilibriumGraphRewriter(
371-
[mixture_replace], max_use_ratio=aesara.config.optdb__max_use_ratio
423+
[mixture_replace, switch_mixture_replace],
424+
max_use_ratio=aesara.config.optdb__max_use_ratio,
372425
),
373426
0,
374427
"basic",

tests/test_mixture.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import numpy as np
44
import pytest
55
import scipy.stats.distributions as sp
6-
from aesara.graph.basic import Variable
6+
from aesara.graph.basic import Variable, equal_computations
77
from aesara.tensor.random.basic import CategoricalRV
88
from aesara.tensor.shape import shape_tuple
99
from aesara.tensor.subtensor import as_index_constant
1010

1111
from aeppl.joint_logprob import factorized_joint_logprob, joint_logprob
12-
from aeppl.mixture import expand_indices
12+
from aeppl.mixture import MixtureRV, expand_indices
13+
from aeppl.rewriting import construct_ir_fgraph
1314
from tests.test_logprob import scipy_logprob
1415
from tests.utils import assert_no_rvs
1516

@@ -67,7 +68,14 @@ def create_mix_model(size, axis):
6768

6869

6970
@aesara.config.change_flags(compute_test_value="warn")
70-
def test_compute_test_value():
71+
@pytest.mark.parametrize(
72+
"op_constructor",
73+
[
74+
lambda _I, _X, _Y: at.stack([_X, _Y])[_I],
75+
lambda _I, _X, _Y: at.switch(_I, _X, _Y),
76+
],
77+
)
78+
def test_compute_test_value(op_constructor):
7179

7280
srng = at.random.RandomStream(29833)
7381

@@ -82,7 +90,7 @@ def test_compute_test_value():
8290
i_vv = I_rv.clone()
8391
i_vv.name = "i"
8492

85-
M_rv = at.stack([X_rv, Y_rv])[I_rv]
93+
M_rv = op_constructor(I_rv, X_rv, Y_rv)
8694
M_rv.name = "M"
8795

8896
m_vv = M_rv.clone()
@@ -705,3 +713,49 @@ def test_mixture_with_DiracDelta():
705713
logp_res = factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv})
706714

707715
assert m_vv in logp_res
716+
717+
718+
def test_switch_mixture():
719+
srng = at.random.RandomStream(29833)
720+
721+
X_rv = srng.normal(-10.0, 0.1, name="X")
722+
Y_rv = srng.normal(10.0, 0.1, name="Y")
723+
724+
I_rv = srng.bernoulli(0.5, name="I")
725+
i_vv = I_rv.clone()
726+
i_vv.name = "i"
727+
728+
Z1_rv = at.switch(I_rv, X_rv, Y_rv)
729+
z_vv = Z1_rv.clone()
730+
z_vv.name = "z1"
731+
732+
fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
733+
734+
assert isinstance(fgraph.outputs[0].owner.op, MixtureRV)
735+
assert not hasattr(
736+
fgraph.outputs[0].tag, "test_value"
737+
) # aesara.config.compute_test_value == "off"
738+
assert fgraph.outputs[0].name is None
739+
740+
Z1_rv.name = "Z1"
741+
742+
fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
743+
744+
assert fgraph.outputs[0].name == "Z1-mixture"
745+
746+
# building the identical graph but with a stack to check that mixture computations are identical
747+
748+
Z2_rv = at.stack((X_rv, Y_rv))[I_rv]
749+
750+
fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv})
751+
752+
assert equal_computations(fgraph.outputs, fgraph2.outputs)
753+
754+
z1_logp = joint_logprob({Z1_rv: z_vv, I_rv: i_vv})
755+
z2_logp = joint_logprob({Z2_rv: z_vv, I_rv: i_vv})
756+
757+
# below should follow immediately from the equal_computations assertion above
758+
assert equal_computations([z1_logp], [z2_logp])
759+
760+
np.testing.assert_almost_equal(0.69049938, z1_logp.eval({z_vv: -10, i_vv: 0}))
761+
np.testing.assert_almost_equal(0.69049938, z2_logp.eval({z_vv: -10, i_vv: 0}))

0 commit comments

Comments
 (0)