diff --git a/src/fastmcp/contrib/bulk_tool_caller/bulk_tool_caller.py b/src/fastmcp/contrib/bulk_tool_caller/bulk_tool_caller.py index 07e82e8aa..cd7483e51 100644 --- a/src/fastmcp/contrib/bulk_tool_caller/bulk_tool_caller.py +++ b/src/fastmcp/contrib/bulk_tool_caller/bulk_tool_caller.py @@ -1,8 +1,6 @@ +import warnings from typing import Any -from mcp.types import CallToolResult -from pydantic import BaseModel, Field - from fastmcp import FastMCP from fastmcp.client import Client from fastmcp.client.transports import FastMCPTransport @@ -12,45 +10,30 @@ mcp_tool, ) +# Re-export types from the new location for backward compatibility +from fastmcp.server.middleware.bulk_tool_caller_types import ( + CallToolRequest, + CallToolRequestResult, +) -class CallToolRequest(BaseModel): - """A class to represent a request to call a tool with specific arguments.""" - tool: str = Field(description="The name of the tool to call.") - arguments: dict[str, Any] = Field( - description="A dictionary containing the arguments for the tool call." - ) +class BulkToolCaller(MCPMixin): + """A class to provide a "bulk tool call" tool for a FastMCP server. + .. deprecated:: 2.1.0 + Use :class:`~fastmcp.server.middleware.BulkToolCallerMiddleware` instead. + This class is maintained for backward compatibility but will be removed + in a future version. -class CallToolRequestResult(CallToolResult): - """ - A class to represent the result of a bulk tool call. - It extends CallToolResult to include information about the requested tool call. - """ + Old usage:: - tool: str = Field(description="The name of the tool that was called.") - arguments: dict[str, Any] = Field( - description="The arguments used for the tool call." - ) + bulk_tool_caller = BulkToolCaller() + bulk_tool_caller.register_tools(mcp) - @classmethod - def from_call_tool_result( - cls, result: CallToolResult, tool: str, arguments: dict[str, Any] - ) -> "CallToolRequestResult": - """ - Create a CallToolRequestResult from a CallToolResult. - """ - return cls( - tool=tool, - arguments=arguments, - isError=result.isError, - content=result.content, - ) + New usage:: - -class BulkToolCaller(MCPMixin): - """ - A class to provide a "bulk tool call" tool for a FastMCP server + from fastmcp.server.middleware import BulkToolCallerMiddleware + mcp = FastMCP(middleware=[BulkToolCallerMiddleware()]) """ def register_tools( @@ -59,9 +42,19 @@ def register_tools( prefix: str | None = None, separator: str = _DEFAULT_SEPARATOR_TOOL, ) -> None: + """Register the tools provided by this class with the given MCP server. + + .. deprecated:: 2.1.0 + Use :class:`~fastmcp.server.middleware.BulkToolCallerMiddleware` instead. """ - Register the tools provided by this class with the given MCP server. - """ + warnings.warn( + "BulkToolCaller is deprecated and will be removed in a future version. " + "Use BulkToolCallerMiddleware instead: " + "FastMCP(middleware=[BulkToolCallerMiddleware()])", + DeprecationWarning, + stacklevel=2, + ) + self.connection = FastMCPTransport(mcp_server) super().register_tools(mcp_server=mcp_server) @@ -125,9 +118,6 @@ async def _call_tool( async with Client(self.connection) as client: result = await client.call_tool_mcp(name=tool, arguments=arguments) - return CallToolRequestResult( - tool=tool, - arguments=arguments, - isError=result.isError, - content=result.content, + return CallToolRequestResult.from_call_tool_result( + result, tool=tool, arguments=arguments ) diff --git a/src/fastmcp/server/middleware/__init__.py b/src/fastmcp/server/middleware/__init__.py index 1e2035b21..e68c663cf 100644 --- a/src/fastmcp/server/middleware/__init__.py +++ b/src/fastmcp/server/middleware/__init__.py @@ -1,10 +1,30 @@ +from typing import TYPE_CHECKING + from .middleware import ( Middleware, MiddlewareContext, CallNext, ) +if TYPE_CHECKING: + from .bulk_tool_caller import BulkToolCallerMiddleware + + +def __getattr__(name: str): + if name == "BulkToolCallerMiddleware": + from .bulk_tool_caller import BulkToolCallerMiddleware + + return BulkToolCallerMiddleware + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__() -> list[str]: + """Ensure BulkToolCallerMiddleware shows up in dir() output.""" + return sorted([*globals().keys(), "BulkToolCallerMiddleware"]) + + __all__ = [ + "BulkToolCallerMiddleware", "CallNext", "Middleware", "MiddlewareContext", diff --git a/src/fastmcp/server/middleware/bulk_tool_caller.py b/src/fastmcp/server/middleware/bulk_tool_caller.py new file mode 100644 index 000000000..c6c1ed5c6 --- /dev/null +++ b/src/fastmcp/server/middleware/bulk_tool_caller.py @@ -0,0 +1,192 @@ +"""Middleware for bulk tool calling functionality.""" + +from typing import Annotated + +from mcp.types import TextContent + +from fastmcp.server.context import Context +from fastmcp.server.middleware.bulk_tool_caller_types import ( + CallToolRequest, + CallToolRequestResult, +) +from fastmcp.server.middleware.tool_injection import ToolInjectionMiddleware +from fastmcp.tools.tool import Tool + + +async def call_tools_bulk( + context: Context, + tool_calls: Annotated[ + list[CallToolRequest], + "List of tool calls to execute. Each call can be for a different tool with different arguments.", + ], + continue_on_error: Annotated[ + bool, + "If True, continue executing remaining tools even if one fails. If False, stop on first error.", + ] = True, +) -> list[CallToolRequestResult]: + """Call multiple tools registered on this MCP server in a single request. + + Each call can be for a different tool and can include different arguments. + Useful for speeding up what would otherwise take several individual tool calls. + + Args: + context: The request context providing access to the server + tool_calls: List of tool calls to execute + continue_on_error: Whether to continue on errors (default: True) + + Returns: + List of results, one per tool call + """ + results = [] + + for tool_call in tool_calls: + try: + # Call the tool directly through the tool manager + tool_result = await context.fastmcp._tool_manager.call_tool( + key=tool_call.tool, arguments=tool_call.arguments + ) + + # Convert ToolResult to CallToolRequestResult, preserving all fields + # Note: ToolResult doesn't have isError - it's only on CallToolResult + # For successful calls, we don't set isError (defaults to None) + results.append( + CallToolRequestResult( + tool=tool_call.tool, + arguments=tool_call.arguments, + content=tool_result.content, + structuredContent=tool_result.structured_content, + ) + ) + except Exception as e: + # Create error result + error_message = f"Error calling tool '{tool_call.tool}': {e}" + results.append( + CallToolRequestResult( + tool=tool_call.tool, + arguments=tool_call.arguments, + isError=True, + content=[TextContent(text=error_message, type="text")], + ) + ) + + if not continue_on_error: + break + + return results + + +async def call_tool_bulk( + context: Context, + tool: Annotated[str, "The name of the tool to call multiple times."], + tool_arguments: Annotated[ + list[dict[str, str | int | float | bool | None]], + "List of argument dictionaries. Each dictionary contains the arguments for one tool invocation.", + ], + continue_on_error: Annotated[ + bool, + "If True, continue executing remaining calls even if one fails. If False, stop on first error.", + ] = True, +) -> list[CallToolRequestResult]: + """Call a single tool registered on this MCP server multiple times with a single request. + + Each call can include different arguments. Useful for speeding up what would + otherwise take several individual tool calls. + + Args: + context: The request context providing access to the server + tool: The name of the tool to call + tool_arguments: List of argument dictionaries for each invocation + continue_on_error: Whether to continue on errors (default: True) + + Returns: + List of results, one per invocation + """ + results = [] + + for args in tool_arguments: + try: + # Call the tool directly through the tool manager + tool_result = await context.fastmcp._tool_manager.call_tool( + key=tool, arguments=args + ) + + # Convert ToolResult to CallToolRequestResult, preserving all fields + # Note: ToolResult doesn't have isError - it's only on CallToolResult + # For successful calls, we don't set isError (defaults to None) + results.append( + CallToolRequestResult( + tool=tool, + arguments=args, + content=tool_result.content, + structuredContent=tool_result.structured_content, + ) + ) + except Exception as e: + # Create error result + error_message = f"Error calling tool '{tool}': {e}" + results.append( + CallToolRequestResult( + tool=tool, + arguments=args, + isError=True, + content=[TextContent(text=error_message, type="text")], + ) + ) + + if not continue_on_error: + break + + return results + + +class BulkToolCallerMiddleware(ToolInjectionMiddleware): + """Middleware for injecting bulk tool calling capabilities into the server. + + This middleware adds two tools to the server: + - call_tools_bulk: Call multiple different tools in a single request + - call_tool_bulk: Call a single tool multiple times with different arguments + + Example: + ```python + from fastmcp import FastMCP + from fastmcp.server.middleware import BulkToolCallerMiddleware + + mcp = FastMCP("MyServer", middleware=[BulkToolCallerMiddleware()]) + + @mcp.tool + def greet(name: str) -> str: + return f"Hello, {name}!" + + @mcp.tool + def add(a: int, b: int) -> int: + return a + b + ``` + + Now clients can use bulk calling: + ```python + # Call multiple different tools + result = await client.call_tool("call_tools_bulk", { + "tool_calls": [ + {"tool": "greet", "arguments": {"name": "Alice"}}, + {"tool": "add", "arguments": {"a": 1, "b": 2}} + ] + }) + + # Call same tool multiple times + result = await client.call_tool("call_tool_bulk", { + "tool": "greet", + "tool_arguments": [ + {"name": "Alice"}, + {"name": "Bob"} + ] + }) + ``` + """ + + def __init__(self) -> None: + """Initialize the bulk tool caller middleware.""" + tools: list[Tool] = [ + Tool.from_function(call_tools_bulk), + Tool.from_function(call_tool_bulk), + ] + super().__init__(tools=tools) diff --git a/src/fastmcp/server/middleware/bulk_tool_caller_types.py b/src/fastmcp/server/middleware/bulk_tool_caller_types.py new file mode 100644 index 000000000..b9a805cc9 --- /dev/null +++ b/src/fastmcp/server/middleware/bulk_tool_caller_types.py @@ -0,0 +1,41 @@ +"""Types for bulk tool caller.""" + +from typing import Any + +from mcp.types import CallToolResult +from pydantic import BaseModel, Field + + +class CallToolRequest(BaseModel): + """A class to represent a request to call a tool with specific arguments.""" + + tool: str = Field(description="The name of the tool to call.") + arguments: dict[str, Any] = Field( + description="A dictionary containing the arguments for the tool call." + ) + + +class CallToolRequestResult(CallToolResult): + """A class to represent the result of a bulk tool call. + + It extends CallToolResult to include information about the requested tool call. + """ + + tool: str = Field(description="The name of the tool that was called.") + arguments: dict[str, Any] = Field( + description="The arguments used for the tool call." + ) + + @classmethod + def from_call_tool_result( + cls, result: CallToolResult, tool: str, arguments: dict[str, Any] + ) -> "CallToolRequestResult": + """Create a CallToolRequestResult from a CallToolResult.""" + return cls( + tool=tool, + arguments=arguments, + isError=result.isError, + content=result.content, + _meta=getattr(result, "_meta", None), + structuredContent=getattr(result, "structuredContent", None), + ) diff --git a/tests/server/middleware/test_bulk_tool_caller.py b/tests/server/middleware/test_bulk_tool_caller.py new file mode 100644 index 000000000..05b8c7561 --- /dev/null +++ b/tests/server/middleware/test_bulk_tool_caller.py @@ -0,0 +1,374 @@ +"""Tests for bulk tool caller middleware.""" + +import pytest +from inline_snapshot import snapshot + +from fastmcp import FastMCP +from fastmcp.client import Client +from fastmcp.server.middleware import BulkToolCallerMiddleware + + +class ToolException(Exception): + """Custom exception for tool errors.""" + + +@pytest.fixture +def server_with_tools(): + """Create a FastMCP server with bulk tool caller middleware and test tools.""" + mcp = FastMCP("BulkToolServer", middleware=[BulkToolCallerMiddleware()]) + + @mcp.tool + async def echo_tool(arg1: str) -> str: + """A simple tool that echoes arguments.""" + return arg1 + + @mcp.tool + async def error_tool(arg1: str) -> str: + """A tool that raises an error for testing purposes.""" + raise ToolException(f"Error in tool with arg1: {arg1}") + + @mcp.tool + async def no_return_tool(arg1: str) -> None: + """A simple tool that returns nothing. + + Returns: + None: This tool does not return any value. + """ + + @mcp.tool + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + return mcp + + +class TestBulkToolCallerMiddleware: + """Tests for BulkToolCallerMiddleware.""" + + async def test_middleware_adds_bulk_tools(self, server_with_tools: FastMCP): + """Test that the middleware adds the bulk tool caller tools.""" + async with Client(server_with_tools) as client: + tools = await client.list_tools() + + tool_names = [tool.name for tool in tools] + # Should have: echo_tool, error_tool, no_return_tool, add, call_tools_bulk, call_tool_bulk + assert len(tools) == 6 + assert "call_tools_bulk" in tool_names + assert "call_tool_bulk" in tool_names + assert "echo_tool" in tool_names + assert "error_tool" in tool_names + assert "no_return_tool" in tool_names + assert "add" in tool_names + + async def test_call_tool_bulk_single_success(self, server_with_tools: FastMCP): + """Test single successful call via call_tool_bulk.""" + async with Client(server_with_tools) as client: + result = await client.call_tool( + "call_tool_bulk", + {"tool": "echo_tool", "tool_arguments": [{"arg1": "value1"}]}, + ) + + assert result.structured_content is not None + assert result.structured_content["result"] == snapshot( + [ + { + "_meta": None, + "content": [ + { + "type": "text", + "text": "value1", + "annotations": None, + "_meta": None, + } + ], + "structuredContent": {"result": "value1"}, + "isError": False, + "tool": "echo_tool", + "arguments": {"arg1": "value1"}, + } + ] + ) + + async def test_call_tool_bulk_multiple_success(self, server_with_tools: FastMCP): + """Test multiple successful calls via call_tool_bulk.""" + async with Client(server_with_tools) as client: + result = await client.call_tool( + "call_tool_bulk", + { + "tool": "echo_tool", + "tool_arguments": [{"arg1": "value1"}, {"arg1": "value2"}], + }, + ) + + assert result.structured_content is not None + assert result.structured_content["result"] == snapshot( + [ + { + "_meta": None, + "content": [ + { + "type": "text", + "text": "value1", + "annotations": None, + "_meta": None, + } + ], + "structuredContent": {"result": "value1"}, + "isError": False, + "tool": "echo_tool", + "arguments": {"arg1": "value1"}, + }, + { + "_meta": None, + "content": [ + { + "type": "text", + "text": "value2", + "annotations": None, + "_meta": None, + } + ], + "structuredContent": {"result": "value2"}, + "isError": False, + "tool": "echo_tool", + "arguments": {"arg1": "value2"}, + }, + ] + ) + + async def test_call_tool_bulk_error_stops(self, server_with_tools: FastMCP): + """Test call_tool_bulk stops on first error.""" + async with Client(server_with_tools) as client: + result = await client.call_tool( + "call_tool_bulk", + { + "tool": "error_tool", + "tool_arguments": [{"arg1": "error_value"}, {"arg1": "value2"}], + "continue_on_error": False, + }, + ) + + assert result.structured_content is not None + results = result.structured_content["result"] # type: ignore[attr-defined] + assert len(results) == 1 + assert results[0]["isError"] is True + assert ( + "Error in tool with arg1: error_value" in results[0]["content"][0]["text"] + ) + + async def test_call_tool_bulk_error_continues(self, server_with_tools: FastMCP): + """Test call_tool_bulk continues on error.""" + async with Client(server_with_tools) as client: + result = await client.call_tool( + "call_tool_bulk", + { + "tool": "error_tool", + "tool_arguments": [{"arg1": "error_value"}, {"arg1": "value2"}], + "continue_on_error": True, + }, + ) + + assert result.structured_content is not None + results = result.structured_content["result"] # type: ignore[attr-defined] + # Both should be errors since the tool always raises + assert len(results) == 2 + assert results[0]["isError"] is True + assert results[1]["isError"] is True + + async def test_call_tools_bulk_single_success(self, server_with_tools: FastMCP): + """Test single successful call via call_tools_bulk.""" + async with Client(server_with_tools) as client: + result = await client.call_tool( + "call_tools_bulk", + { + "tool_calls": [ + {"tool": "echo_tool", "arguments": {"arg1": "value1"}} + ] + }, + ) + + assert result.structured_content is not None + assert result.structured_content["result"] == snapshot( + [ + { + "_meta": None, + "content": [ + { + "type": "text", + "text": "value1", + "annotations": None, + "_meta": None, + } + ], + "structuredContent": {"result": "value1"}, + "isError": False, + "tool": "echo_tool", + "arguments": {"arg1": "value1"}, + } + ] + ) + + async def test_call_tools_bulk_multiple_different_tools( + self, server_with_tools: FastMCP + ): + """Test multiple successful calls via call_tools_bulk with different tools.""" + async with Client(server_with_tools) as client: + result = await client.call_tool( + "call_tools_bulk", + { + "tool_calls": [ + {"tool": "echo_tool", "arguments": {"arg1": "echo_value"}}, + {"tool": "add", "arguments": {"a": 5, "b": 3}}, + ] + }, + ) + + assert result.structured_content is not None + results = result.structured_content["result"] # type: ignore[attr-defined] + assert len(results) == 2 + assert results[0]["tool"] == "echo_tool" + assert results[0]["content"][0]["text"] == "echo_value" + assert results[1]["tool"] == "add" + assert results[1]["content"][0]["text"] == "8" + + async def test_call_tools_bulk_error_stops(self, server_with_tools: FastMCP): + """Test call_tools_bulk stops on first error.""" + async with Client(server_with_tools) as client: + result = await client.call_tool( + "call_tools_bulk", + { + "tool_calls": [ + {"tool": "error_tool", "arguments": {"arg1": "error_value"}}, + {"tool": "echo_tool", "arguments": {"arg1": "skipped_value"}}, + ], + "continue_on_error": False, + }, + ) + + assert result.structured_content is not None + results = result.structured_content["result"] # type: ignore[attr-defined] + # Should only have one result (stopped on error) + assert len(results) == 1 + assert results[0]["isError"] is True + assert ( + "Error in tool with arg1: error_value" in results[0]["content"][0]["text"] + ) + + async def test_call_tools_bulk_error_continues(self, server_with_tools: FastMCP): + """Test call_tools_bulk continues on error.""" + async with Client(server_with_tools) as client: + result = await client.call_tool( + "call_tools_bulk", + { + "tool_calls": [ + {"tool": "error_tool", "arguments": {"arg1": "error_value"}}, + {"tool": "echo_tool", "arguments": {"arg1": "success_value"}}, + ], + "continue_on_error": True, + }, + ) + + assert result.structured_content is not None + results = result.structured_content["result"] # type: ignore[attr-defined] + # Should have both results + assert len(results) == 2 + assert results[0]["isError"] is True + assert results[0]["tool"] == "error_tool" + # isError can be None or False for successful calls + assert results[1]["isError"] in (None, False) + assert results[1]["tool"] == "echo_tool" + assert results[1]["content"][0]["text"] == "success_value" + + async def test_call_tools_bulk_with_no_return_tool( + self, server_with_tools: FastMCP + ): + """Test calling tools that return None.""" + async with Client(server_with_tools) as client: + result = await client.call_tool( + "call_tools_bulk", + { + "tool_calls": [ + { + "tool": "no_return_tool", + "arguments": {"arg1": "no_return_value"}, + } + ] + }, + ) + + assert result.structured_content is not None + results = result.structured_content["result"] # type: ignore[attr-defined] + assert len(results) == 1 + assert results[0]["tool"] == "no_return_tool" + assert results[0]["content"] == [] + + async def test_bulk_tools_bypass_filtering(self): + """Test that bulk caller tools bypass tag filtering.""" + mcp = FastMCP( + "FilteredServer", + middleware=[BulkToolCallerMiddleware()], + exclude_tags={"math"}, + ) + + @mcp.tool(tags={"math"}) + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + async with Client(mcp) as client: + tools = await client.list_tools() + + tool_names = [tool.name for tool in tools] + # The multiply tool should be filtered out, but bulk tools should still be available + assert "call_tools_bulk" in tool_names + assert "call_tool_bulk" in tool_names + assert "multiply" not in tool_names + + +class TestBulkToolCallerDeprecation: + """Tests for BulkToolCaller deprecation warnings.""" + + async def test_old_bulk_tool_caller_shows_deprecation(self): + """Test that using BulkToolCaller shows deprecation warning.""" + from fastmcp.contrib.bulk_tool_caller.bulk_tool_caller import BulkToolCaller + + mcp = FastMCP("OldStyleServer") + + @mcp.tool + def echo(text: str) -> str: + return text + + with pytest.warns(DeprecationWarning, match="BulkToolCaller is deprecated"): + bulk_tool_caller = BulkToolCaller() + bulk_tool_caller.register_tools(mcp) + + async def test_old_bulk_tool_caller_still_works(self): + """Test that old BulkToolCaller still functions correctly.""" + from fastmcp.contrib.bulk_tool_caller.bulk_tool_caller import BulkToolCaller + + mcp = FastMCP("OldStyleServer") + + @mcp.tool + def echo(text: str) -> str: + return text + + with pytest.warns(DeprecationWarning): + bulk_tool_caller = BulkToolCaller() + bulk_tool_caller.register_tools(mcp) + + async with Client(mcp) as client: + tools = await client.list_tools() + tool_names = [tool.name for tool in tools] + assert "call_tools_bulk" in tool_names + assert "call_tool_bulk" in tool_names + + # Test that it actually works + result = await client.call_tool( + "call_tool_bulk", + {"tool": "echo", "tool_arguments": [{"text": "hello"}]}, + ) + + # The old BulkToolCaller returns text content (not structured), so just verify it doesn't error + assert result.content is not None + assert len(result.content) > 0