Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions go/plugins/googlegenai/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,26 @@ func newModel(client *genai.Client, name string, opts ai.ModelOptions) ai.Model
return ai.NewModel(api.NewName(provider, name), meta, fn)
}

// resolveVertexModelName prepares a model name for the google.golang.org/genai
// SDK. The SDK transforms most names into `publishers/google/models/NAME`,
// which is wrong for tuned endpoints. For a short-form tuned endpoint name
// (`endpoints/ID`), this expands it to the full resource path
// `projects/PROJECT/locations/LOCATION/endpoints/ID` using the client's
// configured project and location. Other names are returned unchanged.
func resolveVertexModelName(client *genai.Client, name string) string {
Comment thread
cabljac marked this conversation as resolved.
if !isTunedGeminiName(name) {
return name
}
if strings.HasPrefix(name, "projects/") {
return name
}
cc := client.ClientConfig()
if cc.Backend != genai.BackendVertexAI || cc.Project == "" || cc.Location == "" {
Comment thread
cabljac marked this conversation as resolved.
return name
}
return fmt.Sprintf("projects/%s/locations/%s/%s", cc.Project, cc.Location, name)
}

// generate requests generate call to the specified model with the provided
// configuration.
func generate(
Expand All @@ -157,6 +177,7 @@ func generate(
if model == "" {
return nil, errors.New("model not provided")
}
model = resolveVertexModelName(client, model)

cache, err := handleCache(ctx, client, input, model)
if err != nil {
Expand Down
27 changes: 18 additions & 9 deletions go/plugins/googlegenai/googlegenai.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,24 +207,33 @@ func (ga *GoogleAI) DefineModel(g *genkit.Genkit, name string, opts *ai.ModelOpt
// The second argument describes the capability of the model.
// Use [IsDefinedModel] to determine if a model is already defined.
// After [Init] is called, only the known models are defined.
//
// Tuned Gemini endpoints are accepted in either the short form
// `endpoints/ID` or the full resource path
// `projects/PROJECT/locations/LOCATION/endpoints/ID`. When opts is nil the
// caller gets the default Gemini capability set.
func (v *VertexAI) DefineModel(g *genkit.Genkit, name string, opts *ai.ModelOptions) (ai.Model, error) {
Comment thread
cabljac marked this conversation as resolved.
v.mu.Lock()
defer v.mu.Unlock()
if !v.initted {
return nil, errors.New("VertexAI plugin not initialized")
}
models, err := listModels(vertexAIProvider)
if err != nil {
return nil, err
}

if opts == nil {
var ok bool
modelOpts, ok := models[name]
if !ok {
return nil, fmt.Errorf("VertexAI.DefineModel: called with unknown model %q and nil ModelOptions", name)
if isTunedGeminiName(name) {
defaults := GetModelOptions(name, vertexAIProvider)
opts = &defaults
Comment thread
cabljac marked this conversation as resolved.
} else {
models, err := listModels(vertexAIProvider)
if err != nil {
return nil, err
}
modelOpts, ok := models[name]
if !ok {
return nil, fmt.Errorf("VertexAI.DefineModel: called with unknown model %q and nil ModelOptions", name)
}
opts = &modelOpts
}
opts = &modelOpts
}

return newModel(v.gclient, name, *opts), nil
Expand Down
20 changes: 20 additions & 0 deletions go/plugins/googlegenai/model_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,31 @@ func ClassifyModel(name string) ModelType {
case strings.Contains(name, "embedding"):
// Covers: text-embedding-*, embedding-*, textembedding-*, multimodalembedding
return ModelTypeEmbedder
case isTunedGeminiName(name):
// Vertex tuned Gemini models, addressed either by `endpoints/ID` or a
// full `projects/.../endpoints/ID` path. They speak the Gemini
// generateContent protocol, so dispatch them as Gemini.
return ModelTypeGemini
default:
return ModelTypeUnknown
}
}

// isTunedGeminiName reports whether name refers to a Vertex AI tuned Gemini
// endpoint, either by its short form (`endpoints/ID`) or its fully qualified
// resource path (`projects/PROJECT/locations/LOCATION/endpoints/ID`).
func isTunedGeminiName(name string) bool {
if strings.HasPrefix(name, "endpoints/") {
return true
}
if strings.HasPrefix(name, "projects/") &&
strings.Contains(name, "/locations/") &&
strings.Contains(name, "/endpoints/") {
return true
}
return false
}

// ActionType returns the appropriate API action type for this model type.
func (mt ModelType) ActionType() api.ActionType {
switch mt {
Expand Down
128 changes: 128 additions & 0 deletions go/plugins/googlegenai/tuned_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

package googlegenai

import (
"context"
"net/http"
"testing"

"google.golang.org/genai"
)

func TestClassifyModelTunedEndpoint(t *testing.T) {
cases := []struct {
name string
want ModelType
}{
{"endpoints/1234567890", ModelTypeGemini},
{"projects/my-proj/locations/us-central1/endpoints/1234567890", ModelTypeGemini},
{"gemini-2.5-flash", ModelTypeGemini},
{"imagen-3.0-generate-001", ModelTypeImagen},
{"veo-3.0-generate-001", ModelTypeVeo},
{"text-embedding-004", ModelTypeEmbedder},
{"random-name-with-no-prefix", ModelTypeUnknown},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
if got := ClassifyModel(tc.name); got != tc.want {
t.Fatalf("ClassifyModel(%q) = %v, want %v", tc.name, got, tc.want)
}
})
}
}

func TestResolveVertexModelName(t *testing.T) {
ctx := context.Background()

vertex, err := genai.NewClient(ctx, &genai.ClientConfig{
Backend: genai.BackendVertexAI,
Project: "test-project",
Location: "us-central1",
HTTPClient: &http.Client{},
})
if err != nil {
t.Fatalf("genai.NewClient (vertex): %v", err)
}

geminiAPI, err := genai.NewClient(ctx, &genai.ClientConfig{
Backend: genai.BackendGeminiAPI,
APIKey: "test-key",
})
if err != nil {
t.Fatalf("genai.NewClient (gemini): %v", err)
}

cases := []struct {
name string
client *genai.Client
in string
want string
}{
{
name: "short form on Vertex expands",
client: vertex,
in: "endpoints/1234567890",
want: "projects/test-project/locations/us-central1/endpoints/1234567890",
},
{
name: "fully qualified path is unchanged",
client: vertex,
in: "projects/my-proj/locations/us-central1/endpoints/999",
want: "projects/my-proj/locations/us-central1/endpoints/999",
},
{
name: "non-tuned name is unchanged",
client: vertex,
in: "gemini-2.5-flash",
want: "gemini-2.5-flash",
},
{
name: "short form on Gemini API backend is unchanged",
client: geminiAPI,
in: "endpoints/1234567890",
want: "endpoints/1234567890",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
if got := resolveVertexModelName(tc.client, tc.in); got != tc.want {
t.Errorf("resolveVertexModelName(%q) = %q, want %q", tc.in, got, tc.want)
}
})
}
}

func TestIsTunedGeminiName(t *testing.T) {
cases := []struct {
name string
want bool
}{
{"endpoints/1234567890", true},
{"projects/p/locations/us-central1/endpoints/999", true},
{"projects/p/endpoints/999", false},
{"gemini-2.5-flash", false},
{"imagen-3.0-generate-001", false},
{"projects/p/locations/us-central1/publishers/google/models/gemini-2.5-flash", false},
{"", false},
}
for _, tc := range cases {
if got := isTunedGeminiName(tc.name); got != tc.want {
t.Errorf("isTunedGeminiName(%q) = %v, want %v", tc.name, got, tc.want)
}
}
}
30 changes: 30 additions & 0 deletions go/plugins/googlegenai/vertexai_live_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,4 +391,34 @@ func TestVertexAILive(t *testing.T) {
t.Errorf("expecting 0 thought tokens, got %d", resp.Usage.ThoughtsTokens)
}
})
t.Run("tuned gemini endpoint", func(t *testing.T) {
endpointID := os.Getenv("GENKIT_VERTEX_TUNED_ENDPOINT")
if endpointID == "" {
t.Skip("GENKIT_VERTEX_TUNED_ENDPOINT not set; skipping tuned endpoint live test")
}
modelName := endpointID
if !strings.HasPrefix(modelName, "endpoints/") && !strings.HasPrefix(modelName, "projects/") {
modelName = "endpoints/" + modelName
}

// Use a fresh Genkit instance so we can DefineModel on the Vertex
// plugin before Generate runs.
plugin := &googlegenai.VertexAI{ProjectID: projectID, Location: location}
gTuned := genkit.Init(ctx, genkit.WithPlugins(plugin))
m, err := plugin.DefineModel(gTuned, modelName, nil)
if err != nil {
t.Fatalf("failed to register tuned model %q: %v", modelName, err)
}

resp, err := genkit.Generate(ctx, gTuned,
ai.WithModel(m),
ai.WithPrompt("Say hello in one short sentence."),
)
if err != nil {
t.Fatal(err)
}
if strings.TrimSpace(resp.Text()) == "" {
t.Fatal("expected a non-empty response from the tuned endpoint")
}
})
}
Loading