diff --git a/Cargo.toml b/Cargo.toml index e50ca71..1a8e556 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ futures = "0.3.31" hex = "0.4.3" http = "1.3.1" http-body = "1.0.1" +http-body-util = "0.1.3" http-cache-semantics = "2.1.0" http-serde = "2.1.1" httpdate = "1.0.3" @@ -41,7 +42,7 @@ serde = { version = "1.0.219", features = ["derive"] } sha2 = "0.10.9" smol = "2.0.2" tempfile = "3.20.0" -tokio = { version = "1.47.1", default-features = false, features = ["fs", "io-util", "rt"] } +tokio = { version = "1.47.1", default-features = false, features = ["fs", "io-util", "rt", "macros"] } tokio-util = { version = "0.7.16", features = ["io"] } tracing = "0.1.41" @@ -60,6 +61,7 @@ futures = { workspace = true } hex = { workspace = true } http = { workspace = true } http-body = { workspace = true } +http-body-util = { workspace = true } http-cache-semantics = { workspace = true } http-serde = { workspace = true } httpdate = { workspace = true } diff --git a/README.md b/README.md index 8d71969..89f2ef4 100644 --- a/README.md +++ b/README.md @@ -30,8 +30,8 @@ ## Overview -The `http-cache-stream` crate can be used to cache responses in accordance with -HTTP caching semantics. +The `http-cache-stream` crate can be used to cache streaming responses in +accordance with HTTP caching semantics. ### How this crate differs from [`http-cache`][http-cache] @@ -42,8 +42,9 @@ client APIs. The `http-cache-stream` crate is inspired by the implementation provided by `http-cache`, but differs in significant ways: -* `http-cache-stream` supports streaming of requests/responses and does not - read a response body into memory to store in the cache. +* ~~`http-cache-stream` supports streaming of requests/responses and does not + read a response body into memory to store in the cache.~~ (streaming is now + supported in `http-cache`) * The default storage implementation for `http-cache-stream` uses advisory file locking to coordinate access to storage across multiple processes and threads. * The default storage implementation is simple and provides no integrity of diff --git a/crates/reqwest/Cargo.toml b/crates/reqwest/Cargo.toml index 7fb726d..affa4aa 100644 --- a/crates/reqwest/Cargo.toml +++ b/crates/reqwest/Cargo.toml @@ -10,13 +10,14 @@ homepage = { workspace = true } repository = { workspace = true } [dependencies] -http-cache-stream = { path = "../.." , version = "0.2.0" } anyhow = { workspace = true } +bytes = { workspace = true } +futures = { workspace = true } +http-body-util = { workspace = true } +http-cache-stream = { path = "../..", version = "0.2.0" } reqwest = { workspace = true, features = ["stream"] } reqwest-middleware = { workspace = true } -futures = { workspace = true } -bytes = { workspace = true } -pin-project-lite = { workspace = true } [dev-dependencies] +tempfile = { workspace = true } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } diff --git a/crates/reqwest/src/lib.rs b/crates/reqwest/src/lib.rs index 5d34b0b..69785d9 100644 --- a/crates/reqwest/src/lib.rs +++ b/crates/reqwest/src/lib.rs @@ -27,22 +27,16 @@ #![warn(clippy::missing_docs_in_private_items)] #![warn(rustdoc::broken_intra_doc_links)] -use std::io; -use std::pin::Pin; -use std::task::Context; -use std::task::Poll; - use anyhow::Context as _; use anyhow::Result; -use bytes::Bytes; use futures::FutureExt; use futures::future::BoxFuture; +use http_body_util::BodyDataStream; pub use http_cache_stream::X_CACHE; pub use http_cache_stream::X_CACHE_DIGEST; pub use http_cache_stream::X_CACHE_LOOKUP; use http_cache_stream::http::Extensions; use http_cache_stream::http::Uri; -use http_cache_stream::http_body::Frame; pub use http_cache_stream::semantics; pub use http_cache_stream::semantics::CacheOptions; pub use http_cache_stream::storage; @@ -54,29 +48,6 @@ use reqwest::ResponseBuilderExt; use reqwest::header::HeaderMap; use reqwest_middleware::Next; -pin_project_lite::pin_project! { - /// Adapter for [`Body`] to implement `HttpBody`. - struct MiddlewareBody { - #[pin] - body: Body - } -} - -impl http_cache_stream::http_body::Body for MiddlewareBody { - type Data = Bytes; - type Error = io::Error; - - fn poll_frame( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - // The two body implementations differ on error type, so map it here - self.project().body.poll_frame(cx).map_err(io::Error::other) - } -} - -impl http_cache_stream::HttpBody for MiddlewareBody {} - /// Represents a request flowing through the cache middleware. struct MiddlewareRequest<'a, 'b> { /// The request URI. @@ -89,7 +60,7 @@ struct MiddlewareRequest<'a, 'b> { extensions: &'b mut Extensions, } -impl http_cache_stream::Request for MiddlewareRequest<'_, '_> { +impl http_cache_stream::Request for MiddlewareRequest<'_, '_> { fn version(&self) -> http_cache_stream::http::Version { self.request.version() } @@ -109,7 +80,7 @@ impl http_cache_stream::Request for MiddlewareRequest<'_, '_> { async fn send( mut self, headers: Option, - ) -> anyhow::Result> { + ) -> anyhow::Result> { // Override the specified headers if let Some(headers) = headers { self.request.headers_mut().extend(headers); @@ -133,9 +104,7 @@ impl http_cache_stream::Request for MiddlewareRequest<'_, '_> { .expect("should have headers") .extend(headers); builder - .body(MiddlewareBody { - body: Body::wrap_stream(response.bytes_stream()), - }) + .body(response.into()) .context("failed to create response") } } @@ -205,9 +174,326 @@ impl reqwest_middleware::Middleware for Cache { .0 .send(request) .await - .map(|r| r.map(Body::wrap_stream).into())?; + .map(|r| r.map(|b| Body::wrap_stream(BodyDataStream::new(b))).into())?; Ok(response) } .boxed() } } + +#[cfg(test)] +mod test { + use std::sync::Arc; + use std::sync::Mutex; + + use http_cache_stream::http; + use http_cache_stream::storage::DefaultCacheStorage; + use reqwest::Response; + use reqwest::StatusCode; + use reqwest::header; + use reqwest_middleware::ClientWithMiddleware; + use reqwest_middleware::Middleware; + use tempfile::tempdir; + + use super::*; + + struct MockMiddlewareState { + responses: Vec>, + current: usize, + } + + struct MockMiddleware(Mutex); + + impl MockMiddleware { + fn new(responses: impl IntoIterator) -> Self + where + R: Into, + { + Self(Mutex::new(MockMiddlewareState { + responses: responses.into_iter().map(|r| Some(r.into())).collect(), + current: 0, + })) + } + } + + impl Middleware for MockMiddleware { + fn handle<'a, 'b, 'c, 'd>( + &'a self, + _: Request, + _: &'b mut Extensions, + _: Next<'c>, + ) -> BoxFuture<'d, reqwest_middleware::Result> + where + 'a: 'd, + 'b: 'd, + 'c: 'd, + Self: 'd, + { + async { + let mut state = self.0.lock().unwrap(); + + let current = state.current; + state.current += 1; + + Ok(state + .responses + .get_mut(current) + .expect("unexpected client request: not enough responses defined") + .take() + .unwrap()) + } + .boxed() + } + } + + #[tokio::test] + async fn no_store() { + const BODY: &str = "hello world!"; + // Blake3 digest of the body (from https://emn178.github.io/online-tools/blake3/) + const DIGEST: &str = "3aa61c409fd7717c9d9c639202af2fae470c0ef669be7ba2caea5779cb534e9d"; + + let dir = tempdir().unwrap(); + let cache = Arc::new(Cache::new(DefaultCacheStorage::new(dir.path()))); + let mock = Arc::new(MockMiddleware::new([ + http::Response::builder() + .header(header::CACHE_CONTROL, "no-store") + .body(BODY) + .unwrap(), + http::Response::builder() + .header(header::CACHE_CONTROL, "no-store") + .body(BODY) + .unwrap(), + ])); + let client = ClientWithMiddleware::new( + Default::default(), + vec![cache.clone() as Arc, mock.clone()], + ); + + // Response should not be served from the cache or stored + let response = client.get("http://test.local/").send().await.unwrap(); + assert_eq!( + response.headers().get(header::CACHE_CONTROL).unwrap(), + "no-store" + ); + assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "MISS"); + assert_eq!(response.headers().get(X_CACHE).unwrap(), "MISS"); + assert!(response.headers().get(X_CACHE_DIGEST).is_none()); + assert_eq!(response.text().await.unwrap(), BODY); + + // Ensure the body wasn't stored in the cache + assert!(!cache.storage().body_path(DIGEST).is_file()); + + // Response should *still* not be served from the cache or stored + let response = client.get("http://test.local/").send().await.unwrap(); + assert_eq!( + response.headers().get(header::CACHE_CONTROL).unwrap(), + "no-store" + ); + assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "MISS"); + assert_eq!(response.headers().get(X_CACHE).unwrap(), "MISS"); + assert!(response.headers().get(X_CACHE_DIGEST).is_none()); + assert_eq!(response.text().await.unwrap(), BODY); + + // Ensure the body wasn't stored in the cache + assert!(!cache.storage().body_path(DIGEST).is_file()); + } + + #[tokio::test] + async fn max_age() { + const BODY: &str = "hello world!"; + // Blake3 digest of the body (from https://emn178.github.io/online-tools/blake3/) + const DIGEST: &str = "3aa61c409fd7717c9d9c639202af2fae470c0ef669be7ba2caea5779cb534e9d"; + + let dir = tempdir().unwrap(); + let cache = Arc::new( + Cache::new(DefaultCacheStorage::new(dir.path())) + .with_revalidation_hook(|_, _| panic!("a revalidation should not take place")), + ); + let mock = Arc::new(MockMiddleware::new([http::Response::builder() + .header(header::CACHE_CONTROL, "max-age=1000") + .body(BODY) + .unwrap()])); + let client = ClientWithMiddleware::new( + Default::default(), + vec![cache.clone() as Arc, mock.clone()], + ); + + // First response should not be served from the cache + let response = client.get("http://test.local/").send().await.unwrap(); + assert_eq!( + response.headers().get(header::CACHE_CONTROL).unwrap(), + "max-age=1000" + ); + assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "MISS"); + assert_eq!(response.headers().get(X_CACHE).unwrap(), "MISS"); + assert!(response.headers().get(X_CACHE_DIGEST).is_none()); + assert_eq!(response.text().await.unwrap(), BODY); + + // Ensure the body was stored in the cache + assert!(cache.storage().body_path(DIGEST).is_file()); + + // Second response should be served from the cache without revalidation + // If a revalidation is made, the mock middleware will panic since there was + // only one response defined + let response = client.get("http://test.local/").send().await.unwrap(); + assert_eq!( + response.headers().get(header::CACHE_CONTROL).unwrap(), + "max-age=1000" + ); + assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "HIT"); + assert_eq!(response.headers().get(X_CACHE).unwrap(), "HIT"); + assert_eq!( + response + .headers() + .get(X_CACHE_DIGEST) + .map(|v| v.to_str().unwrap()) + .unwrap(), + DIGEST + ); + assert_eq!(response.text().await.unwrap(), BODY); + } + + #[tokio::test] + async fn cache_hit_unmodified() { + const BODY: &str = "hello world!"; + // Blake3 digest of the body (from https://emn178.github.io/online-tools/blake3/) + const DIGEST: &str = "3aa61c409fd7717c9d9c639202af2fae470c0ef669be7ba2caea5779cb534e9d"; + + #[derive(Default)] + struct State { + revalidated: bool, + } + + let dir = tempdir().unwrap(); + let state = Arc::new(Mutex::new(State::default())); + let state_clone = state.clone(); + let cache = Arc::new( + Cache::new(DefaultCacheStorage::new(dir.path())).with_revalidation_hook(move |_, _| { + state_clone.lock().unwrap().revalidated = true; + Ok(()) + }), + ); + let mock = Arc::new(MockMiddleware::new([ + http::Response::builder().body(BODY).unwrap(), + http::Response::builder() + .status(StatusCode::NOT_MODIFIED) + .body("") + .unwrap(), + ])); + let client = ClientWithMiddleware::new( + Default::default(), + vec![cache.clone() as Arc, mock.clone()], + ); + + // First response should be a miss + let response = client.get("http://test.local/").send().await.unwrap(); + assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "MISS"); + assert_eq!(response.headers().get(X_CACHE).unwrap(), "MISS"); + assert!(response.headers().get(X_CACHE_DIGEST).is_none()); + assert_eq!(response.text().await.unwrap(), BODY); + + // Ensure the body was stored in the cache + assert!(cache.storage().body_path(DIGEST).is_file()); + + // Assert no revalidation took place + assert!(!state.lock().unwrap().revalidated); + + // Second response should be served from the cache + let response = client.get("http://test.local/").send().await.unwrap(); + assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "HIT"); + assert_eq!(response.headers().get(X_CACHE).unwrap(), "HIT"); + assert_eq!( + response + .headers() + .get(X_CACHE_DIGEST) + .map(|v| v.to_str().unwrap()) + .unwrap(), + DIGEST + ); + assert_eq!(response.text().await.unwrap(), BODY); + + // Assert a revalidation took place + assert!(state.lock().unwrap().revalidated); + } + + #[tokio::test] + async fn cache_hit_modified() { + const BODY: &str = "hello world!"; + const MODIFIED_BODY: &str = "hello world!!!"; + // Blake3 digest of the body (from https://emn178.github.io/online-tools/blake3/) + const DIGEST: &str = "3aa61c409fd7717c9d9c639202af2fae470c0ef669be7ba2caea5779cb534e9d"; + // Blake3 digest of the modified body (from https://emn178.github.io/online-tools/blake3/) + const MODIFIED_DIGEST: &str = + "22b8d362b2e8064356915b1451f630d1d920b427d3b2f9b3432fbf4c03d94184"; + + #[derive(Default)] + struct State { + revalidated: bool, + } + + let dir = tempdir().unwrap(); + let state = Arc::new(Mutex::new(State::default())); + let state_clone = state.clone(); + let cache = Arc::new( + Cache::new(DefaultCacheStorage::new(dir.path())).with_revalidation_hook(move |_, _| { + state_clone.lock().unwrap().revalidated = true; + Ok(()) + }), + ); + let mock = Arc::new(MockMiddleware::new([ + http::Response::builder().body(BODY).unwrap(), + http::Response::builder().body(MODIFIED_BODY).unwrap(), + http::Response::builder() + .status(StatusCode::NOT_MODIFIED) + .body("") + .unwrap(), + ])); + let client = ClientWithMiddleware::new( + Default::default(), + vec![cache.clone() as Arc, mock.clone()], + ); + + // First response should be a miss + let response = client.get("http://test.local/").send().await.unwrap(); + assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "MISS"); + assert_eq!(response.headers().get(X_CACHE).unwrap(), "MISS"); + assert!(response.headers().get(X_CACHE_DIGEST).is_none()); + assert_eq!(response.text().await.unwrap(), BODY); + + // Ensure the body was stored in the cache + assert!(cache.storage().body_path(DIGEST).is_file()); + + // Assert no revalidation took place + assert!(!state.lock().unwrap().revalidated); + + // Second response should not be served from the cache (was modified) + let response = client.get("http://test.local/").send().await.unwrap(); + assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "HIT"); + assert_eq!(response.headers().get(X_CACHE).unwrap(), "MISS"); + assert!(response.headers().get(X_CACHE_DIGEST).is_none()); + assert_eq!(response.text().await.unwrap(), MODIFIED_BODY); + + // Ensure the body was stored in the cache + assert!(cache.storage().body_path(MODIFIED_DIGEST).is_file()); + + // Assert a revalidation took place and reset the flag back to false + assert!(std::mem::take(&mut state.lock().unwrap().revalidated)); + + // Second response should be served from the cache (not modified) + let response = client.get("http://test.local/").send().await.unwrap(); + assert_eq!(response.headers().get(X_CACHE_LOOKUP).unwrap(), "HIT"); + assert_eq!(response.headers().get(X_CACHE).unwrap(), "HIT"); + assert_eq!( + response + .headers() + .get(X_CACHE_DIGEST) + .map(|v| v.to_str().unwrap()) + .unwrap(), + MODIFIED_DIGEST + ); + assert_eq!(response.text().await.unwrap(), MODIFIED_BODY); + + // Assert a revalidation took place + assert!(state.lock().unwrap().revalidated); + } +} diff --git a/src/body.rs b/src/body.rs index aa68932..734fc4f 100644 --- a/src/body.rs +++ b/src/body.rs @@ -14,13 +14,14 @@ use bytes::Bytes; use bytes::BytesMut; use futures::Stream; use futures::future::BoxFuture; +use http_body::Body; use http_body::Frame; +use http_body_util::BodyStream; use pin_project_lite::pin_project; use runtime::AsyncWrite; use tempfile::NamedTempFile; use tempfile::TempPath; -use crate::HttpBody; use crate::runtime; /// The default capacity for reading from files. @@ -34,7 +35,7 @@ pin_project! { ReadingUpstream { // The upstream response body. #[pin] - upstream: B, + upstream: BodyStream, // The writer for the cache file. #[pin] writer: Option>, @@ -99,7 +100,7 @@ impl CachingUpstreamSource { Ok(Self { state: CachingUpstreamSourceState::ReadingUpstream { - upstream, + upstream: BodyStream::new(upstream), writer: Some(runtime::BufWriter::new(file)), path: Some(path), callback: Some(Box::new(callback)), @@ -110,13 +111,19 @@ impl CachingUpstreamSource { } } -impl Stream for CachingUpstreamSource +impl Body for CachingUpstreamSource where - B: HttpBody, + B: Body, + B::Data: Into, + B::Error: Into>, { - type Item = io::Result; + type Data = Bytes; + type Error = Box; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { loop { let this = self.as_mut().project(); match this.state.project() { @@ -130,19 +137,25 @@ where } => { // Check to see if a read is needed if current.is_empty() { - match ready!(upstream.poll_next_data(cx)) { - Some(Ok(data)) if data.is_empty() => continue, - Some(Ok(data)) => { - // Update the hasher with the data that was read - hasher.update(&data); - *current = data; + match ready!(upstream.poll_next(cx)) { + Some(Ok(frame)) => { + let frame = frame.map_data(Into::into); + match frame.into_data() { + Ok(data) if !data.is_empty() => { + // Update the hasher with the data that was read + hasher.update(&data); + *current = data; + } + Ok(_) => continue, + Err(frame) => return Poll::Ready(Some(Ok(frame))), + } } Some(Err(e)) => { // Set state to finished and return self.set(Self { state: CachingUpstreamSourceState::Completed, }); - return Poll::Ready(Some(Err(e))); + return Poll::Ready(Some(Err(e.into()))); } None => { let writer = writer.take(); @@ -170,13 +183,13 @@ where return match ready!(writer.as_pin_mut().unwrap().poll_write(cx, &data)) { Ok(n) => { *current = data.split_off(n); - Poll::Ready(Some(Ok(data))) + Poll::Ready(Some(Ok(Frame::data(data)))) } Err(e) => { self.set(Self { state: CachingUpstreamSourceState::Completed, }); - Poll::Ready(Some(Err(e))) + Poll::Ready(Some(Err(e.into()))) } }; } @@ -205,7 +218,7 @@ where self.set(Self { state: CachingUpstreamSourceState::Completed, }); - return Poll::Ready(Some(Err(e))); + return Poll::Ready(Some(Err(e.into()))); } } } @@ -221,7 +234,7 @@ where self.set(Self { state: CachingUpstreamSourceState::Completed, }); - Poll::Ready(Some(Err(io::Error::other(e)))) + Poll::Ready(Some(Err(e.into_boxed_dyn_error()))) } }; } @@ -246,10 +259,14 @@ pin_project! { } } -impl Stream for FileSource { - type Item = io::Result; +impl Body for FileSource { + type Data = Bytes; + type Error = io::Error; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>>> { let this = self.project(); if *this.finished { @@ -269,7 +286,7 @@ impl Stream for FileSource { } Ok(_) => { let chunk = this.buf.split(); - Poll::Ready(Some(Ok(chunk.freeze()))) + Poll::Ready(Some(Ok(Frame::data(chunk.freeze())))) } Err(err) => { *this.finished = true; @@ -306,7 +323,7 @@ impl Stream for FileSource { unsafe { this.buf.advance_mut(n); } - Poll::Ready(Some(Ok(this.buf.split().freeze()))) + Poll::Ready(Some(Ok(Frame::data(this.buf.split().freeze())))) } Err(e) => { *this.finished = true; @@ -334,7 +351,7 @@ pin_project! { Upstream { // The underlying source for the body. #[pin] - source: B + source: BodyStream }, /// The body is coming from upstream with being cached. CachingUpstream { @@ -352,23 +369,27 @@ pin_project! { } pin_project! { - /// Represents a response body. - pub struct Body { + /// Represents a cache body. + /// + /// The cache body may be sourced from an upstream response or from a file from the cache. + pub struct CacheBody { // The body source. #[pin] source: BodySource } } -impl Body +impl CacheBody where - B: HttpBody, + B: Body, { /// Constructs a new body from an upstream response body that is not being /// cached. pub(crate) fn from_upstream(upstream: B) -> Self { Self { - source: BodySource::Upstream { source: upstream }, + source: BodySource::Upstream { + source: BodyStream::new(upstream), + }, } } @@ -406,23 +427,26 @@ where } } -impl http_body::Body for Body +impl Body for CacheBody where - B: HttpBody, + B: Body, + B::Data: Into, + B::Error: Into>, { type Data = Bytes; - type Error = io::Error; + type Error = Box; fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, io::Error>>> { + ) -> Poll, Self::Error>>> { match self.project().source.project() { - ProjectedBodySource::Upstream { source } => source.poll_frame(cx), - ProjectedBodySource::CachingUpstream { source } => { - source.poll_next(cx).map_ok(Frame::data) - } - ProjectedBodySource::File { source } => source.poll_next(cx).map_ok(Frame::data), + ProjectedBodySource::Upstream { source } => source + .poll_frame(cx) + .map_ok(|f| f.map_data(Into::into)) + .map_err(Into::into), + ProjectedBodySource::CachingUpstream { source } => source.poll_frame(cx), + ProjectedBodySource::File { source } => source.poll_frame(cx).map_err(Into::into), } } @@ -438,10 +462,10 @@ where fn size_hint(&self) -> http_body::SizeHint { match &self.source { - BodySource::Upstream { source } => source.size_hint(), + BodySource::Upstream { source } => Body::size_hint(source), BodySource::CachingUpstream { source } => match &source.state { CachingUpstreamSourceState::ReadingUpstream { upstream, .. } => { - upstream.size_hint() + Body::size_hint(upstream) } _ => http_body::SizeHint::default(), }, @@ -449,25 +473,3 @@ where } } } - -impl HttpBody for Body where B: HttpBody + Send {} - -/// An implementation of `Stream` for body. -/// -/// This implementation only retrieves the data frames of the body. -/// -/// Trailer frames are not read. -impl Stream for Body -where - B: HttpBody, -{ - type Item = io::Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project().source.project() { - ProjectedBodySource::Upstream { source } => source.poll_next_data(cx), - ProjectedBodySource::CachingUpstream { source } => source.poll_next(cx), - ProjectedBodySource::File { source } => source.poll_next(cx), - } - } -} diff --git a/src/cache.rs b/src/cache.rs index d05a359..b05676e 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,15 +1,9 @@ //! Implementation of the HTTP cache. use std::fmt; -use std::io; -use std::pin::Pin; -use std::task::Context; -use std::task::Poll; -use std::task::ready; use std::time::SystemTime; use anyhow::Result; -use bytes::Bytes; use http::HeaderMap; use http::HeaderValue; use http::Method; @@ -20,6 +14,7 @@ use http::Version; use http::header; use http::header::CACHE_CONTROL; use http::uri::Authority; +use http_body::Body; use http_cache_semantics::AfterResponse; use http_cache_semantics::BeforeRequest; use http_cache_semantics::CacheOptions; @@ -28,7 +23,7 @@ use sha2::Digest; use sha2::Sha256; use tracing::debug; -use crate::body::Body; +use crate::body::CacheBody; use crate::storage::CacheStorage; use crate::storage::StoredResponse; @@ -206,31 +201,11 @@ impl ResponseExt for Response { } } -/// Represents the supported HTTP body trait from middleware integrations. -pub trait HttpBody: http_body::Body + Send { - /// Polls the next data frame as bytes. - /// - /// Returns end of stream after all data frames, thereby ignoring trailers. - fn poll_next_data( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - match ready!(self.poll_frame(cx)) { - Some(Ok(frame)) => match frame.into_data().ok() { - Some(data) => Poll::Ready(Some(Ok(data))), - None => Poll::Ready(None), - }, - Some(Err(e)) => Poll::Ready(Some(Err(e))), - None => Poll::Ready(None), - } - } -} - /// An abstraction of an HTTP request. /// /// This trait is used in HTTP middleware integrations to abstract the request /// type and sending the request upstream. -pub trait Request: Send { +pub trait Request: Send { /// Gets the request's version. fn version(&self) -> Version; @@ -262,7 +237,7 @@ struct RequestLike { impl RequestLike { /// Constructs a new `RequestLike` for the given request. - fn new, B: HttpBody>(request: &R) -> Self { + fn new, B: Body>(request: &R) -> Self { // Unfortunate we have to clone the header map here Self { method: request.method().clone(), @@ -382,7 +357,10 @@ where /// /// If a previous response is not in the cache, the request is sent upstream /// and the response is cached, if it is cacheable. - pub async fn send(&self, request: impl Request) -> Result>> { + pub async fn send( + &self, + request: impl Request, + ) -> Result>> { let method = request.method(); let uri = request.uri(); @@ -433,12 +411,12 @@ where /// Sends the original request upstream. /// /// Caches the response if the response is cacheable. - async fn send_upstream( + async fn send_upstream( &self, key: String, request: impl Request, lookup_status: CacheLookupStatus, - ) -> Result>> { + ) -> Result>> { let request_like: RequestLike = RequestLike::new(&request); let mut response = request.send(None).await?; @@ -498,7 +476,7 @@ where } } - Ok(response.map(Body::from_upstream)) + Ok(response.map(CacheBody::from_upstream)) } /// Performs a conditional send to upstream. @@ -506,12 +484,12 @@ where /// If a cached request is still fresh, it is returned. /// /// If a cached request is stale, an attempt is made to revalidate it. - async fn conditional_send_upstream( + async fn conditional_send_upstream( &self, key: String, request: impl Request, mut stored: StoredResponse, - ) -> Result>> { + ) -> Result>> { let request_like = RequestLike::new(&request); let mut headers = match stored @@ -723,7 +701,7 @@ where // Otherwise, don't serve the cached response at all response.set_cache_status(CacheLookupStatus::Hit, CacheStatus::Miss, None); - Ok(response.map(Body::from_upstream)) + Ok(response.map(CacheBody::from_upstream)) } Err(e) => { if stored.response.must_revalidate() { @@ -752,7 +730,7 @@ where } /// Prepares a stale response for sending back to the client. - fn prepare_stale_response(uri: &Uri, response: &mut Response>, digest: &str) { + fn prepare_stale_response(uri: &Uri, response: &mut Response>, digest: &str) { // If the server failed to give us a response, add the required warning to the // cached response: // 111 Revalidation failed diff --git a/src/storage.rs b/src/storage.rs index 7c6d82c..aefa132 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -5,19 +5,19 @@ use std::path::PathBuf; use anyhow::Result; use http::Response; use http::response::Parts; +use http_body::Body; use http_cache_semantics::CachePolicy; -use crate::HttpBody; -use crate::body::Body; +use crate::body::CacheBody; mod default; pub use default::*; /// Represents a response from storage. -pub struct StoredResponse { +pub struct StoredResponse { /// The cached response. - pub response: Response>, + pub response: Response>, /// The current cache policy. pub policy: CachePolicy, /// The response content digest. @@ -32,7 +32,7 @@ pub trait CacheStorage: Send + Sync + 'static { /// /// Returns `Ok(None)` if a response does not exist in the storage for the /// given response key. - fn get( + fn get( &self, key: &str, ) -> impl Future>>> + Send; @@ -53,13 +53,13 @@ pub trait CacheStorage: Send + Sync + 'static { /// Stores a new response body in the cache. /// /// Returns a response with a body streaming to the cache. - fn store( + fn store( &self, key: String, parts: Parts, body: B, policy: CachePolicy, - ) -> impl Future>>> + Send; + ) -> impl Future>>> + Send; /// Deletes a previously cached response for the given response key. /// diff --git a/src/storage/default.rs b/src/storage/default.rs index 5e41449..add89ab 100644 --- a/src/storage/default.rs +++ b/src/storage/default.rs @@ -15,14 +15,14 @@ use http::Response; use http::StatusCode; use http::Version; use http::response::Parts; +use http_body::Body; use http_cache_semantics::CachePolicy; use serde::Deserialize; use serde::Serialize; use tracing::debug; use super::StoredResponse; -use crate::HttpBody; -use crate::body::Body; +use crate::body::CacheBody; use crate::runtime; use crate::storage::CacheStorage; @@ -172,7 +172,7 @@ impl DefaultCacheStorage { } impl CacheStorage for DefaultCacheStorage { - async fn get(&self, key: &str) -> Result>> { + async fn get(&self, key: &str) -> Result>> { let cached = match self.0.read_response(key).await? { Some(response) => response, None => return Ok(None), @@ -209,7 +209,7 @@ impl CacheStorage for DefaultCacheStorage { Ok(Some(StoredResponse { response: builder - .body(Body::from_file(body).await.with_context(|| { + .body(CacheBody::from_file(body).await.with_context(|| { format!( "failed to create response body for `{path}`", path = path.display() @@ -242,13 +242,13 @@ impl CacheStorage for DefaultCacheStorage { .await } - async fn store( + async fn store( &self, key: String, parts: Parts, body: B, policy: CachePolicy, - ) -> Result>> { + ) -> Result>> { // Create a temporary file for the download of the body let inner = self.0.clone(); let temp_dir = inner.temp_dir_path(); @@ -265,7 +265,7 @@ impl CacheStorage for DefaultCacheStorage { let version = parts.version; let headers = parts.headers.clone(); - let body = Body::from_caching_upstream(body, &temp_dir, move |digest, path| { + let body = CacheBody::from_caching_upstream(body, &temp_dir, move |digest, path| { async move { let content_path = inner.content_path(&digest); fs::create_dir_all(content_path.parent().expect("should have parent")) @@ -479,3 +479,107 @@ impl DefaultCacheStorageInner { Ok(file) } } + +#[cfg(all(test, feature = "tokio"))] +mod test { + use futures::StreamExt; + use http::Request; + use http_body_util::BodyDataStream; + use http_cache_semantics::CachePolicy; + use tempfile::tempdir; + + use super::*; + + #[tokio::test] + async fn cache_miss() { + let dir = tempdir().unwrap(); + let storage = DefaultCacheStorage::new(dir.path()); + assert!( + storage + .get::("does-not-exist") + .await + .expect("should not fail") + .is_none() + ); + } + + #[tokio::test] + async fn cache_hit() { + const KEY: &str = "key"; + const BODY: &str = "hello world"; + const DIGEST: &str = "d74981efa70a0c880b8d8c1985d075dbcbf679b99a5f9914e5aaf96b831a9e24"; + const HEADER_NAME: &str = "foo"; + const HEADER_VALUE: &str = "bar"; + + let dir = tempdir().unwrap(); + let storage = DefaultCacheStorage::new(dir.path()); + + // Assert the key doesn't currently exist in the cache + assert!(storage.get::(KEY).await.unwrap().is_none()); + + // Store a response in the cache + let request = Request::builder().body("").unwrap(); + let response = Response::builder().body(BODY.to_string()).unwrap(); + let policy: CachePolicy = CachePolicy::new(&request, &response); + + let (parts, body) = response.into_parts(); + let response = storage + .store(KEY.to_string(), parts, body, policy) + .await + .unwrap(); + + // Read the response to the end to fully cache the body + let mut stream = BodyDataStream::new(response.into_body()); + let data = stream.next().await.unwrap().unwrap(); + assert!(stream.next().await.is_none()); + assert_eq!(data, BODY); + drop(stream); + + // Lookup the cache entry (should exist now, without the header) + let cached = storage.get::(KEY).await.unwrap().unwrap(); + assert!(cached.response.headers().get(HEADER_NAME).is_none()); + + // Read the cached response + let data = BodyDataStream::new(cached.response.into_body()) + .next() + .await + .unwrap() + .unwrap(); + assert_eq!(data, BODY); + assert_eq!(cached.digest, DIGEST); + + // Create an "updated" response and put it into the cache with the same body + let response = Response::builder() + .header(HEADER_NAME, HEADER_VALUE) + .body(BODY.to_string()) + .unwrap(); + let policy = CachePolicy::new(&request, &response); + + let (parts, _) = response.into_parts(); + storage.put(KEY, &parts, &policy, DIGEST).await.unwrap(); + + // Lookup the cache entry (should exist with the header) + let cached = storage.get::(KEY).await.unwrap().unwrap(); + assert_eq!( + cached + .response + .headers() + .get(HEADER_NAME) + .map(|v| v.to_str().unwrap()), + Some(HEADER_VALUE) + ); + + // Read the cached response (should be unchanged) + let data = BodyDataStream::new(cached.response.into_body()) + .next() + .await + .unwrap() + .unwrap(); + assert_eq!(data, BODY); + assert_eq!(cached.digest, DIGEST); + + // Delete the key and ensure it no longer exists + storage.delete(KEY).await.unwrap(); + assert!(storage.get::(KEY).await.unwrap().is_none()); + } +}