diff --git a/src/deepagents/graph.py b/src/deepagents/graph.py index d49075b33..06d40f4a7 100644 --- a/src/deepagents/graph.py +++ b/src/deepagents/graph.py @@ -53,6 +53,7 @@ def create_deep_agent( debug: bool = False, name: str | None = None, cache: BaseCache | None = None, + preserve_message_tool_names: list[str] | None = None, ) -> CompiledStateGraph: """Create a deep agent. @@ -114,6 +115,7 @@ def create_deep_agent( ], default_interrupt_on=interrupt_on, general_purpose_agent=True, + preserve_message_tool_names=preserve_message_tool_names, ), SummarizationMiddleware( model=model, diff --git a/src/deepagents/middleware/subagents.py b/src/deepagents/middleware/subagents.py index 4dc5e584c..646196e50 100644 --- a/src/deepagents/middleware/subagents.py +++ b/src/deepagents/middleware/subagents.py @@ -9,7 +9,7 @@ from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, ModelResponse from langchain.tools import BaseTool, ToolRuntime from langchain_core.language_models import BaseChatModel -from langchain_core.messages import HumanMessage, ToolMessage +from langchain_core.messages import HumanMessage, ToolMessage, AIMessage from langchain_core.runnables import Runnable from langchain_core.tools import StructuredTool from langgraph.types import Command @@ -286,6 +286,7 @@ def _create_task_tool( subagents: list[SubAgent | CompiledSubAgent], general_purpose_agent: bool, task_description: str | None = None, + preserve_message_tool_names: list[str] | None = None, ) -> BaseTool: """Create a task tool for invoking subagents. @@ -315,10 +316,71 @@ def _create_task_tool( def _return_command_with_state_update(result: dict, tool_call_id: str) -> Command: state_update = {k: v for k, v in result.items() if k not in _EXCLUDED_STATE_KEYS} + messages_update: list = [ + ToolMessage(result["messages"][-1].text, tool_call_id=tool_call_id) + ] + + if preserve_message_tool_names: + preserve_set = set(preserve_message_tool_names) + + tool_call_id_to_tool_msg: dict[str, ToolMessage] = {} + for m in result["messages"]: + if isinstance(m, ToolMessage) and getattr(m, "tool_call_id", None): + tool_call_id_to_tool_msg[m.tool_call_id] = m + + def _tc_name(tc: Any) -> str | None: + if isinstance(tc, dict): + name = tc.get("name") + if name: + return name + func = tc.get("function") or {} + return func.get("name") if isinstance(func, dict) else None + name = getattr(tc, "name", None) + if name: + return name + func = getattr(tc, "function", None) + if isinstance(func, dict): + return func.get("name") + return getattr(func, "name", None) + + def _tc_id(tc: Any) -> str | None: + if isinstance(tc, dict): + return tc.get("id") + return getattr(tc, "id", None) + + latest_ai_for_tool: dict[str, AIMessage] = {} + for m in reversed(result["messages"]): + if isinstance(m, AIMessage) and getattr(m, "tool_calls", None): + names_in_msg = {n for n in (_tc_name(tc) for tc in (m.tool_calls or [])) if n} + intersect = [n for n in names_in_msg if n in preserve_set and n not in latest_ai_for_tool] + if not intersect: + continue + for tool_name in intersect: + filtered_calls = [tc for tc in (m.tool_calls or []) if _tc_name(tc) == tool_name] + if not filtered_calls: + continue + cloned_ai = AIMessage( + content=m.content, + additional_kwargs=getattr(m, "additional_kwargs", {}), + response_metadata=getattr(m, "response_metadata", {}), + tool_calls=filtered_calls, + ) + latest_ai_for_tool[tool_name] = cloned_ai + + if latest_ai_for_tool: + ordered_tools = [t for t in preserve_message_tool_names if t in latest_ai_for_tool] + for tool_name in ordered_tools: + ai_msg = latest_ai_for_tool[tool_name] + messages_update.append(ai_msg) + for tc in ai_msg.tool_calls or []: + tc_id = _tc_id(tc) + if tc_id and tc_id in tool_call_id_to_tool_msg: + messages_update.append(tool_call_id_to_tool_msg[tc_id]) + return Command( update={ **state_update, - "messages": [ToolMessage(result["messages"][-1].text, tool_call_id=tool_call_id)], + "messages": messages_update, } ) @@ -444,6 +506,7 @@ def __init__( system_prompt: str | None = TASK_SYSTEM_PROMPT, general_purpose_agent: bool = True, task_description: str | None = None, + preserve_message_tool_names: list[str] | None = None, ) -> None: """Initialize the SubAgentMiddleware.""" super().__init__() @@ -456,6 +519,7 @@ def __init__( subagents=subagents or [], general_purpose_agent=general_purpose_agent, task_description=task_description, + preserve_message_tool_names=preserve_message_tool_names, ) self.tools = [task_tool]