Skip to content

Commit 90f8cdc

Browse files
authored
Merge pull request #567 from piodul/host-filter
Host filter
2 parents 51fa01d + c2944af commit 90f8cdc

File tree

10 files changed

+256
-21
lines changed

10 files changed

+256
-21
lines changed

scylla/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ pub use transport::query_result::QueryResult;
116116
pub use transport::session::{IntoTypedRows, Session, SessionConfig};
117117
pub use transport::session_builder::SessionBuilder;
118118

119+
pub use transport::host_filter;
119120
pub use transport::load_balancing;
120121
pub use transport::retry_policy;
121122
pub use transport::speculative_execution;

scylla/src/transport/cluster.rs

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::frame::response::event::{Event, StatusChangeEvent};
33
use crate::frame::value::ValueList;
44
use crate::load_balancing::TokenAwarePolicy;
55
use crate::routing::Token;
6+
use crate::transport::host_filter::HostFilter;
67
use crate::transport::{
78
connection::{Connection, VerifiedKeyspaceName},
89
connection_pool::PoolConfig,
@@ -110,6 +111,10 @@ struct ClusterWorker {
110111

111112
// Keyspace sent in "USE <keyspace name>" when opening each connection
112113
used_keyspace: Option<VerifiedKeyspaceName>,
114+
115+
// The host filter determines towards which nodes we should open
116+
// connections
117+
host_filter: Option<Arc<dyn HostFilter>>,
113118
}
114119

115120
#[derive(Debug)]
@@ -129,6 +134,7 @@ impl Cluster {
129134
pool_config: PoolConfig,
130135
fetch_schema_metadata: bool,
131136
address_translator: &Option<Arc<dyn AddressTranslator>>,
137+
host_filter: &Option<Arc<dyn HostFilter>>,
132138
) -> Result<Cluster, QueryError> {
133139
let (refresh_sender, refresh_receiver) = tokio::sync::mpsc::channel(32);
134140
let (use_keyspace_sender, use_keyspace_receiver) = tokio::sync::mpsc::channel(32);
@@ -141,10 +147,17 @@ impl Cluster {
141147
server_events_sender,
142148
fetch_schema_metadata,
143149
address_translator,
150+
host_filter,
144151
);
145152

146153
let metadata = metadata_reader.read_metadata(true).await?;
147-
let cluster_data = ClusterData::new(metadata, &pool_config, &HashMap::new(), &None);
154+
let cluster_data = ClusterData::new(
155+
metadata,
156+
&pool_config,
157+
&HashMap::new(),
158+
&None,
159+
host_filter.as_deref(),
160+
);
148161
cluster_data.wait_until_all_pools_are_initialized().await;
149162
let cluster_data: Arc<ArcSwap<ClusterData>> =
150163
Arc::new(ArcSwap::from(Arc::new(cluster_data)));
@@ -160,6 +173,8 @@ impl Cluster {
160173

161174
use_keyspace_channel: use_keyspace_receiver,
162175
used_keyspace: None,
176+
177+
host_filter: host_filter.clone(),
163178
};
164179

165180
let (fut, worker_handle) = worker.work().remote_handle();
@@ -273,6 +288,7 @@ impl ClusterData {
273288
pool_config: &PoolConfig,
274289
known_peers: &HashMap<SocketAddr, Arc<Node>>,
275290
used_keyspace: &Option<VerifiedKeyspaceName>,
291+
host_filter: Option<&dyn HostFilter>,
276292
) -> Self {
277293
// Create new updated known_peers and ring
278294
let mut new_known_peers: HashMap<SocketAddr, Arc<Node>> =
@@ -289,13 +305,17 @@ impl ClusterData {
289305
Some(node) if node.datacenter == peer.datacenter && node.rack == peer.rack => {
290306
node.clone()
291307
}
292-
_ => Arc::new(Node::new(
293-
peer.address,
294-
pool_config.clone(),
295-
peer.datacenter,
296-
peer.rack,
297-
used_keyspace.clone(),
298-
)),
308+
_ => {
309+
let is_enabled = host_filter.map_or(true, |f| f.accept(&peer));
310+
Arc::new(Node::new(
311+
peer.address,
312+
pool_config.clone(),
313+
peer.datacenter,
314+
peer.rack,
315+
used_keyspace.clone(),
316+
is_enabled,
317+
))
318+
}
299319
};
300320

301321
new_known_peers.insert(peer.address, node.clone());
@@ -567,6 +587,7 @@ impl ClusterWorker {
567587
&self.pool_config,
568588
&cluster_data.known_peers,
569589
&self.used_keyspace,
590+
self.host_filter.as_deref(),
570591
));
571592

572593
new_cluster_data
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//! Host filters.
2+
//!
3+
//! Host filters are essentially just a predicate over
4+
//! [`Peer`](crate::transport::topology::Peer)s. Currently, they are used
5+
//! by the [`Session`](crate::transport::session::Session) to determine whether
6+
//! connections should be opened to a given node or not.
7+
8+
use std::collections::HashSet;
9+
use std::io::Error;
10+
use std::net::{SocketAddr, ToSocketAddrs};
11+
12+
use crate::transport::topology::Peer;
13+
14+
/// A predicate over peers which decides whether connections should be opened to a given node.
15+
pub trait HostFilter: Send + Sync {
16+
/// Returns whether a peer should be accepted or not.
17+
fn accept(&self, peer: &Peer) -> bool;
18+
}
19+
20+
/// Unconditionally accepts all nodes.
21+
pub struct AcceptAllHostFilter;
22+
23+
impl HostFilter for AcceptAllHostFilter {
24+
fn accept(&self, _peer: &Peer) -> bool {
25+
true
26+
}
27+
}
28+
29+
/// Accepts nodes whose addresses are present in the allow list provided
30+
/// during filter's construction.
31+
pub struct AllowListHostFilter {
32+
allowed: HashSet<SocketAddr>,
33+
}
34+
35+
impl AllowListHostFilter {
36+
/// Creates a new `AllowListHostFilter` which only accepts nodes from the
37+
/// list.
38+
pub fn new<I, A>(allowed_iter: I) -> Result<Self, Error>
39+
where
40+
I: IntoIterator<Item = A>,
41+
A: ToSocketAddrs,
42+
{
43+
// I couldn't get the iterator combinators to work
44+
let mut allowed = HashSet::new();
45+
for item in allowed_iter {
46+
for addr in item.to_socket_addrs()? {
47+
allowed.insert(addr);
48+
}
49+
}
50+
51+
Ok(Self { allowed })
52+
}
53+
}
54+
55+
impl HostFilter for AllowListHostFilter {
56+
fn accept(&self, peer: &Peer) -> bool {
57+
self.allowed.contains(&peer.address)
58+
}
59+
}
60+
61+
/// Accepts nodes from given DC.
62+
pub struct DcHostFilter {
63+
local_dc: String,
64+
}
65+
66+
impl DcHostFilter {
67+
/// Creates a new `DcHostFilter` that accepts nodes only from the
68+
/// `local_dc`.
69+
pub fn new(local_dc: String) -> Self {
70+
Self { local_dc }
71+
}
72+
}
73+
74+
impl HostFilter for DcHostFilter {
75+
fn accept(&self, peer: &Peer) -> bool {
76+
peer.datacenter.as_ref() == Some(&self.local_dc)
77+
}
78+
}

scylla/src/transport/load_balancing/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ mod tests {
150150
keyspaces: HashMap::new(),
151151
};
152152

153-
ClusterData::new(info, &Default::default(), &HashMap::new(), &None)
153+
ClusterData::new(info, &Default::default(), &HashMap::new(), &None, None)
154154
}
155155

156156
pub const EMPTY_STATEMENT: Statement = Statement {

scylla/src/transport/load_balancing/token_aware.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ mod tests {
345345
keyspaces,
346346
};
347347

348-
ClusterData::new(info, &Default::default(), &HashMap::new(), &None)
348+
ClusterData::new(info, &Default::default(), &HashMap::new(), &None, None)
349349
}
350350

351351
// creates ClusterData with info about 8 nodes living in two different datacenters
@@ -444,7 +444,7 @@ mod tests {
444444
keyspaces,
445445
};
446446

447-
ClusterData::new(info, &Default::default(), &HashMap::new(), &None)
447+
ClusterData::new(info, &Default::default(), &HashMap::new(), &None, None)
448448
}
449449

