Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 79 additions & 21 deletions crates/deadpool-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ use std::{
use deadpool::managed;
#[cfg(not(target_arch = "wasm32"))]
use tokio::spawn;
use tokio::sync::Semaphore;
use tokio::task::JoinHandle;
use tokio_postgres::{
types::Type, Client as PgClient, Config as PgConfig, Error, IsolationLevel, Statement,
Expand Down Expand Up @@ -288,6 +289,15 @@ struct StatementCacheKey<'a> {
types: Cow<'a, [Type]>,
}

// The contents of a [`StatementCache`].
enum StatementCacheValue {
// The previously prepared statement.
Statement(Statement),
// A semaphore limiting how many threads will try to prepare the statement
// at once.
Semaphore(Arc<Semaphore>),
}

/// Representation of a cache of [`Statement`]s.
///
/// [`StatementCache`] is bound to one [`Client`], and [`Statement`]s generated
Expand All @@ -308,7 +318,7 @@ struct StatementCacheKey<'a> {
/// and [`ClientWrapper::prepare_typed_cached()`] methods instead (or the
/// similar ones on [`Transaction`]).
pub struct StatementCache {
map: RwLock<HashMap<StatementCacheKey<'static>, Statement>>,
map: RwLock<HashMap<StatementCacheKey<'static>, StatementCacheValue>>,
size: AtomicUsize,
}

Expand Down Expand Up @@ -348,20 +358,15 @@ impl StatementCache {
types: Cow::Owned(types.to_owned()),
};
let mut map = self.map.write().unwrap();
let removed = map.remove(&key);
if removed.is_some() {
let _ = self.size.fetch_sub(1, Ordering::Relaxed);
match map.remove(&key) {
Some(StatementCacheValue::Statement(statement)) => {
// Decrease cache size only when removing a statement
let _ = self.size.fetch_sub(1, Ordering::Relaxed);
Some(statement)
}
Some(StatementCacheValue::Semaphore(_)) => None,
None => None,
}
removed
}

/// Returns a [`Statement`] from this [`StatementCache`].
fn get(&self, query: &str, types: &[Type]) -> Option<Statement> {
let key = StatementCacheKey {
query: Cow::Borrowed(query),
types: Cow::Borrowed(types),
};
self.map.read().unwrap().get(&key).map(ToOwned::to_owned)
}

/// Inserts a [`Statement`] into this [`StatementCache`].
Expand All @@ -371,7 +376,11 @@ impl StatementCache {
types: Cow::Owned(types.to_owned()),
};
let mut map = self.map.write().unwrap();
if map.insert(key, stmt).is_none() {
// Increase cache size if key was absent or when replacing a semaphore
// with a statement
if let None | Some(StatementCacheValue::Semaphore(_)) =
map.insert(key, StatementCacheValue::Statement(stmt))
{
let _ = self.size.fetch_add(1, Ordering::Relaxed);
}
}
Expand All @@ -394,14 +403,63 @@ impl StatementCache {
query: &str,
types: &[Type],
) -> Result<Statement, Error> {
match self.get(query, types) {
Some(statement) => Ok(statement),
None => {
let stmt = client.prepare_typed(query, types).await?;
self.insert(query, types, stmt.clone());
Ok(stmt)
let borrowed_key = StatementCacheKey {
query: Cow::Borrowed(query),
types: Cow::Borrowed(types),
};
// Each map entry is empty, contains a Semaphore, or contains a prepared
// Statement. The first call to `prepare_typed()` will insert a
// `Semaphore`. Then, tasks will wait for a semaphore permit before
// preparing the query, to ensure that only one task sends a `PREPARE`
// statement at once. At both steps in the process, the map will be read
// via a read lock first, then again via a write lock, before updating
// the entry.
let semaphore = {
let read_lock = self.map.read().unwrap();
match read_lock.get(&borrowed_key) {
// Fast path: statement already prepared.
Some(StatementCacheValue::Statement(stmt)) => {
return Ok(stmt.clone());
}
// Slow path: statement not yet prepared but semaphore exists.
Some(StatementCacheValue::Semaphore(semaphore)) => semaphore.clone(),
// Slow path: statement not yet prepared and no semaphore
// exists, so create one.
None => {
// Drop the read lock and upgrade to a write lock.
drop(read_lock);
let mut write_lock = self.map.write().unwrap();

// A statement may have been inserted while we waited for
// the write lock
match write_lock
.entry(StatementCacheKey {
query: Cow::Owned(query.to_owned()),
types: Cow::Owned(types.to_owned()),
})
.or_insert(StatementCacheValue::Semaphore(Arc::new(Semaphore::new(1))))
{
StatementCacheValue::Statement(stmt) => return Ok(stmt.clone()),
StatementCacheValue::Semaphore(semaphore) => semaphore.clone(),
}
}
}
};

// unwrap safety: we never close the semaphore.
let _permit = semaphore.acquire().await.unwrap();
// A statement may have been inserted while we waited to acquire the
// semaphore.
if let Some(StatementCacheValue::Statement(stmt)) =
self.map.read().unwrap().get(&borrowed_key)
{
return Ok(stmt.clone());
}
// Still no statement in the cache, so do the expensive statement
// preparation.
let stmt = client.prepare_typed(query, types).await?;
self.insert(query, types, stmt.clone());
Ok(stmt)
}
}

Expand Down
Loading