@@ -72,10 +72,12 @@ async def validate_token(token: str) -> Principal | None:
7272import inspect
7373import json
7474import logging
75- from collections .abc import Awaitable , Callable
75+ from collections .abc import Awaitable , Mapping
7676from contextvars import ContextVar
7777from dataclasses import dataclass , field
78- from typing import TYPE_CHECKING , Any
78+ from typing import TYPE_CHECKING , Any , Protocol , TypeVar
79+
80+ _V = TypeVar ("_V" )
7981
8082from starlette .middleware .base import BaseHTTPMiddleware
8183from starlette .responses import JSONResponse
@@ -138,16 +140,32 @@ class Principal:
138140 metadata : dict [str , Any ] = field (default_factory = dict )
139141
140142
141- TokenValidator = Callable [[str ], "Principal | None | Awaitable[Principal | None]" ]
143+ class SyncTokenValidator (Protocol ):
144+ """Synchronous token validator — ``def validate_token(token) -> Principal | None``."""
145+
146+ def __call__ (self , token : str ) -> Principal | None : ...
147+
148+
149+ class AsyncTokenValidator (Protocol ):
150+ """Asynchronous token validator —
151+ ``async def validate_token(token) -> Principal | None``."""
152+
153+ def __call__ (self , token : str ) -> Awaitable [Principal | None ]: ...
154+
155+
156+ TokenValidator = SyncTokenValidator | AsyncTokenValidator
142157"""Seller-supplied callable that validates a bearer token.
143158
144159Called with the raw token string (``Authorization: Bearer <token>``
145160with the prefix already stripped). Return a :class:`Principal` on
146- success, ``None`` to reject.
161+ success, ``None`` to reject. Sync and async callables are both
162+ accepted — the middleware awaits the result when it's awaitable.
147163
148- Sync and async callables are both accepted — the middleware awaits the
149- result when it's awaitable, so plain ``def validate_token(...)`` and
150- ``async def validate_token(...)`` both work.
164+ Declared as a union of two Protocols (rather than a
165+ ``Callable[[str], Principal | None | Awaitable[...]]`` alias)
166+ because mypy narrows Protocol unions per-call-site: downstream code
167+ using ``async def validate_token`` gets the async branch without
168+ ``type: ignore`` noise. Either protocol is a valid ``TokenValidator``.
151169
152170**Do not raise on invalid tokens.** Exceptions become ``500 Internal
153171Server Error`` responses, which leak the presence of an auth path
@@ -367,7 +385,7 @@ def auth_context_factory(meta: RequestMetadata) -> ToolContext:
367385# ------------------------------------------------------------------
368386
369387
370- def constant_time_token_match (token : str , stored_hashes : dict [str , Any ]) -> Any :
388+ def constant_time_token_match (token : str , stored_hashes : Mapping [str , _V ]) -> _V | None :
371389 """Look up a token in a dict of SHA-256 hashes using
372390 :func:`hmac.compare_digest` rather than dict-containment.
373391
@@ -393,3 +411,42 @@ def constant_time_token_match(token: str, stored_hashes: dict[str, Any]) -> Any:
393411 if hmac .compare_digest (candidate , stored_hash ):
394412 return value
395413 return None
414+
415+
416+ def validator_from_token_map (
417+ token_map : Mapping [str , Principal ],
418+ ) -> SyncTokenValidator :
419+ """Build a :data:`TokenValidator` from a ``{raw_token: Principal}`` map.
420+
421+ The shape most demo/test agents actually need — a fixed set of
422+ tokens mapped to principals — without having to write the
423+ constant-time plumbing. The returned validator hashes each raw
424+ token at construction time and does constant-time lookups via
425+ :func:`hmac.compare_digest` on every call, matching the security
426+ properties of a hand-rolled validator::
427+
428+ validate_token = validator_from_token_map({
429+ "token-acme": Principal(caller_identity="p-acme", tenant_id="acme"),
430+ "token-globex": Principal(caller_identity="p-globex", tenant_id="globex"),
431+ })
432+ app.add_middleware(BearerTokenAuthMiddleware, validate_token=validate_token)
433+
434+ Production agents looking tokens up in Postgres / Redis / Vault
435+ should write their own async validator instead — this helper is
436+ for the small-fixed-set case (demo, test, CI fixtures).
437+
438+ :param token_map: Mapping of raw bearer tokens to their resolved
439+ :class:`Principal`. Tokens are hashed at construction; the
440+ plaintext is not retained.
441+ :returns: A :data:`SyncTokenValidator` (which satisfies
442+ :data:`TokenValidator`).
443+ """
444+ stored_hashes : dict [str , Principal ] = {
445+ hashlib .sha256 (token .encode ()).hexdigest (): principal
446+ for token , principal in token_map .items ()
447+ }
448+
449+ def _validate (token : str ) -> Principal | None :
450+ return constant_time_token_match (token , stored_hashes )
451+
452+ return _validate
0 commit comments