Skip to content

Commit df84555

Browse files
authored
Implements outputSchema validation (#566)
* feat: implement output schema validation * fix: calculator example comply MCP spec * refactor: merge cached_schema_for_output into schema_for_output
1 parent 3c62ee8 commit df84555

File tree

7 files changed

+122
-10
lines changed

7 files changed

+122
-10
lines changed

crates/rmcp-macros/src/tool.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,14 @@ fn extract_schema_from_return_type(ret_type: &syn::Type) -> Option<Expr> {
2727
// First, try direct Json<T>
2828
if let Some(inner_type) = extract_json_inner_type(ret_type) {
2929
return syn::parse2::<Expr>(quote! {
30-
rmcp::handler::server::tool::cached_schema_for_type::<#inner_type>()
30+
rmcp::handler::server::tool::schema_for_output::<#inner_type>()
31+
.unwrap_or_else(|e| {
32+
panic!(
33+
"Invalid output schema for Json<{}>: {}",
34+
std::any::type_name::<#inner_type>(),
35+
e
36+
)
37+
})
3138
})
3239
.ok();
3340
}
@@ -57,7 +64,14 @@ fn extract_schema_from_return_type(ret_type: &syn::Type) -> Option<Expr> {
5764
let inner_type = extract_json_inner_type(ok_type)?;
5865

5966
syn::parse2::<Expr>(quote! {
60-
rmcp::handler::server::tool::cached_schema_for_type::<#inner_type>()
67+
rmcp::handler::server::tool::schema_for_output::<#inner_type>()
68+
.unwrap_or_else(|e| {
69+
panic!(
70+
"Invalid output schema for Result<Json<{}>, E>: {}",
71+
std::any::type_name::<#inner_type>(),
72+
e
73+
)
74+
})
6175
})
6276
.ok()
6377
}

crates/rmcp/src/handler/server/common.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,49 @@ pub fn cached_schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject
5050
})
5151
}
5252

