Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,23 @@ jobs:
job:
- target: x86_64-unknown-linux-gnu
os: ubuntu-latest
flags: --features=native-tls
flags: --features=native-tls,plugin-tests
- target: x86_64-unknown-linux-gnu
os: ubuntu-latest
flags: --no-default-features --features=native-tls,online-tests # disables rustls
flags: --no-default-features --features=native-tls,online-tests,plugin-tests # disables rustls
- target: x86_64-unknown-linux-gnu
os: ubuntu-latest
flags: --features=http3,http-message-signatures
flags: --features=http3,http-message-signatures,plugin-tests
rustflags: --cfg reqwest_unstable
- target: x86_64-apple-darwin
os: macos-15-intel
flags: --features=native-tls
flags: --features=native-tls,plugin-tests
- target: aarch64-apple-darwin
os: macos-latest
flags: --features=native-tls
flags: --features=native-tls,plugin-tests
- target: x86_64-pc-windows-msvc
os: windows-latest
flags: --features=native-tls
flags: --features=native-tls,plugin-tests
- target: x86_64-unknown-linux-musl
os: ubuntu-latest
use-cross: true
Expand Down
10 changes: 5 additions & 5 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@ jobs:
job:
- target: x86_64-unknown-linux-gnu
os: ubuntu-latest
flags: --features=native-tls
flags: --features=native-tls,plugin-tests
- target: x86_64-unknown-linux-gnu
os: ubuntu-latest
flags: --no-default-features --features=native-tls,online-tests # disables rustls
flags: --no-default-features --features=native-tls,online-tests,plugin-tests # disables rustls
- target: x86_64-apple-darwin
os: macos-15-intel
flags: --features=native-tls
flags: --features=native-tls,plugin-tests
- target: aarch64-apple-darwin
os: macos-latest
flags: --features=native-tls
flags: --features=native-tls,plugin-tests
- target: x86_64-pc-windows-msvc
os: windows-latest
flags: --features=native-tls
flags: --features=native-tls,plugin-tests
- target: x86_64-unknown-linux-musl
os: ubuntu-latest
use-cross: true
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ network-interface = ["dep:network-interface"]

online-tests = []
ipv6-tests = []
plugin-tests = []

[package.metadata.cross.build.env]
passthrough = ["CARGO_PROFILE_RELEASE_LTO"]
Expand Down
183 changes: 179 additions & 4 deletions src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
use std::io;
use std::ffi::OsString;
use std::io::{self, Write};
use std::process;

use anyhow::Result;
use anyhow::{Context as _, Result, anyhow};
use base64::{Engine as _, engine::general_purpose::STANDARD as base64_standard};
use regex_lite::Regex;
use reqwest::StatusCode;
use reqwest::blocking::{Request, Response};
use reqwest::header::{AUTHORIZATION, HeaderValue, WWW_AUTHENTICATE};
use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue, WWW_AUTHENTICATE};
use serde::{Deserialize, Serialize};

use crate::cli::AuthType;
use crate::middleware::{Context, Middleware};
use crate::netrc;
use crate::utils::clone_request;
use crate::utils::{clone_request, is_path};

#[derive(Debug, PartialEq, Eq)]
pub enum Auth {
Bearer(String),
Basic(String, Option<String>),
Digest(String, String),
Plugin(AuthPlugin),
}

impl Auth {
Expand All @@ -33,6 +38,7 @@ impl Auth {
))
}
AuthType::Bearer => Ok(Auth::Bearer(auth.into())),
AuthType::Plugin(..) => unreachable!(),
}
}

Expand All @@ -41,6 +47,7 @@ impl Auth {
AuthType::Basic => Some(Auth::Basic(entry.login?, Some(entry.password))),
AuthType::Bearer => Some(Auth::Bearer(entry.password)),
AuthType::Digest => Some(Auth::Digest(entry.login?, entry.password)),
AuthType::Plugin(..) => None,
}
}
}
Expand Down Expand Up @@ -97,6 +104,174 @@ impl Middleware for DigestAuthMiddleware<'_> {
}
}

#[derive(Debug, PartialEq, Eq)]
pub struct AuthPlugin {
name_or_path: OsString,
auth: Vec<String>,
state: serde_json::Value,
config: PluginConfig,
}

