Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@
TextPart,
ToolDefinition,
)
from genkit._core._typing import GenerationCommonConfig
from genkit.model import Candidate, FinishReason, get_basic_usage_stats
from genkit.plugin_api import (
ActionRunContext,
Expand Down Expand Up @@ -1690,34 +1689,46 @@ def _normalize_config_to_dict(
self,
config: GeminiConfigSchema | ModelConfig | dict,
) -> dict[str, Any] | None:
"""Normalize any config type into a plain dict for uniform processing.
"""Return the config as a snake_case dict for the rest of the pipeline.

Handles three input shapes:
- GeminiConfigSchema (and subclasses like TTS/Image): model_dump
- ModelConfig: model_dump
- dict: route to the appropriate schema first, then model_dump
Callers can hand us three shapes: a typed ``GeminiConfigSchema``, the
generic ``GenerationCommonConfig`` (which keeps plugin-specific keys
as alias-form extras), or a raw dict in either casing. Only the
plugin schema knows the alias mapping (e.g. ``codeExecution`` <->
``code_execution``), so we re-validate through it whenever the input
isn't already one — that's what folds aliased keys onto their
canonical snake_case fields before tool extraction runs.

Returns:
A mutable dict ready for tool extraction and key cleanup,
or None if the config is empty after dumping.
Returns ``None`` if the config has no meaningful values.
"""
if isinstance(config, GeminiConfigSchema):
schema = config
elif isinstance(config, (ModelConfig, GenerationCommonConfig)):
schema = config
elif isinstance(config, ModelConfig):
# Re-route through the plugin schema so the alias machinery folds
# any plugin-specific extras onto their canonical fields.
schema = self._pick_plugin_schema(config.model_dump(exclude_none=True, by_alias=True))
elif isinstance(config, dict):
if 'image_config' in config:
schema = GeminiImageConfigSchema(**config)
elif 'speech_config' in config:
schema = GeminiTtsConfigSchema(**config)
else:
schema = GeminiConfigSchema(**config)
schema = self._pick_plugin_schema(config)
else:
return None

dumped = schema.model_dump(exclude_none=True, by_alias=False)
return dumped or None

def _pick_plugin_schema(self, data: dict[str, Any]) -> GeminiConfigSchema:
"""Validate ``data`` through whichever subclass matches the model.

Routing is purely by model name so each family gets its own
validation rules -- most importantly Gemma, which intentionally
relaxes the standard Gemini temperature bounds and would otherwise
reject valid configs. The per-request ``version`` override (when
present) takes precedence over the version this instance is bound
to, mirroring how the actual model name is resolved at call time.
"""
model_name = data.get('version') or self._version
schema_cls = get_model_config_schema(model_name)
return schema_cls.model_validate(data)

def _extract_tools_from_config(
self,
config: dict[str, Any],
Expand Down
123 changes: 123 additions & 0 deletions py/plugins/google-genai/test/models/googlegenai_gemini_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,15 @@
TextPart,
ToolDefinition,
)
from genkit._core._typing import GenerationCommonConfig
from genkit.plugin_api import to_json_schema
from genkit.plugins.google_genai.models.gemini import (
DEFAULT_SUPPORTS_MODEL,
GeminiConfigSchema,
GeminiImageConfigSchema,
GeminiModel,
GeminiTtsConfigSchema,
GemmaConfigSchema,
GoogleAIGeminiVersion,
VertexAIGeminiVersion,
google_model_info,
Expand Down Expand Up @@ -775,3 +780,121 @@ class MockPage(AsyncMock):
)

assert isinstance(cache, genai_types.CachedContent)


# ---------------------------------------------------------------------------
# Config normalization
#
# Plugin-specific keys like ``code_execution`` carry a camelCase alias
# (``codeExecution``) on the wire so that the Python and JS SDKs share the
# same JSON. Callers can hand the plugin three different shapes for the same
# logical config and we have to fold all of them onto the canonical
# snake_case field name before downstream translation runs. These tests pin
# that contract so a future refactor can't quietly let an alias-form key
# leak through to the strict ``GenerateContentConfig``.
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
('label', 'config'),
[
('snake_case dict', {'code_execution': True}),
('camelCase dict', {'codeExecution': True}),
(
'GenerationCommonConfig with alias-form extra',
GenerationCommonConfig.model_validate({'codeExecution': True}),
),
('GeminiConfigSchema instance', GeminiConfigSchema(code_execution=True)),
],
)
def test_gemini_model__normalize_config_canonicalizes_aliases(
gemini_model_instance: GeminiModel,
label: str,
config: object,
) -> None:
"""Every input shape collapses onto the canonical snake_case field."""
dumped = gemini_model_instance._normalize_config_to_dict(config)

assert dumped == {'code_execution': True}, label


@pytest.mark.asyncio
async def test_gemini_model__camelcase_code_execution_translates_to_tool(
gemini_model_instance: GeminiModel,
) -> None:
"""A camelCase convenience flag is translated into a tool, not leaked.

Reproduces the bug where ``ai.generate(config=GeminiConfigSchema(...).model_dump())``
produced an alias-form dict that fell through to the SDK's strict
``GenerateContentConfig`` and raised ``extra_forbidden``.
"""
request = ModelRequest(
messages=[Message(role=Role.USER, content=[Part(root=TextPart(text='hi'))])],
config=GeminiConfigSchema.model_validate({'code_execution': True}).model_dump(),
)

cfg = await gemini_model_instance._genkit_to_googleai_cfg(request)

assert cfg is not None
assert cfg.tools is not None
code_exec_tools = [t for t in cfg.tools if isinstance(t, genai_types.Tool) and t.code_execution is not None]
assert len(code_exec_tools) == 1
# The flag should not survive as an unknown SDK field in any casing.
assert 'codeExecution' not in cfg.model_dump(exclude_none=True)
assert 'code_execution' not in cfg.model_dump(exclude_none=True)


def test_gemini_model__normalize_config_picks_gemma_schema() -> None:
"""Gemma's relaxed temperature bounds survive normalization.

Gemma intentionally drops the [0.0, 2.0] cap that vanilla Gemini enforces,
so a config like ``temperature=3.0`` must be allowed when the bound model
is Gemma. If the routing falls back to the strict Gemini schema instead,
validation here would raise.
"""
gemma_model = GeminiModel(version='gemma-2-27b-it', client=MagicMock(spec=genai.Client))

dumped = gemma_model._normalize_config_to_dict({'temperature': 3.0})

assert dumped == {'temperature': 3.0}


def test_gemini_model__normalize_config_respects_version_override() -> None:
"""A per-request ``version`` override picks the matching schema.

Same model instance, but the caller overrides the version to a Gemma one,
so the schema selection has to follow the override -- otherwise the
instance's standard Gemini schema would reject the relaxed temperature.
"""
gemini_model = GeminiModel(version='gemini-2.0-flash-001', client=MagicMock(spec=genai.Client))

dumped = gemini_model._normalize_config_to_dict({'version': 'gemma-2-27b-it', 'temperature': 3.0})

assert dumped == {'version': 'gemma-2-27b-it', 'temperature': 3.0}


@pytest.mark.parametrize(
('version', 'expected_schema'),
[
('gemini-2.5-flash-preview-tts', GeminiTtsConfigSchema),
('gemini-2.0-flash-preview-image-generation', GeminiImageConfigSchema),
('gemma-2-27b-it', GemmaConfigSchema),
('gemini-2.0-flash-001', GeminiConfigSchema),
],
)
def test_gemini_model__pick_plugin_schema_routes_by_model_family(
version: str,
expected_schema: type[GeminiConfigSchema],
) -> None:
"""Each model family lands on its own schema based on the bound version.

Pins the routing contract so a future change can't quietly send TTS or
image models down the standard Gemini path (which would silently drop
their typed fields into ``extra='allow'`` and skip the family-specific
validation rules).
"""
model = GeminiModel(version=version, client=MagicMock(spec=genai.Client))

picked = model._pick_plugin_schema({})

assert type(picked) is expected_schema
Loading