diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 6ca084d188..ef5d6b272b 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -5097,10 +5097,10 @@ dependencies = [ [[package]] name = "rmcp" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5947688160b56fb6c827e3c20a72c90392a1d7e9dec74749197aa1780ac42ca" +version = "0.9.0" +source = "git+https://github.com/bolinfest/rust-sdk?branch=pr556#4d9cc16f4c76c84486344f542ed9a3e9364019ba" dependencies = [ + "async-trait", "base64", "bytes", "chrono", @@ -5131,9 +5131,8 @@ dependencies = [ [[package]] name = "rmcp-macros" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01263441d3f8635c628e33856c468b96ebbce1af2d3699ea712ca71432d4ee7a" +version = "0.9.0" +source = "git+https://github.com/bolinfest/rust-sdk?branch=pr556#4d9cc16f4c76c84486344f542ed9a3e9364019ba" dependencies = [ "darling 0.21.3", "proc-macro2", diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 6fb285c256..e1f64cb5d9 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -108,8 +108,8 @@ async-trait = "0.1.89" axum = { version = "0.8", default-features = false } base64 = "0.22.1" bytes = "1.10.1" -chrono = "0.4.42" chardetng = "0.1.17" +chrono = "0.4.42" clap = "4" clap_complete = "4" color-eyre = "0.6.3" @@ -120,9 +120,9 @@ diffy = "0.4.2" dirs = "6" dotenvy = "0.15.7" dunce = "1.0.4" +encoding_rs = "0.8.35" env-flags = "0.1.1" env_logger = "0.11.5" -encoding_rs = "0.8.35" escargot = "0.5" eventsource-stream = "0.2.3" futures = { version = "0.3", default-features = false } @@ -167,7 +167,7 @@ ratatui-macros = "0.6.0" regex-lite = "0.1.7" regex = "1.11.1" reqwest = "0.12" -rmcp = { version = "0.8.5", default-features = false } +rmcp = { version = "0.9.0", default-features = false } schemars = "0.8.22" seccompiler = "0.5.0" serde = "1" @@ -261,11 +261,7 @@ unwrap_used = "deny" # cargo-shear cannot see the platform-specific openssl-sys usage, so we # silence the false positive here instead of deleting a real dependency. [workspace.metadata.cargo-shear] -ignored = [ - "icu_provider", - "openssl-sys", - "codex-utils-readiness", -] +ignored = ["icu_provider", "openssl-sys", "codex-utils-readiness"] [profile.release] lto = "fat" @@ -286,6 +282,7 @@ opt-level = 0 # ratatui = { path = "../../ratatui" } crossterm = { git = "https://github.com/nornagon/crossterm", branch = "nornagon/color-query" } ratatui = { git = "https://github.com/nornagon/ratatui", branch = "nornagon-v0.29.0-patch" } +rmcp = { git = "https://github.com/bolinfest/rust-sdk", branch = "pr556" } # Uncomment to debug local changes. # rmcp = { path = "../../rust-sdk/crates/rmcp" } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 019ba8ce37..a91024348a 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use std::sync::atomic::AtomicU64; use crate::AuthManager; +use crate::SandboxState; use crate::client_common::REVIEW_PROMPT; use crate::compact; use crate::compact::run_inline_auto_compact_task; @@ -614,6 +615,22 @@ impl Session { ) .await; + let sandbox_state = SandboxState { + sandbox_policy: session_configuration.sandbox_policy.clone(), + codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(), + sandbox_cwd: session_configuration.cwd.clone(), + }; + if let Err(e) = sess + .services + .mcp_connection_manager + .read() + .await + .notify_sandbox_state_change(&sandbox_state) + .await + { + tracing::error!("Failed to notify sandbox state change: {e}"); + } + // record_initial_history can emit events. We record only after the SessionConfiguredEvent is emitted. sess.record_initial_history(initial_history).await; diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 6906489e7e..805943a2e7 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -32,6 +32,9 @@ pub mod git_info; pub mod landlock; pub mod mcp; mod mcp_connection_manager; +pub use mcp_connection_manager::MCP_SANDBOX_STATE_CAPABILITY; +pub use mcp_connection_manager::MCP_SANDBOX_STATE_NOTIFICATION; +pub use mcp_connection_manager::SandboxState; mod mcp_tool_call; mod message_history; mod model_provider_info; diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index e1b05cef48..22cb84e2c9 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -10,6 +10,7 @@ use std::collections::HashMap; use std::collections::HashSet; use std::env; use std::ffi::OsString; +use std::path::PathBuf; use std::sync::Arc; use std::sync::Mutex; use std::time::Duration; @@ -28,6 +29,7 @@ use codex_protocol::protocol::McpStartupCompleteEvent; use codex_protocol::protocol::McpStartupFailure; use codex_protocol::protocol::McpStartupStatus; use codex_protocol::protocol::McpStartupUpdateEvent; +use codex_protocol::protocol::SandboxPolicy; use codex_rmcp_client::ElicitationResponse; use codex_rmcp_client::OAuthCredentialsStoreMode; use codex_rmcp_client::RmcpClient; @@ -48,6 +50,8 @@ use mcp_types::Resource; use mcp_types::ResourceTemplate; use mcp_types::Tool; +use serde::Deserialize; +use serde::Serialize; use serde_json::json; use sha1::Digest; use sha1::Sha1; @@ -174,6 +178,7 @@ struct ManagedClient { tools: Vec, tool_filter: ToolFilter, tool_timeout: Option, + server_supports_sandbox_state_capability: bool, } #[derive(Clone)] @@ -222,6 +227,35 @@ impl AsyncManagedClient { async fn client(&self) -> Result { self.client.clone().await } + + async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> { + let managed = self.client().await?; + if !managed.server_supports_sandbox_state_capability { + return Ok(()); + } + + managed + .client + .send_custom_notification( + MCP_SANDBOX_STATE_NOTIFICATION, + Some(serde_json::to_value(sandbox_state)?), + ) + .await + } +} + +pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state"; + +/// Custom MCP notification for sandbox state updates. +/// When used, the `params` field of the notification is [`SandboxState`]. +pub const MCP_SANDBOX_STATE_NOTIFICATION: &str = "codex/sandbox-state/update"; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SandboxState { + pub sandbox_policy: SandboxPolicy, + pub codex_linux_sandbox_exe: Option, + pub sandbox_cwd: PathBuf, } /// A thin wrapper around a set of running [`RmcpClient`] instances. @@ -567,6 +601,34 @@ impl McpConnectionManager { .get(tool_name) .map(|tool| (tool.server_name.clone(), tool.tool_name.clone())) } + + pub async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> { + let mut join_set = JoinSet::new(); + + for async_managed_client in self.clients.values() { + let sandbox_state = sandbox_state.clone(); + let async_managed_client = async_managed_client.clone(); + join_set.spawn(async move { + async_managed_client + .notify_sandbox_state_change(&sandbox_state) + .await + }); + } + + while let Some(join_res) = join_set.join_next().await { + match join_res { + Ok(Ok(())) => {} + Ok(Err(err)) => { + warn!("Failed to notify sandbox state change to MCP server: {err:#}"); + } + Err(err) => { + warn!("Task panic when notifying sandbox state change to MCP server: {err:#}"); + } + } + } + + Ok(()) + } } async fn emit_update( @@ -700,7 +762,7 @@ async fn start_server_task( let send_elicitation = elicitation_requests.make_sender(server_name.clone(), tx_event); - client + let initialize_result = client .initialize(params, startup_timeout, send_elicitation) .await .map_err(StartupOutcomeError::from)?; @@ -709,11 +771,19 @@ async fn start_server_task( .await .map_err(StartupOutcomeError::from)?; + let server_supports_sandbox_state_capability = initialize_result + .capabilities + .experimental + .as_ref() + .and_then(|exp| exp.get(MCP_SANDBOX_STATE_CAPABILITY)) + .is_some(); + let managed = ManagedClient { client: Arc::clone(&client), tools, tool_timeout: Some(tool_timeout), tool_filter, + server_supports_sandbox_state_capability, }; Ok(managed) diff --git a/codex-rs/exec-server/src/posix/escalate_server.rs b/codex-rs/exec-server/src/posix/escalate_server.rs index 3ad37f5ec3..b71142d5b1 100644 --- a/codex-rs/exec-server/src/posix/escalate_server.rs +++ b/codex-rs/exec-server/src/posix/escalate_server.rs @@ -8,8 +8,8 @@ use std::time::Duration; use anyhow::Context as _; use path_absolutize::Absolutize as _; +use codex_core::SandboxState; use codex_core::exec::process_exec_tool_call; -use codex_core::protocol::SandboxPolicy; use tokio::process::Command; use tokio_util::sync::CancellationToken; @@ -48,6 +48,7 @@ impl EscalateServer { &self, params: ExecParams, cancel_rx: CancellationToken, + sandbox_state: &SandboxState, ) -> anyhow::Result { let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?; let client_socket = escalate_client.into_inner(); @@ -64,12 +65,6 @@ impl EscalateServer { self.execve_wrapper.to_string_lossy().to_string(), ); - // TODO: use the sandbox policy and cwd from the calling client. - // Note that sandbox_cwd is ignored for ReadOnly, but needs to be legit - // for `SandboxPolicy::WorkspaceWrite`. - let sandbox_policy = SandboxPolicy::ReadOnly; - let sandbox_cwd = PathBuf::from("/__NONEXISTENT__"); - let ExecParams { command, workdir, @@ -94,9 +89,9 @@ impl EscalateServer { justification: None, arg0: None, }, - &sandbox_policy, - &sandbox_cwd, - &None, + &sandbox_state.sandbox_policy, + &sandbox_state.sandbox_cwd, + &sandbox_state.codex_linux_sandbox_exe, None, ) .await?; diff --git a/codex-rs/exec-server/src/posix/mcp.rs b/codex-rs/exec-server/src/posix/mcp.rs index e4e7b25d92..134fdc01c0 100644 --- a/codex-rs/exec-server/src/posix/mcp.rs +++ b/codex-rs/exec-server/src/posix/mcp.rs @@ -1,8 +1,13 @@ use std::path::PathBuf; +use std::sync::Arc; use std::time::Duration; use anyhow::Context as _; use anyhow::Result; +use codex_core::MCP_SANDBOX_STATE_CAPABILITY; +use codex_core::MCP_SANDBOX_STATE_NOTIFICATION; +use codex_core::SandboxState; +use codex_core::protocol::SandboxPolicy; use rmcp::ErrorData as McpError; use rmcp::RoleServer; use rmcp::ServerHandler; @@ -17,6 +22,8 @@ use rmcp::tool; use rmcp::tool_handler; use rmcp::tool_router; use rmcp::transport::stdio; +use tokio::sync::RwLock; +use tracing::debug; use crate::posix::escalate_server::EscalateServer; use crate::posix::escalate_server::{self}; @@ -27,6 +34,8 @@ use crate::posix::stopwatch::Stopwatch; /// Path to our patched bash. const CODEX_BASH_PATH_ENV_VAR: &str = "CODEX_BASH_PATH"; +const SANDBOX_STATE_CAPABILITY_VERSION: &str = "1.0.0"; + pub(crate) fn get_bash_path() -> Result { std::env::var(CODEX_BASH_PATH_ENV_VAR) .map(PathBuf::from) @@ -70,6 +79,7 @@ pub struct ExecTool { bash_path: PathBuf, execve_wrapper: PathBuf, policy: ExecPolicy, + sandbox_state: Arc>>, } #[tool_router] @@ -80,6 +90,7 @@ impl ExecTool { bash_path, execve_wrapper, policy, + sandbox_state: Arc::new(RwLock::new(None)), } } @@ -97,13 +108,24 @@ impl ExecTool { ); let stopwatch = Stopwatch::new(effective_timeout); let cancel_token = stopwatch.cancellation_token(); + let sandbox_state = + self.sandbox_state + .read() + .await + .clone() + .unwrap_or_else(|| SandboxState { + sandbox_policy: SandboxPolicy::ReadOnly, + codex_linux_sandbox_exe: None, + sandbox_cwd: PathBuf::from(¶ms.workdir), + }); let escalate_server = EscalateServer::new( self.bash_path.clone(), self.execve_wrapper.clone(), McpEscalationPolicy::new(self.policy, context, stopwatch.clone()), ); + let result = escalate_server - .exec(params, cancel_token) + .exec(params, cancel_token, &sandbox_state) .await .map_err(|e| McpError::internal_error(e.to_string(), None))?; Ok(CallToolResult::success(vec![Content::json( @@ -115,9 +137,22 @@ impl ExecTool { #[tool_handler] impl ServerHandler for ExecTool { fn get_info(&self) -> ServerInfo { + let mut experimental_capabilities = ExperimentalCapabilities::new(); + let mut sandbox_state_capability = JsonObject::new(); + sandbox_state_capability.insert( + "version".to_string(), + serde_json::Value::String(SANDBOX_STATE_CAPABILITY_VERSION.to_string()), + ); + experimental_capabilities.insert( + MCP_SANDBOX_STATE_CAPABILITY.to_string(), + sandbox_state_capability, + ); ServerInfo { protocol_version: ProtocolVersion::V_2025_06_18, - capabilities: ServerCapabilities::builder().enable_tools().build(), + capabilities: ServerCapabilities::builder() + .enable_tools() + .enable_experimental_with(experimental_capabilities) + .build(), server_info: Implementation::from_build_env(), instructions: Some( "This server provides a tool to execute shell commands and return their output." @@ -133,6 +168,31 @@ impl ServerHandler for ExecTool { ) -> Result { Ok(self.get_info()) } + + async fn on_custom_notification( + &self, + notification: rmcp::model::CustomClientNotification, + _context: rmcp::service::NotificationContext, + ) { + let rmcp::model::CustomClientNotification { method, params, .. } = notification; + if method == MCP_SANDBOX_STATE_NOTIFICATION + && let Some(params) = params + { + match serde_json::from_value::(params) { + Ok(sandbox_state) => { + debug!( + ?sandbox_state.sandbox_policy, + "received sandbox state notification" + ); + let mut state = self.sandbox_state.write().await; + *state = Some(sandbox_state); + } + Err(err) => { + tracing::warn!(?err, "failed to deserialize sandbox state notification"); + } + } + } + } } pub(crate) async fn serve( diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index fe9f48d04e..bcf7b49e93 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -25,8 +25,11 @@ use mcp_types::ReadResourceResult; use mcp_types::RequestId; use reqwest::header::HeaderMap; use rmcp::model::CallToolRequestParam; +use rmcp::model::ClientNotification; use rmcp::model::CreateElicitationRequestParam; use rmcp::model::CreateElicitationResult; +use rmcp::model::CustomClientNotification; +use rmcp::model::Extensions; use rmcp::model::InitializeRequestParam; use rmcp::model::PaginatedRequestParam; use rmcp::model::ReadResourceRequestParam; @@ -361,6 +364,25 @@ impl RmcpClient { Ok(converted) } + pub async fn send_custom_notification( + &self, + method: &str, + params: Option, + ) -> Result<()> { + let service: Arc> = self.service().await?; + service.service(); + service + .send_notification(ClientNotification::CustomClientNotification( + CustomClientNotification { + method: method.to_string(), + params, + extensions: Extensions::new(), + }, + )) + .await?; + Ok(()) + } + async fn service(&self) -> Result>> { let guard = self.state.lock().await; match &*guard {