Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions google/genai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,44 @@
logger = logging.getLogger('google_genai.models')


def _filter_thought_parts(
return_value: 'types.GenerateContentResponse',
parameter_model: 'types._GenerateContentParameters',
) -> 'types.GenerateContentResponse':
"""Filters thought parts from response when include_thoughts is False.

When the Vertex AI API returns thought parts despite include_thoughts=False
being set in ThinkingConfig, this function performs client-side filtering
to suppress them. The part.thought flag is reliably set by the API, so
filtering on it is safe.

Args:
return_value: The GenerateContentResponse to filter.
parameter_model: The request parameters, used to read ThinkingConfig.

Returns:
The response with thought parts removed if include_thoughts=False,
otherwise the response unchanged.
"""
config = parameter_model.config
if config is None:
return return_value
thinking_config = getattr(config, 'thinking_config', None)
if thinking_config is None:
return return_value
include_thoughts = getattr(thinking_config, 'include_thoughts', None)
if include_thoughts is not False:
return return_value
if not return_value.candidates:
return return_value
for candidate in return_value.candidates:
if candidate.content and candidate.content.parts:
candidate.content.parts = [
part for part in candidate.content.parts if not part.thought
]
return return_value


def _PersonGeneration_to_mldev_enum_validate(enum_value: Any) -> None:
if enum_value in set(['ALLOW_ALL']):
raise ValueError(f'{enum_value} enum value is not supported in Gemini API.')
Expand Down Expand Up @@ -4725,7 +4763,7 @@ def _generate_content(
headers=response.headers
)
self._api_client._verify_response(return_value)
return return_value
return _filter_thought_parts(return_value, parameter_model)

def _generate_content_stream(
self,
Expand Down Expand Up @@ -4826,7 +4864,7 @@ def _generate_content_stream(
headers=response.headers
)
self._api_client._verify_response(return_value)
yield return_value
yield _filter_thought_parts(return_value, parameter_model)

def _embed_content(
self,
Expand Down Expand Up @@ -6891,7 +6929,7 @@ async def _generate_content(
headers=response.headers
)
self._api_client._verify_response(return_value)
return return_value
return _filter_thought_parts(return_value, parameter_model)

async def _generate_content_stream(
self,
Expand Down Expand Up @@ -6995,7 +7033,7 @@ async def async_generator(): # type: ignore[no-untyped-def]
headers=response.headers
)
self._api_client._verify_response(return_value)
yield return_value
yield _filter_thought_parts(return_value, parameter_model)

return async_generator() # type: ignore[no-untyped-call, no-any-return]

Expand Down
139 changes: 139 additions & 0 deletions google/genai/tests/models/test_filter_thought_parts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for _filter_thought_parts.

Verifies client-side filtering of thought parts when include_thoughts=False
is set in ThinkingConfig. Regression test for:
https://github.com/googleapis/python-genai/issues/2239
"""

import pytest

from ... import _transformers as t
from ... import types
from ...models import _filter_thought_parts


def _make_response_with_thoughts() -> types.GenerateContentResponse:
"""Build a synthetic response resembling Vertex AI image gen with thoughts."""
parts = [
types.Part(text='thinking step 1', thought=True),
types.Part(text='thinking step 2', thought=True),
types.Part(inline_data=types.Blob(mime_type='image/png', data=b'draft'), thought=True),
types.Part(text='self-critique', thought=True),
types.Part(inline_data=types.Blob(mime_type='image/png', data=b'final')),
]
content = types.Content(role='model', parts=parts)
candidate = types.Candidate(content=content)
return types.GenerateContentResponse(candidates=[candidate])


def _make_parameter_model(include_thoughts) -> types._GenerateContentParameters:
return types._GenerateContentParameters(
model='gemini-3.1-flash-image-preview',
contents=t.t_contents('Draw a red car'),
config=types.GenerateContentConfig(
thinking_config=types.ThinkingConfig(include_thoughts=include_thoughts)
),
)


class TestFilterThoughtParts:

def test_include_thoughts_false_removes_thought_parts(self):
"""When include_thoughts=False, all parts with thought=True are removed."""
response = _make_response_with_thoughts()
parameter_model = _make_parameter_model(include_thoughts=False)

result = _filter_thought_parts(response, parameter_model)

parts = result.candidates[0].content.parts
assert all(not part.thought for part in parts), (
'Expected no thought parts but found some'
)
assert len(parts) == 1, f'Expected 1 non-thought part, got {len(parts)}'
assert parts[0].inline_data is not None
assert parts[0].inline_data.data == b'final'

def test_include_thoughts_true_preserves_all_parts(self):
"""When include_thoughts=True, no parts are filtered."""
response = _make_response_with_thoughts()
parameter_model = _make_parameter_model(include_thoughts=True)

result = _filter_thought_parts(response, parameter_model)

parts = result.candidates[0].content.parts
assert len(parts) == 5, f'Expected 5 parts, got {len(parts)}'

def test_include_thoughts_none_preserves_all_parts(self):
"""When include_thoughts is unset, no parts are filtered."""
response = _make_response_with_thoughts()
parameter_model = _make_parameter_model(include_thoughts=None)

result = _filter_thought_parts(response, parameter_model)

parts = result.candidates[0].content.parts
assert len(parts) == 5

def test_no_thinking_config_preserves_all_parts(self):
"""When ThinkingConfig is absent entirely, no parts are filtered."""
response = _make_response_with_thoughts()
parameter_model = types._GenerateContentParameters(
model='gemini-3.1-flash-image-preview',
contents=t.t_contents('Draw a red car'),
config=types.GenerateContentConfig(),
)

result = _filter_thought_parts(response, parameter_model)

parts = result.candidates[0].content.parts
assert len(parts) == 5

def test_no_config_preserves_all_parts(self):
"""When config is None entirely, no parts are filtered."""
response = _make_response_with_thoughts()
parameter_model = types._GenerateContentParameters(
model='gemini-3.1-flash-image-preview',
contents=t.t_contents('Draw a red car'),
)

result = _filter_thought_parts(response, parameter_model)

parts = result.candidates[0].content.parts
assert len(parts) == 5

def test_empty_candidates_is_safe(self):
"""Response with no candidates does not raise."""
response = types.GenerateContentResponse(candidates=[])
parameter_model = _make_parameter_model(include_thoughts=False)

result = _filter_thought_parts(response, parameter_model)

assert result.candidates == []

def test_no_thought_parts_in_response(self):
"""If API returns no thought parts, filtering is a no-op."""
parts = [
types.Part(inline_data=types.Blob(mime_type='image/png', data=b'final')),
]
content = types.Content(role='model', parts=parts)
candidate = types.Candidate(content=content)
response = types.GenerateContentResponse(candidates=[candidate])
parameter_model = _make_parameter_model(include_thoughts=False)

result = _filter_thought_parts(response, parameter_model)

assert len(result.candidates[0].content.parts) == 1
Loading