Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go/ai/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ func TestGenerateAction(t *testing.T) {
cmpopts.EquateEmpty(),
cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs"),
cmpopts.IgnoreFields(GenerationUsage{}, "InputCharacters", "OutputCharacters"),
cmpopts.IgnoreFields(ToolDefinition{}, "Metadata"),
}); diff != "" {
t.Errorf("response mismatch (-want +got):\n%s", diff)
}
Expand All @@ -157,6 +158,7 @@ func TestGenerateAction(t *testing.T) {
cmpopts.EquateEmpty(),
cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs"),
cmpopts.IgnoreFields(GenerationUsage{}, "InputCharacters", "OutputCharacters"),
cmpopts.IgnoreFields(ToolDefinition{}, "Metadata"),
}); diff != "" {
t.Errorf("response mismatch (-want +got):\n%s", diff)
}
Expand Down
4 changes: 2 additions & 2 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ type toolRequestPart struct {
// the results of running a specific tool on the arguments passed to the client
// by the model in a [ToolRequest].
type ToolResponse struct {
Content []any `json:"content,omitempty"`
Name string `json:"name,omitempty"`
Content []*Part `json:"content,omitempty"`
Name string `json:"name,omitempty"`
// Output is a JSON object describing the results of running the tool.
// An example might be map[string]any{"name":"Thomas Jefferson", "born":1743}.
Output any `json:"output,omitempty"`
Expand Down
31 changes: 12 additions & 19 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -628,20 +628,12 @@ func clone[T any](obj *T) *T {
// either a new request to continue the conversation or nil if no tool requests
// need handling.
func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int) (*ModelRequest, *Message, error) {
toolCount := 0
if resp.Message != nil {
for _, part := range resp.Message.Content {
if part.IsToolRequest() {
toolCount++
}
}
}

toolCount := len(resp.ToolRequests())
if toolCount == 0 {
return nil, nil, nil
}

resultChan := make(chan result[any])
resultChan := make(chan result[*MultipartToolResponse])
toolMsg := &Message{Role: RoleTool}
revisedMsg := clone(resp.Message)

Expand All @@ -654,11 +646,11 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest,
toolReq := p.ToolRequest
tool := LookupTool(r, p.ToolRequest.Name)
if tool == nil {
resultChan <- result[any]{idx, nil, core.NewError(core.NOT_FOUND, "tool %q not found", toolReq.Name)}
resultChan <- result[*MultipartToolResponse]{index: idx, err: core.NewError(core.NOT_FOUND, "tool %q not found", toolReq.Name)}
return
}

output, err := tool.RunRaw(ctx, toolReq.Input)
multipartResp, err := tool.RunRawMultipart(ctx, toolReq.Input)
if err != nil {
var tie *toolInterruptError
if errors.As(err, &tie) {
Expand All @@ -676,22 +668,22 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest,

revisedMsg.Content[idx] = newPart

resultChan <- result[any]{idx, nil, tie}
resultChan <- result[*MultipartToolResponse]{index: idx, err: tie}
return
}

resultChan <- result[any]{idx, nil, core.NewError(core.INTERNAL, "tool %q failed: %v", toolReq.Name, err)}
resultChan <- result[*MultipartToolResponse]{index: idx, err: core.NewError(core.INTERNAL, "tool %q failed: %v", toolReq.Name, err)}
return
}

newPart := clone(p)
if newPart.Metadata == nil {
newPart.Metadata = make(map[string]any)
}
newPart.Metadata["pendingOutput"] = output
newPart.Metadata["pendingOutput"] = multipartResp.Output
revisedMsg.Content[idx] = newPart

resultChan <- result[any]{idx, output, nil}
resultChan <- result[*MultipartToolResponse]{index: idx, value: multipartResp}
}(i, part)
}

Expand All @@ -711,9 +703,10 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest,

toolReq := revisedMsg.Content[res.index].ToolRequest
toolResps = append(toolResps, NewToolResponsePart(&ToolResponse{
Name: toolReq.Name,
Ref: toolReq.Ref,
Output: res.value,
Name: toolReq.Name,
Ref: toolReq.Ref,
Output: res.value.Output,
Content: res.value.Content,
}))
}

Expand Down
181 changes: 181 additions & 0 deletions go/ai/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,9 @@ func TestGenerate(t *testing.T) {
},
Name: "gablorken",
OutputSchema: map[string]any{"type": string("number")},
Metadata: map[string]any{
"multipart": false,
},
},
},
ToolChoice: ToolChoiceAuto,
Expand Down Expand Up @@ -1332,3 +1335,181 @@ func TestResourceProcessingError(t *testing.T) {
t.Fatalf("wrong error: %v", err)
}
}

