Skip to content

Commit 05d4cd8

Browse files
bokelleyclaude
andauthored
feat(server): pre-validation request hook for spec-default injection (#614) (#629)
* feat(server): pre-validation request hook for spec-default injection Adds `pre_validation_hooks` to `serve()`, `create_mcp_server()`, `create_a2a_server()`, and `create_tool_caller()`. The per-tool hook runs on the raw wire dict before `validate_request` (schema) and `model_validate` (Pydantic), so it can supply spec-mandated defaults that absent pre-v3 buyers omit. Hook exceptions surface as INVALID_REQUEST. Includes 7 new pytest tests. Closes #614 https://claude.ai/code/session_015hHQqXRbn2jX9WTu564gjZ * fix(server): snapshot raw_params before hook to preserve context-echo contract Both code-reviewer and dx-expert expert review flagged that assigning raw_params after the hook would echo server-injected defaults back to the buyer as if they were sent on the wire, violating the AdCP context-echo contract. Move the snapshot to before the hook call (raw_params = params), so inject_context always uses the original wire dict regardless of what the hook returns or mutates. Also strengthens test_hook_does_not_pollute_context_echo: the previous test passed a copy of the outer dict so the assertion was trivially true. The new test sends a wire payload with a context field, uses a stripping hook that returns a completely new dict (no context key), and asserts the response context echo still carries the original wire context — only possible if raw_params was captured before the hook ran. https://claude.ai/code/session_015hHQqXRbn2jX9WTu564gjZ * fix(server): defensive copy before hook call eliminates in-place mutation footgun Pass dict(params) to pre_validation_hook so hooks that mutate their argument in-place still leave raw_params (the context-echo snapshot) untouched. The "must return a new dict" restriction is removed from docstrings; either mutation style is now safe. Adds test_in_place_mutation_is_safe_for_context_echo to prove the invariant. Adds a docstring cross-reference to #623 for the account-omission case per adopter feedback on #629. Co-authored-by: bokelley <bokelley@users.noreply.github.com> https://claude.ai/code/session_0115Pruuy4MbdaxhPvgTBFoo --------- Co-authored-by: Claude <noreply@anthropic.com>
1 parent ae687b6 commit 05d4cd8

5 files changed

Lines changed: 342 additions & 5 deletions

File tree

src/adcp/decisioning/serve.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def serve(
373373
advertise_all: bool = False,
374374
mock_ad_server: Any | None = None,
375375
enable_debug_endpoints: bool = False,
376+
pre_validation_hooks: dict[str, Any] | None = None,
376377
**serve_kwargs: Any,
377378
) -> None:
378379
"""One-call wrapper — build the handler and serve over MCP.
@@ -435,6 +436,25 @@ def serve(
435436
responses="strict")`` to enable schema-driven request/response
436437
validation against the bundled AdCP JSON schemas — sellers who
437438
want their server to enforce wire conformance turn it on here.
439+
:param pre_validation_hooks: Optional dict mapping AdCP tool name to
440+
a ``(tool_name, raw_args) -> raw_args`` callable. The hook runs
441+
on the raw wire dict **before** schema + Pydantic validation —
442+
use it to apply spec-mandated defaults for pre-v3 buyers that
443+
omit required fields. Example::
444+
445+
serve(
446+
router,
447+
pre_validation_hooks={
448+
"get_products": lambda n, a: {
449+
**a, "buying_mode": a.get("buying_mode", "brief")
450+
},
451+
},
452+
)
453+
454+
Hook exceptions surface as ``INVALID_REQUEST`` on the wire.
455+
The hook receives a shallow copy of the wire args, so it may
456+
mutate its argument freely or return a new dict — either style
457+
is safe. Context echo always reflects the original wire input.
438458
"""
439459
# Local import to avoid a circular at module-load time. Adopter
440460
# serves never run during foundation imports anyway.
@@ -504,6 +524,8 @@ def serve(
504524

505525
server_name = name or type(platform).__name__
506526
debug_traffic_source = mock_ad_server.get_traffic if mock_ad_server is not None else None
527+
if pre_validation_hooks is not None:
528+
serve_kwargs["pre_validation_hooks"] = pre_validation_hooks
507529
_adcp_serve(
508530
handler,
509531
name=server_name,

src/adcp/server/a2a_server.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def __init__(
143143
message_parser: MessageParser | None = None,
144144
advertise_all: bool = False,
145145
validation: ValidationHookConfig | None = SERVER_DEFAULT_VALIDATION,
146+
pre_validation_hooks: dict[str, Any] | None = None,
146147
test_controller_account_resolver: Any | None = None,
147148
) -> None:
148149
self._handler = handler
@@ -169,7 +170,10 @@ def __init__(
169170
name = tool_def["name"]
170171
if name == "comply_test_controller" and test_controller is None:
171172
continue
172-
self._tool_callers[name] = create_tool_caller(handler, name, validation=validation)
173+
hook = (pre_validation_hooks or {}).get(name)
174+
self._tool_callers[name] = create_tool_caller(
175+
handler, name, validation=validation, pre_validation_hook=hook
176+
)
173177

174178
if test_controller is not None:
175179
self._register_test_controller(test_controller)
@@ -758,6 +762,7 @@ def create_a2a_server(
758762
message_parser: MessageParser | None = None,
759763
advertise_all: bool = False,
760764
validation: ValidationHookConfig | None = SERVER_DEFAULT_VALIDATION,
765+
pre_validation_hooks: dict[str, Any] | None = None,
761766
context_builder: Any | None = None,
762767
auth: BearerTokenAuth | None = None,
763768
public_url: str | None = None,
@@ -884,6 +889,7 @@ def create_a2a_server(
884889
message_parser=message_parser,
885890
advertise_all=advertise_all,
886891
validation=validation,
892+
pre_validation_hooks=pre_validation_hooks,
887893
test_controller_account_resolver=test_controller_account_resolver,
888894
)
889895

src/adcp/server/mcp_tools.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1870,6 +1870,7 @@ def create_tool_caller(
18701870
method_name: str,
18711871
*,
18721872
validation: ValidationHookConfig | None = None,
1873+
pre_validation_hook: Callable[[str, dict[str, Any]], dict[str, Any]] | None = None,
18731874
) -> Callable[..., Any]:
18741875
"""Create a tool caller function for an ADCP handler method.
18751876
@@ -1899,12 +1900,36 @@ def create_tool_caller(
18991900
server validation is a deliberate opt-in for authors who want
19001901
dispatcher-level enforcement.
19011902
1903+
**Pre-validation hook (issue #614).** When ``pre_validation_hook`` is
1904+
supplied, it is called with ``(tool_name, shallow_copy_of_args)`` and
1905+
must return a ``dict`` that replaces the wire args before schema
1906+
validation and Pydantic ``model_validate`` run. The framework passes
1907+
a shallow copy of the incoming params dict, so the hook may mutate
1908+
its argument freely or return a brand-new dict — either style is safe.
1909+
The original wire params are captured before the copy is made, so
1910+
context echo always reflects what the buyer sent. Use this to apply
1911+
spec-mandated defaults for pre-v3 buyers that omit required fields
1912+
(e.g. ``buying_mode``, ``format_id`` shape coercion, ``asset_type``
1913+
inference). The hook runs on every call; keep it fast.
1914+
Exceptions from the hook surface as ``INVALID_REQUEST`` — do not raise
1915+
for missing-but-defaultable fields, only for structurally unusable args.
1916+
1917+
.. note::
1918+
For the specific case of buyers omitting ``account``, see issue
1919+
#623 ("Typed dispatcher rejects valid request when ``account`` is
1920+
omitted") — that will be the canonical spec-level fix for that
1921+
field. Once #623 lands you can drop any ``account`` placeholder
1922+
hook entry.
1923+
19021924
Args:
19031925
handler: The ADCP handler instance
19041926
method_name: Name of the method to call
19051927
validation: Optional :class:`ValidationHookConfig` with
19061928
per-side modes (``strict`` / ``warn`` / ``off``). Omitting
19071929
it disables server-side schema validation entirely.
1930+
pre_validation_hook: Optional callable ``(tool_name, args) -> args``
1931+
invoked on the raw wire dict before schema + Pydantic validation.
1932+
See the **Pre-validation hook** section above.
19081933
19091934
Returns:
19101935
Async callable ``call_tool(params, context=None)``. The ``context``
@@ -1938,7 +1963,22 @@ def create_tool_caller(
19381963

19391964
async def call_tool(params: dict[str, Any], context: ToolContext | None = None) -> Any:
19401965
ctx = context if context is not None else ToolContext()
1941-
raw_params = params # Preserve the original dict for context echo.
1966+
1967+
raw_params = params # Preserve original wire params for context echo.
1968+
1969+
if pre_validation_hook is not None:
1970+
try:
1971+
params = pre_validation_hook(method_name, dict(params))
1972+
except Exception as exc:
1973+
raise ADCPTaskError(
1974+
operation=method_name,
1975+
errors=[
1976+
Error(
1977+
code="INVALID_REQUEST",
1978+
message=f"pre_validation_hook raised {type(exc).__name__}: {exc}",
1979+
)
1980+
],
1981+
) from exc
19421982

19431983
if request_mode is not None and request_mode != "off":
19441984
outcome = validate_request(method_name, params)
@@ -2069,6 +2109,7 @@ def __init__(
20692109
*,
20702110
advertise_all: bool = False,
20712111
validation: ValidationHookConfig | None = None,
2112+
pre_validation_hooks: dict[str, Callable[[str, dict[str, Any]], dict[str, Any]]] | None = None,
20722113
):
20732114
"""Create tool set from handler.
20742115
@@ -2081,6 +2122,9 @@ def __init__(
20812122
(override-filtered advertisement).
20822123
validation: Opt-in schema validation config applied to every
20832124
tool caller. See :func:`create_tool_caller`.
2125+
pre_validation_hooks: Optional dict mapping tool name to a
2126+
``(tool_name, args) -> args`` callable. Applied before
2127+
schema + Pydantic validation. See :func:`create_tool_caller`.
20842128
"""
20852129
self.handler = handler
20862130
self._filtered_definitions = get_tools_for_handler(handler, advertise_all=advertise_all)
@@ -2089,7 +2133,10 @@ def __init__(
20892133
# Create tool callers only for filtered tools
20902134
for tool_def in self._filtered_definitions:
20912135
name = tool_def["name"]
2092-
self._tools[name] = create_tool_caller(handler, name, validation=validation)
2136+
hook = (pre_validation_hooks or {}).get(name)
2137+
self._tools[name] = create_tool_caller(
2138+
handler, name, validation=validation, pre_validation_hook=hook
2139+
)
20932140

20942141
@property
20952142
def tool_definitions(self) -> list[dict[str, Any]]:
@@ -2123,6 +2170,7 @@ def create_mcp_tools(
21232170
*,
21242171
advertise_all: bool = False,
21252172
validation: ValidationHookConfig | None = None,
2173+
pre_validation_hooks: dict[str, Callable[[str, dict[str, Any]], dict[str, Any]]] | None = None,
21262174
) -> MCPToolSet:
21272175
"""Create MCP tools from an ADCP handler.
21282176
@@ -2157,8 +2205,16 @@ async def call_tool(name: str, arguments: dict):
21572205
every tool caller validates requests and responses against
21582206
the bundled AdCP JSON schemas. See
21592207
:func:`create_tool_caller` for mode semantics.
2208+
pre_validation_hooks: Optional dict mapping tool name to a
2209+
``(tool_name, args) -> args`` callable. Applied before schema
2210+
+ Pydantic validation. See :func:`create_tool_caller`.
21602211
21612212
Returns:
21622213
MCPToolSet with tool definitions and handlers.
21632214
"""
2164-
return MCPToolSet(handler, advertise_all=advertise_all, validation=validation)
2215+
return MCPToolSet(
2216+
handler,
2217+
advertise_all=advertise_all,
2218+
validation=validation,
2219+
pre_validation_hooks=pre_validation_hooks,
2220+
)

src/adcp/server/serve.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class ServeConfig:
134134
advertise_all: bool = False
135135
max_request_size: int | None = None
136136
validation: ValidationHookConfig | None = None
137+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None
137138

138139
# --- Discovery manifest ---
139140
base_url: str | None = None
@@ -525,6 +526,7 @@ def serve(
525526
max_request_size: int | None = None,
526527
streaming_responses: bool = False,
527528
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
529+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
528530
enable_debug_endpoints: bool = False,
529531
debug_traffic_source: Callable[[], dict[str, int]] | None = None,
530532
base_url: str | None = None,
@@ -772,6 +774,7 @@ async def force_account_status(self, account_id, status):
772774
max_request_size = config.max_request_size
773775
streaming_responses = config.streaming_responses
774776
validation = config.validation
777+
pre_validation_hooks = config.pre_validation_hooks
775778
enable_debug_endpoints = config.enable_debug_endpoints
776779
debug_traffic_source = config.debug_traffic_source
777780
base_url = config.base_url
@@ -815,6 +818,7 @@ async def force_account_status(self, account_id, status):
815818
advertise_all=advertise_all,
816819
max_request_size=max_request_size,
817820
validation=validation,
821+
pre_validation_hooks=pre_validation_hooks,
818822
base_url=base_url,
819823
specialisms=specialisms,
820824
description=description,
@@ -838,6 +842,7 @@ async def force_account_status(self, account_id, status):
838842
max_request_size=max_request_size,
839843
streaming_responses=streaming_responses,
840844
validation=validation,
845+
pre_validation_hooks=pre_validation_hooks,
841846
base_url=base_url,
842847
specialisms=specialisms,
843848
description=description,
@@ -865,6 +870,7 @@ async def force_account_status(self, account_id, status):
865870
max_request_size=max_request_size,
866871
streaming_responses=streaming_responses,
867872
validation=validation,
873+
pre_validation_hooks=pre_validation_hooks,
868874
base_url=base_url,
869875
specialisms=specialisms,
870876
description=description,
@@ -1239,6 +1245,7 @@ def _serve_mcp(
12391245
max_request_size: int | None = None,
12401246
streaming_responses: bool = False,
12411247
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
1248+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
12421249
base_url: str | None = None,
12431250
specialisms: list[str] | None = None,
12441251
description: str | None = None,
@@ -1260,6 +1267,7 @@ def _serve_mcp(
12601267
advertise_all=advertise_all,
12611268
streaming_responses=streaming_responses,
12621269
validation=validation,
1270+
pre_validation_hooks=pre_validation_hooks,
12631271
allowed_hosts=allowed_hosts,
12641272
allowed_origins=allowed_origins,
12651273
enable_dns_rebinding_protection=enable_dns_rebinding_protection,
@@ -1399,6 +1407,7 @@ def _serve_a2a(
13991407
advertise_all: bool = False,
14001408
max_request_size: int | None = None,
14011409
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
1410+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
14021411
base_url: str | None = None,
14031412
specialisms: list[str] | None = None,
14041413
description: str | None = None,
@@ -1427,6 +1436,7 @@ def _serve_a2a(
14271436
message_parser=message_parser,
14281437
advertise_all=advertise_all,
14291438
validation=validation,
1439+
pre_validation_hooks=pre_validation_hooks,
14301440
auth=auth,
14311441
public_url=public_url,
14321442
)
@@ -1481,6 +1491,7 @@ def _build_mcp_and_a2a_app(
14811491
max_request_size: int | None = None,
14821492
streaming_responses: bool = False,
14831493
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
1494+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
14841495
base_url: str | None = None,
14851496
specialisms: list[str] | None = None,
14861497
description: str | None = None,
@@ -1523,6 +1534,7 @@ def _build_mcp_and_a2a_app(
15231534
advertise_all=advertise_all,
15241535
streaming_responses=streaming_responses,
15251536
validation=validation,
1537+
pre_validation_hooks=pre_validation_hooks,
15261538
allowed_hosts=allowed_hosts,
15271539
allowed_origins=allowed_origins,
15281540
enable_dns_rebinding_protection=enable_dns_rebinding_protection,
@@ -1576,6 +1588,7 @@ def _build_mcp_and_a2a_app(
15761588
message_parser=message_parser,
15771589
advertise_all=advertise_all,
15781590
validation=validation,
1591+
pre_validation_hooks=pre_validation_hooks,
15791592
auth=auth,
15801593
public_url=public_url,
15811594
)
@@ -1659,6 +1672,7 @@ def _serve_mcp_and_a2a(
16591672
max_request_size: int | None = None,
16601673
streaming_responses: bool = False,
16611674
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
1675+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
16621676
base_url: str | None = None,
16631677
specialisms: list[str] | None = None,
16641678
description: str | None = None,
@@ -1706,6 +1720,7 @@ def _serve_mcp_and_a2a(
17061720
max_request_size=max_request_size,
17071721
streaming_responses=streaming_responses,
17081722
validation=validation,
1723+
pre_validation_hooks=pre_validation_hooks,
17091724
base_url=base_url,
17101725
specialisms=specialisms,
17111726
description=description,
@@ -1787,6 +1802,7 @@ def create_mcp_server(
17871802
advertise_all: bool = False,
17881803
streaming_responses: bool = False,
17891804
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
1805+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
17901806
allowed_hosts: Sequence[str] | None = None,
17911807
allowed_origins: Sequence[str] | None = None,
17921808
enable_dns_rebinding_protection: bool | None = None,
@@ -1948,6 +1964,7 @@ def create_mcp_server(
19481964
middleware=middleware,
19491965
advertise_all=advertise_all,
19501966
validation=validation,
1967+
pre_validation_hooks=pre_validation_hooks,
19511968
)
19521969
return mcp
19531970

@@ -1961,6 +1978,7 @@ def _register_handler_tools(
19611978
middleware: Sequence[SkillMiddleware] | None = None,
19621979
advertise_all: bool = False,
19631980
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
1981+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
19641982
) -> None:
19651983
"""Register all ADCP tools from a handler onto a FastMCP server."""
19661984
# Freeze middleware ordering at registration time. Tuple both guards
@@ -1980,7 +1998,10 @@ def _register_handler_tools(
19801998
description = tool_def.get("description", "")
19811999
input_schema = tool_def.get("inputSchema", {"type": "object", "properties": {}})
19822000
output_schema = tool_def.get("outputSchema")
1983-
caller = create_tool_caller(handler, tool_name, validation=validation)
2001+
hook = (pre_validation_hooks or {}).get(tool_name)
2002+
caller = create_tool_caller(
2003+
handler, tool_name, validation=validation, pre_validation_hook=hook
2004+
)
19842005
_register_tool(
19852006
mcp,
19862007
tool_name,

0 commit comments

Comments
 (0)