Skip to content

Commit 1265be4

Browse files
authored
Improve metadata provider round robin logic (#3954)
1 parent 1d0f5e8 commit 1265be4

File tree

1 file changed

+71
-33
lines changed

1 file changed

+71
-33
lines changed

crates/metadata-providers/src/replicated.rs

Lines changed: 71 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ use bytes::BytesMut;
1616
use bytestring::ByteString;
1717
use indexmap::IndexMap;
1818
use parking_lot::Mutex;
19-
use rand::Rng;
2019
use restate_types::retries::RetryPolicy;
2120
use tonic::transport::Channel;
2221
use tonic::{Code, Status};
@@ -489,7 +488,7 @@ impl ChannelManager {
489488
fn choose_channel(&self) -> Option<ChannelWithAddress> {
490489
self.channels
491490
.lock()
492-
.choose_next_round_robin()
491+
.choose_next(&mut rand::rng())
493492
.map(|c| c.into_channel(self.connection_options.deref()))
494493
}
495494

@@ -606,48 +605,65 @@ struct Channels {
606605
// this is to allow for the initial addresses being dns entries that resolve to multiple node IPs
607606
initial_addresses: Vec<AdvertisedAddress<FabricPort>>,
608607
channels: IndexMap<PlainNodeId, ChannelWithAddress>,
609-
channel_index: usize,
608+
last: Option<AdvertisedAddress<FabricPort>>,
610609
}
611610

612611
impl Channels {
613612
fn new(initial_addresses: Vec<AdvertisedAddress<FabricPort>>) -> Self {
614613
assert!(!initial_addresses.is_empty());
615-
let initial_index = rand::rng().random_range(..initial_addresses.len());
616614
Channels {
617615
initial_addresses,
618616
channels: IndexMap::default(),
619-
channel_index: initial_index,
617+
last: None,
620618
}
621619
}
622620

623621
fn register(&mut self, plain_node_id: PlainNodeId, channel: ChannelWithAddress) {
624622
self.channels.insert(plain_node_id, channel);
625623
}
626624

627-
fn choose_next_round_robin(&mut self) -> Option<ChannelOrInitialAddress> {
625+
fn get(&self, i: usize) -> Option<ChannelOrInitialAddress> {
628626
let num_channels = self.channels.len();
629627

630-
self.channel_index += 1;
631-
if self.channel_index >= num_channels + self.initial_addresses.len() {
632-
self.channel_index = 0;
633-
}
634-
635-
if self.channel_index < num_channels {
636-
self.channels
637-
.get_index(self.channel_index)
638-
.map(|(_, channel)| channel)
639-
.cloned()
640-
.map(ChannelOrInitialAddress::from)
641-
} else if !self.initial_addresses.is_empty() {
642-
self.initial_addresses
643-
.get(self.channel_index - num_channels)
644-
.cloned()
645-
.map(ChannelOrInitialAddress::from)
628+
if i < num_channels {
629+
Some(ChannelOrInitialAddress::from(self.channels[i].clone()))
630+
} else if i < num_channels + self.initial_addresses.len() {
631+
Some(ChannelOrInitialAddress::from(
632+
self.initial_addresses[i - num_channels].clone(),
633+
))
646634
} else {
647635
None
648636
}
649637
}
650638

639+
fn choose_next(&mut self, rng: &mut impl rand::Rng) -> Option<ChannelOrInitialAddress> {
640+
// sample up to two distinct channels/initial addresses from the full list
641+
let mut random_channels = rand::seq::IteratorRandom::choose_multiple(
642+
0..(self.channels.len() + self.initial_addresses.len()),
643+
rng,
644+
2,
645+
);
646+
647+
// choose_multiple picks random items but the order is not random
648+
rand::seq::SliceRandom::shuffle(random_channels.as_mut_slice(), rng);
649+
650+
let mut random_channels = random_channels
651+
.into_iter()
652+
.map(|channel_index| self.get(channel_index).unwrap());
653+
654+
let next = match random_channels.next() {
655+
Some(channel) if Some(channel.address()) == self.last.as_ref() => {
656+
// same as last time, try another if available
657+
Some(random_channels.next().unwrap_or(channel))
658+
}
659+
Some(channel) => Some(channel), // different to last time
660+
None => None, // no channels
661+
};
662+
663+
self.last = next.as_ref().map(|n| n.address().clone());
664+
next
665+
}
666+
651667
/// Returns true if there are no channels. It ignores the initial channels.
652668
fn is_empty(&self) -> bool {
653669
self.channels.is_empty()
@@ -670,6 +686,9 @@ impl Channels {
670686

671687
#[cfg(test)]
672688
mod tests {
689+
use std::collections::HashSet;
690+
691+
use rand::SeedableRng;
673692
use restate_types::{PlainNodeId, net::address::AdvertisedAddress};
674693
use test_log::test;
675694
use tonic::transport::Channel;
@@ -684,10 +703,11 @@ mod tests {
684703

685704
#[test(restate_core::test)]
686705
async fn update_channels() {
706+
let mut rng = rand::rngs::SmallRng::seed_from_u64(4360796539057359171);
687707
let initial_addr: AdvertisedAddress<_> = "http://localhost".parse().unwrap();
688708
let mut channels = Channels::new(vec![initial_addr.clone()]);
689709

690-
assert!(channels.choose_next_round_robin().is_some());
710+
assert!(channels.choose_next(&mut rng).is_some());
691711

692712
// Define node addresses for easier comparison later
693713
let node1_addr: AdvertisedAddress<_> = "http://node1".parse().unwrap();
@@ -718,13 +738,22 @@ mod tests {
718738
),
719739
]);
720740

721-
let mut seen_addresses = Vec::new();
722-
for _ in 0..4 {
723-
if let Some(channel) = channels.choose_next_round_robin() {
724-
seen_addresses.push(channel.address().clone());
725-
}
741+
let mut seen_addresses = HashSet::new();
742+
743+
let mut last = None;
744+
745+
for _ in 0..50 {
746+
let next_address = channels
747+
.choose_next(&mut rng)
748+
.as_ref()
749+
.map(|c| c.address().clone());
750+
assert!(next_address.is_some());
751+
assert!(last != next_address); // check that we never get a repeat
752+
last = next_address.clone();
753+
seen_addresses.insert(next_address.unwrap());
726754
}
727755

756+
// check that all values turn up eventually
728757
assert!(seen_addresses.contains(&initial_addr));
729758
assert!(seen_addresses.contains(&node1_addr));
730759
assert!(seen_addresses.contains(&node2_addr));
@@ -733,18 +762,27 @@ mod tests {
733762
channels.drop_initial_addresses();
734763

735764
seen_addresses.clear();
736-
for _ in 0..3 {
737-
if let Some(channel) = channels.choose_next_round_robin() {
738-
seen_addresses.push(channel.address().clone());
739-
}
765+
last = None;
766+
767+
for _ in 0..50 {
768+
let next_address = channels
769+
.choose_next(&mut rng)
770+
.as_ref()
771+
.map(|c| c.address().clone());
772+
assert!(next_address.is_some());
773+
assert!(last != next_address); // check that we never get a repeat
774+
last = next_address.clone();
775+
seen_addresses.insert(next_address.unwrap());
740776
}
741777

778+
// check that all values turn up eventually
779+
assert!(!seen_addresses.contains(&initial_addr));
742780
assert!(seen_addresses.contains(&node1_addr));
743781
assert!(seen_addresses.contains(&node2_addr));
744782
assert!(seen_addresses.contains(&node3_addr));
745783

746784
channels.update_channels(vec![]);
747785

748-
assert!(channels.choose_next_round_robin().is_none());
786+
assert!(channels.choose_next(&mut rng).is_none());
749787
}
750788
}

0 commit comments

Comments
 (0)