Skip to content

Commit ad8b520

Browse files
bokelleyclaude
andauthored
feat(types): narrow discriminated-union errors (Stability AI Emma P2) (#340)
* feat(types): narrow discriminated-union ValidationErrors to user's intended variant Stability AI Emma backend test (verdict 5/10) flagged that constructing a CreativeManifest whose assets value matched one variant (e.g. ImageAsset) but was missing fields required by THAT variant produced a 60-line pydantic ValidationError listing every variant of the asset content union (13+ variants, 26 errors). The user's actual mistake (one missing field on the variant they picked) was buried. adcp.types.error_narrowing.narrow_union_errors post-processes ValidationError.errors() to keep only errors from the closest-fit variant. Strategy: 1. Discriminator match — variants with no literal_error / union_tag_not_found had their discriminator value match the user's input. If exactly one such variant exists, surface only its errors. Stability case: ImageAsset. 2. Fewest-errors fallback — when no clear winner, pick the variant with fewest non-literal errors as a closest-fit guess. 3. Tie → pass through all errors so the adopter can disambiguate. Wired into create_tool_caller's INVALID_REQUEST projection so wire-side validation errors get the narrowing automatically. Adopters can also call narrow_union_errors manually on a caught ValidationError. Before/after: BEFORE: 26 errors (every variant of the asset union) AFTER: 2 errors (assets.hero.ImageAsset.width / .height) Tests: 6 new. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(types): expert-review hardening on PR #340 narrowing heuristic Two reviewers (code-reviewer, python-expert) flagged five P0/P1s: P0 (python-expert): pydantic upper bound. The narrowing heuristic depends on pydantic-2 ValidationError.errors() internals (err["type"] literals, CamelCase variant names interleaved in err["loc"]). Pydantic 3 makes no API guarantee on these. Pin <3. P1 (code-reviewer): missing union_tag_invalid in mismatch types. Sibling of union_tag_not_found that pydantic-2 emits when a tag is found but invalid. Without it, a variant with that error stays in the candidate pool and may falsely win. P1 (python-expert): _split_at_variant used FIRST CamelCase segment. Nested unions (Union[Outer[Union[A, B]], C]) emit loc like ("field", "Outer", "inner", "A", "subfield"); splitting at "Outer" collapsed A and B into one bucket. Switch to LAST variant segment (innermost wins). P1 (code-reviewer): narrowing call was unguarded. A bug in the heuristic would 500 the wire path. Wrap in try/except in the INVALID_REQUEST projection; fall back to unfiltered errors with WARNING log. P1 (python-expert): defensive copy of returned dicts. Caller mutation could leak back into pydantic's internal error list. ``dict(err)`` per-error at the boundary; cheap insurance. P2 cleanup: - Lift import out of hot path to module top. - DoS guard: cap input at 500 errors (narrowing is UX, not correctness; an attacker shouldn't get to amplify CPU through bucketing). - Document residual literal_error overloading edge case. Tests: 4 new (union_tag_invalid mismatch handling, nested-union innermost-variant resolution, defensive-copy guard, DoS cap). Deferred to follow-up (per python-expert): schema-level fix — teach codegen to emit Annotated[Union[...], Field(discriminator=...)] so pydantic narrows correctly without our post-processor. Test count: 2878 passed (was 2874 — +4 hardening tests). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 985dcd9 commit ad8b520

4 files changed

Lines changed: 600 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,13 @@ dependencies = [
3535
# ``tests/conformance/signing/test_ip_pinned_transport_contract.py``
3636
# guards the specific API shapes we rely on.
3737
"httpcore>=1.0,<2.0",
38-
"pydantic>=2.0.0",
38+
# Upper bound is load-bearing for ``adcp.types.error_narrowing``,
39+
# which depends on pydantic-2 ValidationError.errors() internals
40+
# (``err["type"]`` literals like ``"literal_error"`` /
41+
# ``"union_tag_not_found"`` and CamelCase variant names interleaved
42+
# in ``err["loc"]``). Pydantic 3 has no API guarantee on these
43+
# internals; bump only after porting the narrowing heuristics.
44+
"pydantic>=2.0.0,<3",
3945
"typing-extensions>=4.5.0",
4046
# A2A protocol v1.0 (protobuf types, ProtoJSON on the wire). We run
4147
# on the v1.0 Python SDK with ``enable_v0_3_compat=True`` on the

src/adcp/server/mcp_tools.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from adcp.server.base import ADCPHandler, ToolContext
2828
from adcp.server.test_controller import SCENARIOS as _CONTROLLER_SCENARIOS
29+
from adcp.types.error_narrowing import narrow_union_errors
2930
from adcp.validation.client_hooks import ValidationHookConfig
3031

3132
logger = logging.getLogger(__name__)
@@ -1746,6 +1747,26 @@ async def call_tool(params: dict[str, Any], context: ToolContext | None = None)
17461747
errors_list = exc.errors(
17471748
include_input=False, include_context=False, include_url=False
17481749
)
1750+
# Narrow discriminated-union failures to the variant
1751+
# the user actually intended (Stability AI Emma P2:
1752+
# 60-line dump → focused error). For non-union
1753+
# failures the function is a no-op.
1754+
#
1755+
# Defensive: if the narrowing helper itself raises
1756+
# (heuristic edge case, future pydantic format
1757+
# change), keep the original error list rather than
1758+
# 500'ing the wire path. The narrowed-error UX is a
1759+
# nice-to-have; correctness is surfacing SOME error.
1760+
try:
1761+
errors_list = list(narrow_union_errors(errors_list))
1762+
except Exception:
1763+
logger.warning(
1764+
"narrow_union_errors raised on %s — passing through "
1765+
"unfiltered errors. This is a bug in the narrowing "
1766+
"heuristic, NOT in the validation itself.",
1767+
method_name,
1768+
exc_info=True,
1769+
)
17491770
first: dict[str, Any] = dict(errors_list[0]) if errors_list else {}
17501771
field_path = ".".join(str(loc) for loc in first.get("loc", ()))
17511772
message = first.get("msg", "validation failed")

src/adcp/types/error_narrowing.py

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
"""Narrow pydantic discriminated-union ValidationErrors to the
2+
variant the user actually intended.
3+
4+
Background (Stability AI Emma backend test, verdict 5/10): when an
5+
adopter constructs a ``CreativeManifest`` whose ``assets`` value
6+
matches one variant of a discriminated union but is missing fields
7+
required by THAT variant, pydantic 2 reports validation errors for
8+
EVERY variant in the union (13+ for asset content types). The error
9+
dump runs 60+ lines and obscures the actual problem (a single
10+
missing field on the variant the user picked).
11+
12+
This module exposes :func:`narrow_union_errors` which post-processes
13+
``ValidationError.errors()`` to keep only the errors from the
14+
"closest fit" variant — the one whose discriminator matched OR the
15+
one with the fewest non-discriminator errors. The result is a
16+
focused error pointing at the user's actual mistake.
17+
18+
Used by:
19+
20+
* :func:`adcp.server.mcp_tools.create_tool_caller` — narrows
21+
wire-side ``INVALID_REQUEST`` errors automatically.
22+
* Adopter code via :func:`narrow_validation_error` (manual) — for
23+
adopters who construct typed models in their platform method
24+
bodies and want the same friendlier error UX.
25+
"""
26+
27+
from __future__ import annotations
28+
29+
from typing import Any
30+
31+
32+
# Heuristic: a "variant" location segment is a class name.
33+
# Pydantic emits ``("assets", "hero", "ImageAsset", "width")`` for
34+
# union-validation errors. The variant name is the
35+
# second-to-last position when the error is on a field, OR the
36+
# last position when the error is at the variant itself
37+
# (e.g. ``union_tag_not_found``).
38+
def _looks_like_variant_name(segment: Any) -> bool:
39+
"""Heuristic: a Python class name (CamelCase, starts with capital).
40+
41+
Used to detect variant segments in ``ValidationError.errors()[i].loc``.
42+
Pydantic interleaves variant class names into the loc tuple for
43+
discriminated-union failures; we strip those segments to identify
44+
"this error belongs to variant X."
45+
"""
46+
if not isinstance(segment, str):
47+
return False
48+
if not segment:
49+
return False
50+
# Class names start with an uppercase letter and contain only
51+
# alphanumerics. Reject ``snake_case`` field names.
52+
return segment[0].isupper() and segment.replace("_", "").isalnum() and "_" not in segment
53+
54+
55+
def _split_at_variant(
56+
loc: tuple[Any, ...],
57+
) -> tuple[tuple[Any, ...], str, tuple[Any, ...]] | None:
58+
"""Split a loc tuple at the LAST variant-name segment.
59+
60+
Returns ``(prefix_before_variant, variant_name, suffix_after_variant)``
61+
or ``None`` if no variant segment is found. Used to group
62+
union-validation errors by their containing field path + variant.
63+
64+
The LAST variant segment (innermost) is the one whose error
65+
mattered. For nested unions
66+
(``Union[Outer[Union[A, B]], C]``) pydantic emits
67+
``("field", "Outer", "inner", "A", "subfield")`` — splitting at
68+
the FIRST variant segment ("Outer") would collapse "A" and "B"
69+
into one bucket; splitting at the LAST ("A") correctly groups by
70+
innermost variant.
71+
72+
Example::
73+
74+
loc = ("assets", "hero", "ImageAsset", "width")
75+
→ (("assets", "hero"), "ImageAsset", ("width",))
76+
77+
loc = ("field", "Outer", "inner", "A", "subfield")
78+
→ (("field", "Outer", "inner"), "A", ("subfield",))
79+
"""
80+
last_variant_idx: int | None = None
81+
for i, segment in enumerate(loc):
82+
if _looks_like_variant_name(segment):
83+
last_variant_idx = i
84+
if last_variant_idx is None:
85+
return None
86+
variant = loc[last_variant_idx]
87+
assert isinstance(variant, str) # _looks_like_variant_name guarantees this
88+
return loc[:last_variant_idx], variant, loc[last_variant_idx + 1 :]
89+
90+
91+
#: Cap on input list size. Beyond this we pass through unchanged —
92+
#: narrowing is a UX feature, not correctness, and an attacker
93+
#: submitting a request that triggers thousands of validation errors
94+
#: shouldn't get to amplify CPU through O(N) bucketing logic. The cap
95+
#: is generous enough that genuine union dumps (~30 errors for a 13-
96+
#: variant asset union) never hit it.
97+
_MAX_NARROW_INPUT_SIZE = 500
98+
99+
100+
def narrow_union_errors(
101+
errors: Any,
102+
) -> list[Any]:
103+
"""Return a focused subset of ``errors`` for discriminated-union
104+
failures.
105+
106+
For each (parent_loc) where multiple variant errors exist, pick
107+
the "best fit" variant by:
108+
109+
1. **Discriminator match**: variants with no ``literal_error``,
110+
``union_tag_not_found``, or ``union_tag_invalid`` had their
111+
discriminator value match the user's input. Keep ONLY their
112+
errors.
113+
2. **Fewest non-discriminator errors**: if no clear discriminator
114+
winner, the variant with the smallest count of non-literal
115+
errors is the closest fit. Keep ONLY its errors.
116+
117+
Errors that aren't part of a union failure (no variant in their
118+
``loc``) pass through unchanged. The function never returns an
119+
empty list when the input is non-empty — the worst case falls
120+
back to the input.
121+
122+
**Edge case** (residual): pydantic's ``literal_error`` type fires
123+
on ANY ``Literal[...]`` field mismatch, not just the discriminator.
124+
A user input that hits a non-discriminator literal mismatch on the
125+
matched variant (e.g., correct ``asset_type`` but wrong
126+
``codec``) will eliminate the matched variant from step 1 and the
127+
fallback may pick a wrong variant. The narrowing reduces noise
128+
even in this case but may surface the wrong variant's errors.
129+
Resolving this requires knowing the discriminator field name,
130+
which the heuristic doesn't have access to. Schema-level fix
131+
(``Annotated[Union[...], Field(discriminator=...)]``) avoids the
132+
issue entirely; tracked as a follow-up.
133+
134+
Mirrors the JS-side ``narrowUnionValidationErrors`` (when ported).
135+
"""
136+
if not errors:
137+
return []
138+
139+
errors_list = list(errors)
140+
# DoS guard: don't process pathologically-large inputs. Below the
141+
# cap, narrowing helps. Above it, we're either in a hostile
142+
# request or a legitimately massive schema; either way, the
143+
# narrowing UX win doesn't justify the CPU.
144+
if len(errors_list) > _MAX_NARROW_INPUT_SIZE:
145+
return errors_list
146+
147+
# Bucket errors by (prefix_before_variant) — every error sharing
148+
# the same prefix is contending for the same logical slot, and
149+
# different errors in the same bucket are different variants of
150+
# the same union.
151+
buckets: dict[tuple[Any, ...], list[tuple[str, Any]]] = {}
152+
passthrough: list[Any] = []
153+
154+
for err in errors_list:
155+
loc = tuple(err.get("loc", ()))
156+
split = _split_at_variant(loc)
157+
if split is None:
158+
passthrough.append(err)
159+
continue
160+
prefix, variant, _suffix = split
161+
buckets.setdefault(prefix, []).append((variant, err))
162+
163+
if not buckets:
164+
return errors_list
165+
166+
# Defensive copy of the dicts we're about to surface — the caller
167+
# might mutate the returned list and we don't want that to leak
168+
# back into the input. ``dict(err)`` is shallow which is fine:
169+
# ``loc`` is a tuple (immutable), and other values are scalars or
170+
# nested dicts pydantic doesn't share across errors.
171+
narrowed: list[Any] = [dict(err) for err in passthrough]
172+
for _prefix, variant_errors in buckets.items():
173+
# Group by variant name within this bucket.
174+
per_variant: dict[str, list[Any]] = {}
175+
for variant, err in variant_errors:
176+
per_variant.setdefault(variant, []).append(err)
177+
178+
if len(per_variant) <= 1:
179+
# Only one variant in this bucket — no narrowing needed.
180+
for errs in per_variant.values():
181+
narrowed.extend(dict(e) for e in errs)
182+
continue
183+
184+
winner = _pick_winning_variant(per_variant)
185+
if winner is None:
186+
# Couldn't disambiguate; fall back to all variants for
187+
# this bucket so the adopter doesn't lose information.
188+
for errs in per_variant.values():
189+
narrowed.extend(dict(e) for e in errs)
190+
continue
191+
narrowed.extend(dict(e) for e in per_variant[winner])
192+
193+
return narrowed
194+
195+
196+
#: Pydantic-2 error types that signal "this variant's discriminator
197+
#: didn't match the user's input". A variant whose error list contains
198+
#: ANY of these is eliminated from step 1's candidate pool.
199+
#:
200+
#: * ``literal_error`` — a ``Literal[...]`` field rejected the value.
201+
#: Discriminators are typically Literal-typed (``asset_type:
202+
#: Literal["image"]``); a mismatch here means this variant isn't
203+
#: the user's intent.
204+
#: * ``union_tag_not_found`` — a NESTED tagged union inside this
205+
#: variant couldn't be narrowed to any of ITS sub-variants. Means
206+
#: the user's input doesn't fit this variant's shape at all.
207+
#: * ``union_tag_invalid`` — pydantic-2's "tag found but invalid for
208+
#: this union" code. Same semantic as ``union_tag_not_found`` for
209+
#: our purposes.
210+
_DISCRIMINATOR_MISMATCH_TYPES = frozenset(
211+
{"literal_error", "union_tag_not_found", "union_tag_invalid"}
212+
)
213+
214+
215+
def _pick_winning_variant(
216+
per_variant: dict[str, list[Any]],
217+
) -> str | None:
218+
"""Return the variant name whose errors are the closest fit.
219+
220+
Strategy (in order):
221+
222+
1. **Discriminator match**: variants with ZERO discriminator-mismatch
223+
errors (see :data:`_DISCRIMINATOR_MISMATCH_TYPES`) had their
224+
discriminator value match the user's input. If exactly one
225+
such variant exists, it's the winner. This is the Stability AI
226+
/ AudioStack 60-line-dump fix — when the user provides
227+
``asset_type='image'`` and ImageAsset's other fields fail, we
228+
surface ImageAsset errors only.
229+
2. **Fewest errors among matched**: if multiple variants matched
230+
the discriminator, pick the one with fewest errors (closest
231+
fit to user's input shape).
232+
3. **Fallback to fewest errors overall**: if NO variant matched
233+
the discriminator (the user provided an invalid discriminator
234+
value, e.g. ``asset_type='image_asset'`` instead of ``'image'``),
235+
pick the variant with fewest non-literal errors as a closest-
236+
fit guess.
237+
4. **Tie**: return ``None`` so the caller passes through all
238+
errors — adopter can disambiguate manually.
239+
"""
240+
if not per_variant:
241+
return None
242+
243+
matched = {
244+
variant: errs
245+
for variant, errs in per_variant.items()
246+
if not any(e.get("type") in _DISCRIMINATOR_MISMATCH_TYPES for e in errs)
247+
}
248+
249+
if matched:
250+
if len(matched) == 1:
251+
return next(iter(matched.keys()))
252+
# Multiple matched (rare — would mean two variants share the
253+
# same discriminator literal). Pick the one with fewest
254+
# errors.
255+
sorted_matched = sorted(matched.items(), key=lambda kv: len(kv[1]))
256+
if len(sorted_matched) >= 2 and len(sorted_matched[0][1]) == len(sorted_matched[1][1]):
257+
return None # tie
258+
return sorted_matched[0][0]
259+
260+
# Step 3: no discriminator match — pick the variant with fewest
261+
# non-literal errors. That's the closest-fit guess for an adopter
262+
# who used an invalid discriminator.
263+
def _non_literal_score(item: tuple[str, list[Any]]) -> int:
264+
_variant, errs = item
265+
return sum(1 for e in errs if e.get("type") != "literal_error")
266+
267+
sorted_all = sorted(per_variant.items(), key=_non_literal_score)
268+
if len(sorted_all) >= 2 and _non_literal_score(sorted_all[0]) == _non_literal_score(
269+
sorted_all[1]
270+
):
271+
return None # tie — don't guess
272+
return sorted_all[0][0]
273+
274+
275+
__all__ = [
276+
"narrow_union_errors",
277+
]

0 commit comments

Comments
 (0)