func TestMultipartTools(t *testing.T) {
t.Run("define multipart tool registers as tool.v2 only", func(t *testing.T) {
r := registry.New()

DefineMultipartTool(r, "multipartTest", "a multipart tool",
func(ctx *ToolContext, input struct{ Query string }) (*MultipartToolResponse, error) {
return &MultipartToolResponse{
Output: "main output",
Content: []*Part{NewTextPart("content part 1")},
}, nil
},
)

// Should be found via LookupTool
tool := LookupTool(r, "multipartTest")
if tool == nil {
t.Fatal("expected multipart tool to be found via LookupTool")
}

// Should be able to produce response with content
resp, err := tool.RunRawMultipart(context.Background(), struct{ Query string }{Query: "Q"})
if err != nil {
t.Fatalf("failed running multipart tool: %v", err)
}
if len(resp.Content) == 0 {
t.Error("expected tool response to have content")
}
})

t.Run("regular tool registers as both tool and tool.v2", func(t *testing.T) {
r := registry.New()

DefineTool(r, "regularTestTool", "a regular tool",
func(ctx *ToolContext, input struct{ Value int }) (int, error) {
return input.Value * 2, nil
},
)

// Should be found via LookupTool
tool := LookupTool(r, "regularTestTool")
if tool == nil {
t.Fatal("expected regular tool to be found via LookupTool")
}

// Should produce response without content by default
resp, err := tool.RunRawMultipart(context.Background(), struct{ Value int }{Value: 21})
if err != nil {
t.Fatalf("failed running regular tool: %v", err)
}
if len(resp.Content) > 0 {
t.Error("expected regular tool response to have no content")
}
})

t.Run("multipart tool returns content in response", func(t *testing.T) {
r := registry.New()
ConfigureFormats(r)
DefineGenerateAction(context.Background(), r)

multipartTool := DefineMultipartTool(r, "imageGenerator", "generates images",
func(ctx *ToolContext, input struct{ Prompt string }) (*MultipartToolResponse, error) {
return &MultipartToolResponse{
Output: map[string]any{"description": "generated image"},
Content: []*Part{
NewMediaPart("image/png", "data:image/png;base64,iVBORw0..."),
},
}, nil
},
)

// Create a model that requests the tool
multipartToolModel := DefineModel(r, "test/multipartToolModel", &metadata, func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
// Check if we already have a tool response
for _, msg := range gr.Messages {
if msg.Role == RoleTool {
for _, part := range msg.Content {
if part.IsToolResponse() {
// Verify the content is present
if len(part.ToolResponse.Content) == 0 {
return nil, fmt.Errorf("expected tool response to have content")
}
return &ModelResponse{
Request: gr,
Message: NewModelTextMessage("Image generated successfully"),
}, nil
}
}
}
}

// First call: request the tool
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{NewToolRequestPart(&ToolRequest{
Name: "imageGenerator",
Input: map[string]any{"Prompt": "a cat"},
Ref: "img1",
})},
},
}, nil
})

resp, err := Generate(context.Background(), r,
WithModel(multipartToolModel),
WithPrompt("Generate an image of a cat"),
WithTools(multipartTool),
)
if err != nil {
t.Fatalf("Generate failed: %v", err)
}

if resp.Text() != "Image generated successfully" {
t.Errorf("expected 'Image generated successfully', got %q", resp.Text())
}
})

t.Run("RunRawMultipart returns MultipartToolResponse for regular tool", func(t *testing.T) {
r := registry.New()

tool := DefineTool(r, "multipartWrapperTest", "test multipart wrapper",
func(ctx *ToolContext, input struct{ Value int }) (int, error) {
return input.Value * 3, nil
},
)

resp, err := tool.RunRawMultipart(context.Background(), map[string]any{"Value": 5})
if err != nil {
t.Fatalf("RunRawMultipart failed: %v", err)
}

// Output should be wrapped in MultipartToolResponse
output, ok := resp.Output.(float64) // JSON unmarshals numbers as float64
if !ok {
t.Fatalf("expected output to be float64, got %T", resp.Output)
}
if output != 15 {
t.Errorf("expected output 15, got %v", output)
}

// Content should be nil for regular tools
if resp.Content != nil {
t.Errorf("expected nil content for regular tool, got %v", resp.Content)
}
})

