Skip to content

Commit dbbc390

Browse files
authored
Merge pull request #248 from adcontextprotocol/bokelley/round-1-followups
feat(server+migrate): round-1 feedback followups
2 parents 0f50d39 + 0fe9cb5 commit dbbc390

6 files changed

Lines changed: 344 additions & 36 deletions

File tree

examples/mcp_with_auth_middleware.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
:func:`~adcp.server.auth_context_factory` onto a multi-tenant sales
55
agent. The SDK owns the security-critical plumbing (constant-time
66
token compare, discovery bypass, ``ContextVar`` reset-in-finally);
7-
the seller supplies only ``validate_token`` and the handler logic.
7+
the seller supplies only the token → principal map and the handler
8+
logic.
89
910
Run::
1011
@@ -14,12 +15,14 @@
1415
1516
Production note: ``mcp.run()`` is used here for brevity. Real
1617
deployments should mount the Starlette app behind uvicorn + a reverse
17-
proxy that terminates TLS and handles rate limiting.
18+
proxy that terminates TLS and handles rate limiting. Production
19+
agents also typically load tokens from a database — swap
20+
``validator_from_token_map`` for an ``async def validate_token`` that
21+
hits your token store.
1822
"""
1923

2024
from __future__ import annotations
2125

22-
import hashlib
2326
from typing import Any
2427

2528
from adcp.server import (
@@ -28,37 +31,23 @@
2831
Principal,
2932
ToolContext,
3033
auth_context_factory,
31-
constant_time_token_match,
3234
create_mcp_server,
35+
validator_from_token_map,
3336
)
3437
from adcp.server.responses import capabilities_response, products_response
3538

3639
# Real agents look tokens up in Postgres / Vault / an identity provider.
37-
# Keyed by SHA-256 so the comparison uses ``hmac.compare_digest`` rather
38-
# than raw string equality — never compare raw bearer tokens with ``==``.
39-
_TOKEN_HASHES: dict[str, Principal] = {
40-
hashlib.sha256(raw.encode()).hexdigest(): principal
41-
for raw, principal in {
42-
"token-acme": Principal(
43-
caller_identity="principal-acme-ops",
44-
tenant_id="tenant-acme",
45-
),
40+
# ``validator_from_token_map`` hashes the raw tokens at construction and
41+
# does ``hmac.compare_digest`` lookups — same security properties as a
42+
# hand-rolled validator, one line instead of a dozen.
43+
validate_token = validator_from_token_map(
44+
{
45+
"token-acme": Principal(caller_identity="principal-acme-ops", tenant_id="tenant-acme"),
4646
"token-globex": Principal(
47-
caller_identity="principal-globex-ops",
48-
tenant_id="tenant-globex",
47+
caller_identity="principal-globex-ops", tenant_id="tenant-globex"
4948
),
50-
}.items()
51-
}
52-
53-
54-
def validate_token(token: str) -> Principal | None:
55-
"""Seller-supplied token validator.
56-
57-
``constant_time_token_match`` iterates every stored hash with
58-
:func:`hmac.compare_digest`, avoiding the prefix-match timing leak
59-
that a plain ``dict`` lookup would have.
60-
"""
61-
return constant_time_token_match(token, _TOKEN_HASHES)
49+
}
50+
)
6251

6352

6453
class MultiTenantSalesAgent(ADCPHandler):

src/adcp/migrate/v3_to_v4.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,44 @@ def _format_text_report(report: Report, *, apply_changes: bool) -> str:
396396
return "\n".join(lines)
397397

398398

399+
REPORT_SCHEMA_VERSION = 1
400+
"""Version of the JSON report shape. CI scripts / editors parsing the
401+
migrate output key on this so a future shape change (adding a summary
402+
block, renaming fields) doesn't silently break them.
403+
404+
Bump the minor SDK version AND this constant when changing the JSON
405+
shape in a non-additive way. Additive changes (new optional keys)
406+
stay at the same version.
407+
408+
**v1 shape:**
409+
410+
.. code-block:: json
411+
412+
{
413+
"schema_version": 1,
414+
"scanned_files": int,
415+
"rewritten_files": int,
416+
"applied": [
417+
{"kind": "rename", "path": str, "line": int, "column": int,
418+
"before": str, "after": str, "hint": null, "migration_anchor": null}
419+
],
420+
"flagged": [
421+
{"kind": "flag_removed" | "flag_numbered" | "flag_private" | "flag_attribute",
422+
"path": str, "line": int, "column": int, "before": str,
423+
"after": null, "hint": str | null, "migration_anchor": str | null}
424+
]
425+
}
426+
"""
427+
428+
399429
def _format_json_report(report: Report) -> str:
400-
"""JSON report for programmatic consumption (CI, editors)."""
430+
"""JSON report for programmatic consumption (CI, editors).
431+
432+
Versioned via :data:`REPORT_SCHEMA_VERSION` — parsers should check
433+
the top-level ``schema_version`` key before reading the rest.
434+
"""
401435
payload = {
436+
"schema_version": REPORT_SCHEMA_VERSION,
402437
"scanned_files": report.scanned_files,
403438
"rewritten_files": report.rewritten_files,
404439
"applied": [asdict(f) for f in report.applied],
@@ -407,6 +442,46 @@ def _format_json_report(report: Report) -> str:
407442
return json.dumps(payload, indent=2)
408443

409444

445+
def _is_dirty_tree(path: Path) -> bool:
446+
"""True when ``path`` is inside a git repo with uncommitted changes.
447+
448+
Uses ``git status --porcelain`` for speed and stability. Returns
449+
``False`` when git isn't installed, the path isn't in a repo, or
450+
the repo is clean — any non-clean state returns ``True`` so the
451+
``--apply`` guard fails safe.
452+
453+
The check is best-effort: absence of git isn't a reason to block
454+
the rewrite (sellers may run in sandboxed or read-only environments
455+
where git isn't available). A ``True`` result means we saw
456+
definite uncommitted state.
457+
"""
458+
import shutil
459+
import subprocess
460+
461+
if shutil.which("git") is None:
462+
return False
463+
464+
target = path.resolve()
465+
cwd = target if target.is_dir() else target.parent
466+
try:
467+
result = subprocess.run(
468+
["git", "status", "--porcelain"],
469+
cwd=cwd,
470+
check=False,
471+
capture_output=True,
472+
text=True,
473+
timeout=5,
474+
)
475+
except (OSError, subprocess.SubprocessError):
476+
return False
477+
# Exit 128 = not a git repo; anything non-zero → treat as clean
478+
# (not blocking — we don't want `--apply` in a sandboxed env to
479+
# break because git can't run).
480+
if result.returncode != 0:
481+
return False
482+
return bool(result.stdout.strip())
483+
484+
410485
def main(argv: list[str] | None = None) -> int:
411486
"""CLI entry point for ``python -m adcp.migrate v3-to-v4``."""
412487
parser = argparse.ArgumentParser(
@@ -429,6 +504,18 @@ def main(argv: list[str] | None = None) -> int:
429504
"Commit your tree first so `git diff` is your review view."
430505
),
431506
)
507+
parser.add_argument(
508+
"--allow-dirty",
509+
action="store_true",
510+
help=(
511+
"Allow --apply even when the git working tree has "
512+
"uncommitted changes. Default is to refuse so `git diff` "
513+
"after the migration shows only the codemod's rewrites, "
514+
"not a mix of the seller's in-progress work and the "
515+
"codemod. Pass --allow-dirty when you know what you're "
516+
"doing (e.g. applying to a staged change deliberately)."
517+
),
518+
)
432519
parser.add_argument(
433520
"--json",
434521
action="store_true",
@@ -440,6 +527,17 @@ def main(argv: list[str] | None = None) -> int:
440527
print(f"error: path does not exist: {args.path}", file=sys.stderr)
441528
return 2
442529

530+
if args.apply and not args.allow_dirty and _is_dirty_tree(args.path):
531+
print(
532+
"error: --apply refused on a dirty git working tree.\n"
533+
" Commit your changes first so `git diff` after the\n"
534+
" migration shows only the codemod's rewrites. Pass\n"
535+
" --allow-dirty to override (e.g. you're deliberately\n"
536+
" applying on top of staged changes).",
537+
file=sys.stderr,
538+
)
539+
return 2
540+
443541
report = run(args.path, apply_changes=args.apply)
444542

445543
if args.json:

src/adcp/server/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,14 @@ async def get_products(params, context=None):
5555
from adcp.capabilities import validate_capabilities
5656
from adcp.server.a2a_server import ADCPAgentExecutor, MessageParser, create_a2a_server
5757
from adcp.server.auth import (
58+
AsyncTokenValidator,
5859
BearerTokenAuthMiddleware,
5960
Principal,
61+
SyncTokenValidator,
6062
TokenValidator,
6163
auth_context_factory,
6264
constant_time_token_match,
65+
validator_from_token_map,
6366
)
6467
from adcp.server.base import (
6568
AccountAwareToolContext,
@@ -172,11 +175,14 @@ async def get_products(params, context=None):
172175
"SkillMiddleware",
173176
"create_a2a_server",
174177
# Bearer-token auth middleware (seller-facing recipe)
178+
"AsyncTokenValidator",
175179
"BearerTokenAuthMiddleware",
176180
"Principal",
181+
"SyncTokenValidator",
177182
"TokenValidator",
178183
"auth_context_factory",
179184
"constant_time_token_match",
185+
"validator_from_token_map",
180186
# Idempotency middleware (AdCP #2315 seller side)
181187
"IdempotencyStore",
182188
"MemoryBackend",

src/adcp/server/auth.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,12 @@ async def validate_token(token: str) -> Principal | None:
7272
import inspect
7373
import json
7474
import logging
75-
from collections.abc import Awaitable, Callable
75+
from collections.abc import Awaitable, Mapping
7676
from contextvars import ContextVar
7777
from dataclasses import dataclass, field
78-
from typing import TYPE_CHECKING, Any
78+
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
79+
80+
_V = TypeVar("_V")
7981

8082
from starlette.middleware.base import BaseHTTPMiddleware
8183
from starlette.responses import JSONResponse
@@ -138,16 +140,32 @@ class Principal:
138140
metadata: dict[str, Any] = field(default_factory=dict)
139141

140142

141-
TokenValidator = Callable[[str], "Principal | None | Awaitable[Principal | None]"]
143+
class SyncTokenValidator(Protocol):
144+
"""Synchronous token validator — ``def validate_token(token) -> Principal | None``."""
145+
146+
def __call__(self, token: str) -> Principal | None: ...
147+
148+
149+
class AsyncTokenValidator(Protocol):
150+
"""Asynchronous token validator —
151+
``async def validate_token(token) -> Principal | None``."""
152+
153+
def __call__(self, token: str) -> Awaitable[Principal | None]: ...
154+
155+
156+
TokenValidator = SyncTokenValidator | AsyncTokenValidator
142157
"""Seller-supplied callable that validates a bearer token.
143158
144159
Called with the raw token string (``Authorization: Bearer <token>``
145160
with the prefix already stripped). Return a :class:`Principal` on
146-
success, ``None`` to reject.
161+
success, ``None`` to reject. Sync and async callables are both
162+
accepted — the middleware awaits the result when it's awaitable.
147163
148-
Sync and async callables are both accepted — the middleware awaits the
149-
result when it's awaitable, so plain ``def validate_token(...)`` and
150-
``async def validate_token(...)`` both work.
164+
Declared as a union of two Protocols (rather than a
165+
``Callable[[str], Principal | None | Awaitable[...]]`` alias)
166+
because mypy narrows Protocol unions per-call-site: downstream code
167+
using ``async def validate_token`` gets the async branch without
168+
``type: ignore`` noise. Either protocol is a valid ``TokenValidator``.
151169
152170
**Do not raise on invalid tokens.** Exceptions become ``500 Internal
153171
Server Error`` responses, which leak the presence of an auth path
@@ -367,7 +385,7 @@ def auth_context_factory(meta: RequestMetadata) -> ToolContext:
367385
# ------------------------------------------------------------------
368386

369387

370-
def constant_time_token_match(token: str, stored_hashes: dict[str, Any]) -> Any:
388+
def constant_time_token_match(token: str, stored_hashes: Mapping[str, _V]) -> _V | None:
371389
"""Look up a token in a dict of SHA-256 hashes using
372390
:func:`hmac.compare_digest` rather than dict-containment.
373391
@@ -393,3 +411,42 @@ def constant_time_token_match(token: str, stored_hashes: dict[str, Any]) -> Any:
393411
if hmac.compare_digest(candidate, stored_hash):
394412
return value
395413
return None
414+
415+
416+
def validator_from_token_map(
417+
token_map: Mapping[str, Principal],
418+
) -> SyncTokenValidator:
419+
"""Build a :data:`TokenValidator` from a ``{raw_token: Principal}`` map.
420+
421+
The shape most demo/test agents actually need — a fixed set of
422+
tokens mapped to principals — without having to write the
423+
constant-time plumbing. The returned validator hashes each raw
424+
token at construction time and does constant-time lookups via
425+
:func:`hmac.compare_digest` on every call, matching the security
426+
properties of a hand-rolled validator::
427+
428+
validate_token = validator_from_token_map({
429+
"token-acme": Principal(caller_identity="p-acme", tenant_id="acme"),
430+
"token-globex": Principal(caller_identity="p-globex", tenant_id="globex"),
431+
})
432+
app.add_middleware(BearerTokenAuthMiddleware, validate_token=validate_token)
433+
434+
Production agents looking tokens up in Postgres / Redis / Vault
435+
should write their own async validator instead — this helper is
436+
for the small-fixed-set case (demo, test, CI fixtures).
437+
438+
:param token_map: Mapping of raw bearer tokens to their resolved
439+
:class:`Principal`. Tokens are hashed at construction; the
440+
plaintext is not retained.
441+
:returns: A :data:`SyncTokenValidator` (which satisfies
442+
:data:`TokenValidator`).
443+
"""
444+
stored_hashes: dict[str, Principal] = {
445+
hashlib.sha256(token.encode()).hexdigest(): principal
446+
for token, principal in token_map.items()
447+
}
448+
449+
def _validate(token: str) -> Principal | None:
450+
return constant_time_token_match(token, stored_hashes)
451+
452+
return _validate

0 commit comments

Comments
 (0)