From ba634532d0c207206314f302039f283b0982ff5b Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Mon, 17 Nov 2025 15:04:39 -0800 Subject: [PATCH 1/8] Add unit tests --- integrations/langchain/pyproject.toml | 2 + .../multi_server_mcp_client.py | 175 ++++++ .../test_multi_server_mcp_client.py | 527 ++++++++++++++++++ 3 files changed, 704 insertions(+) create mode 100644 integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py create mode 100644 integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml index 5216f17a..f2c666a0 100644 --- a/integrations/langchain/pyproject.toml +++ b/integrations/langchain/pyproject.toml @@ -17,6 +17,8 @@ dependencies = [ "unitycatalog-langchain[databricks]>=0.3.0", "databricks-sdk>=0.65.0", "openai>=1.99.9", + "langchain-mcp-adapters>=0.1.13" + ] [project.optional-dependencies] diff --git a/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py b/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py new file mode 100644 index 00000000..e8551b0f --- /dev/null +++ b/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py @@ -0,0 +1,175 @@ +from langchain_mcp_adapters.client import MultiServerMCPClient +from typing import List, Literal, Callable, Union +from databricks.sdk import WorkspaceClient +from pydantic import BaseModel, Field, ConfigDict, model_validator +from typing import Any +from databricks_mcp.oauth_provider import DatabricksOAuthClientProvider + + +class Server(BaseModel): + """ + Base configuration for an MCP server connection using streamable HTTP transport. + + Accepts any additional keyword arguments which are automatically passed through + to LangChain's Connection type, making this forward-compatible with future updates. + + Common optional parameters: + - headers: dict[str, str] - Custom HTTP headers + - timeout: float - Request timeout in seconds + - sse_read_timeout: float - SSE read timeout in seconds + - auth: httpx.Auth - Authentication handler + - httpx_client_factory: Callable - Custom httpx client factory + - terminate_on_close: bool - Terminate connection on close + - session_kwargs: dict - Additional session kwargs + """ + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + + name: str = Field(..., exclude=True, description="Name to identify this server connection") + url: str + handle_tool_error: Union[bool, str, Callable[[Exception], str], None] = Field( + default=None, + exclude=True, + description=( + "How to handle errors raised by tools from this server. Options:\n" + "- None/False: Raise the error\n" + "- True: Return error message as string\n" + "- str: Return this string when errors occur\n" + "- Callable: Function that takes error and returns error message string" + ) + ) + + def to_connection_dict(self) -> dict[str, Any]: + """ + Convert to connection dictionary for LangChain MultiServerMCPClient. + + Automatically includes all extra fields passed to the constructor, + allowing forward compatibility with new LangChain connection fields. + """ + # Get all model fields including extra fields (name is auto-excluded) + data = self.model_dump() + + # Add transport type (hardcoded to streamable_http) + data["transport"] = "streamable_http" + + return data + + +class DatabricksServer(Server): + """ + MCP server configuration with Databricks authentication. + + Automatically sets up OAuth authentication using the provided WorkspaceClient. + Also accepts any additional connection parameters as keyword arguments. + """ + + workspace_client: WorkspaceClient | None = Field( + default=None, + description="Databricks WorkspaceClient for authentication. If None, will be auto-initialized.", + exclude=True + ) + + @model_validator(mode="after") + def setup_auth(self) -> "DatabricksServer": + """Set up Databricks OAuth authentication.""" + if self.workspace_client is None: + self.workspace_client = WorkspaceClient() + + # Set up Databricks OAuth authentication and store as a regular attribute + # This will be picked up by model_dump() since we have extra="allow" + object.__setattr__(self, "auth", DatabricksOAuthClientProvider(self.workspace_client)) + + return self + + +class DatabricksMultiServerMCPClient(MultiServerMCPClient): + """ + MultiServerMCPClient with simplified configuration for Databricks servers. + + This wrapper provides an ergonomic interface similar to LangChain's API while + remaining forward-compatible with future connection parameters. + + Example: + ```python + from databricks.sdk import WorkspaceClient + from databricks_langchain import DatabricksMultiServerMCPClient, DatabricksServer, Server + + client = DatabricksMultiServerMCPClient([ + # Databricks server with automatic OAuth - just pass params as kwargs! + DatabricksServer( + name="databricks-prod", + url="https://your-workspace.databricks.com/mcp", + workspace_client=WorkspaceClient(), + timeout=30.0, + sse_read_timeout=60.0, + handle_tool_error=True, # Return errors as strings instead of raising + ), + # Generic server with custom params - same flat API + Server( + name="other-server", + url="https://other-server.com/mcp", + headers={"X-API-Key": "secret"}, + timeout=15.0, + handle_tool_error="An error occurred. Please try again.", + ) + ]) + + tools = await client.get_tools() + ``` + """ + + def __init__( + self, + servers: List[Server], + **kwargs + ): + """ + Initialize the client with a list of server configurations. + + Args: + servers: List of Server or DatabricksServer configurations + **kwargs: Additional arguments to pass to MultiServerMCPClient + """ + # Store server configs for later use (e.g., handle_tool_errors) + self._server_configs = {server.name: server for server in servers} + + # Create connections dict (excluding tool-level params like handle_tool_errors) + connections = { + server.name: server.to_connection_dict() + for server in servers + } + super().__init__(connections=connections, **kwargs) + + async def get_tools(self, server_name: str | None = None): + """ + Get tools from MCP servers, applying server-level configurations. + + Args: + server_name: Optional server name to get tools from. If None, gets tools from all servers. + + Returns: + List of LangChain tools with server-level configurations applied. + """ + import asyncio + + # Determine which servers to load from + server_names = [server_name] if server_name is not None else list(self.connections.keys()) + + # Load tools from servers in parallel + load_tool_tasks = [ + asyncio.create_task(super().get_tools(server_name=name)) + for name in server_names + ] + tools_list = await asyncio.gather(*load_tool_tasks) + + # Apply server-level configurations and collect tools + all_tools = [] + for name, tools in zip(server_names, tools_list, strict=True): + if name in self._server_configs: + server_config = self._server_configs[name] + if server_config.handle_tool_error is not None: + for tool in tools: + tool.handle_tool_error = server_config.handle_tool_error + all_tools.extend(tools) + + return all_tools \ No newline at end of file diff --git a/integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py b/integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py new file mode 100644 index 00000000..7d4376f5 --- /dev/null +++ b/integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py @@ -0,0 +1,527 @@ +"""Unit tests for DatabricksMultiServerMCPClient and related classes.""" + +import asyncio +from typing import Any +from unittest.mock import AsyncMock, MagicMock, create_autospec, patch + +import pytest +from databricks.sdk import WorkspaceClient +from databricks_langchain.multi_server_mcp_client import ( + DatabricksMultiServerMCPClient, + DatabricksServer, + Server, +) + + +class TestServer: + """Tests for the Server class.""" + + def test_basic_server_creation(self): + """Test creating a basic server with minimal parameters.""" + server = Server(name="test-server", url="https://example.com/mcp") + + assert server.name == "test-server" + assert server.url == "https://example.com/mcp" + assert server.handle_tool_error is None + + @pytest.mark.parametrize( + "extra_params", + [ + {"timeout": 30.0}, + {"headers": {"X-API-Key": "secret"}}, + {"sse_read_timeout": 60.0}, + {"timeout": 15.0, "headers": {"Authorization": "Bearer token"}}, + {"session_kwargs": {"some_param": "value"}}, + ], + ) + def test_server_accepts_extra_params(self, extra_params: dict[str, Any]): + """Test that Server accepts and preserves extra parameters.""" + server = Server( + name="test-server", + url="https://example.com/mcp", + **extra_params + ) + + connection_dict = server.to_connection_dict() + + # Check that extra params are in connection dict + for key, value in extra_params.items(): + assert connection_dict[key] == value + + def test_server_to_connection_dict_excludes_name(self): + """Test that name is excluded from connection dict.""" + server = Server(name="test-server", url="https://example.com/mcp") + connection_dict = server.to_connection_dict() + + assert "name" not in connection_dict + assert "url" in connection_dict + + def test_server_to_connection_dict_excludes_handle_tool_error(self): + """Test that handle_tool_error is excluded from connection dict.""" + server = Server( + name="test-server", + url="https://example.com/mcp", + handle_tool_error=True + ) + connection_dict = server.to_connection_dict() + + assert "handle_tool_error" not in connection_dict + assert "url" in connection_dict + + def test_server_to_connection_dict_adds_transport(self): + """Test that transport is added to connection dict.""" + server = Server(name="test-server", url="https://example.com/mcp") + connection_dict = server.to_connection_dict() + + assert connection_dict["transport"] == "streamable_http" + + def test_server_connection_dict_has_required_fields(self): + """Test that connection dict has required fields for streamable_http.""" + server = Server( + name="test-server", + url="https://example.com/mcp", + timeout=30.0, + headers={"X-Custom": "value"} + ) + connection_dict = server.to_connection_dict() + + # Required fields for streamable_http connection + assert "url" in connection_dict + assert "transport" in connection_dict + assert connection_dict["transport"] == "streamable_http" + + # Extra fields should be present + assert connection_dict["timeout"] == 30.0 + assert connection_dict["headers"] == {"X-Custom": "value"} + + @pytest.mark.parametrize( + "handle_tool_error_value", + [ + True, + False, + "Custom error message", + lambda e: f"Error: {e}", + None, + ], + ) + def test_server_handle_tool_error_types(self, handle_tool_error_value: Any): + """Test that handle_tool_error accepts various types.""" + server = Server( + name="test-server", + url="https://example.com/mcp", + handle_tool_error=handle_tool_error_value + ) + + assert server.handle_tool_error == handle_tool_error_value + + +class TestDatabricksServer: + """Tests for the DatabricksServer class.""" + + def test_databricks_server_without_workspace_client(self): + """Test DatabricksServer creates WorkspaceClient automatically.""" + with patch("databricks_langchain.multi_server_mcp_client.WorkspaceClient") as mock_ws, \ + patch("databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider") as mock_auth: + + mock_ws_instance = MagicMock() + mock_ws.return_value = mock_ws_instance + mock_auth_instance = MagicMock() + mock_auth.return_value = mock_auth_instance + + server = DatabricksServer( + name="databricks", + url="https://databricks.com/mcp" + ) + + # Should have created WorkspaceClient + mock_ws.assert_called_once() + # Should have created auth provider + mock_auth.assert_called_once_with(mock_ws_instance) + + def test_databricks_server_with_workspace_client(self): + """Test DatabricksServer uses provided WorkspaceClient.""" + mock_workspace_client = create_autospec(WorkspaceClient, instance=True) + + with patch("databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider") as mock_auth: + mock_auth_instance = MagicMock() + mock_auth.return_value = mock_auth_instance + + server = DatabricksServer( + name="databricks", + url="https://databricks.com/mcp", + workspace_client=mock_workspace_client + ) + + # Should have used provided client + mock_auth.assert_called_once_with(mock_workspace_client) + assert server.workspace_client is mock_workspace_client + + def test_databricks_server_excludes_workspace_client_from_connection(self): + """Test that workspace_client is excluded from connection dict.""" + mock_workspace_client = create_autospec(WorkspaceClient, instance=True) + + with patch("databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider") as mock_auth: + mock_auth_instance = MagicMock() + mock_auth.return_value = mock_auth_instance + + server = DatabricksServer( + name="databricks", + url="https://databricks.com/mcp", + workspace_client=mock_workspace_client + ) + + connection_dict = server.to_connection_dict() + + assert "workspace_client" not in connection_dict + assert "auth" in connection_dict + + def test_databricks_server_includes_auth_in_connection(self): + """Test that auth is included in connection dict.""" + mock_workspace_client = create_autospec(WorkspaceClient, instance=True) + + with patch("databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider") as mock_auth: + mock_auth_instance = MagicMock() + mock_auth.return_value = mock_auth_instance + + server = DatabricksServer( + name="databricks", + url="https://databricks.com/mcp", + workspace_client=mock_workspace_client + ) + + connection_dict = server.to_connection_dict() + + assert connection_dict["auth"] is mock_auth_instance + + def test_databricks_server_accepts_extra_params(self): + """Test that DatabricksServer accepts extra connection params.""" + mock_workspace_client = create_autospec(WorkspaceClient, instance=True) + + with patch("databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider") as mock_auth: + mock_auth_instance = MagicMock() + mock_auth.return_value = mock_auth_instance + + server = DatabricksServer( + name="databricks", + url="https://databricks.com/mcp", + workspace_client=mock_workspace_client, + timeout=45.0, + headers={"X-Custom": "header"} + ) + + connection_dict = server.to_connection_dict() + + assert connection_dict["timeout"] == 45.0 + assert connection_dict["headers"] == {"X-Custom": "header"} + + +class TestDatabricksMultiServerMCPClient: + """Tests for the DatabricksMultiServerMCPClient class.""" + + def test_client_initialization_with_single_server(self): + """Test client initialization with a single server.""" + with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init: + mock_init.return_value = None + + server = Server(name="test", url="https://example.com/mcp") + client = DatabricksMultiServerMCPClient([server]) + + # Check that parent __init__ was called + mock_init.assert_called_once() + + # Check connections dict structure + call_kwargs = mock_init.call_args[1] + assert "connections" in call_kwargs + connections = call_kwargs["connections"] + + assert "test" in connections + assert connections["test"]["url"] == "https://example.com/mcp" + assert connections["test"]["transport"] == "streamable_http" + + def test_client_initialization_with_multiple_servers(self): + """Test client initialization with multiple servers.""" + with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init: + mock_init.return_value = None + + servers = [ + Server(name="server1", url="https://server1.com/mcp"), + Server(name="server2", url="https://server2.com/mcp"), + ] + client = DatabricksMultiServerMCPClient(servers) + + # Check that parent __init__ was called + mock_init.assert_called_once() + + # Check connections dict structure + call_kwargs = mock_init.call_args[1] + connections = call_kwargs["connections"] + + assert len(connections) == 2 + assert "server1" in connections + assert "server2" in connections + + def test_client_stores_server_configs(self): + """Test that client stores server configs for later use.""" + with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init: + mock_init.return_value = None + + server = Server( + name="test", + url="https://example.com/mcp", + handle_tool_error=True + ) + client = DatabricksMultiServerMCPClient([server]) + + # Check that server configs are stored + assert hasattr(client, "_server_configs") + assert "test" in client._server_configs + assert client._server_configs["test"].handle_tool_error is True + + @pytest.mark.asyncio + async def test_get_tools_single_server(self): + """Test get_tools with a specific server name.""" + server = Server( + name="test", + url="https://example.com/mcp", + handle_tool_error="Error occurred" + ) + + with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init, \ + patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", new_callable=AsyncMock) as mock_parent_get_tools: + + mock_init.return_value = None + + # Create mock tools + mock_tool1 = MagicMock() + mock_tool2 = MagicMock() + mock_tools = [mock_tool1, mock_tool2] + mock_parent_get_tools.return_value = mock_tools + + client = DatabricksMultiServerMCPClient([server]) + client.connections = {"test": server.to_connection_dict()} + + tools = await client.get_tools(server_name="test") + + # Should call parent get_tools with server_name + mock_parent_get_tools.assert_called_once_with(server_name="test") + + # Should apply handle_tool_error to all tools + assert mock_tool1.handle_tool_error == "Error occurred" + assert mock_tool2.handle_tool_error == "Error occurred" + assert tools == mock_tools + + @pytest.mark.asyncio + async def test_get_tools_all_servers(self): + """Test get_tools without server_name (all servers).""" + servers = [ + Server(name="server1", url="https://server1.com/mcp", handle_tool_error=True), + Server(name="server2", url="https://server2.com/mcp", handle_tool_error="Custom error"), + ] + + # Create mock tools for each server + mock_tool1 = MagicMock() + mock_tool2 = MagicMock() + mock_tool3 = MagicMock() + + # Mock parent get_tools to return different tools for different servers + async def mock_get_tools_side_effect(server_name=None): + if server_name == "server1": + return [mock_tool1, mock_tool2] + elif server_name == "server2": + return [mock_tool3] + return [] + + with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init, \ + patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", new_callable=AsyncMock, side_effect=mock_get_tools_side_effect) as mock_parent_get_tools: + + mock_init.return_value = None + + client = DatabricksMultiServerMCPClient(servers) + client.connections = { + "server1": servers[0].to_connection_dict(), + "server2": servers[1].to_connection_dict(), + } + + tools = await client.get_tools() + + # Should call parent get_tools for each server + assert mock_parent_get_tools.call_count == 2 + + # Should apply handle_tool_error from respective servers + assert mock_tool1.handle_tool_error is True + assert mock_tool2.handle_tool_error is True + assert mock_tool3.handle_tool_error == "Custom error" + + # Should return all tools + assert len(tools) == 3 + assert mock_tool1 in tools + assert mock_tool2 in tools + assert mock_tool3 in tools + + @pytest.mark.asyncio + async def test_get_tools_no_handle_tool_error(self): + """Test get_tools when handle_tool_error is None.""" + server = Server(name="test", url="https://example.com/mcp") + + # Create mock tool + mock_tool = MagicMock() + mock_tool.handle_tool_error = "original_value" + + with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init, \ + patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", new_callable=AsyncMock) as mock_parent_get_tools: + + mock_init.return_value = None + mock_parent_get_tools.return_value = [mock_tool] + + client = DatabricksMultiServerMCPClient([server]) + client.connections = {"test": server.to_connection_dict()} + + tools = await client.get_tools(server_name="test") + + # Should NOT modify handle_tool_error when it's None in config + assert mock_tool.handle_tool_error == "original_value" + + @pytest.mark.asyncio + async def test_get_tools_parallel_execution(self): + """Test that get_tools executes server requests in parallel.""" + servers = [ + Server(name=f"server{i}", url=f"https://server{i}.com/mcp") + for i in range(5) + ] + + call_count = 0 + call_times = [] + + async def mock_get_tools_with_delay(server_name=None): + nonlocal call_count + call_count += 1 + call_times.append(asyncio.get_event_loop().time()) + await asyncio.sleep(0.1) # Simulate async work + return [MagicMock()] + + with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init, \ + patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", new_callable=AsyncMock, side_effect=mock_get_tools_with_delay) as mock_parent_get_tools: + + mock_init.return_value = None + + client = DatabricksMultiServerMCPClient(servers) + client.connections = { + server.name: server.to_connection_dict() + for server in servers + } + + start_time = asyncio.get_event_loop().time() + tools = await client.get_tools() + end_time = asyncio.get_event_loop().time() + + # All 5 servers should be called + assert call_count == 5 + + # All calls should start around the same time (parallel) + # If sequential, would take 0.5s+. Parallel should be ~0.1s + elapsed = end_time - start_time + assert elapsed < 0.3 # Much less than 5 * 0.1s + + # Should return tools from all servers + assert len(tools) == 5 + + @pytest.mark.asyncio + async def test_get_tools_with_databricks_server(self): + """Test get_tools with DatabricksServer.""" + mock_workspace_client = create_autospec(WorkspaceClient, instance=True) + mock_tool = MagicMock() + + with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init, \ + patch("databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider") as mock_auth, \ + patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", new_callable=AsyncMock) as mock_parent_get_tools: + + mock_init.return_value = None + mock_auth_instance = MagicMock() + mock_auth.return_value = mock_auth_instance + mock_parent_get_tools.return_value = [mock_tool] + + server = DatabricksServer( + name="databricks", + url="https://databricks.com/mcp", + workspace_client=mock_workspace_client, + handle_tool_error=True + ) + client = DatabricksMultiServerMCPClient([server]) + client.connections = {"databricks": server.to_connection_dict()} + + tools = await client.get_tools(server_name="databricks") + + # Should apply handle_tool_error + assert mock_tool.handle_tool_error is True + + # Connection should have auth + assert "auth" in client.connections["databricks"] + + +class TestConnectionDictCompatibility: + """Tests to ensure connection dict compatibility with LangChain.""" + + def test_connection_dict_structure_is_flexible(self): + """Test that connection dict allows extra fields (forward compatible).""" + # This test ensures we won't break if LangChain adds new fields + server = Server( + name="test", + url="https://example.com/mcp", + future_field_1="value1", + future_field_2=123, + nested_config={"key": "value"} + ) + + connection_dict = server.to_connection_dict() + + # Should include extra fields + assert connection_dict["future_field_1"] == "value1" + assert connection_dict["future_field_2"] == 123 + assert connection_dict["nested_config"] == {"key": "value"} + + def test_connection_dict_has_transport_field(self): + """Test that transport field is always present.""" + server = Server(name="test", url="https://example.com/mcp") + connection_dict = server.to_connection_dict() + + assert "transport" in connection_dict + assert isinstance(connection_dict["transport"], str) + + def test_connection_dict_has_url_field(self): + """Test that url field is always present.""" + server = Server(name="test", url="https://example.com/mcp") + connection_dict = server.to_connection_dict() + + assert "url" in connection_dict + assert isinstance(connection_dict["url"], str) + assert connection_dict["url"].startswith("http") + + @pytest.mark.parametrize( + "field_name", + ["name", "handle_tool_error", "workspace_client"], + ) + def test_connection_dict_excludes_internal_fields(self, field_name: str): + """Test that internal fields are excluded from connection dict.""" + # Create servers with fields that should be excluded + if field_name == "workspace_client": + with patch("databricks_langchain.multi_server_mcp_client.WorkspaceClient") as mock_ws, \ + patch("databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider") as mock_auth: + mock_ws.return_value = MagicMock() + mock_auth.return_value = MagicMock() + + server = DatabricksServer( + name="test", + url="https://example.com/mcp" + ) + else: + server = Server( + name="test", + url="https://example.com/mcp", + handle_tool_error=True + ) + + connection_dict = server.to_connection_dict() + + # Internal fields should not be in connection dict + assert field_name not in connection_dict + From b003bbe1831c49949afd100181c43c7bf2c2ef1b Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Fri, 21 Nov 2025 11:32:25 -0500 Subject: [PATCH 2/8] Update tests --- .../multi_server_mcp_client.py | 196 +++++---- .../test_multi_server_mcp_client.py | 412 +++++------------- 2 files changed, 221 insertions(+), 387 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py b/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py index e8551b0f..efcd6544 100644 --- a/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py +++ b/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py @@ -1,18 +1,18 @@ -from langchain_mcp_adapters.client import MultiServerMCPClient -from typing import List, Literal, Callable, Union +from typing import Any, Callable, List, Union + from databricks.sdk import WorkspaceClient -from pydantic import BaseModel, Field, ConfigDict, model_validator -from typing import Any from databricks_mcp.oauth_provider import DatabricksOAuthClientProvider +from langchain_mcp_adapters.client import MultiServerMCPClient +from pydantic import BaseModel, ConfigDict, Field class Server(BaseModel): """ Base configuration for an MCP server connection using streamable HTTP transport. - + Accepts any additional keyword arguments which are automatically passed through to LangChain's Connection type, making this forward-compatible with future updates. - + Common optional parameters: - headers: dict[str, str] - Custom HTTP headers - timeout: float - Request timeout in seconds @@ -21,10 +21,27 @@ class Server(BaseModel): - httpx_client_factory: Callable - Custom httpx client factory - terminate_on_close: bool - Terminate connection on close - session_kwargs: dict - Additional session kwargs + + Example: + ```python + from databricks_langchain import DatabricksMultiServerMCPClient, Server + + # Generic server with custom params - flat API for easy configuration + server = Server( + name="other-server", + url="https://other-server.com/mcp", + headers={"X-API-Key": "secret"}, + timeout=15.0, + handle_tool_error="An error occurred. Please try again.", + ) + + client = DatabricksMultiServerMCPClient([server]) + tools = await client.get_tools() + ``` """ - + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") - + name: str = Field(..., exclude=True, description="Name to identify this server connection") url: str handle_tool_error: Union[bool, str, Callable[[Exception], str], None] = Field( @@ -36,133 +53,164 @@ class Server(BaseModel): "- True: Return error message as string\n" "- str: Return this string when errors occur\n" "- Callable: Function that takes error and returns error message string" - ) + ), ) - + def to_connection_dict(self) -> dict[str, Any]: """ Convert to connection dictionary for LangChain MultiServerMCPClient. - + Automatically includes all extra fields passed to the constructor, allowing forward compatibility with new LangChain connection fields. """ # Get all model fields including extra fields (name is auto-excluded) data = self.model_dump() - + # Add transport type (hardcoded to streamable_http) data["transport"] = "streamable_http" - + return data class DatabricksServer(Server): """ MCP server configuration with Databricks authentication. - + Automatically sets up OAuth authentication using the provided WorkspaceClient. Also accepts any additional connection parameters as keyword arguments. + + Example: + ```python + from databricks.sdk import WorkspaceClient + from databricks_langchain import DatabricksMultiServerMCPClient, DatabricksServer + + # Databricks server with automatic OAuth - just pass params as kwargs! + server = DatabricksServer( + name="databricks-prod", + url="https://your-workspace.databricks.com/mcp", + workspace_client=WorkspaceClient(), + timeout=30.0, + sse_read_timeout=60.0, + handle_tool_error=True, # Return errors as strings instead of raising + ) + + client = DatabricksMultiServerMCPClient([server]) + tools = await client.get_tools() + ``` """ - + workspace_client: WorkspaceClient | None = Field( default=None, description="Databricks WorkspaceClient for authentication. If None, will be auto-initialized.", - exclude=True + exclude=True, ) - - @model_validator(mode="after") - def setup_auth(self) -> "DatabricksServer": - """Set up Databricks OAuth authentication.""" + + def __init__(self, **data): + """Initialize DatabricksServer with auth setup.""" + super().__init__(**data) + + # Set up Databricks OAuth authentication after initialization if self.workspace_client is None: self.workspace_client = WorkspaceClient() - - # Set up Databricks OAuth authentication and store as a regular attribute - # This will be picked up by model_dump() since we have extra="allow" - object.__setattr__(self, "auth", DatabricksOAuthClientProvider(self.workspace_client)) - - return self + + # Store the auth provider internally + self._auth_provider = DatabricksOAuthClientProvider(self.workspace_client) + + def to_connection_dict(self) -> dict[str, Any]: + """ + Convert to connection dictionary, including Databricks auth. + """ + # Get base connection dict + data = super().to_connection_dict() + + # Add Databricks auth provider + data["auth"] = self._auth_provider + + return data class DatabricksMultiServerMCPClient(MultiServerMCPClient): """ MultiServerMCPClient with simplified configuration for Databricks servers. - + This wrapper provides an ergonomic interface similar to LangChain's API while remaining forward-compatible with future connection parameters. - + Example: ```python from databricks.sdk import WorkspaceClient - from databricks_langchain import DatabricksMultiServerMCPClient, DatabricksServer, Server - - client = DatabricksMultiServerMCPClient([ - # Databricks server with automatic OAuth - just pass params as kwargs! - DatabricksServer( - name="databricks-prod", - url="https://your-workspace.databricks.com/mcp", - workspace_client=WorkspaceClient(), - timeout=30.0, - sse_read_timeout=60.0, - handle_tool_error=True, # Return errors as strings instead of raising - ), - # Generic server with custom params - same flat API - Server( - name="other-server", - url="https://other-server.com/mcp", - headers={"X-API-Key": "secret"}, - timeout=15.0, - handle_tool_error="An error occurred. Please try again.", - ) - ]) - + from databricks_langchain import ( + DatabricksMultiServerMCPClient, + DatabricksServer, + Server, + ) + + client = DatabricksMultiServerMCPClient( + [ + # Databricks server with automatic OAuth - just pass params as kwargs! + DatabricksServer( + name="databricks-prod", + url="https://your-workspace.databricks.com/mcp", + workspace_client=WorkspaceClient(), + timeout=30.0, + sse_read_timeout=60.0, + handle_tool_error=True, # Return errors as strings instead of raising + ), + # Generic server with custom params - same flat API + Server( + name="other-server", + url="https://other-server.com/mcp", + headers={"X-API-Key": "secret"}, + timeout=15.0, + handle_tool_error="An error occurred. Please try again.", + ), + ] + ) + tools = await client.get_tools() ``` """ - - def __init__( - self, - servers: List[Server], - **kwargs - ): + + def __init__(self, servers: List[Server], **kwargs): """ Initialize the client with a list of server configurations. - + Args: servers: List of Server or DatabricksServer configurations **kwargs: Additional arguments to pass to MultiServerMCPClient """ # Store server configs for later use (e.g., handle_tool_errors) self._server_configs = {server.name: server for server in servers} - + # Create connections dict (excluding tool-level params like handle_tool_errors) - connections = { - server.name: server.to_connection_dict() - for server in servers - } + connections = {server.name: server.to_connection_dict() for server in servers} super().__init__(connections=connections, **kwargs) - + async def get_tools(self, server_name: str | None = None): """ - Get tools from MCP servers, applying server-level configurations. - + Get tools from MCP servers, applying handle_tool_error configuration. + Args: server_name: Optional server name to get tools from. If None, gets tools from all servers. - + Returns: - List of LangChain tools with server-level configurations applied. + List of LangChain tools with handle_tool_error configurations applied. """ import asyncio - + # Determine which servers to load from server_names = [server_name] if server_name is not None else list(self.connections.keys()) - + # Load tools from servers in parallel load_tool_tasks = [ - asyncio.create_task(super().get_tools(server_name=name)) + asyncio.create_task( + super(DatabricksMultiServerMCPClient, self).get_tools(server_name=name) + ) for name in server_names ] tools_list = await asyncio.gather(*load_tool_tasks) - - # Apply server-level configurations and collect tools + + # Apply handle_tool_error configurations and collect tools all_tools = [] for name, tools in zip(server_names, tools_list, strict=True): if name in self._server_configs: @@ -171,5 +219,5 @@ async def get_tools(self, server_name: str | None = None): for tool in tools: tool.handle_tool_error = server_config.handle_tool_error all_tools.extend(tools) - - return all_tools \ No newline at end of file + + return all_tools diff --git a/integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py b/integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py index 7d4376f5..f6a3fc31 100644 --- a/integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py +++ b/integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py @@ -6,6 +6,7 @@ import pytest from databricks.sdk import WorkspaceClient + from databricks_langchain.multi_server_mcp_client import ( DatabricksMultiServerMCPClient, DatabricksServer, @@ -19,7 +20,7 @@ class TestServer: def test_basic_server_creation(self): """Test creating a basic server with minimal parameters.""" server = Server(name="test-server", url="https://example.com/mcp") - + assert server.name == "test-server" assert server.url == "https://example.com/mcp" assert server.handle_tool_error is None @@ -39,60 +40,17 @@ def test_server_accepts_extra_params(self, extra_params: dict[str, Any]): server = Server( name="test-server", url="https://example.com/mcp", - **extra_params + handle_tool_error=True, + **extra_params, ) - + connection_dict = server.to_connection_dict() - + # Check that extra params are in connection dict for key, value in extra_params.items(): assert connection_dict[key] == value - - def test_server_to_connection_dict_excludes_name(self): - """Test that name is excluded from connection dict.""" - server = Server(name="test-server", url="https://example.com/mcp") - connection_dict = server.to_connection_dict() - - assert "name" not in connection_dict - assert "url" in connection_dict - - def test_server_to_connection_dict_excludes_handle_tool_error(self): - """Test that handle_tool_error is excluded from connection dict.""" - server = Server( - name="test-server", - url="https://example.com/mcp", - handle_tool_error=True - ) - connection_dict = server.to_connection_dict() - - assert "handle_tool_error" not in connection_dict - assert "url" in connection_dict - - def test_server_to_connection_dict_adds_transport(self): - """Test that transport is added to connection dict.""" - server = Server(name="test-server", url="https://example.com/mcp") - connection_dict = server.to_connection_dict() - - assert connection_dict["transport"] == "streamable_http" - - def test_server_connection_dict_has_required_fields(self): - """Test that connection dict has required fields for streamable_http.""" - server = Server( - name="test-server", - url="https://example.com/mcp", - timeout=30.0, - headers={"X-Custom": "value"} - ) - connection_dict = server.to_connection_dict() - - # Required fields for streamable_http connection - assert "url" in connection_dict - assert "transport" in connection_dict - assert connection_dict["transport"] == "streamable_http" - - # Extra fields should be present - assert connection_dict["timeout"] == 30.0 - assert connection_dict["headers"] == {"X-Custom": "value"} + assert "name" not in connection_dict + assert "handle_tool_error" not in connection_dict @pytest.mark.parametrize( "handle_tool_error_value", @@ -109,9 +67,9 @@ def test_server_handle_tool_error_types(self, handle_tool_error_value: Any): server = Server( name="test-server", url="https://example.com/mcp", - handle_tool_error=handle_tool_error_value + handle_tool_error=handle_tool_error_value, ) - + assert server.handle_tool_error == handle_tool_error_value @@ -120,19 +78,19 @@ class TestDatabricksServer: def test_databricks_server_without_workspace_client(self): """Test DatabricksServer creates WorkspaceClient automatically.""" - with patch("databricks_langchain.multi_server_mcp_client.WorkspaceClient") as mock_ws, \ - patch("databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider") as mock_auth: - + with ( + patch("databricks_langchain.multi_server_mcp_client.WorkspaceClient") as mock_ws, + patch( + "databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider" + ) as mock_auth, + ): mock_ws_instance = MagicMock() mock_ws.return_value = mock_ws_instance mock_auth_instance = MagicMock() mock_auth.return_value = mock_auth_instance - - server = DatabricksServer( - name="databricks", - url="https://databricks.com/mcp" - ) - + + server = DatabricksServer(name="databricks", url="https://databricks.com/mcp") + # Should have created WorkspaceClient mock_ws.assert_called_once() # Should have created auth provider @@ -141,76 +99,48 @@ def test_databricks_server_without_workspace_client(self): def test_databricks_server_with_workspace_client(self): """Test DatabricksServer uses provided WorkspaceClient.""" mock_workspace_client = create_autospec(WorkspaceClient, instance=True) - - with patch("databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider") as mock_auth: + + with patch( + "databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider" + ) as mock_auth: mock_auth_instance = MagicMock() mock_auth.return_value = mock_auth_instance - + server = DatabricksServer( name="databricks", url="https://databricks.com/mcp", - workspace_client=mock_workspace_client + workspace_client=mock_workspace_client, ) - + # Should have used provided client mock_auth.assert_called_once_with(mock_workspace_client) assert server.workspace_client is mock_workspace_client - def test_databricks_server_excludes_workspace_client_from_connection(self): - """Test that workspace_client is excluded from connection dict.""" - mock_workspace_client = create_autospec(WorkspaceClient, instance=True) - - with patch("databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider") as mock_auth: - mock_auth_instance = MagicMock() - mock_auth.return_value = mock_auth_instance - - server = DatabricksServer( - name="databricks", - url="https://databricks.com/mcp", - workspace_client=mock_workspace_client - ) - connection_dict = server.to_connection_dict() - assert "workspace_client" not in connection_dict assert "auth" in connection_dict - - def test_databricks_server_includes_auth_in_connection(self): - """Test that auth is included in connection dict.""" - mock_workspace_client = create_autospec(WorkspaceClient, instance=True) - - with patch("databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider") as mock_auth: - mock_auth_instance = MagicMock() - mock_auth.return_value = mock_auth_instance - - server = DatabricksServer( - name="databricks", - url="https://databricks.com/mcp", - workspace_client=mock_workspace_client - ) - - connection_dict = server.to_connection_dict() - assert connection_dict["auth"] is mock_auth_instance def test_databricks_server_accepts_extra_params(self): """Test that DatabricksServer accepts extra connection params.""" mock_workspace_client = create_autospec(WorkspaceClient, instance=True) - - with patch("databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider") as mock_auth: + + with patch( + "databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider" + ) as mock_auth: mock_auth_instance = MagicMock() mock_auth.return_value = mock_auth_instance - + server = DatabricksServer( name="databricks", url="https://databricks.com/mcp", workspace_client=mock_workspace_client, timeout=45.0, - headers={"X-Custom": "header"} + headers={"X-Custom": "header"}, ) - + connection_dict = server.to_connection_dict() - + assert connection_dict["timeout"] == 45.0 assert connection_dict["headers"] == {"X-Custom": "header"} @@ -218,97 +148,34 @@ def test_databricks_server_accepts_extra_params(self): class TestDatabricksMultiServerMCPClient: """Tests for the DatabricksMultiServerMCPClient class.""" - def test_client_initialization_with_single_server(self): - """Test client initialization with a single server.""" - with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init: - mock_init.return_value = None - - server = Server(name="test", url="https://example.com/mcp") - client = DatabricksMultiServerMCPClient([server]) - - # Check that parent __init__ was called - mock_init.assert_called_once() - - # Check connections dict structure - call_kwargs = mock_init.call_args[1] - assert "connections" in call_kwargs - connections = call_kwargs["connections"] - - assert "test" in connections - assert connections["test"]["url"] == "https://example.com/mcp" - assert connections["test"]["transport"] == "streamable_http" - def test_client_initialization_with_multiple_servers(self): """Test client initialization with multiple servers.""" - with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init: + with patch( + "databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__" + ) as mock_init: mock_init.return_value = None - + servers = [ Server(name="server1", url="https://server1.com/mcp"), Server(name="server2", url="https://server2.com/mcp"), ] client = DatabricksMultiServerMCPClient(servers) - + # Check that parent __init__ was called mock_init.assert_called_once() - + # Check connections dict structure call_kwargs = mock_init.call_args[1] connections = call_kwargs["connections"] - + assert len(connections) == 2 assert "server1" in connections assert "server2" in connections - def test_client_stores_server_configs(self): - """Test that client stores server configs for later use.""" - with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init: - mock_init.return_value = None - - server = Server( - name="test", - url="https://example.com/mcp", - handle_tool_error=True - ) - client = DatabricksMultiServerMCPClient([server]) - - # Check that server configs are stored assert hasattr(client, "_server_configs") - assert "test" in client._server_configs - assert client._server_configs["test"].handle_tool_error is True - - @pytest.mark.asyncio - async def test_get_tools_single_server(self): - """Test get_tools with a specific server name.""" - server = Server( - name="test", - url="https://example.com/mcp", - handle_tool_error="Error occurred" - ) - - with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init, \ - patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", new_callable=AsyncMock) as mock_parent_get_tools: - - mock_init.return_value = None - - # Create mock tools - mock_tool1 = MagicMock() - mock_tool2 = MagicMock() - mock_tools = [mock_tool1, mock_tool2] - mock_parent_get_tools.return_value = mock_tools - - client = DatabricksMultiServerMCPClient([server]) - client.connections = {"test": server.to_connection_dict()} - - tools = await client.get_tools(server_name="test") - - # Should call parent get_tools with server_name - mock_parent_get_tools.assert_called_once_with(server_name="test") - - # Should apply handle_tool_error to all tools - assert mock_tool1.handle_tool_error == "Error occurred" - assert mock_tool2.handle_tool_error == "Error occurred" - assert tools == mock_tools + assert len(client._server_configs) == 2 + assert "server1" in client._server_configs + assert "server2" in client._server_configs @pytest.mark.asyncio async def test_get_tools_all_servers(self): @@ -317,12 +184,12 @@ async def test_get_tools_all_servers(self): Server(name="server1", url="https://server1.com/mcp", handle_tool_error=True), Server(name="server2", url="https://server2.com/mcp", handle_tool_error="Custom error"), ] - + # Create mock tools for each server mock_tool1 = MagicMock() mock_tool2 = MagicMock() mock_tool3 = MagicMock() - + # Mock parent get_tools to return different tools for different servers async def mock_get_tools_side_effect(server_name=None): if server_name == "server1": @@ -330,98 +197,78 @@ async def mock_get_tools_side_effect(server_name=None): elif server_name == "server2": return [mock_tool3] return [] - - with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init, \ - patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", new_callable=AsyncMock, side_effect=mock_get_tools_side_effect) as mock_parent_get_tools: - + + with ( + patch( + "databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__" + ) as mock_init, + patch( + "databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", + new_callable=AsyncMock, + side_effect=mock_get_tools_side_effect, + ) as mock_parent_get_tools, + ): mock_init.return_value = None - + client = DatabricksMultiServerMCPClient(servers) client.connections = { "server1": servers[0].to_connection_dict(), "server2": servers[1].to_connection_dict(), } - + tools = await client.get_tools() - + # Should call parent get_tools for each server assert mock_parent_get_tools.call_count == 2 - + # Should apply handle_tool_error from respective servers assert mock_tool1.handle_tool_error is True assert mock_tool2.handle_tool_error is True assert mock_tool3.handle_tool_error == "Custom error" - + # Should return all tools assert len(tools) == 3 assert mock_tool1 in tools assert mock_tool2 in tools assert mock_tool3 in tools - @pytest.mark.asyncio - async def test_get_tools_no_handle_tool_error(self): - """Test get_tools when handle_tool_error is None.""" - server = Server(name="test", url="https://example.com/mcp") - - # Create mock tool - mock_tool = MagicMock() - mock_tool.handle_tool_error = "original_value" - - with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init, \ - patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", new_callable=AsyncMock) as mock_parent_get_tools: - - mock_init.return_value = None - mock_parent_get_tools.return_value = [mock_tool] - - client = DatabricksMultiServerMCPClient([server]) - client.connections = {"test": server.to_connection_dict()} - - tools = await client.get_tools(server_name="test") - - # Should NOT modify handle_tool_error when it's None in config - assert mock_tool.handle_tool_error == "original_value" - @pytest.mark.asyncio async def test_get_tools_parallel_execution(self): """Test that get_tools executes server requests in parallel.""" - servers = [ - Server(name=f"server{i}", url=f"https://server{i}.com/mcp") - for i in range(5) - ] - + servers = [Server(name=f"server{i}", url=f"https://server{i}.com/mcp") for i in range(5)] + call_count = 0 call_times = [] - + async def mock_get_tools_with_delay(server_name=None): nonlocal call_count call_count += 1 call_times.append(asyncio.get_event_loop().time()) await asyncio.sleep(0.1) # Simulate async work return [MagicMock()] - - with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init, \ - patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", new_callable=AsyncMock, side_effect=mock_get_tools_with_delay) as mock_parent_get_tools: - + + with ( + patch( + "databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__" + ) as mock_init, + patch( + "databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", + new_callable=AsyncMock, + side_effect=mock_get_tools_with_delay, + ) as mock_parent_get_tools, + ): mock_init.return_value = None - + client = DatabricksMultiServerMCPClient(servers) - client.connections = { - server.name: server.to_connection_dict() - for server in servers - } - + client.connections = {server.name: server.to_connection_dict() for server in servers} + start_time = asyncio.get_event_loop().time() tools = await client.get_tools() end_time = asyncio.get_event_loop().time() - + # All 5 servers should be called assert call_count == 5 - - # All calls should start around the same time (parallel) - # If sequential, would take 0.5s+. Parallel should be ~0.1s - elapsed = end_time - start_time - assert elapsed < 0.3 # Much less than 5 * 0.1s - + # Should return tools from all servers assert len(tools) == 5 @@ -430,98 +277,37 @@ async def test_get_tools_with_databricks_server(self): """Test get_tools with DatabricksServer.""" mock_workspace_client = create_autospec(WorkspaceClient, instance=True) mock_tool = MagicMock() - - with patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__") as mock_init, \ - patch("databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider") as mock_auth, \ - patch("databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", new_callable=AsyncMock) as mock_parent_get_tools: - + + with ( + patch( + "databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.__init__" + ) as mock_init, + patch( + "databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider" + ) as mock_auth, + patch( + "databricks_langchain.multi_server_mcp_client.MultiServerMCPClient.get_tools", + new_callable=AsyncMock, + ) as mock_parent_get_tools, + ): mock_init.return_value = None mock_auth_instance = MagicMock() mock_auth.return_value = mock_auth_instance mock_parent_get_tools.return_value = [mock_tool] - + server = DatabricksServer( name="databricks", url="https://databricks.com/mcp", workspace_client=mock_workspace_client, - handle_tool_error=True + handle_tool_error=True, ) client = DatabricksMultiServerMCPClient([server]) client.connections = {"databricks": server.to_connection_dict()} - + tools = await client.get_tools(server_name="databricks") - + # Should apply handle_tool_error assert mock_tool.handle_tool_error is True - + # Connection should have auth assert "auth" in client.connections["databricks"] - - -class TestConnectionDictCompatibility: - """Tests to ensure connection dict compatibility with LangChain.""" - - def test_connection_dict_structure_is_flexible(self): - """Test that connection dict allows extra fields (forward compatible).""" - # This test ensures we won't break if LangChain adds new fields - server = Server( - name="test", - url="https://example.com/mcp", - future_field_1="value1", - future_field_2=123, - nested_config={"key": "value"} - ) - - connection_dict = server.to_connection_dict() - - # Should include extra fields - assert connection_dict["future_field_1"] == "value1" - assert connection_dict["future_field_2"] == 123 - assert connection_dict["nested_config"] == {"key": "value"} - - def test_connection_dict_has_transport_field(self): - """Test that transport field is always present.""" - server = Server(name="test", url="https://example.com/mcp") - connection_dict = server.to_connection_dict() - - assert "transport" in connection_dict - assert isinstance(connection_dict["transport"], str) - - def test_connection_dict_has_url_field(self): - """Test that url field is always present.""" - server = Server(name="test", url="https://example.com/mcp") - connection_dict = server.to_connection_dict() - - assert "url" in connection_dict - assert isinstance(connection_dict["url"], str) - assert connection_dict["url"].startswith("http") - - @pytest.mark.parametrize( - "field_name", - ["name", "handle_tool_error", "workspace_client"], - ) - def test_connection_dict_excludes_internal_fields(self, field_name: str): - """Test that internal fields are excluded from connection dict.""" - # Create servers with fields that should be excluded - if field_name == "workspace_client": - with patch("databricks_langchain.multi_server_mcp_client.WorkspaceClient") as mock_ws, \ - patch("databricks_langchain.multi_server_mcp_client.DatabricksOAuthClientProvider") as mock_auth: - mock_ws.return_value = MagicMock() - mock_auth.return_value = MagicMock() - - server = DatabricksServer( - name="test", - url="https://example.com/mcp" - ) - else: - server = Server( - name="test", - url="https://example.com/mcp", - handle_tool_error=True - ) - - connection_dict = server.to_connection_dict() - - # Internal fields should not be in connection dict - assert field_name not in connection_dict - From 299a09fcd0139c02562c5fe67a3c42aad0eaa47a Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Fri, 21 Nov 2025 11:34:44 -0500 Subject: [PATCH 3/8] Update init --- integrations/langchain/src/databricks_langchain/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/integrations/langchain/src/databricks_langchain/__init__.py b/integrations/langchain/src/databricks_langchain/__init__.py index 40135d97..923374e8 100644 --- a/integrations/langchain/src/databricks_langchain/__init__.py +++ b/integrations/langchain/src/databricks_langchain/__init__.py @@ -20,6 +20,7 @@ from databricks_langchain.chat_models import ChatDatabricks from databricks_langchain.embeddings import DatabricksEmbeddings from databricks_langchain.genie import GenieAgent +from databricks_langchain.multi_server_mcp_client import DatabricksMultiServerMCPClient from databricks_langchain.vector_search_retriever_tool import VectorSearchRetrieverTool from databricks_langchain.vectorstores import DatabricksVectorSearch @@ -34,4 +35,5 @@ "UnityCatalogTool", "DatabricksFunctionClient", "set_uc_function_client", + "DatabricksMultiServerMCPClient", ] From 3ea93af5975c815ccbc3d500b832a2ac965639a8 Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Fri, 21 Nov 2025 11:44:02 -0500 Subject: [PATCH 4/8] Change Server to MCPServer --- .../src/databricks_langchain/__init__.py | 8 +++- .../multi_server_mcp_client.py | 24 +++++----- .../test_multi_server_mcp_client.py | 48 ++++++++++--------- 3 files changed, 44 insertions(+), 36 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/__init__.py b/integrations/langchain/src/databricks_langchain/__init__.py index 923374e8..3691ed7c 100644 --- a/integrations/langchain/src/databricks_langchain/__init__.py +++ b/integrations/langchain/src/databricks_langchain/__init__.py @@ -20,7 +20,11 @@ from databricks_langchain.chat_models import ChatDatabricks from databricks_langchain.embeddings import DatabricksEmbeddings from databricks_langchain.genie import GenieAgent -from databricks_langchain.multi_server_mcp_client import DatabricksMultiServerMCPClient +from databricks_langchain.multi_server_mcp_client import ( + DatabricksMCPServer, + DatabricksMultiServerMCPClient, + MCPServer, +) from databricks_langchain.vector_search_retriever_tool import VectorSearchRetrieverTool from databricks_langchain.vectorstores import DatabricksVectorSearch @@ -36,4 +40,6 @@ "DatabricksFunctionClient", "set_uc_function_client", "DatabricksMultiServerMCPClient", + "DatabricksMCPServer", + "MCPServer", ] diff --git a/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py b/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py index efcd6544..973fe329 100644 --- a/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py +++ b/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict, Field -class Server(BaseModel): +class MCPServer(BaseModel): """ Base configuration for an MCP server connection using streamable HTTP transport. @@ -24,10 +24,10 @@ class Server(BaseModel): Example: ```python - from databricks_langchain import DatabricksMultiServerMCPClient, Server + from databricks_langchain import DatabricksMultiServerMCPClient, MCPServer # Generic server with custom params - flat API for easy configuration - server = Server( + server = MCPServer( name="other-server", url="https://other-server.com/mcp", headers={"X-API-Key": "secret"}, @@ -72,7 +72,7 @@ def to_connection_dict(self) -> dict[str, Any]: return data -class DatabricksServer(Server): +class DatabricksMCPServer(MCPServer): """ MCP server configuration with Databricks authentication. @@ -82,10 +82,10 @@ class DatabricksServer(Server): Example: ```python from databricks.sdk import WorkspaceClient - from databricks_langchain import DatabricksMultiServerMCPClient, DatabricksServer + from databricks_langchain import DatabricksMultiServerMCPClient, DatabricksMCPServer # Databricks server with automatic OAuth - just pass params as kwargs! - server = DatabricksServer( + server = DatabricksMCPServer( name="databricks-prod", url="https://your-workspace.databricks.com/mcp", workspace_client=WorkspaceClient(), @@ -141,14 +141,14 @@ class DatabricksMultiServerMCPClient(MultiServerMCPClient): from databricks.sdk import WorkspaceClient from databricks_langchain import ( DatabricksMultiServerMCPClient, - DatabricksServer, - Server, + DatabricksMCPServer, + MCPServer, ) client = DatabricksMultiServerMCPClient( [ # Databricks server with automatic OAuth - just pass params as kwargs! - DatabricksServer( + DatabricksMCPServer( name="databricks-prod", url="https://your-workspace.databricks.com/mcp", workspace_client=WorkspaceClient(), @@ -157,7 +157,7 @@ class DatabricksMultiServerMCPClient(MultiServerMCPClient): handle_tool_error=True, # Return errors as strings instead of raising ), # Generic server with custom params - same flat API - Server( + MCPServer( name="other-server", url="https://other-server.com/mcp", headers={"X-API-Key": "secret"}, @@ -171,12 +171,12 @@ class DatabricksMultiServerMCPClient(MultiServerMCPClient): ``` """ - def __init__(self, servers: List[Server], **kwargs): + def __init__(self, servers: List[MCPServer], **kwargs): """ Initialize the client with a list of server configurations. Args: - servers: List of Server or DatabricksServer configurations + servers: List of MCPServer or DatabricksMCPServer configurations **kwargs: Additional arguments to pass to MultiServerMCPClient """ # Store server configs for later use (e.g., handle_tool_errors) diff --git a/integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py b/integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py index f6a3fc31..17b469f4 100644 --- a/integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py +++ b/integrations/langchain/tests/unit_tests/test_multi_server_mcp_client.py @@ -8,18 +8,18 @@ from databricks.sdk import WorkspaceClient from databricks_langchain.multi_server_mcp_client import ( + DatabricksMCPServer, DatabricksMultiServerMCPClient, - DatabricksServer, - Server, + MCPServer, ) -class TestServer: - """Tests for the Server class.""" +class TestMCPServer: + """Tests for the MCPServer class.""" def test_basic_server_creation(self): """Test creating a basic server with minimal parameters.""" - server = Server(name="test-server", url="https://example.com/mcp") + server = MCPServer(name="test-server", url="https://example.com/mcp") assert server.name == "test-server" assert server.url == "https://example.com/mcp" @@ -36,8 +36,8 @@ def test_basic_server_creation(self): ], ) def test_server_accepts_extra_params(self, extra_params: dict[str, Any]): - """Test that Server accepts and preserves extra parameters.""" - server = Server( + """Test that MCPServer accepts and preserves extra parameters.""" + server = MCPServer( name="test-server", url="https://example.com/mcp", handle_tool_error=True, @@ -64,7 +64,7 @@ def test_server_accepts_extra_params(self, extra_params: dict[str, Any]): ) def test_server_handle_tool_error_types(self, handle_tool_error_value: Any): """Test that handle_tool_error accepts various types.""" - server = Server( + server = MCPServer( name="test-server", url="https://example.com/mcp", handle_tool_error=handle_tool_error_value, @@ -73,11 +73,11 @@ def test_server_handle_tool_error_types(self, handle_tool_error_value: Any): assert server.handle_tool_error == handle_tool_error_value -class TestDatabricksServer: - """Tests for the DatabricksServer class.""" +class TestDatabricksMCPServer: + """Tests for the DatabricksMCPServer class.""" def test_databricks_server_without_workspace_client(self): - """Test DatabricksServer creates WorkspaceClient automatically.""" + """Test DatabricksMCPServer creates WorkspaceClient automatically.""" with ( patch("databricks_langchain.multi_server_mcp_client.WorkspaceClient") as mock_ws, patch( @@ -89,7 +89,7 @@ def test_databricks_server_without_workspace_client(self): mock_auth_instance = MagicMock() mock_auth.return_value = mock_auth_instance - server = DatabricksServer(name="databricks", url="https://databricks.com/mcp") + server = DatabricksMCPServer(name="databricks", url="https://databricks.com/mcp") # Should have created WorkspaceClient mock_ws.assert_called_once() @@ -97,7 +97,7 @@ def test_databricks_server_without_workspace_client(self): mock_auth.assert_called_once_with(mock_ws_instance) def test_databricks_server_with_workspace_client(self): - """Test DatabricksServer uses provided WorkspaceClient.""" + """Test DatabricksMCPServer uses provided WorkspaceClient.""" mock_workspace_client = create_autospec(WorkspaceClient, instance=True) with patch( @@ -106,7 +106,7 @@ def test_databricks_server_with_workspace_client(self): mock_auth_instance = MagicMock() mock_auth.return_value = mock_auth_instance - server = DatabricksServer( + server = DatabricksMCPServer( name="databricks", url="https://databricks.com/mcp", workspace_client=mock_workspace_client, @@ -122,7 +122,7 @@ def test_databricks_server_with_workspace_client(self): assert connection_dict["auth"] is mock_auth_instance def test_databricks_server_accepts_extra_params(self): - """Test that DatabricksServer accepts extra connection params.""" + """Test that DatabricksMCPServer accepts extra connection params.""" mock_workspace_client = create_autospec(WorkspaceClient, instance=True) with patch( @@ -131,7 +131,7 @@ def test_databricks_server_accepts_extra_params(self): mock_auth_instance = MagicMock() mock_auth.return_value = mock_auth_instance - server = DatabricksServer( + server = DatabricksMCPServer( name="databricks", url="https://databricks.com/mcp", workspace_client=mock_workspace_client, @@ -156,8 +156,8 @@ def test_client_initialization_with_multiple_servers(self): mock_init.return_value = None servers = [ - Server(name="server1", url="https://server1.com/mcp"), - Server(name="server2", url="https://server2.com/mcp"), + MCPServer(name="server1", url="https://server1.com/mcp"), + MCPServer(name="server2", url="https://server2.com/mcp"), ] client = DatabricksMultiServerMCPClient(servers) @@ -181,8 +181,10 @@ def test_client_initialization_with_multiple_servers(self): async def test_get_tools_all_servers(self): """Test get_tools without server_name (all servers).""" servers = [ - Server(name="server1", url="https://server1.com/mcp", handle_tool_error=True), - Server(name="server2", url="https://server2.com/mcp", handle_tool_error="Custom error"), + MCPServer(name="server1", url="https://server1.com/mcp", handle_tool_error=True), + MCPServer( + name="server2", url="https://server2.com/mcp", handle_tool_error="Custom error" + ), ] # Create mock tools for each server @@ -235,7 +237,7 @@ async def mock_get_tools_side_effect(server_name=None): @pytest.mark.asyncio async def test_get_tools_parallel_execution(self): """Test that get_tools executes server requests in parallel.""" - servers = [Server(name=f"server{i}", url=f"https://server{i}.com/mcp") for i in range(5)] + servers = [MCPServer(name=f"server{i}", url=f"https://server{i}.com/mcp") for i in range(5)] call_count = 0 call_times = [] @@ -274,7 +276,7 @@ async def mock_get_tools_with_delay(server_name=None): @pytest.mark.asyncio async def test_get_tools_with_databricks_server(self): - """Test get_tools with DatabricksServer.""" + """Test get_tools with DatabricksMCPServer.""" mock_workspace_client = create_autospec(WorkspaceClient, instance=True) mock_tool = MagicMock() @@ -295,7 +297,7 @@ async def test_get_tools_with_databricks_server(self): mock_auth.return_value = mock_auth_instance mock_parent_get_tools.return_value = [mock_tool] - server = DatabricksServer( + server = DatabricksMCPServer( name="databricks", url="https://databricks.com/mcp", workspace_client=mock_workspace_client, From fc4471e834fb34008580006dbea713b6e62f316d Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Fri, 21 Nov 2025 12:31:13 -0500 Subject: [PATCH 5/8] Add databricks_mcp dependency --- integrations/langchain/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml index f2c666a0..56ede148 100644 --- a/integrations/langchain/pyproject.toml +++ b/integrations/langchain/pyproject.toml @@ -17,7 +17,8 @@ dependencies = [ "unitycatalog-langchain[databricks]>=0.3.0", "databricks-sdk>=0.65.0", "openai>=1.99.9", - "langchain-mcp-adapters>=0.1.13" + "langchain-mcp-adapters>=0.1.13", + "databricks_mcp>=0.4.0" ] From d3c5cd3ab07a1d1761a700962e33bbfdfd8425e4 Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Mon, 24 Nov 2025 11:49:38 -0500 Subject: [PATCH 6/8] Fix dependencies --- integrations/langchain/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml index 56ede148..4c82939a 100644 --- a/integrations/langchain/pyproject.toml +++ b/integrations/langchain/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ [project.optional-dependencies] dev = [ "pytest", + "pytest-asyncio", "typing_extensions", "ruff==0.6.4", ] From 79c55cfcde4c55dc54fc99d6d303bfa3b1a2c9c1 Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Wed, 26 Nov 2025 12:29:43 -0500 Subject: [PATCH 7/8] Use model post init --- .../src/databricks_langchain/multi_server_mcp_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py b/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py index 973fe329..3142e6ea 100644 --- a/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py +++ b/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py @@ -105,9 +105,9 @@ class DatabricksMCPServer(MCPServer): exclude=True, ) - def __init__(self, **data): + def model_post_init(self, __context: Any) -> None: """Initialize DatabricksServer with auth setup.""" - super().__init__(**data) + super().model_post_init(__context) # Set up Databricks OAuth authentication after initialization if self.workspace_client is None: From 28b3995aeddf18cf9007bc4ed9d0be3d9ba788ef Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Tue, 2 Dec 2025 13:58:32 -0500 Subject: [PATCH 8/8] Address review comments --- .../databricks_langchain/multi_server_mcp_client.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py b/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py index 3142e6ea..0750891b 100644 --- a/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py +++ b/integrations/langchain/src/databricks_langchain/multi_server_mcp_client.py @@ -21,6 +21,7 @@ class MCPServer(BaseModel): - httpx_client_factory: Callable - Custom httpx client factory - terminate_on_close: bool - Terminate connection on close - session_kwargs: dict - Additional session kwargs + - handle_tool_error: bool | str | Callable - Error handling strategy Example: ```python @@ -48,11 +49,9 @@ class MCPServer(BaseModel): default=None, exclude=True, description=( - "How to handle errors raised by tools from this server. Options:\n" - "- None/False: Raise the error\n" - "- True: Return error message as string\n" - "- str: Return this string when errors occur\n" - "- Callable: Function that takes error and returns error message string" + "If True, return the error message as the output. If False, raise the error. " + "If a string, return the string as the error message. " + "If a callable, return the result of the callable as the error message." ), ) @@ -105,9 +104,9 @@ class DatabricksMCPServer(MCPServer): exclude=True, ) - def model_post_init(self, __context: Any) -> None: + def model_post_init(self, context: Any) -> None: """Initialize DatabricksServer with auth setup.""" - super().model_post_init(__context) + super().model_post_init(context) # Set up Databricks OAuth authentication after initialization if self.workspace_client is None: