Introduce graph rewrite for mixture sub-graphs defined via IfElse Op
#169
base: main
Conversation
Codecov Report: Base: 95.15% // Head: 94.94% // Decreases project coverage by 0.22%.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #169      +/-   ##
==========================================
- Coverage   95.15%   94.94%   -0.22%
==========================================
  Files          12       12
  Lines        2023     1878     -145
  Branches      253      280      +27
==========================================
- Hits         1925     1783     -142
+ Misses         56       53       -3
  Partials       42       42
```
I added many tests; I used …
Here are some notes regarding why some tests are failing. The discrepancy between the graphs occurs here: …

This comment will serve as a reminder of what I am stuck on before the upcoming meeting.
I'm slowly revisiting this PR after quite some time. I'm investigating one of my failing test cases, and I have probably forgotten many details just due to time passing... Consider the following code:

```python
import aesara
import aesara.tensor as at

from aeppl.rewriting import construct_ir_fgraph

srng = at.random.RandomStream(29833)

X_rv = srng.normal(loc=[10, 20], scale=0.1, size=(2,), name="X")
Y_rv = srng.normal(loc=[-10, -20], scale=0.1, size=(2,), name="Y")
I_rv = srng.bernoulli([0.9, 0.1], size=(2,), name="I")

i_vv = I_rv.clone()
i_vv.name = "i"

Z1_rv = at.switch(I_rv, X_rv, Y_rv)
z_vv = Z1_rv.clone()
z_vv.name = "z1"

fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})

aesara.dprint(fgraph.outputs[0])
```

yields

```
SpecifyShape [id A]
|MixtureRV{indices_end_idx=2, out_dtype='float64', out_broadcastable=(False,)} [id B]
| |TensorConstant{0} [id C]
| |bernoulli_rv{0, (0,), int64, False}.1 [id D] 'I'
| | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x16488AB20>) [id E]
| | |TensorConstant{(1,) of 2} [id F]
| | |TensorConstant{4} [id G]
| | |TensorConstant{[0.9 0.1]} [id H]
| |normal_rv{0, (0, 0), floatX, False}.1 [id I] 'X'
| | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x1648899A0>) [id J]
| | |TensorConstant{(1,) of 2} [id F]
| | |TensorConstant{11} [id K]
| | |TensorConstant{[10 20]} [id L]
| | |TensorConstant{0.1} [id M]
| |normal_rv{0, (0, 0), floatX, False}.1 [id N] 'Y'
| |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x16488A340>) [id O]
| |TensorConstant{(1,) of 2} [id F]
| |TensorConstant{11} [id K]
| |TensorConstant{[-10 -20]} [id P]
| |TensorConstant{0.1} [id M]
|TensorConstant{2} [id Q]
bernoulli_rv{0, (0,), int64, False}.1 [id D] 'I'
```

Where does the `SpecifyShape` come from?
Hey, thanks for revisiting the PR! Do you mean running the code on this PR branch? If I run your code snippet on …, I get:

```
Elemwise{switch,no_inplace} [id A]
|bernoulli_rv{0, (0,), int64, False}.1 [id B] 'I'
| |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F0239D0E880>) [id C]
| |TensorConstant{(1,) of 2} [id D]
| |TensorConstant{4} [id E]
| |TensorConstant{[0.9 0.1]} [id F]
|normal_rv{0, (0, 0), floatX, False}.1 [id G] 'X'
| |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F023B681B60>) [id H]
| |TensorConstant{(1,) of 2} [id D]
| |TensorConstant{11} [id I]
| |TensorConstant{[10 20]} [id J]
| |TensorConstant{0.1} [id K]
|normal_rv{0, (0, 0), floatX, False}.1 [id L] 'Y'
|RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F0239D0DEE0>) [id M]
|TensorConstant{(1,) of 2} [id D]
|TensorConstant{11} [id I]
|TensorConstant{[-10 -20]} [id N]
 |TensorConstant{0.1} [id K]
```

Have you considered dispatching …? You will also need to resolve the (small) merge conflict due to the new …
Yes, running the code on this branch! The graph rewrite for …

Yes, but I felt like the graph rewrites for both would be very similar. I can separate them as I work through them for now...

Okay, sounds good!
They likely will, and we may want to merge them later. But I think it would be easier for you to move from one stable state to another, changing one thing at a time, and always keeping a reference implementation (…).
I just rebased my code. I am working on having Switch/IfElse-induced mixture subgraphs yield the same canonical (or IR? i.e. the graph obtained after running `construct_ir_fgraph`) representation.
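For concreteness, here is a minimal sketch of the kind of equivalence this is aiming at, reusing the variables from the snippet earlier in the thread; whether `aesara.graph.basic.equal_computations` is the right comparison is an assumption made here for illustration:

```python
# Sketch: the switch form and the stack-and-index form of the same mixture
# should end up with the same IR after `construct_ir_fgraph`.
import aesara.tensor as at
from aesara.graph.basic import equal_computations

from aeppl.rewriting import construct_ir_fgraph

srng = at.random.RandomStream(29833)

X_rv = srng.normal([10, 20], 0.1, size=(2,), name="X")
Y_rv = srng.normal([-10, -20], 0.1, size=(2,), name="Y")
I_rv = srng.bernoulli([0.9, 0.1], size=(2,), name="I")
i_vv = I_rv.clone()

# Two ways of writing the same mixture over the same random variables
Z_switch = at.switch(I_rv, X_rv, Y_rv)
Z_stack = at.stack((X_rv, Y_rv))[I_rv]

fg_switch, _, _ = construct_ir_fgraph({Z_switch: Z_switch.clone(), I_rv: i_vv})
fg_stack, _, _ = construct_ir_fgraph({Z_stack: Z_stack.clone(), I_rv: i_vv})

# The goal is for both IR graphs to contain the same canonical mixture node
print(equal_computations(fg_switch.outputs, fg_stack.outputs))
```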
Of course: https://aesara.readthedocs.io/en/latest/extending/graph_rewriting.html#detailed-profiling-of-aesara-rewrites. Alternatively, you can add a breakpoint here: …
There's also … In general, as @rlouf said, don't be afraid to put …
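A minimal sketch of the breakpoint suggestion, assuming the `aesara.graph.rewriting.basic.node_rewriter` import path; the rewrite below is a stand-in for inspection only, not the actual mixture rewrite from this PR:

```python
# A stand-in rewrite used only to drop into the debugger whenever an `IfElse`
# node is visited during rewriting; it never replaces anything.
from aesara.graph.rewriting.basic import node_rewriter
from aesara.ifelse import IfElse


@node_rewriter([IfElse])
def debug_ifelse_nodes(fgraph, node):
    # Inspect `node`, `node.inputs`, and `fgraph` interactively here
    breakpoint()
    # Returning None tells the rewriter that no replacement was made
    return None
```

For the breakpoint to actually trigger, a rewrite like this would still need to be registered in the rewrite database that `construct_ir_fgraph` uses.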
Adding to this, you can set up any reasonable IDE so that, when you run tests, it opens a debugger console whenever it hits a breakpoint or a test fails. If you don't have that in place already, spend some time setting it up; it was a huge boost to my productivity.
Thanks for the tip. I also saw your recent related tweet 😅 As for this PR, I am thinking that it's best to close it in order to 1) split the tasks into smaller sub-PRs (I felt like too much was going on at once) and 2) address some other issues that came up. As for the latter, I divided them into the subsections below. Any guidance would be helpful...

Reworking …

```python
new_node = mix_op.make_node(
    *([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
)
```
SpecifyShape Op
The appearance of the `SpecifyShape` `Op` seems to be new... perhaps due to this recent addition to Aesara? Would replacing `out_broadcastable` in `MixtureRV` with the corresponding static shapes, if available, be a good first step? (A rough sketch of this idea follows the class definition below.)
Lines 180 to 197 in 473c1e6:

```python
class MixtureRV(Op):
    """A placeholder used to specify a log-likelihood for a mixture sub-graph."""

    __props__ = ("indices_end_idx", "out_dtype", "out_broadcastable")

    def __init__(self, indices_end_idx, out_dtype, out_broadcastable):
        super().__init__()
        self.indices_end_idx = indices_end_idx
        self.out_dtype = out_dtype
        self.out_broadcastable = out_broadcastable

    def make_node(self, *inputs):
        return Apply(
            self, list(inputs), [TensorType(self.out_dtype, self.out_broadcastable)()]
        )

    def perform(self, node, inputs, outputs):
        raise NotImplementedError("This is a stand-in Op.")  # pragma: no cover
```
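A minimal sketch of the static-shape idea, assuming `TensorType` accepts a `shape` argument with `None` entries for unknown sizes; the class name and the `out_shape` property are hypothetical, not part of this PR:

```python
# Hypothetical MixtureRV variant that stores a static `out_shape` instead of
# `out_broadcastable` and passes it to TensorType's `shape` argument.
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.tensor.type import TensorType


class MixtureRVWithShape(Op):
    """A placeholder that also records the mixture output's static shape."""

    __props__ = ("indices_end_idx", "out_dtype", "out_shape")

    def __init__(self, indices_end_idx, out_dtype, out_shape):
        super().__init__()
        self.indices_end_idx = indices_end_idx
        self.out_dtype = out_dtype
        # e.g. (2,) for the Switch example below, (None, None) for the stack one
        self.out_shape = tuple(out_shape)

    def make_node(self, *inputs):
        return Apply(
            self, list(inputs), [TensorType(self.out_dtype, shape=self.out_shape)()]
        )

    def perform(self, node, inputs, outputs):
        raise NotImplementedError("This is a stand-in Op.")
```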
Mismatch in MixtureRV shapes generated by Switch vs. at.stack
With the hot fix replacing broadcastable with shape, the MixtureRV shapes seem to differ depending on whether they were generated by a Switch or a Join. Is this because subtensors don't have static shape inference yet? That would be my guess (Aesara issue #922?), but I'm not sure. Below is an example that I created using this branch's additions.

```python
import aesara.tensor as at
from aeppl.rewriting import construct_ir_fgraph
from aeppl.mixture import MixtureRV
srng = at.random.RandomStream(29833)
X_rv = srng.normal([10, 20], 0.1, size=(2,), name="X")
Y_rv = srng.normal([-10, -20], 0.1, size=(2,), name="Y")
I_rv = srng.bernoulli([0.99, 0.01], size=(2,), name="I")
i_vv = I_rv.clone()
i_vv.name = "i"
Z1_rv = at.switch(I_rv, X_rv, Y_rv)
z_vv = Z1_rv.clone()
z_vv.name = "z1"
fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
assert isinstance(fgraph.outputs[0].owner.op, MixtureRV)
assert not hasattr(
fgraph.outputs[0].tag, "test_value"
) # aesara.config.compute_test_value == "off"
assert fgraph.outputs[0].name is None
Z1_rv.name = "Z1"
fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
assert fgraph.outputs[0].name == "Z1-mixture"
# building the identical graph but with a stack to check that mixture computations are identical
Z2_rv = at.stack((X_rv, Y_rv))[I_rv]
fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv})
fgraph.outputs[0].type.shape # (2,)
fgraph2.outputs[0].type.shape  # (None, None)
```

IfElse mixture subgraphs
Given that IfElse requires scalar conditions, maybe it would be good to start with those instead of refining switch-mixtures... Happy to hear any thoughts on the points above. I feel like there's a lot going on, and it can be challenging to address it all at once (especially given that this is a continuation of this summer's work...).
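A minimal sketch of the scalar-condition contrast, reusing names from the snippets above: `at.switch` blends elementwise with a tensor condition, whereas `aesara.ifelse.ifelse` takes a single scalar condition and returns one whole branch:

```python
import aesara.tensor as at
from aesara.ifelse import ifelse

srng = at.random.RandomStream(29833)

X_rv = srng.normal([10, 20], 0.1, size=(2,), name="X")
Y_rv = srng.normal([-10, -20], 0.1, size=(2,), name="Y")

# Elementwise mixture: the condition is a length-2 tensor of indicators
I_vec = srng.bernoulli([0.5, 0.5], size=(2,), name="I_vec")
Z_switch = at.switch(I_vec, X_rv, Y_rv)

# IfElse mixture: the condition must be a single scalar, so either the whole
# X_rv or the whole Y_rv is returned rather than an elementwise blend
i_scalar = srng.bernoulli(0.5, name="i_scalar")
Z_ifelse = ifelse(i_scalar, X_rv, Y_rv)
```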
PRs that touch core mechanisms in Aesara, or that simply implement big changes, can easily get frustrating. Breaking the problem down like you did is a great reaction to this situation. Do you mind if I keep it open? I will come back to you next week, at least with some questions and maybe some insight.
Of course, not a problem at all!
@rlouf Just a quick update: @brandonwillard and I conversed recently, hence the recent force-push. The current focus is to ensure that the current mixture indexing operations via …
Glad to hear this is back on track!
Closes #76.

Akin to #154, this PR introduces a `node_rewriter` for `IfElse`. Effectively, it builds on the recently added `switch_mixture_replace` to accommodate mixture sub-graphs that are essentially the same but defined with a different `Op`: `IfElse`. Below is an example of the new functionality.
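The example itself was not captured above; what follows is a minimal sketch of what such an `IfElse`-defined mixture could look like, assuming the new rewrite canonicalizes it to a `MixtureRV` in the IR the same way the existing Switch rewrite does:

```python
import aesara
import aesara.tensor as at
from aesara.ifelse import ifelse

from aeppl.rewriting import construct_ir_fgraph

srng = at.random.RandomStream(29833)

X_rv = srng.normal(10, 0.1, name="X")
Y_rv = srng.normal(-10, 0.1, name="Y")

# `IfElse` needs a scalar condition, so the index is a single Bernoulli draw
I_rv = srng.bernoulli(0.5, name="I")
i_vv = I_rv.clone()
i_vv.name = "i"

# The mixture sub-graph is defined via `IfElse` instead of `at.switch`
Z_rv = ifelse(I_rv, X_rv, Y_rv)
z_vv = Z_rv.clone()
z_vv.name = "z"

# With the rewrite from this PR, the IR is expected to contain a mixture node
# (e.g. a `MixtureRV`) in place of the `IfElse`; without it, the `IfElse`
# would be left untouched.
fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv, I_rv: i_vv})
aesara.dprint(fgraph.outputs[0])
```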