|
| 1 | +import inspect |
1 | 2 | import json |
| 3 | +from dataclasses import asdict |
2 | 4 | from typing import Any |
3 | 5 |
|
4 | 6 | import pytest |
@@ -81,6 +83,44 @@ async def test_simple_function(): |
81 | 83 | ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" |
82 | 84 | ) |
83 | 85 |
|
| 86 | + # Direct call |
| 87 | + result = tool(2, 2) |
| 88 | + assert result == 4 |
| 89 | + |
| 90 | + |
| 91 | +async def async_function(a: int, b: int = 5): |
| 92 | + return a + b |
| 93 | + |
| 94 | + |
| 95 | +@pytest.mark.asyncio |
| 96 | +async def test_async_function(): |
| 97 | + tool = function_tool(async_function, failure_error_function=None) |
| 98 | + assert tool.name == "async_function" |
| 99 | + |
| 100 | + result = await tool.on_invoke_tool( |
| 101 | + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'), |
| 102 | + '{"a": 1}', |
| 103 | + ) |
| 104 | + assert result == 6 |
| 105 | + |
| 106 | + result = await tool.on_invoke_tool( |
| 107 | + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1, "b": 2}'), |
| 108 | + '{"a": 1, "b": 2}', |
| 109 | + ) |
| 110 | + assert result == 3 |
| 111 | + |
| 112 | + # Missing required argument should raise an error |
| 113 | + with pytest.raises(ModelBehaviorError): |
| 114 | + await tool.on_invoke_tool( |
| 115 | + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" |
| 116 | + ) |
| 117 | + |
| 118 | + # Direct call |
| 119 | + result = await tool(2, 2) |
| 120 | + assert result == 4 |
| 121 | + |
| 122 | + assert not inspect.iscoroutinefunction(tool.__call__), "tool.__call__ should sync." |
| 123 | + |
84 | 124 |
|
85 | 125 | class Foo(BaseModel): |
86 | 126 | a: int |
@@ -148,6 +188,16 @@ async def test_complex_args_function(): |
148 | 188 | ) |
149 | 189 |
|
150 | 190 |
|
| 191 | +def test_absent_func_tool(): |
| 192 | + tool = function_tool(simple_function) |
| 193 | + kwargs = asdict(tool) |
| 194 | + kwargs.pop("_func") |
| 195 | + manually_defined_tool = FunctionTool(**kwargs) |
| 196 | + |
| 197 | + with pytest.raises(AttributeError, match="not callable"): |
| 198 | + manually_defined_tool(1, 1) |
| 199 | + |
| 200 | + |
151 | 201 | def test_function_config_overrides(): |
152 | 202 | tool = function_tool(simple_function, name_override="custom_name") |
153 | 203 | assert tool.name == "custom_name" |
|
0 commit comments