Skip to content

Commit b31f335

Browse files
claudebokelley
authored andcommitted
feat(server): pre-validation request hook for spec-default injection
Adds `pre_validation_hooks` to `serve()`, `create_mcp_server()`, `create_a2a_server()`, and `create_tool_caller()`. The per-tool hook runs on the raw wire dict before `validate_request` (schema) and `model_validate` (Pydantic), so it can supply spec-mandated defaults that absent pre-v3 buyers omit. Hook exceptions surface as INVALID_REQUEST. Includes 7 new pytest tests. Closes #614 https://claude.ai/code/session_015hHQqXRbn2jX9WTu564gjZ
1 parent 14d294c commit b31f335

5 files changed

Lines changed: 297 additions & 5 deletions

File tree

src/adcp/decisioning/serve.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def serve(
373373
advertise_all: bool = False,
374374
mock_ad_server: Any | None = None,
375375
enable_debug_endpoints: bool = False,
376+
pre_validation_hooks: dict[str, Any] | None = None,
376377
**serve_kwargs: Any,
377378
) -> None:
378379
"""One-call wrapper — build the handler and serve over MCP.
@@ -435,6 +436,24 @@ def serve(
435436
responses="strict")`` to enable schema-driven request/response
436437
validation against the bundled AdCP JSON schemas — sellers who
437438
want their server to enforce wire conformance turn it on here.
439+
:param pre_validation_hooks: Optional dict mapping AdCP tool name to
440+
a ``(tool_name, raw_args) -> raw_args`` callable. The hook runs
441+
on the raw wire dict **before** schema + Pydantic validation —
442+
use it to apply spec-mandated defaults for pre-v3 buyers that
443+
omit required fields. Example::
444+
445+
serve(
446+
router,
447+
pre_validation_hooks={
448+
"get_products": lambda n, a: {
449+
**a, "buying_mode": a.get("buying_mode", "brief")
450+
},
451+
},
452+
)
453+
454+
Hook exceptions surface as ``INVALID_REQUEST`` on the wire.
455+
The hook must return a new dict; mutating the input in-place is
456+
a bug — the original is captured separately for context echo.
438457
"""
439458
# Local import to avoid a circular at module-load time. Adopter
440459
# serves never run during foundation imports anyway.
@@ -504,6 +523,8 @@ def serve(
504523

505524
server_name = name or type(platform).__name__
506525
debug_traffic_source = mock_ad_server.get_traffic if mock_ad_server is not None else None
526+
if pre_validation_hooks is not None:
527+
serve_kwargs["pre_validation_hooks"] = pre_validation_hooks
507528
_adcp_serve(
508529
handler,
509530
name=server_name,

src/adcp/server/a2a_server.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def __init__(
143143
message_parser: MessageParser | None = None,
144144
advertise_all: bool = False,
145145
validation: ValidationHookConfig | None = SERVER_DEFAULT_VALIDATION,
146+
pre_validation_hooks: dict[str, Any] | None = None,
146147
test_controller_account_resolver: Any | None = None,
147148
) -> None:
148149
self._handler = handler
@@ -169,7 +170,10 @@ def __init__(
169170
name = tool_def["name"]
170171
if name == "comply_test_controller" and test_controller is None:
171172
continue
172-
self._tool_callers[name] = create_tool_caller(handler, name, validation=validation)
173+
hook = (pre_validation_hooks or {}).get(name)
174+
self._tool_callers[name] = create_tool_caller(
175+
handler, name, validation=validation, pre_validation_hook=hook
176+
)
173177

174178
if test_controller is not None:
175179
self._register_test_controller(test_controller)
@@ -758,6 +762,7 @@ def create_a2a_server(
758762
message_parser: MessageParser | None = None,
759763
advertise_all: bool = False,
760764
validation: ValidationHookConfig | None = SERVER_DEFAULT_VALIDATION,
765+
pre_validation_hooks: dict[str, Any] | None = None,
761766
context_builder: Any | None = None,
762767
auth: BearerTokenAuth | None = None,
763768
public_url: str | None = None,
@@ -884,6 +889,7 @@ def create_a2a_server(
884889
message_parser=message_parser,
885890
advertise_all=advertise_all,
886891
validation=validation,
892+
pre_validation_hooks=pre_validation_hooks,
887893
test_controller_account_resolver=test_controller_account_resolver,
888894
)
889895

src/adcp/server/mcp_tools.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1870,6 +1870,7 @@ def create_tool_caller(
18701870
method_name: str,
18711871
*,
18721872
validation: ValidationHookConfig | None = None,
1873+
pre_validation_hook: Callable[[str, dict[str, Any]], dict[str, Any]] | None = None,
18731874
) -> Callable[..., Any]:
18741875
"""Create a tool caller function for an ADCP handler method.
18751876
@@ -1899,12 +1900,28 @@ def create_tool_caller(
18991900
server validation is a deliberate opt-in for authors who want
19001901
dispatcher-level enforcement.
19011902
1903+
**Pre-validation hook (issue #614).** When ``pre_validation_hook`` is
1904+
supplied, it is called with ``(tool_name, raw_args_dict)`` and must
1905+
return a (possibly modified) ``dict`` that replaces the wire args
1906+
before schema validation and Pydantic ``model_validate`` run. Use
1907+
this to apply spec-mandated defaults for pre-v3 buyers that omit
1908+
required fields (e.g. ``buying_mode``, ``format_id`` shape coercion,
1909+
``asset_type`` inference). The hook runs on every call; keep it fast.
1910+
Exceptions from the hook surface as ``INVALID_REQUEST`` — do not raise
1911+
for missing-but-defaultable fields, only for structurally unusable args.
1912+
The hook must return a new dict (or the original unchanged); mutating
1913+
the input dict in-place is a bug — the original is captured separately
1914+
for the context-echo path.
1915+
19021916
Args:
19031917
handler: The ADCP handler instance
19041918
method_name: Name of the method to call
19051919
validation: Optional :class:`ValidationHookConfig` with
19061920
per-side modes (``strict`` / ``warn`` / ``off``). Omitting
19071921
it disables server-side schema validation entirely.
1922+
pre_validation_hook: Optional callable ``(tool_name, args) -> args``
1923+
invoked on the raw wire dict before schema + Pydantic validation.
1924+
See the **Pre-validation hook** section above.
19081925
19091926
Returns:
19101927
Async callable ``call_tool(params, context=None)``. The ``context``
@@ -1938,7 +1955,22 @@ def create_tool_caller(
19381955

19391956
async def call_tool(params: dict[str, Any], context: ToolContext | None = None) -> Any:
19401957
ctx = context if context is not None else ToolContext()
1941-
raw_params = params # Preserve the original dict for context echo.
1958+
1959+
if pre_validation_hook is not None:
1960+
try:
1961+
params = pre_validation_hook(method_name, params)
1962+
except Exception as exc:
1963+
raise ADCPTaskError(
1964+
operation=method_name,
1965+
errors=[
1966+
Error(
1967+
code="INVALID_REQUEST",
1968+
message=f"pre_validation_hook raised {type(exc).__name__}: {exc}",
1969+
)
1970+
],
1971+
) from exc
1972+
1973+
raw_params = params # Preserve the (possibly hook-modified) dict for context echo.
19421974

19431975
if request_mode is not None and request_mode != "off":
19441976
outcome = validate_request(method_name, params)
@@ -2069,6 +2101,7 @@ def __init__(
20692101
*,
20702102
advertise_all: bool = False,
20712103
validation: ValidationHookConfig | None = None,
2104+
pre_validation_hooks: dict[str, Callable[[str, dict[str, Any]], dict[str, Any]]] | None = None,
20722105
):
20732106
"""Create tool set from handler.
20742107
@@ -2081,6 +2114,9 @@ def __init__(
20812114
(override-filtered advertisement).
20822115
validation: Opt-in schema validation config applied to every
20832116
tool caller. See :func:`create_tool_caller`.
2117+
pre_validation_hooks: Optional dict mapping tool name to a
2118+
``(tool_name, args) -> args`` callable. Applied before
2119+
schema + Pydantic validation. See :func:`create_tool_caller`.
20842120
"""
20852121
self.handler = handler
20862122
self._filtered_definitions = get_tools_for_handler(handler, advertise_all=advertise_all)
@@ -2089,7 +2125,10 @@ def __init__(
20892125
# Create tool callers only for filtered tools
20902126
for tool_def in self._filtered_definitions:
20912127
name = tool_def["name"]
2092-
self._tools[name] = create_tool_caller(handler, name, validation=validation)
2128+
hook = (pre_validation_hooks or {}).get(name)
2129+
self._tools[name] = create_tool_caller(
2130+
handler, name, validation=validation, pre_validation_hook=hook
2131+
)
20932132

20942133
@property
20952134
def tool_definitions(self) -> list[dict[str, Any]]:
@@ -2123,6 +2162,7 @@ def create_mcp_tools(
21232162
*,
21242163
advertise_all: bool = False,
21252164
validation: ValidationHookConfig | None = None,
2165+
pre_validation_hooks: dict[str, Callable[[str, dict[str, Any]], dict[str, Any]]] | None = None,
21262166
) -> MCPToolSet:
21272167
"""Create MCP tools from an ADCP handler.
21282168
@@ -2157,8 +2197,16 @@ async def call_tool(name: str, arguments: dict):
21572197
every tool caller validates requests and responses against
21582198
the bundled AdCP JSON schemas. See
21592199
:func:`create_tool_caller` for mode semantics.
2200+
pre_validation_hooks: Optional dict mapping tool name to a
2201+
``(tool_name, args) -> args`` callable. Applied before schema
2202+
+ Pydantic validation. See :func:`create_tool_caller`.
21602203
21612204
Returns:
21622205
MCPToolSet with tool definitions and handlers.
21632206
"""
2164-
return MCPToolSet(handler, advertise_all=advertise_all, validation=validation)
2207+
return MCPToolSet(
2208+
handler,
2209+
advertise_all=advertise_all,
2210+
validation=validation,
2211+
pre_validation_hooks=pre_validation_hooks,
2212+
)

src/adcp/server/serve.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class ServeConfig:
134134
advertise_all: bool = False
135135
max_request_size: int | None = None
136136
validation: ValidationHookConfig | None = None
137+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None
137138

138139
# --- Discovery manifest ---
139140
base_url: str | None = None
@@ -525,6 +526,7 @@ def serve(
525526
max_request_size: int | None = None,
526527
streaming_responses: bool = False,
527528
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
529+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
528530
enable_debug_endpoints: bool = False,
529531
debug_traffic_source: Callable[[], dict[str, int]] | None = None,
530532
base_url: str | None = None,
@@ -772,6 +774,7 @@ async def force_account_status(self, account_id, status):
772774
max_request_size = config.max_request_size
773775
streaming_responses = config.streaming_responses
774776
validation = config.validation
777+
pre_validation_hooks = config.pre_validation_hooks
775778
enable_debug_endpoints = config.enable_debug_endpoints
776779
debug_traffic_source = config.debug_traffic_source
777780
base_url = config.base_url
@@ -815,6 +818,7 @@ async def force_account_status(self, account_id, status):
815818
advertise_all=advertise_all,
816819
max_request_size=max_request_size,
817820
validation=validation,
821+
pre_validation_hooks=pre_validation_hooks,
818822
base_url=base_url,
819823
specialisms=specialisms,
820824
description=description,
@@ -838,6 +842,7 @@ async def force_account_status(self, account_id, status):
838842
max_request_size=max_request_size,
839843
streaming_responses=streaming_responses,
840844
validation=validation,
845+
pre_validation_hooks=pre_validation_hooks,
841846
base_url=base_url,
842847
specialisms=specialisms,
843848
description=description,
@@ -865,6 +870,7 @@ async def force_account_status(self, account_id, status):
865870
max_request_size=max_request_size,
866871
streaming_responses=streaming_responses,
867872
validation=validation,
873+
pre_validation_hooks=pre_validation_hooks,
868874
base_url=base_url,
869875
specialisms=specialisms,
870876
description=description,
@@ -1239,6 +1245,7 @@ def _serve_mcp(
12391245
max_request_size: int | None = None,
12401246
streaming_responses: bool = False,
12411247
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
1248+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
12421249
base_url: str | None = None,
12431250
specialisms: list[str] | None = None,
12441251
description: str | None = None,
@@ -1260,6 +1267,7 @@ def _serve_mcp(
12601267
advertise_all=advertise_all,
12611268
streaming_responses=streaming_responses,
12621269
validation=validation,
1270+
pre_validation_hooks=pre_validation_hooks,
12631271
allowed_hosts=allowed_hosts,
12641272
allowed_origins=allowed_origins,
12651273
enable_dns_rebinding_protection=enable_dns_rebinding_protection,
@@ -1399,6 +1407,7 @@ def _serve_a2a(
13991407
advertise_all: bool = False,
14001408
max_request_size: int | None = None,
14011409
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
1410+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
14021411
base_url: str | None = None,
14031412
specialisms: list[str] | None = None,
14041413
description: str | None = None,
@@ -1427,6 +1436,7 @@ def _serve_a2a(
14271436
message_parser=message_parser,
14281437
advertise_all=advertise_all,
14291438
validation=validation,
1439+
pre_validation_hooks=pre_validation_hooks,
14301440
auth=auth,
14311441
public_url=public_url,
14321442
)
@@ -1481,6 +1491,7 @@ def _build_mcp_and_a2a_app(
14811491
max_request_size: int | None = None,
14821492
streaming_responses: bool = False,
14831493
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
1494+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
14841495
base_url: str | None = None,
14851496
specialisms: list[str] | None = None,
14861497
description: str | None = None,
@@ -1523,6 +1534,7 @@ def _build_mcp_and_a2a_app(
15231534
advertise_all=advertise_all,
15241535
streaming_responses=streaming_responses,
15251536
validation=validation,
1537+
pre_validation_hooks=pre_validation_hooks,
15261538
allowed_hosts=allowed_hosts,
15271539
allowed_origins=allowed_origins,
15281540
enable_dns_rebinding_protection=enable_dns_rebinding_protection,
@@ -1576,6 +1588,7 @@ def _build_mcp_and_a2a_app(
15761588
message_parser=message_parser,
15771589
advertise_all=advertise_all,
15781590
validation=validation,
1591+
pre_validation_hooks=pre_validation_hooks,
15791592
auth=auth,
15801593
public_url=public_url,
15811594
)
@@ -1659,6 +1672,7 @@ def _serve_mcp_and_a2a(
16591672
max_request_size: int | None = None,
16601673
streaming_responses: bool = False,
16611674
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
1675+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
16621676
base_url: str | None = None,
16631677
specialisms: list[str] | None = None,
16641678
description: str | None = None,
@@ -1706,6 +1720,7 @@ def _serve_mcp_and_a2a(
17061720
max_request_size=max_request_size,
17071721
streaming_responses=streaming_responses,
17081722
validation=validation,
1723+
pre_validation_hooks=pre_validation_hooks,
17091724
base_url=base_url,
17101725
specialisms=specialisms,
17111726
description=description,
@@ -1787,6 +1802,7 @@ def create_mcp_server(
17871802
advertise_all: bool = False,
17881803
streaming_responses: bool = False,
17891804
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
1805+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
17901806
allowed_hosts: Sequence[str] | None = None,
17911807
allowed_origins: Sequence[str] | None = None,
17921808
enable_dns_rebinding_protection: bool | None = None,
@@ -1948,6 +1964,7 @@ def create_mcp_server(
19481964
middleware=middleware,
19491965
advertise_all=advertise_all,
19501966
validation=validation,
1967+
pre_validation_hooks=pre_validation_hooks,
19511968
)
19521969
return mcp
19531970

@@ -1961,6 +1978,7 @@ def _register_handler_tools(
19611978
middleware: Sequence[SkillMiddleware] | None = None,
19621979
advertise_all: bool = False,
19631980
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
1981+
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
19641982
) -> None:
19651983
"""Register all ADCP tools from a handler onto a FastMCP server."""
19661984
# Freeze middleware ordering at registration time. Tuple both guards
@@ -1980,7 +1998,10 @@ def _register_handler_tools(
19801998
description = tool_def.get("description", "")
19811999
input_schema = tool_def.get("inputSchema", {"type": "object", "properties": {}})
19822000
output_schema = tool_def.get("outputSchema")
1983-
caller = create_tool_caller(handler, tool_name, validation=validation)
2001+
hook = (pre_validation_hooks or {}).get(tool_name)
2002+
caller = create_tool_caller(
2003+
handler, tool_name, validation=validation, pre_validation_hook=hook
2004+
)
19842005
_register_tool(
19852006
mcp,
19862007
tool_name,

0 commit comments

Comments
 (0)