@@ -16,7 +16,6 @@ use bytes::BytesMut;
1616use bytestring:: ByteString ;
1717use indexmap:: IndexMap ;
1818use parking_lot:: Mutex ;
19- use rand:: Rng ;
2019use restate_types:: retries:: RetryPolicy ;
2120use tonic:: transport:: Channel ;
2221use tonic:: { Code , Status } ;
@@ -489,7 +488,7 @@ impl ChannelManager {
489488 fn choose_channel ( & self ) -> Option < ChannelWithAddress > {
490489 self . channels
491490 . lock ( )
492- . choose_next_round_robin ( )
491+ . choose_next ( & mut rand :: rng ( ) )
493492 . map ( |c| c. into_channel ( self . connection_options . deref ( ) ) )
494493 }
495494
@@ -606,48 +605,65 @@ struct Channels {
606605 // this is to allow for the initial addresses being dns entries that resolve to multiple node IPs
607606 initial_addresses : Vec < AdvertisedAddress < FabricPort > > ,
608607 channels : IndexMap < PlainNodeId , ChannelWithAddress > ,
609- channel_index : usize ,
608+ last : Option < AdvertisedAddress < FabricPort > > ,
610609}
611610
612611impl Channels {
613612 fn new ( initial_addresses : Vec < AdvertisedAddress < FabricPort > > ) -> Self {
614613 assert ! ( !initial_addresses. is_empty( ) ) ;
615- let initial_index = rand:: rng ( ) . random_range ( ..initial_addresses. len ( ) ) ;
616614 Channels {
617615 initial_addresses,
618616 channels : IndexMap :: default ( ) ,
619- channel_index : initial_index ,
617+ last : None ,
620618 }
621619 }
622620
623621 fn register ( & mut self , plain_node_id : PlainNodeId , channel : ChannelWithAddress ) {
624622 self . channels . insert ( plain_node_id, channel) ;
625623 }
626624
627- fn choose_next_round_robin ( & mut self ) -> Option < ChannelOrInitialAddress > {
625+ fn get ( & self , i : usize ) -> Option < ChannelOrInitialAddress > {
628626 let num_channels = self . channels . len ( ) ;
629627
630- self . channel_index += 1 ;
631- if self . channel_index >= num_channels + self . initial_addresses . len ( ) {
632- self . channel_index = 0 ;
633- }
634-
635- if self . channel_index < num_channels {
636- self . channels
637- . get_index ( self . channel_index )
638- . map ( |( _, channel) | channel)
639- . cloned ( )
640- . map ( ChannelOrInitialAddress :: from)
641- } else if !self . initial_addresses . is_empty ( ) {
642- self . initial_addresses
643- . get ( self . channel_index - num_channels)
644- . cloned ( )
645- . map ( ChannelOrInitialAddress :: from)
628+ if i < num_channels {
629+ Some ( ChannelOrInitialAddress :: from ( self . channels [ i] . clone ( ) ) )
630+ } else if i < num_channels + self . initial_addresses . len ( ) {
631+ Some ( ChannelOrInitialAddress :: from (
632+ self . initial_addresses [ i - num_channels] . clone ( ) ,
633+ ) )
646634 } else {
647635 None
648636 }
649637 }
650638
639+ fn choose_next ( & mut self , rng : & mut impl rand:: Rng ) -> Option < ChannelOrInitialAddress > {
640+ // sample up to two distinct channels/initial addresses from the full list
641+ let mut random_channels = rand:: seq:: IteratorRandom :: choose_multiple (
642+ 0 ..( self . channels . len ( ) + self . initial_addresses . len ( ) ) ,
643+ rng,
644+ 2 ,
645+ ) ;
646+
647+ // choose_multiple picks random items but the order is not random
648+ rand:: seq:: SliceRandom :: shuffle ( random_channels. as_mut_slice ( ) , rng) ;
649+
650+ let mut random_channels = random_channels
651+ . into_iter ( )
652+ . map ( |channel_index| self . get ( channel_index) . unwrap ( ) ) ;
653+
654+ let next = match random_channels. next ( ) {
655+ Some ( channel) if Some ( channel. address ( ) ) == self . last . as_ref ( ) => {
656+ // same as last time, try another if available
657+ Some ( random_channels. next ( ) . unwrap_or ( channel) )
658+ }
659+ Some ( channel) => Some ( channel) , // different to last time
660+ None => None , // no channels
661+ } ;
662+
663+ self . last = next. as_ref ( ) . map ( |n| n. address ( ) . clone ( ) ) ;
664+ next
665+ }
666+
651667 /// Returns true if there are no channels. It ignores the initial channels.
652668 fn is_empty ( & self ) -> bool {
653669 self . channels . is_empty ( )
@@ -670,6 +686,9 @@ impl Channels {
670686
671687#[ cfg( test) ]
672688mod tests {
689+ use std:: collections:: HashSet ;
690+
691+ use rand:: SeedableRng ;
673692 use restate_types:: { PlainNodeId , net:: address:: AdvertisedAddress } ;
674693 use test_log:: test;
675694 use tonic:: transport:: Channel ;
@@ -684,10 +703,11 @@ mod tests {
684703
685704 #[ test( restate_core:: test) ]
686705 async fn update_channels ( ) {
706+ let mut rng = rand:: rngs:: SmallRng :: seed_from_u64 ( 4360796539057359171 ) ;
687707 let initial_addr: AdvertisedAddress < _ > = "http://localhost" . parse ( ) . unwrap ( ) ;
688708 let mut channels = Channels :: new ( vec ! [ initial_addr. clone( ) ] ) ;
689709
690- assert ! ( channels. choose_next_round_robin ( ) . is_some( ) ) ;
710+ assert ! ( channels. choose_next ( & mut rng ) . is_some( ) ) ;
691711
692712 // Define node addresses for easier comparison later
693713 let node1_addr: AdvertisedAddress < _ > = "http://node1" . parse ( ) . unwrap ( ) ;
@@ -718,13 +738,22 @@ mod tests {
718738 ) ,
719739 ] ) ;
720740
721- let mut seen_addresses = Vec :: new ( ) ;
722- for _ in 0 ..4 {
723- if let Some ( channel) = channels. choose_next_round_robin ( ) {
724- seen_addresses. push ( channel. address ( ) . clone ( ) ) ;
725- }
741+ let mut seen_addresses = HashSet :: new ( ) ;
742+
743+ let mut last = None ;
744+
745+ for _ in 0 ..50 {
746+ let next_address = channels
747+ . choose_next ( & mut rng)
748+ . as_ref ( )
749+ . map ( |c| c. address ( ) . clone ( ) ) ;
750+ assert ! ( next_address. is_some( ) ) ;
751+ assert ! ( last != next_address) ; // check that we never get a repeat
752+ last = next_address. clone ( ) ;
753+ seen_addresses. insert ( next_address. unwrap ( ) ) ;
726754 }
727755
756+ // check that all values turn up eventually
728757 assert ! ( seen_addresses. contains( & initial_addr) ) ;
729758 assert ! ( seen_addresses. contains( & node1_addr) ) ;
730759 assert ! ( seen_addresses. contains( & node2_addr) ) ;
@@ -733,18 +762,27 @@ mod tests {
733762 channels. drop_initial_addresses ( ) ;
734763
735764 seen_addresses. clear ( ) ;
736- for _ in 0 ..3 {
737- if let Some ( channel) = channels. choose_next_round_robin ( ) {
738- seen_addresses. push ( channel. address ( ) . clone ( ) ) ;
739- }
765+ last = None ;
766+
767+ for _ in 0 ..50 {
768+ let next_address = channels
769+ . choose_next ( & mut rng)
770+ . as_ref ( )
771+ . map ( |c| c. address ( ) . clone ( ) ) ;
772+ assert ! ( next_address. is_some( ) ) ;
773+ assert ! ( last != next_address) ; // check that we never get a repeat
774+ last = next_address. clone ( ) ;
775+ seen_addresses. insert ( next_address. unwrap ( ) ) ;
740776 }
741777
778+ // check that all values turn up eventually
779+ assert ! ( !seen_addresses. contains( & initial_addr) ) ;
742780 assert ! ( seen_addresses. contains( & node1_addr) ) ;
743781 assert ! ( seen_addresses. contains( & node2_addr) ) ;
744782 assert ! ( seen_addresses. contains( & node3_addr) ) ;
745783
746784 channels. update_channels ( vec ! [ ] ) ;
747785
748- assert ! ( channels. choose_next_round_robin ( ) . is_none( ) ) ;
786+ assert ! ( channels. choose_next ( & mut rng ) . is_none( ) ) ;
749787 }
750788}
0 commit comments