Skip to content

Commit b48ed6f

Browse files
committed
comments
1 parent 583d001 commit b48ed6f

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

src/agents/tool.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import functools
34
import inspect
45
import json
56
from collections.abc import Awaitable
@@ -179,7 +180,7 @@ class FunctionTool:
179180
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
180181
based on your context/state."""
181182

182-
func: ToolFunction[...] | None = None
183+
_func: ToolFunction[...] | None = field(default=None, repr=False)
183184
"""The function that implements the tool. Ensures that a reference to the
184185
original function exists when @function_tool is used."""
185186

@@ -194,6 +195,17 @@ def __post_init__(self):
194195
if self.strict_json_schema:
195196
self.params_json_schema = ensure_strict_json_schema(self.params_json_schema)
196197

198+
if self._func:
199+
functools.update_wrapper(self, self._func)
200+
201+
def __call__(self, *args, **kwargs):
202+
if not self._func:
203+
raise AttributeError("""FunctionTool has no attribute `_func` and is not callable.
204+
Likely because it was created directly without the
205+
@function_tool decorator.""")
206+
207+
return self._func(*args, **kwargs)
208+
197209

198210
@dataclass
199211
class FileSearchTool:
@@ -665,7 +677,7 @@ async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any:
665677
on_invoke_tool=_on_invoke_tool,
666678
strict_json_schema=strict_mode,
667679
is_enabled=is_enabled,
668-
func=func,
680+
_func=func,
669681
)
670682

671683
# If func is actually a callable, we were used as @function_tool with no parentheses

tests/test_function_tool.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import inspect
12
import json
3+
from dataclasses import asdict
24
from typing import Any
35

46
import pytest
@@ -81,6 +83,44 @@ async def test_simple_function():
8183
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
8284
)
8385

86+
# Direct call
87+
result = tool(2, 2)
88+
assert result == 4
89+
90+
91+
async def async_function(a: int, b: int = 5):
92+
return a + b
93+
94+
95+
@pytest.mark.asyncio
96+
async def test_async_function():
97+
tool = function_tool(async_function, failure_error_function=None)
98+
assert tool.name == "async_function"
99+
100+
result = await tool.on_invoke_tool(
101+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'),
102+
'{"a": 1}',
103+
)
104+
assert result == 6
105+
106+
result = await tool.on_invoke_tool(
107+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1, "b": 2}'),
108+
'{"a": 1, "b": 2}',
109+
)
110+
assert result == 3
111+
112+
# Missing required argument should raise an error
113+
with pytest.raises(ModelBehaviorError):
114+
await tool.on_invoke_tool(
115+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
116+
)
117+
118+
# Direct call
119+
result = await tool(2, 2)
120+
assert result == 4
121+
122+
assert not inspect.iscoroutinefunction(tool.__call__), "tool.__call__ should sync."
123+
84124

85125
class Foo(BaseModel):
86126
a: int
@@ -148,6 +188,16 @@ async def test_complex_args_function():
148188
)
149189

150190

191+
def test_absent_func_tool():
192+
tool = function_tool(simple_function)
193+
kwargs = asdict(tool)
194+
kwargs.pop("_func")
195+
manually_defined_tool = FunctionTool(**kwargs)
196+
197+
with pytest.raises(AttributeError, match="not callable"):
198+
manually_defined_tool(1, 1)
199+
200+
151201
def test_function_config_overrides():
152202
tool = function_tool(simple_function, name_override="custom_name")
153203
assert tool.name == "custom_name"

0 commit comments

Comments
 (0)