From ff631cac0affef4b179536f57b751c42f9e385ea Mon Sep 17 00:00:00 2001 From: Matt Hicks Date: Sat, 23 Aug 2025 13:04:35 -0700 Subject: [PATCH] fix: properly wrap Vertex AI beta streaming responses Fixes tool input parameters being lost when using Claude models via Vertex AI with streaming enabled. Both create(stream=True) and stream() methods now properly handle event accumulation. The Vertex beta messages implementation was returning plain Stream objects instead of BetaMessageStream/BetaMessageStreamManager, bypassing the event accumulation logic entirely. This fix overrides the create() and stream() methods in src/anthropic/lib/vertex/_beta_messages.py to: - Wrap create(stream=True) responses in BetaMessageStream - Return BetaMessageStreamManager from stream() - Apply same fixes to async implementations Fixes #1020 --- src/anthropic/lib/vertex/_beta_messages.py | 336 ++++++++++++++++++++- 1 file changed, 332 insertions(+), 4 deletions(-) diff --git a/src/anthropic/lib/vertex/_beta_messages.py b/src/anthropic/lib/vertex/_beta_messages.py index 72b97b049..569c9db09 100644 --- a/src/anthropic/lib/vertex/_beta_messages.py +++ b/src/anthropic/lib/vertex/_beta_messages.py @@ -2,20 +2,197 @@ from __future__ import annotations +from typing import TYPE_CHECKING, List, Union, Iterable, Optional +from functools import partial +from typing_extensions import Literal + +import httpx + +from ... 
import _legacy_response +from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper +from ..._streaming import Stream, AsyncStream +from ...types.beta import BetaMessage, BetaRawMessageStreamEvent +from ...lib.streaming import ( + BetaMessageStream, + BetaAsyncMessageStream, + BetaMessageStreamManager, + BetaAsyncMessageStreamManager, +) from ...resources.beta import Messages as FirstPartyMessagesAPI, AsyncMessages as FirstPartyAsyncMessagesAPI +from ...types.model_param import ModelParam +from ...types.anthropic_beta_param import AnthropicBetaParam +from ...types.beta.beta_message_param import BetaMessageParam +from ...types.beta.beta_metadata_param import BetaMetadataParam +from ...types.beta.beta_text_block_param import BetaTextBlockParam +from ...types.beta.beta_tool_union_param import BetaToolUnionParam +from ...types.beta.beta_tool_choice_param import BetaToolChoiceParam +from ...types.beta.beta_thinking_config_param import BetaThinkingConfigParam +from ...types.beta.beta_request_mcp_server_url_definition_param import BetaRequestMCPServerURLDefinitionParam + +if TYPE_CHECKING: + pass __all__ = ["Messages", "AsyncMessages"] class Messages(SyncAPIResource): - create = FirstPartyMessagesAPI.create - stream = FirstPartyMessagesAPI.stream + # Delegate count_tokens to the first-party implementation count_tokens = FirstPartyMessagesAPI.count_tokens + def create( + self, + *, + max_tokens: int, + messages: Iterable[BetaMessageParam], + model: ModelParam, + container: Optional[str] | NotGiven = NOT_GIVEN, + mcp_servers: Iterable[BetaRequestMCPServerURLDefinitionParam] | NotGiven = NOT_GIVEN, + metadata: BetaMetadataParam | NotGiven = NOT_GIVEN, + service_tier: Literal["auto", "standard_only"] | NotGiven = NOT_GIVEN, + stop_sequences: List[str] | NotGiven = NOT_GIVEN, + stream: 
Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + system: Union[str, Iterable[BetaTextBlockParam]] | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + thinking: BetaThinkingConfigParam | NotGiven = NOT_GIVEN, + tool_choice: BetaToolChoiceParam | NotGiven = NOT_GIVEN, + tools: Iterable[BetaToolUnionParam] | NotGiven = NOT_GIVEN, + top_k: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + betas: List[AnthropicBetaParam] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> BetaMessage | Stream[BetaRawMessageStreamEvent]: + """ + Create a message using the Vertex AI endpoint. + + When streaming is enabled, this wraps the response in BetaMessageStream + for proper event accumulation, particularly for tool_use inputs. 
+ """ + # If streaming is enabled, wrap the response in BetaMessageStream for accumulation + if stream is True: + # Get the raw stream from the first-party API + raw_stream = FirstPartyMessagesAPI.create( + self, + max_tokens=max_tokens, + messages=messages, + model=model, + container=container, + mcp_servers=mcp_servers, + metadata=metadata, + service_tier=service_tier, + stop_sequences=stop_sequences, + stream=True, + system=system, + temperature=temperature, + thinking=thinking, + tool_choice=tool_choice, + tools=tools, + top_k=top_k, + top_p=top_p, + betas=betas, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + ) + # Wrap in BetaMessageStream which has the accumulation logic + # This ensures tool inputs are properly accumulated from delta events + return BetaMessageStream(raw_stream) + + # For non-streaming, delegate normally + return FirstPartyMessagesAPI.create( + self, + max_tokens=max_tokens, + messages=messages, + model=model, + container=container, + mcp_servers=mcp_servers, + metadata=metadata, + service_tier=service_tier, + stop_sequences=stop_sequences, + stream=stream, + system=system, + temperature=temperature, + thinking=thinking, + tool_choice=tool_choice, + tools=tools, + top_k=top_k, + top_p=top_p, + betas=betas, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + ) + + def stream( + self, + *, + max_tokens: int, + messages: Iterable[BetaMessageParam], + model: ModelParam, + container: Optional[str] | NotGiven = NOT_GIVEN, + mcp_servers: Iterable[BetaRequestMCPServerURLDefinitionParam] | NotGiven = NOT_GIVEN, + metadata: BetaMetadataParam | NotGiven = NOT_GIVEN, + service_tier: Literal["auto", "standard_only"] | NotGiven = NOT_GIVEN, + stop_sequences: List[str] | NotGiven = NOT_GIVEN, + system: Union[str, Iterable[BetaTextBlockParam]] | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + thinking: BetaThinkingConfigParam | 
NotGiven = NOT_GIVEN, + tool_choice: BetaToolChoiceParam | NotGiven = NOT_GIVEN, + tools: Iterable[BetaToolUnionParam] | NotGiven = NOT_GIVEN, + top_k: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + betas: List[AnthropicBetaParam] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> BetaMessageStreamManager: + """ + Create a streaming message using the Vertex AI endpoint. + + This method ensures that the response is properly wrapped in a BetaMessageStreamManager + for correct event accumulation, particularly for tool_use inputs. + """ + # Create a function that makes the streaming request + make_request = partial( + self.create, + max_tokens=max_tokens, + messages=messages, + model=model, + container=container, + mcp_servers=mcp_servers, + metadata=metadata, + service_tier=service_tier, + stop_sequences=stop_sequences, + stream=True, # Force streaming + system=system, + temperature=temperature, + thinking=thinking, + tool_choice=tool_choice, + tools=tools, + top_k=top_k, + top_p=top_p, + betas=betas, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + ) + + # Return the proper stream manager wrapper + return BetaMessageStreamManager(make_request) + @cached_property def with_raw_response(self) -> MessagesWithRawResponse: """ @@ -37,10 +214,161 @@ def with_streaming_response(self) -> MessagesWithStreamingResponse: class AsyncMessages(AsyncAPIResource): - create = FirstPartyAsyncMessagesAPI.create - stream = FirstPartyAsyncMessagesAPI.stream + # Delegate count_tokens to the first-party implementation count_tokens = 
FirstPartyAsyncMessagesAPI.count_tokens + async def create( + self, + *, + max_tokens: int, + messages: Iterable[BetaMessageParam], + model: ModelParam, + container: Optional[str] | NotGiven = NOT_GIVEN, + mcp_servers: Iterable[BetaRequestMCPServerURLDefinitionParam] | NotGiven = NOT_GIVEN, + metadata: BetaMetadataParam | NotGiven = NOT_GIVEN, + service_tier: Literal["auto", "standard_only"] | NotGiven = NOT_GIVEN, + stop_sequences: List[str] | NotGiven = NOT_GIVEN, + stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + system: Union[str, Iterable[BetaTextBlockParam]] | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + thinking: BetaThinkingConfigParam | NotGiven = NOT_GIVEN, + tool_choice: BetaToolChoiceParam | NotGiven = NOT_GIVEN, + tools: Iterable[BetaToolUnionParam] | NotGiven = NOT_GIVEN, + top_k: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + betas: List[AnthropicBetaParam] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> BetaMessage | AsyncStream[BetaRawMessageStreamEvent]: + """ + Create a message using the Vertex AI endpoint. + + When streaming is enabled, this properly wraps the response in BetaAsyncMessageStream + for proper event accumulation, particularly for tool_use inputs. 
+ """ + # If streaming is enabled, wrap the response in BetaAsyncMessageStream for accumulation + if stream is True: + # Get the raw stream from the first-party API + raw_stream = await FirstPartyAsyncMessagesAPI.create( + self, + max_tokens=max_tokens, + messages=messages, + model=model, + container=container, + mcp_servers=mcp_servers, + metadata=metadata, + service_tier=service_tier, + stop_sequences=stop_sequences, + stream=True, + system=system, + temperature=temperature, + thinking=thinking, + tool_choice=tool_choice, + tools=tools, + top_k=top_k, + top_p=top_p, + betas=betas, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + ) + # Wrap in BetaAsyncMessageStream which has the accumulation logic + # This ensures tool inputs are properly accumulated from delta events + return BetaAsyncMessageStream(raw_stream) + + # For non-streaming, delegate normally + return await FirstPartyAsyncMessagesAPI.create( + self, + max_tokens=max_tokens, + messages=messages, + model=model, + container=container, + mcp_servers=mcp_servers, + metadata=metadata, + service_tier=service_tier, + stop_sequences=stop_sequences, + stream=stream, + system=system, + temperature=temperature, + thinking=thinking, + tool_choice=tool_choice, + tools=tools, + top_k=top_k, + top_p=top_p, + betas=betas, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + ) + + def stream( + self, + *, + max_tokens: int, + messages: Iterable[BetaMessageParam], + model: ModelParam, + container: Optional[str] | NotGiven = NOT_GIVEN, + mcp_servers: Iterable[BetaRequestMCPServerURLDefinitionParam] | NotGiven = NOT_GIVEN, + metadata: BetaMetadataParam | NotGiven = NOT_GIVEN, + service_tier: Literal["auto", "standard_only"] | NotGiven = NOT_GIVEN, + stop_sequences: List[str] | NotGiven = NOT_GIVEN, + system: Union[str, Iterable[BetaTextBlockParam]] | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + 
thinking: BetaThinkingConfigParam | NotGiven = NOT_GIVEN, + tool_choice: BetaToolChoiceParam | NotGiven = NOT_GIVEN, + tools: Iterable[BetaToolUnionParam] | NotGiven = NOT_GIVEN, + top_k: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + betas: List[AnthropicBetaParam] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> BetaAsyncMessageStreamManager: + """ + Create a streaming message using the Vertex AI endpoint. + + This method ensures that the response is properly wrapped in a BetaAsyncMessageStreamManager + for correct event accumulation, particularly for tool_use inputs. + """ + + # Create an async function that makes the streaming request + async def make_request(): + return await self.create( + max_tokens=max_tokens, + messages=messages, + model=model, + container=container, + mcp_servers=mcp_servers, + metadata=metadata, + service_tier=service_tier, + stop_sequences=stop_sequences, + stream=True, # Force streaming + system=system, + temperature=temperature, + thinking=thinking, + tool_choice=tool_choice, + tools=tools, + top_k=top_k, + top_p=top_p, + betas=betas, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + ) + + # Return the proper async stream manager wrapper + return BetaAsyncMessageStreamManager(make_request()) + @cached_property def with_raw_response(self) -> AsyncMessagesWithRawResponse: """