11use crate :: { raw:: RawEClass , Dot , HashMap , Id , Language , RecExpr , UnionFind } ;
2+ use std:: convert:: Infallible ;
23use std:: ops:: { Deref , DerefMut } ;
34use std:: {
45 borrow:: BorrowMut ,
@@ -426,6 +427,26 @@ impl<L: Language, D> RawEGraph<L, D> {
426427 }
427428}
428429
430+ /// Information for [`RawEGraph::raw_union`] callback
431+ #[ non_exhaustive]
432+ pub struct MergeInfo < ' a , D : ' a > {
433+ /// id that will be the root for the newly merged eclass
434+ pub id1 : Id ,
435+ /// data associated with `id1` that can be modified to reflect `data2` being merged into it
436+ pub data1 : & ' a mut D ,
437+ /// parents of `id1` before the merge
438+ pub parents1 : & ' a [ Id ] ,
439+ /// id that used to be a root but will now be in `id1` eclass
440+ pub id2 : Id ,
441+ /// data associated with `id2`
442+ pub data2 : D ,
443+ /// parents of `id2` before the merge
444+ pub parents2 : & ' a [ Id ] ,
445+ /// true if `id1` was the root of the second id passed to [`RawEGraph::raw_union`]
446+ /// false if `id1` was the root of the first id passed to [`RawEGraph::raw_union`]
447+ pub swapped_ids : bool ,
448+ }
449+
429450impl < L : Language , D > RawEGraph < L , D > {
430451 /// Adds `enode` to a [`RawEGraph`] contained within a wrapper type `T`
431452 ///
@@ -524,7 +545,7 @@ impl<L: Language, D> RawEGraph<L, D> {
524545 & mut self ,
525546 enode_id1 : Id ,
526547 enode_id2 : Id ,
527- merge : impl FnOnce ( & mut D , Id , Parents < ' _ > , D , Id , Parents < ' _ > ) ,
548+ merge : impl FnOnce ( MergeInfo < ' _ , D > ) ,
528549 ) {
529550 let mut id1 = self . find_mut ( enode_id1) ;
530551 let mut id2 = self . find_mut ( enode_id2) ;
@@ -534,7 +555,9 @@ impl<L: Language, D> RawEGraph<L, D> {
534555 // make sure class2 has fewer parents
535556 let class1_parents = self . classes [ & id1] . parents . len ( ) ;
536557 let class2_parents = self . classes [ & id2] . parents . len ( ) ;
558+ let mut swapped = false ;
537559 if class1_parents < class2_parents {
560+ swapped = true ;
538561 std:: mem:: swap ( & mut id1, & mut id2) ;
539562 }
540563
@@ -545,22 +568,22 @@ impl<L: Language, D> RawEGraph<L, D> {
545568 let class2 = self . classes . remove ( & id2) . unwrap ( ) ;
546569 let class1 = self . classes . get_mut ( & id1) . unwrap ( ) ;
547570 assert_eq ! ( id1, class1. id) ;
548- let ( p1, p2) = ( Parents ( & class1. parents ) , Parents ( & class2. parents ) ) ;
549- merge (
550- & mut class1. raw_data ,
551- class1. id ,
552- p1,
553- class2. raw_data ,
554- class2. id ,
555- p2,
556- ) ;
571+ let info = MergeInfo {
572+ id1 : class1. id ,
573+ data1 : & mut class1. raw_data ,
574+ parents1 : & class1. parents ,
575+ id2 : class2. id ,
576+ data2 : class2. raw_data ,
577+ parents2 : & class2. parents ,
578+ swapped_ids : swapped,
579+ } ;
580+ merge ( info) ;
557581
558582 self . pending . extend ( & class2. parents ) ;
559583
560584 class1. parents . extend ( class2. parents ) ;
561585 }
562586
563- #[ inline]
564587 /// Rebuild to [`RawEGraph`] to restore congruence closure
565588 ///
566589 /// ## Parameters
@@ -576,25 +599,48 @@ impl<L: Language, D> RawEGraph<L, D> {
576599 /// In order to be correct `perform_union` should call [`raw_union`](RawEGraph::raw_union)
577600 ///
578601 /// ### `handle_pending`
579- /// Called with the uncanonical id of each enode whose canonical children have changned , along with a canonical
602+ /// Called with the uncanonical id of each enode whose canonical children have changed , along with a canonical
580603 /// version of it
604+ #[ inline]
581605 pub fn raw_rebuild < T > (
582606 outer : & mut T ,
583607 get_self : impl Fn ( & mut T ) -> & mut Self ,
584608 mut perform_union : impl FnMut ( & mut T , Id , Id ) ,
585- mut handle_pending : impl FnMut ( & mut T , Id , & L ) ,
609+ handle_pending : impl FnMut ( & mut T , Id , & L ) ,
586610 ) {
611+ let _: Result < ( ) , Infallible > = RawEGraph :: try_raw_rebuild (
612+ outer,
613+ get_self,
614+ |this, id1, id2| Ok ( perform_union ( this, id1, id2) ) ,
615+ handle_pending,
616+ ) ;
617+ }
618+
619+ /// Similar to [`raw_rebuild`] but allows for the union operation to fail and abort the rebuild
620+ #[ inline]
621+ pub fn try_raw_rebuild < T , E > (
622+ outer : & mut T ,
623+ get_self : impl Fn ( & mut T ) -> & mut Self ,
624+ mut perform_union : impl FnMut ( & mut T , Id , Id ) -> Result < ( ) , E > ,
625+ mut handle_pending : impl FnMut ( & mut T , Id , & L ) ,
626+ ) -> Result < ( ) , E > {
587627 loop {
588628 let this = get_self ( outer) ;
589629 if let Some ( class) = this. pending . pop ( ) {
590630 let mut node = this. id_to_node ( class) . clone ( ) ;
591631 node. update_children ( |id| this. find_mut ( id) ) ;
592632 handle_pending ( outer, class, & node) ;
593633 if let Some ( memo_class) = get_self ( outer) . residual . memo . insert ( node, class) {
594- perform_union ( outer, memo_class, class) ;
634+ match perform_union ( outer, memo_class, class) {
635+ Ok ( ( ) ) => { }
636+ Err ( e) => {
637+ get_self ( outer) . pending . push ( class) ;
638+ return Err ( e) ;
639+ }
640+ }
595641 }
596642 } else {
597- break ;
643+ break Ok ( ( ) ) ;
598644 }
599645 }
600646 }
@@ -638,7 +684,7 @@ impl<L: Language> RawEGraph<L, ()> {
638684 /// Simplified version of [`raw_union`](RawEGraph::raw_union) for egraphs without eclass data
639685 pub fn union ( & mut self , id1 : Id , id2 : Id ) -> bool {
640686 let mut unioned = false ;
641- self . raw_union ( id1, id2, |_, _ , _ , _ , _ , _ | {
687+ self . raw_union ( id1, id2, |_| {
642688 unioned = true ;
643689 } ) ;
644690 unioned
0 commit comments