53+
/// Generate and validate a JSON schema for outputSchema (must have root type "object").
54+
pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObject>, String> {
55+
thread_local! {
56+
static CACHE_FOR_OUTPUT: std::sync::RwLock<HashMap<TypeId, Result<Arc<JsonObject>, String>>> = Default::default();
57+
};
58+
59+
CACHE_FOR_OUTPUT.with(|cache| {
60+
// Try to get from cache first
61+
if let Some(result) = cache
62+
.read()
63+
.expect("output schema cache lock poisoned")
64+
.get(&TypeId::of::<T>())
65+
{
66+
return result.clone();
67+
}
68+
69+
// Generate and validate schema
70+
let schema = schema_for_type::<T>();
71+
let result = match schema.get("type") {
72+
Some(serde_json::Value::String(t)) if t == "object" => Ok(Arc::new(schema)),
73+
Some(serde_json::Value::String(t)) => Err(format!(
74+
"MCP specification requires tool outputSchema to have root type 'object', but found '{}'.",
75+
t
76+
)),
77+
None => Err(
78+
"Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'.".to_string()
79+
),
80+
Some(other) => Err(format!(
81+
"Schema 'type' field has unexpected format: {:?}. Expected \"object\".",
82+
other
83+
)),
84+
};
85+
86+
// Cache the result (both success and error cases)
87+
cache
88+
.write()
89+
.expect("output schema cache lock poisoned")
90+
.insert(TypeId::of::<T>(), result.clone());
91+
92+
result
93+
})
94+
}
95+
5396
/// Trait for extracting parts from a context, unifying tool and prompt extraction
5497
pub trait FromContextPart<C>: Sized {
5598
fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData>;
@@ -143,3 +186,25 @@ pub trait AsRequestContext {
143186
fn as_request_context(&self) -> &RequestContext<RoleServer>;
144187
fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer>;
145188
}
189+
190+
#[cfg(test)]
191+
mod tests {
192+
use super::*;
193+
194+
#[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
195+
struct TestObject {
196+
value: i32,
197+
}
198+
199+
#[test]
200+
fn test_schema_for_output_rejects_primitive() {
201+
let result = schema_for_output::<i32>();
202+
assert!(result.is_err(),);
203+
}
204+
205+
#[test]
206+
fn test_schema_for_output_accepts_object() {
207+
let result = schema_for_output::<TestObject>();
208+
assert!(result.is_ok(),);
209+
}
210+
}

crates/rmcp/src/handler/server/tool.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use serde::de::DeserializeOwned;
99

1010
use super::common::{AsRequestContext, FromContextPart};
1111
pub use super::{
12-
common::{Extension, RequestId, cached_schema_for_type, schema_for_type},
12+
common::{Extension, RequestId, cached_schema_for_type, schema_for_output, schema_for_type},
1313
router::tool::{ToolRoute, ToolRouter},
1414
};
1515
use crate::{

crates/rmcp/src/model/tool.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,14 @@ impl Tool {
165165
}
166166

167167
/// Set the output schema using a type that implements JsonSchema
168+
///
169+
/// # Panics
170+
///
171+
/// Panics if the generated schema does not have root type "object" as required by MCP specification.
168172
pub fn with_output_schema<T: JsonSchema + 'static>(mut self) -> Self {
169-
self.output_schema = Some(crate::handler::server::tool::cached_schema_for_type::<T>());
173+
let schema = crate::handler::server::tool::schema_for_output::<T>()
174+
.unwrap_or_else(|e| panic!("Invalid output schema for tool '{}': {}", self.name, e));
175+
self.output_schema = Some(schema);
170176
self
171177
}
172178

examples/servers/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,7 @@ path = "src/simple_auth_streamhttp.rs"
9696
[[example]]
9797
name = "servers_complex_auth_streamhttp"
9898
path = "src/complex_auth_streamhttp.rs"
99+
100+
[[example]]
101+
name = "servers_calculator_stdio"
102+
path = "src/calculator_stdio.rs"
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
use anyhow::Result;
2+
use common::calculator::Calculator;
3+
use rmcp::{ServiceExt, transport::stdio};
4+
use tracing_subscriber::{self, EnvFilter};
5+
mod common;
6+
7+
/// npx @modelcontextprotocol/inspector cargo run -p mcp-server-examples --example servers_calculator_stdio
8+
#[tokio::main]
9+
async fn main() -> Result<()> {
10+
// Initialize the tracing subscriber with file and stdout logging
11+
tracing_subscriber::fmt()
12+
.with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::DEBUG.into()))
13+
.with_writer(std::io::stderr)
14+
.with_ansi(false)
15+
.init();
16+
17+
tracing::info!("Starting Calculator MCP server");
18+
19+
// Create an instance of our calculator router
20+
let service = Calculator::new().serve(stdio()).await.inspect_err(|e| {
21+
tracing::error!("serving error: {:?}", e);
22+
})?;
23+
24+
service.waiting().await?;
25+
Ok(())
26+
}

examples/servers/src/common/calculator.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22

33
use rmcp::{
44
ServerHandler,
5-
handler::server::{
6-
router::tool::ToolRouter,
7-
wrapper::{Json, Parameters},
8-
},
5+
handler::server::{router::tool::ToolRouter, wrapper::Parameters},
96
model::{ServerCapabilities, ServerInfo},
107
schemars, tool, tool_handler, tool_router,
118
};
@@ -44,8 +41,8 @@ impl Calculator {
4441
}
4542

4643
#[tool(description = "Calculate the difference of two numbers")]
47-
fn sub(&self, Parameters(SubRequest { a, b }): Parameters<SubRequest>) -> Json<i32> {
48-
Json(a - b)
44+
fn sub(&self, Parameters(SubRequest { a, b }): Parameters<SubRequest>) -> String {
45+
(a - b).to_string()
4946
}
5047
}
5148

0 commit comments

Comments
 (0)