Skip to content

Commit 4d42569

Browse files
authored
Merge pull request #2449 from jlowin/sep-1686-final-spec
Complete SEP-1686 final spec implementation
2 parents 1bf7ebf + e8c35e4 commit 4d42569

25 files changed

+1916
-702
lines changed

src/fastmcp/cli/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ async def run_v1_server_async(
211211
port: Port to bind to
212212
transport: Transport protocol to use
213213
"""
214-
214+
215215
if host:
216216
server.settings.host = host
217217
if port:
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
"""
2+
⚠️ TEMPORARY CODE - SEP-1686 WORKAROUNDS FOR MCP SDK LIMITATIONS ⚠️
3+
4+
This file contains workarounds for MCP SDK limitations related to SEP-1686 tasks:
5+
6+
1. Client capability declaration - SDK doesn't support customizing experimental capabilities
7+
2. Task protocol types - SDK doesn't have final task protocol types yet
8+
3. Task notification routing - Custom message handler for notifications/tasks/status
9+
10+
These shims will be removed when the MCP SDK is updated to match the final spec.
11+
12+
DO NOT WRITE TESTS FOR THIS FILE - these are temporary hacks.
13+
"""
14+
15+
from __future__ import annotations
16+
17+
import weakref
18+
from dataclasses import dataclass
19+
from typing import TYPE_CHECKING, Any, Literal
20+
21+
import mcp.types
22+
import pydantic
23+
from mcp.client.session import (
24+
SUPPORTED_PROTOCOL_VERSIONS,
25+
ClientSession,
26+
_default_elicitation_callback,
27+
_default_list_roots_callback,
28+
_default_sampling_callback,
29+
)
30+
from pydantic import BaseModel
31+
32+
from fastmcp.client.messages import Message, MessageHandler
33+
34+
if TYPE_CHECKING:
35+
from fastmcp.client.client import Client
36+
37+
38+
# ═══════════════════════════════════════════════════════════════════════════
39+
# 1. Client Capability Declaration
40+
# ═══════════════════════════════════════════════════════════════════════════
41+
42+
43+
class TaskCapableClientSession(ClientSession):
44+
"""Custom ClientSession that declares task capability.
45+
46+
Overrides initialize() to set experimental={"tasks": {}} in ClientCapabilities.
47+
"""
48+
49+
async def initialize(self) -> mcp.types.InitializeResult:
50+
"""Initialize with task capability declaration."""
51+
# Build capabilities
52+
sampling = (
53+
mcp.types.SamplingCapability()
54+
if self._sampling_callback != _default_sampling_callback
55+
else None
56+
)
57+
elicitation = (
58+
mcp.types.ElicitationCapability()
59+
if self._elicitation_callback != _default_elicitation_callback
60+
else None
61+
)
62+
roots = (
63+
mcp.types.RootsCapability(listChanged=True)
64+
if self._list_roots_callback != _default_list_roots_callback
65+
else None
66+
)
67+
68+
# Send initialize request with task capability
69+
result = await self.send_request(
70+
mcp.types.ClientRequest(
71+
mcp.types.InitializeRequest(
72+
params=mcp.types.InitializeRequestParams(
73+
protocolVersion=mcp.types.LATEST_PROTOCOL_VERSION,
74+
capabilities=mcp.types.ClientCapabilities(
75+
sampling=sampling,
76+
elicitation=elicitation,
77+
experimental={"tasks": {}},
78+
roots=roots,
79+
),
80+
clientInfo=self._client_info,
81+
),
82+
)
83+
),
84+
mcp.types.InitializeResult,
85+
)
86+
87+
# Validate protocol version
88+
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
89+
raise RuntimeError(
90+
f"Unsupported protocol version from the server: {result.protocolVersion}"
91+
)
92+
93+
# Send initialized notification
94+
await self.send_notification(
95+
mcp.types.ClientNotification(mcp.types.InitializedNotification())
96+
)
97+
98+
return result
99+
100+
101+
async def task_capable_initialize(
102+
session: ClientSession,
103+
) -> mcp.types.InitializeResult:
104+
"""Initialize a session with task capabilities.
105+
106+
Args:
107+
session: The ClientSession to initialize
108+
109+
Returns:
110+
InitializeResult from the server
111+
"""
112+
# Build capabilities
113+
sampling = (
114+
mcp.types.SamplingCapability()
115+
if session._sampling_callback != _default_sampling_callback
116+
else None
117+
)
118+
elicitation = (
119+
mcp.types.ElicitationCapability()
120+
if session._elicitation_callback != _default_elicitation_callback
121+
else None
122+
)
123+
roots = (
124+
mcp.types.RootsCapability(listChanged=True)
125+
if session._list_roots_callback != _default_list_roots_callback
126+
else None
127+
)
128+
129+
# Send initialize request with task capability
130+
result = await session.send_request(
131+
mcp.types.ClientRequest(
132+
mcp.types.InitializeRequest(
133+
params=mcp.types.InitializeRequestParams(
134+
protocolVersion=mcp.types.LATEST_PROTOCOL_VERSION,
135+
capabilities=mcp.types.ClientCapabilities(
136+
sampling=sampling,
137+
elicitation=elicitation,
138+
experimental={"tasks": {}},
139+
roots=roots,
140+
),
141+
clientInfo=session._client_info,
142+
),
143+
)
144+
),
145+
mcp.types.InitializeResult,
146+
)
147+
148+
# Validate protocol version
149+
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
150+
raise RuntimeError(
151+
f"Unsupported protocol version from the server: {result.protocolVersion}"
152+
)
153+
154+
# Send initialized notification
155+
await session.send_notification(
156+
mcp.types.ClientNotification(mcp.types.InitializedNotification())
157+
)
158+
159+
return result
160+
161+
162+
# ═══════════════════════════════════════════════════════════════════════════
163+
# 2. Task Protocol Types (SDK doesn't have these yet)
164+
# ═══════════════════════════════════════════════════════════════════════════
165+
166+
167+
class TasksGetRequest(BaseModel):
168+
"""Request for tasks/get MCP method."""
169+
170+
method: Literal["tasks/get"] = "tasks/get"
171+
params: TasksGetParams
172+
173+
174+
class TasksGetParams(BaseModel):
175+
"""Parameters for tasks/get request."""
176+
177+
taskId: str
178+
_meta: dict[str, Any] | None = None
179+
180+
181+
class TasksGetResult(BaseModel):
182+
"""Result from tasks/get MCP method."""
183+
184+
taskId: str
185+
status: Literal[
186+
"submitted", "working", "completed", "failed", "cancelled", "unknown"
187+
]
188+
createdAt: str
189+
ttl: int | None = None
190+
pollInterval: int | None = None
191+
error: str | None = None
192+
193+
194+
class TasksResultRequest(BaseModel):
195+
"""Request for tasks/result MCP method."""
196+
197+
method: Literal["tasks/result"] = "tasks/result"
198+
params: TasksResultParams
199+
200+
201+
class TasksResultParams(BaseModel):
202+
"""Parameters for tasks/result request."""
203+
204+
taskId: str
205+
_meta: dict[str, Any] | None = None
206+
207+
208+
class TasksListRequest(BaseModel):
209+
"""Request for tasks/list MCP method."""
210+
211+
method: Literal["tasks/list"] = "tasks/list"
212+
params: TasksListParams
213+
214+
215+
class TasksListParams(BaseModel):
216+
"""Parameters for tasks/list request."""
217+
218+
cursor: str | None = None
219+
limit: int = 50
220+
_meta: dict[str, Any] | None = None
221+
222+
223+
class TasksListResult(BaseModel):
224+
"""Result from tasks/list MCP method."""
225+
226+
tasks: list[dict[str, Any]]
227+
nextCursor: str | None = None
228+
229+
230+
class TasksDeleteRequest(BaseModel):
231+
"""Request for tasks/delete MCP method."""
232+
233+
method: Literal["tasks/delete"] = "tasks/delete"
234+
params: TasksDeleteParams
235+
236+
237+
class TasksDeleteParams(BaseModel):
238+
"""Parameters for tasks/delete request."""
239+
240+
taskId: str
241+
_meta: dict[str, Any] | None = None
242+
243+
244+
class TasksDeleteResult(BaseModel):
245+
"""Result from tasks/delete MCP method."""
246+
247+
_meta: dict[str, Any] | None = None
248+
249+
250+
# ═══════════════════════════════════════════════════════════════════════════
251+
# 3. Client-Side Type Helpers
252+
# ═══════════════════════════════════════════════════════════════════════════
253+
254+
255+
@dataclass
256+
class CallToolResult:
257+
"""Parsed result from a tool call."""
258+
259+
content: list[mcp.types.ContentBlock]
260+
structured_content: dict[str, Any] | None
261+
meta: dict[str, Any] | None
262+
data: Any = None
263+
is_error: bool = False
264+
265+
266+
class TaskStatusResponse(pydantic.BaseModel):
267+
"""Response from tasks/get endpoint."""
268+
269+
task_id: str = pydantic.Field(alias="taskId")
270+
status: Literal[
271+
"submitted", "working", "completed", "failed", "cancelled", "unknown"
272+
]
273+
created_at: str = pydantic.Field(alias="createdAt")
274+
ttl: int | None = pydantic.Field(default=None, alias="ttl")
275+
poll_interval: int | None = pydantic.Field(default=None, alias="pollInterval")
276+
error: str | None = None
277+
278+
model_config = pydantic.ConfigDict(populate_by_name=True)
279+
280+
281+
# ═══════════════════════════════════════════════════════════════════════════
282+
# 4. Task Notification Routing
283+
# ═══════════════════════════════════════════════════════════════════════════
284+
285+
286+
class ClientMessageHandler(MessageHandler):
287+
"""MessageHandler that routes task status notifications to Task objects."""
288+
289+
def __init__(self, client: Client):
290+
super().__init__()
291+
self._client_ref: weakref.ref[Client] = weakref.ref(client)
292+
293+
async def dispatch(self, message: Message) -> None:
294+
"""Dispatch messages, including task status notifications."""
295+
# Handle task status notifications
296+
if isinstance(message, mcp.types.ServerNotification):
297+
if (
298+
hasattr(message.root, "method")
299+
and message.root.method == "notifications/tasks/status"
300+
):
301+
client = self._client_ref()
302+
if client:
303+
client._handle_task_status_notification(message.root)
304+
305+
# Call parent dispatch for all other messages
306+
await super().dispatch(message)

0 commit comments

Comments
 (0)