t.Run("RunRawMultipart returns full response for multipart tool", func(t *testing.T) {
r := registry.New()

tool := DefineMultipartTool(r, "multipartFullTest", "test multipart",
func(ctx *ToolContext, input struct{ Query string }) (*MultipartToolResponse, error) {
return &MultipartToolResponse{
Output: "result",
Content: []*Part{NewTextPart("additional content")},
}, nil
},
)

resp, err := tool.RunRawMultipart(context.Background(), map[string]any{"Query": "test"})
if err != nil {
t.Fatalf("RunRawMultipart failed: %v", err)
}

if resp.Output != "result" {
t.Errorf("expected output 'result', got %v", resp.Output)
}

if len(resp.Content) != 1 {
t.Fatalf("expected 1 content part, got %d", len(resp.Content))
}

if resp.Content[0].Text != "additional content" {
t.Errorf("expected content 'additional content', got %q", resp.Content[0].Text)
}
})
}
48 changes: 24 additions & 24 deletions go/ai/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ func WithToolChoice(toolChoice ToolChoice) CommonGenOption {
return &commonGenOptions{ToolChoice: toolChoice}
}

// WithResources specifies resources to be temporarily available during generation.
// Resources are unregistered resources that get attached to a temporary registry
// during the generation request and cleaned up afterward.
func WithResources(resources ...Resource) CommonGenOption {
return &commonGenOptions{Resources: resources}
}

// inputOptions are options for the input of a prompt.
type inputOptions struct {
InputSchema map[string]any // JSON schema of the input.
Expand All @@ -265,6 +272,7 @@ type inputOptions struct {
type InputOption interface {
applyInput(*inputOptions) error
applyPrompt(*promptOptions) error
applyTool(*toolOptions) error
}

// applyInput applies the option to the input options.
Expand All @@ -291,9 +299,14 @@ func (o *inputOptions) applyPrompt(pOpts *promptOptions) error {
return o.applyInput(&pOpts.inputOptions)
}

// applyTool applies the option to the tool options.
func (o *inputOptions) applyTool(tOpts *toolOptions) error {
return o.applyInput(&tOpts.inputOptions)
}

// WithInputType uses the type provided to derive the input schema.
// The inputted value will serve as the default input if no input is given at generation time.
// Only supports structs and map[string]any api.
// The inputted value may serve as the default input if no input is given at generation time depending on the action.
// Only supports structs and map[string]any.
func WithInputType(input any) InputOption {
var defaultInput map[string]any

Expand Down Expand Up @@ -896,32 +909,19 @@ func WithToolRestarts(parts ...*Part) GenerateOption {
return &generateOptions{RestartParts: parts}
}

// WithResources specifies resources to be temporarily available during generation.
// Resources are unregistered resources that get attached to a temporary registry
// during the generation request and cleaned up afterward.
func WithResources(resources ...Resource) CommonGenOption {
return &withResources{resources: resources}
}

type withResources struct {
resources []Resource
}

func (w *withResources) applyCommonGen(o *commonGenOptions) error {
o.Resources = w.resources
return nil
}

func (w *withResources) applyPrompt(o *promptOptions) error {
return w.applyCommonGen(&o.commonGenOptions)
// toolOptions holds configuration options for defining tools.
type toolOptions struct {
inputOptions
}

func (w *withResources) applyGenerate(o *generateOptions) error {
return w.applyCommonGen(&o.commonGenOptions)
// ToolOption is an option for defining a tool.
type ToolOption interface {
applyTool(*toolOptions) error
}

func (w *withResources) applyPromptExecute(o *promptExecutionOptions) error {
return w.applyCommonGen(&o.commonGenOptions)
// applyTool applies the option to the tool options.
func (o *toolOptions) applyTool(opts *toolOptions) error {
return o.inputOptions.applyTool(opts)
}

// promptExecutionOptions are options for generating a model response by executing a prompt.
Expand Down
6 changes: 6 additions & 0 deletions go/ai/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,9 @@ func TestValidPrompt(t *testing.T) {
"type": string("object"),
},
OutputSchema: map[string]any{"type": string("string")},
Metadata: map[string]any{
"multipart": false,
},
},
},
},
Expand Down Expand Up @@ -589,6 +592,9 @@ func TestValidPrompt(t *testing.T) {
"type": string("object"),
},
OutputSchema: map[string]any{"type": string("string")},
Metadata: map[string]any{
"multipart": false,
},
},
},
},
Expand Down
Loading
Loading