Skip to content

Commit fb07f3b

Browse files
committed
Make raw_union more flexible and add a fallible try_raw_rebuild
1 parent c18f6d4 commit fb07f3b

File tree

2 files changed

+75
-28
lines changed

2 files changed

+75
-28
lines changed

src/egraph.rs

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -804,20 +804,21 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
804804

805805
self.clean = false;
806806
let mut new_root = None;
807-
self.inner
808-
.raw_union(enode_id1, enode_id2, |class1, id1, p1, class2, _, p2| {
809-
new_root = Some(id1);
807+
self.inner.raw_union(enode_id1, enode_id2, |info| {
808+
new_root = Some(info.id1);
810809

811-
let did_merge = self.analysis.merge(&mut class1.data, class2.data);
812-
if did_merge.0 {
813-
self.analysis_pending.extend(p1);
814-
}
815-
if did_merge.1 {
816-
self.analysis_pending.extend(p2);
817-
}
810+
let did_merge = self.analysis.merge(&mut info.data1.data, info.data2.data);
811+
if did_merge.0 {
812+
self.analysis_pending
813+
.extend(info.parents1.into_iter().copied());
814+
}
815+
if did_merge.1 {
816+
self.analysis_pending
817+
.extend(info.parents2.into_iter().copied());
818+
}
818819

819-
concat_vecs(&mut class1.nodes, class2.nodes);
820-
});
820+
concat_vecs(&mut info.data1.nodes, info.data2.nodes);
821+
});
821822
if let Some(id) = new_root {
822823
if let Some(explain) = &mut self.explain {
823824
explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs);

src/raw/egraph.rs

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::{raw::RawEClass, Dot, HashMap, Id, Language, RecExpr, UnionFind};
2+
use std::convert::Infallible;
23
use std::ops::{Deref, DerefMut};
34
use std::{
45
borrow::BorrowMut,
@@ -426,6 +427,26 @@ impl<L: Language, D> RawEGraph<L, D> {
426427
}
427428
}
428429

430+
/// Information for [`RawEGraph::raw_union`] callback
431+
#[non_exhaustive]
432+
pub struct MergeInfo<'a, D: 'a> {
433+
/// id that will be the root for the newly merged eclass
434+
pub id1: Id,
435+
/// data associated with `id1` that can be modified to reflect `data2` being merged into it
436+
pub data1: &'a mut D,
437+
/// parents of `id1` before the merge
438+
pub parents1: &'a [Id],
439+
/// id that used to be a root but will now be in `id1` eclass
440+
pub id2: Id,
441+
/// data associated with `id2`
442+
pub data2: D,
443+
/// parents of `id2` before the merge
444+
pub parents2: &'a [Id],
445+
/// true if `id1` was the root of the second id passed to [`RawEGraph::raw_union`]
446+
/// false if `id1` was the root of the first id passed to [`RawEGraph::raw_union`]
447+
pub swapped_ids: bool,
448+
}
449+
429450
impl<L: Language, D> RawEGraph<L, D> {
430451
/// Adds `enode` to a [`RawEGraph`] contained within a wrapper type `T`
431452
///
@@ -524,7 +545,7 @@ impl<L: Language, D> RawEGraph<L, D> {
524545
&mut self,
525546
enode_id1: Id,
526547
enode_id2: Id,
527-
merge: impl FnOnce(&mut D, Id, Parents<'_>, D, Id, Parents<'_>),
548+
merge: impl FnOnce(MergeInfo<'_, D>),
528549
) {
529550
let mut id1 = self.find_mut(enode_id1);
530551
let mut id2 = self.find_mut(enode_id2);
@@ -534,7 +555,9 @@ impl<L: Language, D> RawEGraph<L, D> {
534555
// make sure class2 has fewer parents
535556
let class1_parents = self.classes[&id1].parents.len();
536557
let class2_parents = self.classes[&id2].parents.len();
558+
let mut swapped = false;
537559
if class1_parents < class2_parents {
560+
swapped = true;
538561
std::mem::swap(&mut id1, &mut id2);
539562
}
540563

@@ -545,22 +568,22 @@ impl<L: Language, D> RawEGraph<L, D> {
545568
let class2 = self.classes.remove(&id2).unwrap();
546569
let class1 = self.classes.get_mut(&id1).unwrap();
547570
assert_eq!(id1, class1.id);
548-
let (p1, p2) = (Parents(&class1.parents), Parents(&class2.parents));
549-
merge(
550-
&mut class1.raw_data,
551-
class1.id,
552-
p1,
553-
class2.raw_data,
554-
class2.id,
555-
p2,
556-
);
571+
let info = MergeInfo {
572+
id1: class1.id,
573+
data1: &mut class1.raw_data,
574+
parents1: &class1.parents,
575+
id2: class2.id,
576+
data2: class2.raw_data,
577+
parents2: &class2.parents,
578+
swapped_ids: swapped,
579+
};
580+
merge(info);
557581

558582
self.pending.extend(&class2.parents);
559583

560584
class1.parents.extend(class2.parents);
561585
}
562586

563-
#[inline]
564587
/// Rebuild to [`RawEGraph`] to restore congruence closure
565588
///
566589
/// ## Parameters
@@ -576,25 +599,48 @@ impl<L: Language, D> RawEGraph<L, D> {
576599
/// In order to be correct `perform_union` should call [`raw_union`](RawEGraph::raw_union)
577600
///
578601
/// ### `handle_pending`
579-
/// Called with the uncanonical id of each enode whose canonical children have changned, along with a canonical
602+
/// Called with the uncanonical id of each enode whose canonical children have changed, along with a canonical
580603
/// version of it
604+
#[inline]
581605
pub fn raw_rebuild<T>(
582606
outer: &mut T,
583607
get_self: impl Fn(&mut T) -> &mut Self,
584608
mut perform_union: impl FnMut(&mut T, Id, Id),
585-
mut handle_pending: impl FnMut(&mut T, Id, &L),
609+
handle_pending: impl FnMut(&mut T, Id, &L),
586610
) {
611+
let _: Result<(), Infallible> = RawEGraph::try_raw_rebuild(
612+
outer,
613+
get_self,
614+
|this, id1, id2| Ok(perform_union(this, id1, id2)),
615+
handle_pending,
616+
);
617+
}
618+
619+
/// Similar to [`raw_rebuild`] but allows for the union operation to fail and abort the rebuild
620+
#[inline]
621+
pub fn try_raw_rebuild<T, E>(
622+
outer: &mut T,
623+
get_self: impl Fn(&mut T) -> &mut Self,
624+
mut perform_union: impl FnMut(&mut T, Id, Id) -> Result<(), E>,
625+
mut handle_pending: impl FnMut(&mut T, Id, &L),
626+
) -> Result<(), E> {
587627
loop {
588628
let this = get_self(outer);
589629
if let Some(class) = this.pending.pop() {
590630
let mut node = this.id_to_node(class).clone();
591631
node.update_children(|id| this.find_mut(id));
592632
handle_pending(outer, class, &node);
593633
if let Some(memo_class) = get_self(outer).residual.memo.insert(node, class) {
594-
perform_union(outer, memo_class, class);
634+
match perform_union(outer, memo_class, class) {
635+
Ok(()) => {}
636+
Err(e) => {
637+
get_self(outer).pending.push(class);
638+
return Err(e);
639+
}
640+
}
595641
}
596642
} else {
597-
break;
643+
break Ok(());
598644
}
599645
}
600646
}
@@ -638,7 +684,7 @@ impl<L: Language> RawEGraph<L, ()> {
638684
/// Simplified version of [`raw_union`](RawEGraph::raw_union) for egraphs without eclass data
639685
pub fn union(&mut self, id1: Id, id2: Id) -> bool {
640686
let mut unioned = false;
641-
self.raw_union(id1, id2, |_, _, _, _, _, _| {
687+
self.raw_union(id1, id2, |_| {
642688
unioned = true;
643689
});
644690
unioned

0 commit comments

Comments
 (0)