450450
// Used as child policy for TokenAwarePolicy tests

scylla/src/transport/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ mod cluster;
33
pub(crate) mod connection;
44
mod connection_pool;
55
pub mod downgrading_consistency_retry_policy;
6+
pub mod host_filter;
67
pub mod iterator;
78
pub mod load_balancing;
89
pub(crate) mod metrics;

scylla/src/transport/node.rs

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ pub struct Node {
2121
pub datacenter: Option<String>,
2222
pub rack: Option<String>,
2323

24-
pool: NodeConnectionPool,
24+
// If the node is filtered out by the host filter, this will be None
25+
pool: Option<NodeConnectionPool>,
2526

2627
down_marker: AtomicBool,
2728
}
@@ -40,9 +41,11 @@ impl Node {
4041
datacenter: Option<String>,
4142
rack: Option<String>,
4243
keyspace_name: Option<VerifiedKeyspaceName>,
44+
enabled: bool,
4345
) -> Self {
44-
let pool =
45-
NodeConnectionPool::new(address.ip(), address.port(), pool_config, keyspace_name);
46+
let pool = enabled.then(|| {
47+
NodeConnectionPool::new(address.ip(), address.port(), pool_config, keyspace_name)
48+
});
4649

4750
Node {
4851
address,
@@ -54,7 +57,7 @@ impl Node {
5457
}
5558

5659
pub fn sharder(&self) -> Option<Sharder> {
57-
self.pool.sharder()
60+
self.pool.as_ref()?.sharder()
5861
}
5962

6063
/// Get connection which should be used to connect using given token
@@ -63,18 +66,25 @@ impl Node {
6366
&self,
6467
token: Token,
6568
) -> Result<Arc<Connection>, QueryError> {
66-
self.pool.connection_for_token(token)
69+
self.get_pool()?.connection_for_token(token)
6770
}
6871

6972
/// Get random connection
7073
pub(crate) async fn random_connection(&self) -> Result<Arc<Connection>, QueryError> {
71-
self.pool.random_connection()
74+
self.get_pool()?.random_connection()
7275
}
7376

7477
pub fn is_down(&self) -> bool {
7578
self.down_marker.load(Ordering::Relaxed)
7679
}
7780

81+
/// Returns a boolean which indicates whether this node is enabled.
82+
/// Only enabled nodes will have connections open. For disabled nodes,
83+
/// no connections will be opened.
84+
pub fn is_enabled(&self) -> bool {
85+
self.pool.is_some()
86+
}
87+
7888
pub(crate) fn change_down_marker(&self, is_down: bool) {
7989
self.down_marker.store(is_down, Ordering::Relaxed);
8090
}
@@ -83,15 +93,30 @@ impl Node {
8393
&self,
8494
keyspace_name: VerifiedKeyspaceName,
8595
) -> Result<(), QueryError> {
86-
self.pool.use_keyspace(keyspace_name).await
96+
if let Some(pool) = &self.pool {
97+
pool.use_keyspace(keyspace_name).await?;
98+
}
99+
Ok(())
87100
}
88101

89102
pub(crate) fn get_working_connections(&self) -> Result<Vec<Arc<Connection>>, QueryError> {
90-
self.pool.get_working_connections()
103+
self.get_pool()?.get_working_connections()
91104
}
92105

93106
pub(crate) async fn wait_until_pool_initialized(&self) {
94-
self.pool.wait_until_initialized().await
107+
if let Some(pool) = &self.pool {
108+
pool.wait_until_initialized().await;
109+
}
110+
}
111+
112+
fn get_pool(&self) -> Result<&NodeConnectionPool, QueryError> {
113+
self.pool.as_ref().ok_or_else(|| {
114+
QueryError::IoError(Arc::new(std::io::Error::new(
115+
std::io::ErrorKind::Other,
116+
"No connections in the pool: the node has been disabled \
117+
by the host filter",
118+
)))
119+
})
95120
}
96121
}
97122

scylla/src/transport/session.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ use crate::tracing::{GetTracingConfig, TracingEvent, TracingInfo};
3636
use crate::transport::cluster::{Cluster, ClusterData, ClusterNeatDebug};
3737
use crate::transport::connection::{Connection, ConnectionConfig, VerifiedKeyspaceName};
3838
use crate::transport::connection_pool::PoolConfig;
39+
use crate::transport::host_filter::HostFilter;
3940
use crate::transport::iterator::{PreparedIteratorConfig, RowIterator};
4041
use crate::transport::load_balancing::{
4142
LoadBalancingPolicy, RoundRobinPolicy, Statement, TokenAwarePolicy,
@@ -203,6 +204,11 @@ pub struct SessionConfig {
203204

204205
pub address_translator: Option<Arc<dyn AddressTranslator>>,
205206

207+
/// The host filter decides whether any connections should be opened
208+
/// to the node or not. The driver will also avoid filtered out nodes when
209+
/// re-establishing the control connection.
210+
pub host_filter: Option<Arc<dyn HostFilter>>,
211+
206212
/// If true, full schema metadata is fetched after successfully reaching a schema agreement.
207213
/// It is true by default but can be disabled if successive schema-altering statements should be performed.
208214
pub refresh_metadata_on_auto_schema_agreement: bool,
@@ -250,6 +256,7 @@ impl SessionConfig {
250256
auto_await_schema_agreement_timeout: Some(std::time::Duration::from_secs(60)),
251257
request_timeout: Some(Duration::from_secs(30)),
252258
address_translator: None,
259+
host_filter: None,
253260
refresh_metadata_on_auto_schema_agreement: true,
254261
}
255262
}
@@ -436,6 +443,7 @@ impl Session {
436443
config.get_pool_config(),
437444
config.fetch_schema_metadata,
438445
&config.address_translator,
446+
&config.host_filter,
439447
)
440448
.await?;
441449

scylla/src/transport/session_builder.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use super::load_balancing::LoadBalancingPolicy;
55
use super::session::{AddressTranslator, Session, SessionConfig};
66
use super::speculative_execution::SpeculativeExecutionPolicy;
77
use super::Compression;
8+
use crate::transport::host_filter::HostFilter;
89
use crate::transport::{connection_pool::PoolSize, retry_policy::RetryPolicy};
910
use std::net::SocketAddr;
1011
use std::sync::Arc;
@@ -622,6 +623,38 @@ impl SessionBuilder {
622623
self
623624
}
624625

626+
/// Sets the host filter. The host filter decides whether any connections
627+
/// should be opened to the node or not. The driver will also avoid
628+
/// those nodes when re-establishing the control connection.
629+
///
630+
/// See the [host filter](crate::transport::host_filter) module for a list
631+
/// of pre-defined filters. It is also possible to provide a custom filter
632+
/// by implementing the HostFilter trait.
633+
///
634+
/// # Example
635+
/// ```
636+
/// # use async_trait::async_trait;
637+
/// # use std::net::SocketAddr;
638+
/// # use std::sync::Arc;
639+
/// # use scylla::{Session, SessionBuilder};
640+
/// # use scylla::transport::session::{AddressTranslator, TranslationError};
641+
/// # use scylla::transport::host_filter::DcHostFilter;
642+
///
643+
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
644+
/// // The session will only connect to nodes from "my-local-dc"
645+
/// let session: Session = SessionBuilder::new()
646+
/// .known_node("127.0.0.1:9042")
647+
/// .host_filter(Arc::new(DcHostFilter::new("my-local-dc".to_string())))
648+
/// .build()
649+
/// .await?;
650+
/// # Ok(())
651+
/// # }
652+
/// ```
653+
pub fn host_filter(mut self, filter: Arc<dyn HostFilter>) -> Self {
654+
self.config.host_filter = Some(filter);
655+
self
656+
}
657+
625658
/// Set the refresh metadata on schema agreement flag.
626659
/// The default is true.
627660
///

0 commit comments

Comments
 (0)