impl AuthPlugin {
pub fn new(name_or_path: OsString, auth: Vec<String>) -> Self {
AuthPlugin {
name_or_path,
auth,
state: serde_json::Value::Null,
config: PluginConfig {
requires_body: Some(false),
},
}
}
}

#[derive(Debug, PartialEq, Eq, Deserialize)]
struct PluginConfig {
requires_body: Option<bool>,
}

#[derive(Debug, Serialize, Deserialize)]
struct Header {
name: String,
value: String,
}

#[derive(Debug, Deserialize)]
struct PluginResponse {
remove_headers: Option<Vec<String>>,
add_headers: Option<Vec<Header>>,
set_state: Option<serde_json::Value>,
}

#[derive(Debug, Serialize)]
struct NextRequest {
method: String,
url: String,
headers: Vec<Header>,
#[serde(skip_serializing_if = "Option::is_none")]
body_base64: Option<String>,
}

#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum PluginInput<'a, 'b> {
Configure,
BeforeRequest {
request: NextRequest,
auth: &'a [String],
state: &'b serde_json::Value,
},
}

impl<'a, 'b> PluginInput<'a, 'b> {
fn new(
request: &mut Request,
auth: &'a [String],
state: &'b serde_json::Value,
config: &PluginConfig,
) -> Result<Self> {
let mut body_base64 = None;
if config.requires_body == Some(true) {
if let Some(body) = request.body_mut() {
let body = body.buffer()?;
body_base64 = Some(base64_standard.encode(body));
}
}

let plugin_input = PluginInput::BeforeRequest {
request: NextRequest {
method: request.method().to_string(),
url: request.url().to_string(),
headers: request
.headers()
.iter()
.map(|(name, value)| {
Ok(Header {
name: name.to_string(),
value: value.to_str()?.into(),
})
})
.collect::<Result<Vec<_>>>()?,
body_base64,
},
auth,
state,
};

Ok(plugin_input)
}
}

impl AuthPlugin {
pub fn configure(&mut self) -> Result<()> {
let plugin_input = PluginInput::Configure;
self.config = serde_json::from_slice::<PluginConfig>(
&self.exec(&serde_json::to_vec(&plugin_input)?)?,
)?;
Ok(())
}

pub fn authenticate(&mut self, next_request: &mut Request) -> Result<()> {
let plugin_input = PluginInput::new(next_request, &self.auth, &self.state, &self.config)?;

let plugin_output = serde_json::from_slice::<PluginResponse>(
&self.exec(&serde_json::to_vec(&plugin_input)?)?,
)?;

if let Some(headers_to_remove) = plugin_output.remove_headers {
for header in headers_to_remove {
next_request.headers_mut().remove(header);
}
}
if let Some(headers_to_add) = plugin_output.add_headers {
next_request.headers_mut().extend(
headers_to_add
.iter()
.map(|Header { name, value }| Ok((name.try_into()?, value.try_into()?)))
.collect::<Result<HeaderMap>>()?,
);
}
if let Some(state) = plugin_output.set_state {
self.state = state
}
Ok(())
}

fn exec(&self, plugin_input: &[u8]) -> Result<Vec<u8>> {
let plugin_path = if is_path(&self.name_or_path) {
std::path::PathBuf::from(&self.name_or_path)
} else {
std::path::PathBuf::from(format!("xh-{}", self.name_or_path.to_string_lossy()))
};

log::debug!("Spawning plugin {:?}", plugin_path);
let mut child = process::Command::new(&plugin_path)
.env("XH_PLUGIN", "auth")
.stdin(process::Stdio::piped())
.stdout(process::Stdio::piped())
.spawn()
.map_err(|e| anyhow!("Unable to spawn plugin {:?}: {}", plugin_path, e))?;

let child_stdin = child.stdin.as_mut().unwrap();
log::debug!("Writing to plugin's stdin");
child_stdin.write_all(plugin_input)?;

let output = child
.wait_with_output()
.context("Failed to wait for plugin output")?;

if !output.status.success() {
if let Some(code) = output.status.code() {
return Err(anyhow!("Plugin exited with exit code {}", code));
} else {
return Err(anyhow!("Plugin exited no exit code"));
}
}

Ok(output.stdout)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading
Loading