-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Add parallel tool execution to ReAct #8999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
Copilot
wants to merge
5
commits into
main
Choose a base branch
from
copilot/add-parallel-tool-execution
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 3 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
1607be3
Initial plan
Copilot e4c1c6f
Implement parallel tool execution for ReAct module
Copilot 84f5263
Address code review feedback
Copilot 741ac11
Merge test_react_parallel.py into test_react.py
Copilot e0833e3
Use ToolCalls class for next_tool_calls schema
Copilot File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,7 @@ | ||
| import asyncio | ||
| import logging | ||
| from typing import TYPE_CHECKING, Any, Callable, Literal | ||
| from concurrent.futures import ThreadPoolExecutor | ||
| from typing import TYPE_CHECKING, Any, Callable | ||
|
|
||
| from litellm import ContextWindowExceededError | ||
|
|
||
|
|
@@ -53,10 +55,11 @@ def get_weather(city: str) -> str: | |
| [ | ||
| f"You are an Agent. In each episode, you will be given the fields {inputs} as input. And you can see your past trajectory so far.", | ||
| f"Your goal is to use one or more of the supplied tools to collect any necessary information for producing {outputs}.\n", | ||
| "To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task.", | ||
| "After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n", | ||
| "To do this, you will interleave next_thought and next_tool_calls in each turn, and also when finishing the task.", | ||
| "You can call multiple tools in parallel by providing multiple tool calls in next_tool_calls.", | ||
| "After each set of tool calls, you receive resulting observations, which get appended to your trajectory.\n", | ||
| "When writing next_thought, you may reason about the current situation and plan for future steps.", | ||
| "When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n", | ||
| "When selecting next_tool_calls, each tool must be one of:\n", | ||
| ] | ||
| ) | ||
|
|
||
|
|
@@ -69,14 +72,16 @@ def get_weather(city: str) -> str: | |
|
|
||
| for idx, tool in enumerate(tools.values()): | ||
| instr.append(f"({idx + 1}) {tool}") | ||
| instr.append("When providing `next_tool_args`, the value inside the field must be in JSON format") | ||
| instr.append( | ||
| "When providing `next_tool_calls`, provide a list of tool calls. Each tool call should be a dictionary with 'name' and 'args' keys. " | ||
| "The 'name' must be one of the tool names listed above, and 'args' must be a dictionary in JSON format containing the arguments for that tool." | ||
| ) | ||
|
|
||
| react_signature = ( | ||
| dspy.Signature({**signature.input_fields}, "\n".join(instr)) | ||
| .append("trajectory", dspy.InputField(), type_=str) | ||
| .append("next_thought", dspy.OutputField(), type_=str) | ||
| .append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())]) | ||
| .append("next_tool_args", dspy.OutputField(), type_=dict[str, Any]) | ||
| .append("next_tool_calls", dspy.OutputField(), type_=list[dict[str, Any]]) | ||
| ) | ||
|
|
||
| fallback_signature = dspy.Signature( | ||
|
|
@@ -104,15 +109,26 @@ def forward(self, **input_args): | |
| break | ||
|
|
||
| trajectory[f"thought_{idx}"] = pred.next_thought | ||
| trajectory[f"tool_name_{idx}"] = pred.next_tool_name | ||
| trajectory[f"tool_args_{idx}"] = pred.next_tool_args | ||
|
|
||
| try: | ||
| trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args) | ||
| except Exception as err: | ||
| trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" | ||
|
|
||
| if pred.next_tool_name == "finish": | ||
| # Parse tool calls - handle both list format and backward compatibility | ||
| tool_calls = self._parse_tool_calls(pred.next_tool_calls) | ||
| trajectory[f"tool_calls_{idx}"] = tool_calls | ||
|
|
||
| # Execute tools in parallel | ||
| observations = self._execute_tools_parallel(tool_calls) | ||
|
|
||
| # Store observations as a structured format that includes tool names | ||
| # This makes it easier for the LLM to understand which observation corresponds to which tool | ||
| formatted_observations = [] | ||
| for tool_call, observation in zip(tool_calls, observations, strict=True): | ||
| formatted_observations.append({ | ||
| "tool": tool_call["name"], | ||
| "result": observation | ||
| }) | ||
| trajectory[f"observations_{idx}"] = formatted_observations | ||
|
|
||
| # Check if any tool call is "finish" | ||
| if any(tc["name"] == "finish" for tc in tool_calls): | ||
| break | ||
|
|
||
| extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args) | ||
|
|
@@ -129,20 +145,92 @@ async def aforward(self, **input_args): | |
| break | ||
|
|
||
| trajectory[f"thought_{idx}"] = pred.next_thought | ||
| trajectory[f"tool_name_{idx}"] = pred.next_tool_name | ||
| trajectory[f"tool_args_{idx}"] = pred.next_tool_args | ||
|
|
||
| try: | ||
| trajectory[f"observation_{idx}"] = await self.tools[pred.next_tool_name].acall(**pred.next_tool_args) | ||
| except Exception as err: | ||
| trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" | ||
|
|
||
| if pred.next_tool_name == "finish": | ||
| # Parse tool calls - handle both list format and backward compatibility | ||
| tool_calls = self._parse_tool_calls(pred.next_tool_calls) | ||
| trajectory[f"tool_calls_{idx}"] = tool_calls | ||
|
|
||
| # Execute tools in parallel | ||
| observations = await self._execute_tools_parallel_async(tool_calls) | ||
|
|
||
| # Store observations as a structured format that includes tool names | ||
| # This makes it easier for the LLM to understand which observation corresponds to which tool | ||
| formatted_observations = [] | ||
| for tool_call, observation in zip(tool_calls, observations, strict=True): | ||
| formatted_observations.append({ | ||
| "tool": tool_call["name"], | ||
| "result": observation | ||
| }) | ||
| trajectory[f"observations_{idx}"] = formatted_observations | ||
|
|
||
| # Check if any tool call is "finish" | ||
| if any(tc["name"] == "finish" for tc in tool_calls): | ||
| break | ||
|
|
||
| extract = await self._async_call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args) | ||
| return dspy.Prediction(trajectory=trajectory, **extract) | ||
|
|
||
| def _parse_tool_calls(self, tool_calls_data): | ||
| """Parse tool calls from the prediction output. | ||
|
|
||
| Handles both the new list format and provides backward compatibility. | ||
| """ | ||
| # If it's already a list of dicts with 'name' and 'args', use it directly | ||
| if isinstance(tool_calls_data, list): | ||
| return tool_calls_data | ||
|
|
||
| # Handle single dict case (shouldn't normally happen but for robustness) | ||
| if isinstance(tool_calls_data, dict) and "name" in tool_calls_data and "args" in tool_calls_data: | ||
| return [tool_calls_data] | ||
|
|
||
| # If we got something unexpected, raise an error | ||
| raise ValueError(f"Invalid tool_calls format: {tool_calls_data}") | ||
|
|
||
| def _execute_tools_parallel(self, tool_calls: list[dict[str, Any]]) -> list[Any]: | ||
| """Execute multiple tools in parallel using ThreadPoolExecutor. | ||
|
|
||
| Args: | ||
| tool_calls: List of tool call dicts, each with 'name' and 'args' keys | ||
|
|
||
| Returns: | ||
| List of observations in the same order as tool_calls | ||
| """ | ||
| def execute_single_tool(tool_call: dict[str, Any]) -> Any: | ||
| tool_name = tool_call["name"] | ||
| tool_args = tool_call.get("args", {}) | ||
| try: | ||
| return self.tools[tool_name](**tool_args) | ||
| except Exception as err: | ||
| return f"Execution error in {tool_name}: {_fmt_exc(err)}" | ||
|
|
||
| # Execute tools in parallel using ThreadPoolExecutor | ||
| with ThreadPoolExecutor() as executor: | ||
| observations = list(executor.map(execute_single_tool, tool_calls)) | ||
|
|
||
| return observations | ||
|
|
||
| async def _execute_tools_parallel_async(self, tool_calls: list[dict[str, Any]]) -> list[Any]: | ||
| """Execute multiple tools in parallel using asyncio.gather. | ||
|
|
||
| Args: | ||
| tool_calls: List of tool call dicts, each with 'name' and 'args' keys | ||
|
|
||
| Returns: | ||
| List of observations in the same order as tool_calls | ||
| """ | ||
| async def execute_single_tool(tool_call: dict[str, Any]) -> Any: | ||
| tool_name = tool_call["name"] | ||
| tool_args = tool_call.get("args", {}) | ||
| try: | ||
| return await self.tools[tool_name].acall(**tool_args) | ||
| except Exception as err: | ||
| return f"Execution error in {tool_name}: {_fmt_exc(err)}" | ||
|
|
||
| # Execute tools in parallel using asyncio.gather | ||
| observations = await asyncio.gather(*[execute_single_tool(tc) for tc in tool_calls]) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @copilot Use the newer |
||
|
|
||
| return observations | ||
|
|
||
| def _call_with_potential_trajectory_truncation(self, module, trajectory, **input_args): | ||
| for _ in range(3): | ||
| try: | ||
|
|
@@ -171,14 +259,14 @@ def truncate_trajectory(self, trajectory): | |
| Users can override this method to implement their own truncation logic. | ||
| """ | ||
| keys = list(trajectory.keys()) | ||
| if len(keys) < 4: | ||
| # Every tool call has 4 keys: thought, tool_name, tool_args, and observation. | ||
| if len(keys) < 3: | ||
| # Every iteration has 3 keys: thought, tool_calls, and observations. | ||
| raise ValueError( | ||
| "The trajectory is too long so your prompt exceeded the context window, but the trajectory cannot be " | ||
| "truncated because it only has one tool call." | ||
| "truncated because it only has one iteration." | ||
| ) | ||
|
|
||
| for key in keys[:4]: | ||
| for key in keys[:3]: | ||
| trajectory.pop(key) | ||
|
|
||
| return trajectory | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot can we use
ToolCallsclass to define the schema explicitly?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done! Changed the type annotation from
list[dict[str, Any]]toToolCallsand updated the_parse_tool_callsmethod to handle ToolCalls objects. All tests passing. (commit e0833e3)