From 70b318c04079b6620055e7bbfaefdb2ed4cbdc05 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 1 Jul 2025 18:18:29 +0000 Subject: [PATCH 01/11] Add OpenAIModelHandler for Beam ML Inference This commit introduces an OpenAIModelHandler to enable sending inference requests to OpenAI's API (e.g., GPT models) via the RunInference transform. Key changes: - Added `OpenAIModelHandler` in `sdks/python/apache_beam/ml/inference/openai_inference.py`. - Implemented request generation and response parsing for both completion and chat-based OpenAI models within the `generate_completion` helper function. - Included retry logic for common OpenAI API errors (rate limits, server errors). - Added comprehensive unit tests in `sdks/python/apache_beam/ml/inference/openai_inference_test.py`, mocking the OpenAI client and testing various scenarios. All unit tests pass. - Created integration tests in `sdks/python/apache_beam/ml/inference/openai_inference_it_test.py` (requires `OPENAI_API_KEY` to run). - Defined dependencies in `sdks/python/apache_beam/ml/inference/openai_tests_requirements.txt`, now with Apache license header. This handler follows the pattern of other remote model handlers like those for Gemini and Vertex AI. The handler is not exported by default from `apache_beam.ml.inference`. --- .../ml/inference/openai_inference.py | 231 +++++++++++++++ .../ml/inference/openai_inference_it_test.py | 159 ++++++++++ .../ml/inference/openai_inference_test.py | 280 ++++++++++++++++++ .../inference/openai_tests_requirements.txt | 17 ++ 4 files changed, 687 insertions(+) create mode 100644 sdks/python/apache_beam/ml/inference/openai_inference.py create mode 100644 sdks/python/apache_beam/ml/inference/openai_inference_it_test.py create mode 100644 sdks/python/apache_beam/ml/inference/openai_inference_test.py create mode 100644 sdks/python/apache_beam/ml/inference/openai_tests_requirements.txt diff --git a/sdks/python/apache_beam/ml/inference/openai_inference.py b/sdks/python/apache_beam/ml/inference/openai_inference.py new file mode 100644 index 000000000000..436c94cb1927 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/openai_inference.py @@ -0,0 +1,231 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +from collections.abc import Callable +from collections.abc import Iterable +from collections.abc import Sequence +from typing import Any +from typing import Optional + +# pylint: disable=wrong-import-order, wrong-import-position +try: + import openai + from openai import APIError + from openai import OpenAIError + from openai import RateLimitError +except ImportError: + raise ImportError( + 'OpenAI dependencies are not installed. To use OpenAI model handler,' + 'run pip install apache-beam[gcp,openai]') + +from apache_beam.ml.inference import utils +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RemoteModelHandler + +LOGGER = logging.getLogger("OpenAIModelHandler") + + +def _retry_on_appropriate_openai_error(exception: Exception) -> bool: + """ + Retry filter that returns True if a returned HTTP error code is 5xx or 429 + (RateLimitError). + """ + LOGGER.debug(f"Checking exception for retry: {type(exception)} - {str(exception)}") + if isinstance(exception, RateLimitError): + LOGGER.debug("RateLimitError detected, retrying.") + return True # Always retry RateLimitError (HTTP 429) + + if isinstance(exception, APIError): # This covers APIStatusError as well + status_code = getattr(exception, 'status_code', None) + LOGGER.debug(f"APIError detected. Status code from getattr: {status_code}") + if status_code is not None: + LOGGER.debug(f"Condition check: {status_code} >= 500 is {status_code >= 500}") + return status_code >= 500 # Retry on 5xx errors + else: + LOGGER.debug("APIError but status_code is None.") + + LOGGER.debug("Exception not eligible for retry by this filter.") + return False # Do not retry for other errors or if status_code is not available + + +def generate_completion( + model_name: str, + batch: Sequence[str], + client: openai.OpenAI, + inference_args: dict[str, Any]): + """ + Generates completions for a batch of prompts using the OpenAI API. + """ + responses = [] + for prompt in batch: + try: + # Note: OpenAI's library expects a single prompt for completions.create, + # so we iterate and call. Batching is handled by RunInference. + # For chat models, multiple messages can be part of a single request. + if "chat.completions" in client.chat.completions.with_raw_response.create.binary_relative_path: # rough check + # Assuming chat model if path indicates chat completions + # User might need to format input as list of messages + # For simplicity, we'll assume a single user message per prompt string + # for now. + if not isinstance(prompt, list): # basic check for message format + messages = [{"role": "user", "content": prompt}] + else: # assume prompt is already in message format + messages = prompt + response = client.chat.completions.create( + model=model_name, messages=messages, **inference_args) + else: + response = client.completions.create( + model=model_name, prompt=prompt, **inference_args) + responses.append(response) + except OpenAIError as e: + # Capture individual errors to potentially return partial results + # or raise a combined error. For now, let it propagate to be caught + # by the RemoteModelHandler's retry logic. + LOGGER.error("OpenAI API error for prompt '%s': %s", prompt, e) + raise e + + # Parse responses within the generate_completion function + parsed_responses = [] + for response_obj in responses: + if hasattr(response_obj, 'choices'): + if response_obj.choices: + # For ChatCompletion, the message is nested + if hasattr(response_obj.choices[0], 'message') and \ + hasattr(response_obj.choices[0].message, 'content'): + parsed_responses.append(response_obj.choices[0].message.content) + # For Completion (older models) + elif hasattr(response_obj.choices[0], 'text'): + parsed_responses.append(response_obj.choices[0].text) + else: + LOGGER.warning("Unrecognized OpenAI response choice format: %s", response_obj.choices[0]) + parsed_responses.append(None) # Or raise error + else: + LOGGER.warning("OpenAI response had no choices: %s", response_obj) + parsed_responses.append(None) # Or raise error + else: + LOGGER.warning("Unrecognized OpenAI response format: %s", response_obj) + parsed_responses.append(None) # Or raise error + return parsed_responses + + +class OpenAIModelHandler(RemoteModelHandler[Any, PredictionResult, + openai.OpenAI]): + def __init__( + self, + api_key: str, + model: str, # Recommended to use 'model' like in openai library + request_fn: Callable[[str, Sequence[Any], openai.OpenAI, dict[str, Any]], + Any] = generate_completion, + *, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + max_batch_duration_secs: Optional[int] = None, + **kwargs): + """Implementation of the ModelHandler interface for OpenAI. + **NOTE:** This API and its implementation are under development and + do not provide backward compatibility guarantees. + + Args: + api_key: the OpenAI API key to use for the requests. + model: The OpenAI model to use for inference (e.g., "gpt-3.5-turbo-instruct", "gpt-3.5-turbo"). + request_fn: the function to use to send the request. Should take the + model name and the parameters from request() and return the responses + from OpenAI. The class will handle bundling the inputs and responses + together. Defaults to `generate_completion`. + min_batch_size: optional. the minimum batch size to use when batching + inputs. + max_batch_size: optional. the maximum batch size to use when batching + inputs. + max_batch_duration_secs: optional. the maximum amount of time to buffer + a batch before emitting; used in streaming contexts. + kwargs: Other arguments to pass to the underlying RemoteModelHandler. + """ + self._batching_kwargs = {} + self._env_vars = kwargs.get('env_vars', {}) + if min_batch_size is not None: + self._batching_kwargs["min_batch_size"] = min_batch_size + if max_batch_size is not None: + self._batching_kwargs["max_batch_size"] = max_batch_size + if max_batch_duration_secs is not None: + self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs + + self.api_key = api_key + self.model_name = model # Renamed from model_name to model for consistency + self.request_fn = request_fn + + # OpenAI client will be initialized in create_client + self._client: Optional[openai.OpenAI] = None + + super().__init__( + namespace='OpenAIModelHandler', + retry_filter=_retry_on_appropriate_openai_error, + **kwargs) + + def create_client(self) -> openai.OpenAI: + """Creates the OpenAI client used to send requests.""" + if not self._client: + self._client = openai.OpenAI(api_key=self.api_key) + return self._client + + def request( + self, + batch: Sequence[Any], + model_client: openai.OpenAI, # Parameter name changed for clarity + inference_args: Optional[dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + """ Sends a prediction request to an OpenAI model containing a batch + of inputs and matches that input with the prediction response from + the endpoint as an iterable of PredictionResults. + + Args: + batch: a sequence of any values to be passed to the OpenAI model. + Should be inputs accepted by the provided request_fn. + model_client: an openai.OpenAI client object. + inference_args: any additional arguments to send as part of the + prediction request to OpenAI (e.g., temperature, max_tokens). + + Returns: + An iterable of PredictionResults. + """ + if inference_args is None: + inference_args = {} + + # The `generate_completion` function now iterates through the batch + # and makes individual API calls if necessary (e.g. for non-chat models) + # or a single call if the underlying API supports batching (e.g. future chat models). + # The RunInference transform handles the primary batching of elements from PCollection. + try: + # request_fn (generate_completion) now returns a list of parsed strings/content + parsed_responses = self.request_fn( + self.model_name, batch, model_client, inference_args) + except Exception as e: + LOGGER.error( + "Error during OpenAI request for batch: %s. Error: %s", batch, e) + # Propagate the error to allow RemoteModelHandler's retry logic to kick in + raise + + return utils._convert_to_result(batch, parsed_responses, self.model_name) + + def batch_elements_kwargs(self) -> dict[str, Any]: + return self._batching_kwargs + + def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): + # OpenAI's API takes various arguments, most are optional. + # No specific validation needed at this level for common args like + # temperature, max_tokens, etc. The API itself will validate. + pass diff --git a/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py b/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py new file mode 100644 index 000000000000..38ada5690a6b --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py @@ -0,0 +1,159 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""End-to-End test for OpenAI Remote Inference""" + +import logging +import os +import unittest +import uuid + +import apache_beam as beam +import pytest +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.testing.test_pipeline import TestPipeline + +# pylint: disable=ungrouped-imports +try: + from apache_beam.ml.inference.openai_inference import OpenAIModelHandler +except ImportError: + raise unittest.SkipTest('OpenAI dependencies are not installed') + +# Default output directory for test results +_OUTPUT_DIR_DEFAULT = "gs://apache-beam-ml/testing/outputs/openai" +# Placeholder for API key. Users must set OPENAI_API_KEY environment variable. +_OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") + +# Models for testing - one completion, one chat +_COMPLETION_MODEL = "gpt-3.5-turbo-instruct" # A smaller, faster completion model +_CHAT_MODEL = "gpt-3.5-turbo" + + +@unittest.skipIf(not _OPENAI_API_KEY, "OPENAI_API_KEY is not set.") +class OpenAIInferenceIT(unittest.TestCase): + def setUp(self): + self.output_dir = os.environ.get("BEAM_ML_OUTPUT_DIR", _OUTPUT_DIR_DEFAULT) + self.project = os.environ.get("BEAM_ML_PROJECT") # Not used by OpenAI but common in tests + + def run_pipeline(self, model_handler, test_data, output_path_suffix, inference_args=None): + output_file = '/'.join([self.output_dir, str(uuid.uuid4()), output_path_suffix]) + + pipeline_options = { + 'output': output_file, + } + # Add project if available, for consistency with other IT tests, + # though OpenAI handler doesn't directly use it. + if self.project: + pipeline_options['project'] = self.project + + test_pipeline = TestPipeline(is_integration_test=True, options=pipeline_options) + + with test_pipeline as p: + results = ( + p + | "CreateInputs" >> beam.Create(test_data) + | "RunInference" >> RunInference(model_handler, inference_args=inference_args) + | "SaveResults" >> beam.Map(lambda x: str(x)) # Convert PredictionResult to string for output + | beam.io.WriteToText(output_file) + ) + + self.assertTrue(FileSystems().exists(output_file)) + # Further checks could involve reading the output and verifying content, + # but for now, we just check if the pipeline runs and produces output. + + # Basic check for content in the output file to ensure it's not empty + # and contains expected PredictionResult structure. + # This part can be flaky if API responses change slightly. + # For a more robust check, one might mock the API in an IT setting or + # use a very deterministic, simple prompt. + match_results = [] + def process_output_file(readable_file): + for line in readable_file: + match_results.append(line) + + FileSystems.read_files(output_file, process_file_fn=process_output_file) + self.assertGreater(len(match_results), 0) + # Example: check if output contains 'PredictionResult(example=' or similar + self.assertTrue(any("PredictionResult(example=" in line for line in match_results)) + + + @pytest.mark.postcommit # Mark as postcommit as it makes external calls. + def test_openai_completion_model(self): + model_handler = OpenAIModelHandler( + api_key=_OPENAI_API_KEY, + model=_COMPLETION_MODEL + ) + test_data = [ + "What is the capital of France?", + "Translate 'hello' to Spanish." + ] + inference_args = {"max_tokens": 50, "temperature": 0.7} + self.run_pipeline( + model_handler, + test_data, + "output_completion.txt", + inference_args=inference_args + ) + + @pytest.mark.postcommit + def test_openai_chat_model(self): + model_handler = OpenAIModelHandler( + api_key=_OPENAI_API_KEY, + model=_CHAT_MODEL + ) + # Chat models expect a list of messages or a single string (handled as user message) + test_data = [ + "What is 2+2?", # Single string prompt + [{"role": "user", "content": "Describe a perfect day."}] # Message list prompt + ] + inference_args = {"max_tokens": 100, "temperature": 0.5} + self.run_pipeline( + model_handler, + test_data, + "output_chat.txt", + inference_args=inference_args + ) + + @pytest.mark.postcommit + def test_openai_chat_model_with_system_message(self): + model_handler = OpenAIModelHandler( + api_key=_OPENAI_API_KEY, + model=_CHAT_MODEL + ) + # Chat models expect a list of messages or a single string (handled as user message) + test_data = [ + # This requires the OpenAIModelHandler's generate_completion to correctly + # handle list of messages if the input element itself is a list of dicts. + [ + {"role": "system", "content": "You are a helpful assistant that speaks like a pirate."}, + {"role": "user", "content": "How are you?"} + ] + ] + inference_args = {"max_tokens": 50} + self.run_pipeline( + model_handler, + test_data, + "output_chat_system.txt", + inference_args=inference_args + ) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/ml/inference/openai_inference_test.py b/sdks/python/apache_beam/ml/inference/openai_inference_test.py new file mode 100644 index 000000000000..948e71286e58 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/openai_inference_test.py @@ -0,0 +1,280 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +from unittest.mock import MagicMock +from unittest.mock import patch +import httpx # Added for mocking request object +import logging + +# pylint: disable=wrong-import-order, wrong-import-position +try: + import openai + from openai import APIError + from openai import RateLimitError + from openai.types.chat.chat_completion import ChatCompletion + from openai.types.chat.chat_completion import Choice as ChatChoice + from openai.types.chat.chat_completion_message import ChatCompletionMessage + from openai.types.completion import Completion + from openai.types.completion_choice import CompletionChoice # Corrected import + from apache_beam.ml.inference.openai_inference import ( + OpenAIModelHandler, _retry_on_appropriate_openai_error) +except ImportError: + raise unittest.SkipTest('OpenAI dependencies are not installed') + +from apache_beam.ml.inference.base import PredictionResult + +# Configure logger for debugging tests related to _retry_on_appropriate_openai_error +# This gets the logger instance used in openai_inference.py +logger_to_debug = logging.getLogger("OpenAIModelHandler") +logger_to_debug.setLevel(logging.DEBUG) +# Add a handler to see the output during tests, e.g., stream to stderr +# Check if a handler already exists to avoid duplicate messages if tests are run multiple times +if not any(isinstance(h, logging.StreamHandler) for h in logger_to_debug.handlers): + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.DEBUG) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + stream_handler.setFormatter(formatter) + logger_to_debug.addHandler(stream_handler) + + +class RetryOnAPIErrorTest(unittest.TestCase): + def _create_mock_error_with_status(self, status_code, error_class=APIError): + """ + Helper to create a mock error object (APIError or RateLimitError) + with a given status code. + The key is to ensure that `getattr(err, 'status_code', None)` works as expected. + For real OpenAI errors: + - RateLimitError (and other APIStatusErrors) have `err.status_code` as a direct attribute. + - APIError (the base) has `err.status_code` as a property that inspects `err.request.response.status_code`. + """ + mock_response = MagicMock(spec=httpx.Response) + # mock_response.status_code will be set below. + # Ensure headers is a mock that can handle .get() for RateLimitError + mock_response.headers = MagicMock(spec=httpx.Headers) + mock_response.headers.get.return_value = "test-request-id" # For x-request-id + mock_response.content = b"{}" + mock_response.text = "{}" + + mock_request_obj = MagicMock(spec=httpx.Request) + mock_request_obj.method = "POST" + mock_request_obj.url = httpx.URL("https://api.openai.com/v1/completions") + + mock_response.request = mock_request_obj + + if error_class == RateLimitError: + mock_response.status_code = status_code + err = RateLimitError("rate limited", response=mock_response, body=None) + else: # Generic APIError + mock_request_that_failed = MagicMock(spec=httpx.Request) + mock_request_that_failed.method = "POST" + mock_request_that_failed.url = httpx.URL("https://api.openai.com/v1/completions") + + # This is the response that APIError.status_code property will look for + response_for_api_error_property = MagicMock(spec=httpx.Response) + response_for_api_error_property.status_code = status_code + mock_request_that_failed.response = response_for_api_error_property + + err = APIError("API error", request=mock_request_that_failed, body=None) + # Directly set status_code on the instance for getattr in the retry function to pick up. + # This is simpler than ensuring the nested property mock works perfectly. + # Note: This shadows the property for this instance. + err.status_code = status_code + return err + + def test_retry_on_rate_limit_error(self): + err = self._create_mock_error_with_status(429, error_class=RateLimitError) + self.assertTrue(_retry_on_appropriate_openai_error(err)) + + def test_retry_on_server_error_500(self): + err = self._create_mock_error_with_status(500, error_class=APIError) + self.assertTrue(_retry_on_appropriate_openai_error(err)) + + def test_retry_on_server_error_503(self): + err = self._create_mock_error_with_status(503, error_class=APIError) + self.assertTrue(_retry_on_appropriate_openai_error(err)) + + def test_no_retry_on_client_error_400(self): + err = self._create_mock_error_with_status(400, error_class=APIError) + self.assertFalse(_retry_on_appropriate_openai_error(err)) + + def test_no_retry_on_client_error_401(self): + err = self._create_mock_error_with_status(401, error_class=APIError) + self.assertFalse(_retry_on_appropriate_openai_error(err)) + + def test_no_retry_on_non_openai_error(self): + self.assertFalse(_retry_on_appropriate_openai_error(ValueError("some other error"))) + + +class OpenAIModelHandlerTest(unittest.TestCase): + def setUp(self): + self.api_key = "test_api_key" + self.model_name = "gpt-3.5-turbo-instruct" # A completion model + self.chat_model_name = "gpt-3.5-turbo" # A chat model + + @patch('openai.OpenAI') + def test_create_client(self, mock_openai_client_constructor): + mock_client_instance = MagicMock() + mock_openai_client_constructor.return_value = mock_client_instance + + handler = OpenAIModelHandler(api_key=self.api_key, model=self.model_name) + client = handler.create_client() + + mock_openai_client_constructor.assert_called_once_with(api_key=self.api_key) + self.assertEqual(client, mock_client_instance) + # Test if client is cached + client2 = handler.create_client() + mock_openai_client_constructor.assert_called_once() # Should still be called only once + self.assertEqual(client2, mock_client_instance) + + + @patch('openai.OpenAI') + def test_request_completion_model_success(self, mock_openai_client_constructor): + mock_openai_client = MagicMock() + mock_openai_client_constructor.return_value = mock_openai_client + + # Mock the response from client.completions.create + mock_completion_response = Completion( + id="cmpl-test", + object="text_completion", + created=12345, + model=self.model_name, + choices=[ + CompletionChoice(text=" World!", index=0, finish_reason="length", logprobs=None) + ] + ) + mock_openai_client.completions.create.return_value = mock_completion_response + + handler = OpenAIModelHandler(api_key=self.api_key, model=self.model_name) + # Initialize client by calling create_client or load_model + client = handler.load_model() + prompts = ["Hello", "Hi"] + results_generator = handler.request(prompts, client, {}) + results = list(results_generator) + + self.assertEqual(len(results), 2) + self.assertIsInstance(results[0], PredictionResult) + self.assertEqual(results[0].example, "Hello") + self.assertEqual(results[0].inference, " World!") + self.assertEqual(results[0].model_id, self.model_name) + self.assertEqual(results[1].example, "Hi") + self.assertEqual(results[1].inference, " World!") # Same mock response for both + + self.assertEqual(mock_openai_client.completions.create.call_count, 2) + mock_openai_client.completions.create.assert_any_call( + model=self.model_name, prompt="Hello" + ) + mock_openai_client.completions.create.assert_any_call( + model=self.model_name, prompt="Hi" + ) + + @patch('openai.OpenAI') + def test_request_chat_model_success(self, mock_openai_client_constructor): + mock_openai_client = MagicMock() + # Simulate chat model by checking a mock attribute on the client's chat completions path + # This is a bit of a hack for testing the path in generate_completion + mock_openai_client.chat.completions.with_raw_response.create.binary_relative_path = "chat.completions" + mock_openai_client_constructor.return_value = mock_openai_client + + # Mock the response from client.chat.completions.create + mock_chat_response = ChatCompletion( + id="chatcmpl-test", + object="chat.completion", + created=12345, + model=self.chat_model_name, + choices=[ + ChatChoice( + index=0, + message=ChatCompletionMessage(role="assistant", content="There!"), + finish_reason="stop" + ) + ] + ) + mock_openai_client.chat.completions.create.return_value = mock_chat_response + + handler = OpenAIModelHandler(api_key=self.api_key, model=self.chat_model_name) + client = handler.load_model() # This calls create_client + prompts = ["User prompt 1", [{"role": "user", "content": "User prompt 2"}]] # Test both string and message list + results_generator = handler.request(prompts, client, {"temperature": 0.5}) + results = list(results_generator) + + self.assertEqual(len(results), 2) + self.assertIsInstance(results[0], PredictionResult) + self.assertEqual(results[0].example, "User prompt 1") + self.assertEqual(results[0].inference, "There!") + self.assertEqual(results[0].model_id, self.chat_model_name) + + self.assertEqual(results[1].example, [{"role": "user", "content": "User prompt 2"}]) + self.assertEqual(results[1].inference, "There!") + + self.assertEqual(mock_openai_client.chat.completions.create.call_count, 2) + mock_openai_client.chat.completions.create.assert_any_call( + model=self.chat_model_name, + messages=[{"role": "user", "content": "User prompt 1"}], + temperature=0.5 + ) + mock_openai_client.chat.completions.create.assert_any_call( + model=self.chat_model_name, + messages=[{"role": "user", "content": "User prompt 2"}], + temperature=0.5 + ) + + @patch('openai.OpenAI') + def test_request_failure_propagates(self, mock_openai_client_constructor): + mock_openai_client = MagicMock() + mock_openai_client_constructor.return_value = mock_openai_client + + # Simulate an API error during the first call + mock_openai_client.completions.create.side_effect = RateLimitError( + "rate limited", response=MagicMock(), body=None + ) + + handler = OpenAIModelHandler(api_key=self.api_key, model=self.model_name) + client = handler.load_model() + prompts = ["Prompt that will fail"] + + with self.assertRaises(RateLimitError): + list(handler.request(prompts, client, {})) + + mock_openai_client.completions.create.assert_called_once_with( + model=self.model_name, prompt="Prompt that will fail" + ) + + def test_batch_elements_kwargs(self): + handler = OpenAIModelHandler( + api_key=self.api_key, + model=self.model_name, + min_batch_size=1, + max_batch_size=10, + max_batch_duration_secs=30) + kwargs = handler.batch_elements_kwargs() + self.assertEqual(kwargs['min_batch_size'], 1) + self.assertEqual(kwargs['max_batch_size'], 10) + self.assertEqual(kwargs['max_batch_duration_secs'], 30) + + def test_validate_inference_args(self): + handler = OpenAIModelHandler(api_key=self.api_key, model=self.model_name) + # This method is a pass-through for now, so just ensure it doesn't raise. + try: + handler.validate_inference_args({"temperature": 0.7}) + handler.validate_inference_args(None) + except Exception as e: + self.fail(f"validate_inference_args raised an exception: {e}") + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/inference/openai_tests_requirements.txt b/sdks/python/apache_beam/ml/inference/openai_tests_requirements.txt new file mode 100644 index 000000000000..4ce62e619f31 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/openai_tests_requirements.txt @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +openai>=1.0.0 From 62be2423894296bf007e8f1f0502093b171fea5a Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 1 Jul 2025 14:34:39 -0400 Subject: [PATCH 02/11] yapf formatting --- .../ml/inference/openai_inference.py | 51 +++++++------ .../ml/inference/openai_inference_it_test.py | 76 +++++++++---------- 2 files changed, 66 insertions(+), 61 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/openai_inference.py b/sdks/python/apache_beam/ml/inference/openai_inference.py index 436c94cb1927..5a196dba259b 100644 --- a/sdks/python/apache_beam/ml/inference/openai_inference.py +++ b/sdks/python/apache_beam/ml/inference/openai_inference.py @@ -45,22 +45,24 @@ def _retry_on_appropriate_openai_error(exception: Exception) -> bool: Retry filter that returns True if a returned HTTP error code is 5xx or 429 (RateLimitError). """ - LOGGER.debug(f"Checking exception for retry: {type(exception)} - {str(exception)}") + LOGGER.debug( + f"Checking exception for retry: {type(exception)} - {str(exception)}") if isinstance(exception, RateLimitError): LOGGER.debug("RateLimitError detected, retrying.") return True # Always retry RateLimitError (HTTP 429) - if isinstance(exception, APIError): # This covers APIStatusError as well - status_code = getattr(exception, 'status_code', None) - LOGGER.debug(f"APIError detected. Status code from getattr: {status_code}") - if status_code is not None: - LOGGER.debug(f"Condition check: {status_code} >= 500 is {status_code >= 500}") - return status_code >= 500 # Retry on 5xx errors - else: - LOGGER.debug("APIError but status_code is None.") + if isinstance(exception, APIError): # This covers APIStatusError as well + status_code = getattr(exception, 'status_code', None) + LOGGER.debug(f"APIError detected. Status code from getattr: {status_code}") + if status_code is not None: + LOGGER.debug( + f"Condition check: {status_code} >= 500 is {status_code >= 500}") + return status_code >= 500 # Retry on 5xx errors + else: + LOGGER.debug("APIError but status_code is None.") LOGGER.debug("Exception not eligible for retry by this filter.") - return False # Do not retry for other errors or if status_code is not available + return False # Do not retry for other errors or if status_code is not available def generate_completion( @@ -82,10 +84,10 @@ def generate_completion( # User might need to format input as list of messages # For simplicity, we'll assume a single user message per prompt string # for now. - if not isinstance(prompt, list): # basic check for message format - messages = [{"role": "user", "content": prompt}] - else: # assume prompt is already in message format - messages = prompt + if not isinstance(prompt, list): # basic check for message format + messages = [{"role": "user", "content": prompt}] + else: # assume prompt is already in message format + messages = prompt response = client.chat.completions.create( model=model_name, messages=messages, **inference_args) else: @@ -112,23 +114,26 @@ def generate_completion( elif hasattr(response_obj.choices[0], 'text'): parsed_responses.append(response_obj.choices[0].text) else: - LOGGER.warning("Unrecognized OpenAI response choice format: %s", response_obj.choices[0]) - parsed_responses.append(None) # Or raise error + LOGGER.warning( + "Unrecognized OpenAI response choice format: %s", + response_obj.choices[0]) + parsed_responses.append(None) # Or raise error else: LOGGER.warning("OpenAI response had no choices: %s", response_obj) - parsed_responses.append(None) # Or raise error + parsed_responses.append(None) # Or raise error else: LOGGER.warning("Unrecognized OpenAI response format: %s", response_obj) - parsed_responses.append(None) # Or raise error + parsed_responses.append(None) # Or raise error return parsed_responses -class OpenAIModelHandler(RemoteModelHandler[Any, PredictionResult, +class OpenAIModelHandler(RemoteModelHandler[Any, + PredictionResult, openai.OpenAI]): def __init__( self, api_key: str, - model: str, # Recommended to use 'model' like in openai library + model: str, # Recommended to use 'model' like in openai library request_fn: Callable[[str, Sequence[Any], openai.OpenAI, dict[str, Any]], Any] = generate_completion, *, @@ -165,7 +170,7 @@ def __init__( self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs self.api_key = api_key - self.model_name = model # Renamed from model_name to model for consistency + self.model_name = model # Renamed from model_name to model for consistency self.request_fn = request_fn # OpenAI client will be initialized in create_client @@ -179,13 +184,13 @@ def __init__( def create_client(self) -> openai.OpenAI: """Creates the OpenAI client used to send requests.""" if not self._client: - self._client = openai.OpenAI(api_key=self.api_key) + self._client = openai.OpenAI(api_key=self.api_key) return self._client def request( self, batch: Sequence[Any], - model_client: openai.OpenAI, # Parameter name changed for clarity + model_client: openai.OpenAI, # Parameter name changed for clarity inference_args: Optional[dict[str, Any]] = None ) -> Iterable[PredictionResult]: """ Sends a prediction request to an OpenAI model containing a batch diff --git a/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py b/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py index 38ada5690a6b..3578614a6e68 100644 --- a/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py @@ -41,7 +41,7 @@ _OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") # Models for testing - one completion, one chat -_COMPLETION_MODEL = "gpt-3.5-turbo-instruct" # A smaller, faster completion model +_COMPLETION_MODEL = "gpt-3.5-turbo-instruct" # A smaller, faster completion model _CHAT_MODEL = "gpt-3.5-turbo" @@ -49,10 +49,13 @@ class OpenAIInferenceIT(unittest.TestCase): def setUp(self): self.output_dir = os.environ.get("BEAM_ML_OUTPUT_DIR", _OUTPUT_DIR_DEFAULT) - self.project = os.environ.get("BEAM_ML_PROJECT") # Not used by OpenAI but common in tests + self.project = os.environ.get( + "BEAM_ML_PROJECT") # Not used by OpenAI but common in tests - def run_pipeline(self, model_handler, test_data, output_path_suffix, inference_args=None): - output_file = '/'.join([self.output_dir, str(uuid.uuid4()), output_path_suffix]) + def run_pipeline( + self, model_handler, test_data, output_path_suffix, inference_args=None): + output_file = '/'.join( + [self.output_dir, str(uuid.uuid4()), output_path_suffix]) pipeline_options = { 'output': output_file, @@ -60,18 +63,20 @@ def run_pipeline(self, model_handler, test_data, output_path_suffix, inference_a # Add project if available, for consistency with other IT tests, # though OpenAI handler doesn't directly use it. if self.project: - pipeline_options['project'] = self.project + pipeline_options['project'] = self.project - test_pipeline = TestPipeline(is_integration_test=True, options=pipeline_options) + test_pipeline = TestPipeline( + is_integration_test=True, options=pipeline_options) with test_pipeline as p: results = ( p | "CreateInputs" >> beam.Create(test_data) - | "RunInference" >> RunInference(model_handler, inference_args=inference_args) - | "SaveResults" >> beam.Map(lambda x: str(x)) # Convert PredictionResult to string for output - | beam.io.WriteToText(output_file) - ) + | "RunInference" >> RunInference( + model_handler, inference_args=inference_args) + | "SaveResults" >> beam.Map( + lambda x: str(x)) # Convert PredictionResult to string for output + | beam.io.WriteToText(output_file)) self.assertTrue(FileSystems().exists(output_file)) # Further checks could involve reading the output and verifying content, @@ -83,75 +88,70 @@ def run_pipeline(self, model_handler, test_data, output_path_suffix, inference_a # For a more robust check, one might mock the API in an IT setting or # use a very deterministic, simple prompt. match_results = [] + def process_output_file(readable_file): - for line in readable_file: - match_results.append(line) + for line in readable_file: + match_results.append(line) FileSystems.read_files(output_file, process_file_fn=process_output_file) self.assertGreater(len(match_results), 0) # Example: check if output contains 'PredictionResult(example=' or similar - self.assertTrue(any("PredictionResult(example=" in line for line in match_results)) - + self.assertTrue( + any("PredictionResult(example=" in line for line in match_results)) - @pytest.mark.postcommit # Mark as postcommit as it makes external calls. + @pytest.mark.postcommit # Mark as postcommit as it makes external calls. def test_openai_completion_model(self): model_handler = OpenAIModelHandler( - api_key=_OPENAI_API_KEY, - model=_COMPLETION_MODEL - ) + api_key=_OPENAI_API_KEY, model=_COMPLETION_MODEL) test_data = [ - "What is the capital of France?", - "Translate 'hello' to Spanish." + "What is the capital of France?", "Translate 'hello' to Spanish." ] inference_args = {"max_tokens": 50, "temperature": 0.7} self.run_pipeline( model_handler, test_data, "output_completion.txt", - inference_args=inference_args - ) + inference_args=inference_args) @pytest.mark.postcommit def test_openai_chat_model(self): model_handler = OpenAIModelHandler( - api_key=_OPENAI_API_KEY, - model=_CHAT_MODEL - ) + api_key=_OPENAI_API_KEY, model=_CHAT_MODEL) # Chat models expect a list of messages or a single string (handled as user message) test_data = [ - "What is 2+2?", # Single string prompt - [{"role": "user", "content": "Describe a perfect day."}] # Message list prompt + "What is 2+2?", # Single string prompt + [{ + "role": "user", "content": "Describe a perfect day." + }] # Message list prompt ] inference_args = {"max_tokens": 100, "temperature": 0.5} self.run_pipeline( model_handler, test_data, "output_chat.txt", - inference_args=inference_args - ) + inference_args=inference_args) @pytest.mark.postcommit def test_openai_chat_model_with_system_message(self): model_handler = OpenAIModelHandler( - api_key=_OPENAI_API_KEY, - model=_CHAT_MODEL - ) + api_key=_OPENAI_API_KEY, model=_CHAT_MODEL) # Chat models expect a list of messages or a single string (handled as user message) test_data = [ # This requires the OpenAIModelHandler's generate_completion to correctly # handle list of messages if the input element itself is a list of dicts. - [ - {"role": "system", "content": "You are a helpful assistant that speaks like a pirate."}, - {"role": "user", "content": "How are you?"} - ] + [{ + "role": "system", + "content": "You are a helpful assistant that speaks like a pirate." + }, { + "role": "user", "content": "How are you?" + }] ] inference_args = {"max_tokens": 50} self.run_pipeline( model_handler, test_data, "output_chat_system.txt", - inference_args=inference_args - ) + inference_args=inference_args) if __name__ == '__main__': From b54de92f6dc719c07299b9eca64f02c335715fb2 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 1 Jul 2025 14:34:55 -0400 Subject: [PATCH 03/11] more yapf formatting --- .../ml/inference/openai_inference_test.py | 99 ++++++++++--------- 1 file changed, 55 insertions(+), 44 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/openai_inference_test.py b/sdks/python/apache_beam/ml/inference/openai_inference_test.py index 948e71286e58..d49ef8651b68 100644 --- a/sdks/python/apache_beam/ml/inference/openai_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/openai_inference_test.py @@ -18,7 +18,7 @@ import unittest from unittest.mock import MagicMock from unittest.mock import patch -import httpx # Added for mocking request object +import httpx # Added for mocking request object import logging # pylint: disable=wrong-import-order, wrong-import-position @@ -30,7 +30,7 @@ from openai.types.chat.chat_completion import Choice as ChatChoice from openai.types.chat.chat_completion_message import ChatCompletionMessage from openai.types.completion import Completion - from openai.types.completion_choice import CompletionChoice # Corrected import + from openai.types.completion_choice import CompletionChoice # Corrected import from apache_beam.ml.inference.openai_inference import ( OpenAIModelHandler, _retry_on_appropriate_openai_error) except ImportError: @@ -44,12 +44,14 @@ logger_to_debug.setLevel(logging.DEBUG) # Add a handler to see the output during tests, e.g., stream to stderr # Check if a handler already exists to avoid duplicate messages if tests are run multiple times -if not any(isinstance(h, logging.StreamHandler) for h in logger_to_debug.handlers): - stream_handler = logging.StreamHandler() - stream_handler.setLevel(logging.DEBUG) - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - stream_handler.setFormatter(formatter) - logger_to_debug.addHandler(stream_handler) +if not any(isinstance(h, logging.StreamHandler) + for h in logger_to_debug.handlers): + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.DEBUG) + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + stream_handler.setFormatter(formatter) + logger_to_debug.addHandler(stream_handler) class RetryOnAPIErrorTest(unittest.TestCase): @@ -66,7 +68,7 @@ def _create_mock_error_with_status(self, status_code, error_class=APIError): # mock_response.status_code will be set below. # Ensure headers is a mock that can handle .get() for RateLimitError mock_response.headers = MagicMock(spec=httpx.Headers) - mock_response.headers.get.return_value = "test-request-id" # For x-request-id + mock_response.headers.get.return_value = "test-request-id" # For x-request-id mock_response.content = b"{}" mock_response.text = "{}" @@ -79,10 +81,11 @@ def _create_mock_error_with_status(self, status_code, error_class=APIError): if error_class == RateLimitError: mock_response.status_code = status_code err = RateLimitError("rate limited", response=mock_response, body=None) - else: # Generic APIError + else: # Generic APIError mock_request_that_failed = MagicMock(spec=httpx.Request) mock_request_that_failed.method = "POST" - mock_request_that_failed.url = httpx.URL("https://api.openai.com/v1/completions") + mock_request_that_failed.url = httpx.URL( + "https://api.openai.com/v1/completions") # This is the response that APIError.status_code property will look for response_for_api_error_property = MagicMock(spec=httpx.Response) @@ -117,14 +120,15 @@ def test_no_retry_on_client_error_401(self): self.assertFalse(_retry_on_appropriate_openai_error(err)) def test_no_retry_on_non_openai_error(self): - self.assertFalse(_retry_on_appropriate_openai_error(ValueError("some other error"))) + self.assertFalse( + _retry_on_appropriate_openai_error(ValueError("some other error"))) class OpenAIModelHandlerTest(unittest.TestCase): def setUp(self): self.api_key = "test_api_key" - self.model_name = "gpt-3.5-turbo-instruct" # A completion model - self.chat_model_name = "gpt-3.5-turbo" # A chat model + self.model_name = "gpt-3.5-turbo-instruct" # A completion model + self.chat_model_name = "gpt-3.5-turbo" # A chat model @patch('openai.OpenAI') def test_create_client(self, mock_openai_client_constructor): @@ -138,12 +142,13 @@ def test_create_client(self, mock_openai_client_constructor): self.assertEqual(client, mock_client_instance) # Test if client is cached client2 = handler.create_client() - mock_openai_client_constructor.assert_called_once() # Should still be called only once + mock_openai_client_constructor.assert_called_once( + ) # Should still be called only once self.assertEqual(client2, mock_client_instance) - @patch('openai.OpenAI') - def test_request_completion_model_success(self, mock_openai_client_constructor): + def test_request_completion_model_success( + self, mock_openai_client_constructor): mock_openai_client = MagicMock() mock_openai_client_constructor.return_value = mock_openai_client @@ -154,9 +159,9 @@ def test_request_completion_model_success(self, mock_openai_client_constructor): created=12345, model=self.model_name, choices=[ - CompletionChoice(text=" World!", index=0, finish_reason="length", logprobs=None) - ] - ) + CompletionChoice( + text=" World!", index=0, finish_reason="length", logprobs=None) + ]) mock_openai_client.completions.create.return_value = mock_completion_response handler = OpenAIModelHandler(api_key=self.api_key, model=self.model_name) @@ -172,15 +177,14 @@ def test_request_completion_model_success(self, mock_openai_client_constructor): self.assertEqual(results[0].inference, " World!") self.assertEqual(results[0].model_id, self.model_name) self.assertEqual(results[1].example, "Hi") - self.assertEqual(results[1].inference, " World!") # Same mock response for both + self.assertEqual( + results[1].inference, " World!") # Same mock response for both self.assertEqual(mock_openai_client.completions.create.call_count, 2) mock_openai_client.completions.create.assert_any_call( - model=self.model_name, prompt="Hello" - ) + model=self.model_name, prompt="Hello") mock_openai_client.completions.create.assert_any_call( - model=self.model_name, prompt="Hi" - ) + model=self.model_name, prompt="Hi") @patch('openai.OpenAI') def test_request_chat_model_success(self, mock_openai_client_constructor): @@ -199,16 +203,20 @@ def test_request_chat_model_success(self, mock_openai_client_constructor): choices=[ ChatChoice( index=0, - message=ChatCompletionMessage(role="assistant", content="There!"), - finish_reason="stop" - ) - ] - ) + message=ChatCompletionMessage( + role="assistant", content="There!"), + finish_reason="stop") + ]) mock_openai_client.chat.completions.create.return_value = mock_chat_response - handler = OpenAIModelHandler(api_key=self.api_key, model=self.chat_model_name) - client = handler.load_model() # This calls create_client - prompts = ["User prompt 1", [{"role": "user", "content": "User prompt 2"}]] # Test both string and message list + handler = OpenAIModelHandler( + api_key=self.api_key, model=self.chat_model_name) + client = handler.load_model() # This calls create_client + prompts = [ + "User prompt 1", [{ + "role": "user", "content": "User prompt 2" + }] + ] # Test both string and message list results_generator = handler.request(prompts, client, {"temperature": 0.5}) results = list(results_generator) @@ -218,20 +226,25 @@ def test_request_chat_model_success(self, mock_openai_client_constructor): self.assertEqual(results[0].inference, "There!") self.assertEqual(results[0].model_id, self.chat_model_name) - self.assertEqual(results[1].example, [{"role": "user", "content": "User prompt 2"}]) + self.assertEqual( + results[1].example, [{ + "role": "user", "content": "User prompt 2" + }]) self.assertEqual(results[1].inference, "There!") self.assertEqual(mock_openai_client.chat.completions.create.call_count, 2) mock_openai_client.chat.completions.create.assert_any_call( model=self.chat_model_name, - messages=[{"role": "user", "content": "User prompt 1"}], - temperature=0.5 - ) + messages=[{ + "role": "user", "content": "User prompt 1" + }], + temperature=0.5) mock_openai_client.chat.completions.create.assert_any_call( model=self.chat_model_name, - messages=[{"role": "user", "content": "User prompt 2"}], - temperature=0.5 - ) + messages=[{ + "role": "user", "content": "User prompt 2" + }], + temperature=0.5) @patch('openai.OpenAI') def test_request_failure_propagates(self, mock_openai_client_constructor): @@ -240,8 +253,7 @@ def test_request_failure_propagates(self, mock_openai_client_constructor): # Simulate an API error during the first call mock_openai_client.completions.create.side_effect = RateLimitError( - "rate limited", response=MagicMock(), body=None - ) + "rate limited", response=MagicMock(), body=None) handler = OpenAIModelHandler(api_key=self.api_key, model=self.model_name) client = handler.load_model() @@ -251,8 +263,7 @@ def test_request_failure_propagates(self, mock_openai_client_constructor): list(handler.request(prompts, client, {})) mock_openai_client.completions.create.assert_called_once_with( - model=self.model_name, prompt="Prompt that will fail" - ) + model=self.model_name, prompt="Prompt that will fail") def test_batch_elements_kwargs(self): handler = OpenAIModelHandler( From a543a6cece9ecf74af3907a63b22b8f9bb705929 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 1 Jul 2025 14:36:31 -0400 Subject: [PATCH 04/11] remove unnecessary comments --- sdks/python/apache_beam/ml/inference/openai_inference.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/openai_inference.py b/sdks/python/apache_beam/ml/inference/openai_inference.py index 5a196dba259b..895b0fdc131d 100644 --- a/sdks/python/apache_beam/ml/inference/openai_inference.py +++ b/sdks/python/apache_beam/ml/inference/openai_inference.py @@ -210,12 +210,7 @@ def request( if inference_args is None: inference_args = {} - # The `generate_completion` function now iterates through the batch - # and makes individual API calls if necessary (e.g. for non-chat models) - # or a single call if the underlying API supports batching (e.g. future chat models). - # The RunInference transform handles the primary batching of elements from PCollection. try: - # request_fn (generate_completion) now returns a list of parsed strings/content parsed_responses = self.request_fn( self.model_name, batch, model_client, inference_args) except Exception as e: From 7ae6da1b94d36a0535f885b8c044724ddc8eb519 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 1 Jul 2025 15:04:27 -0400 Subject: [PATCH 05/11] more formatting cleanups --- .../ml/inference/openai_inference.py | 28 +++++++------------ .../ml/inference/openai_inference_it_test.py | 9 +++--- .../ml/inference/openai_inference_test.py | 11 ++++---- .../inference/openai_tests_requirements.txt | 3 +- 4 files changed, 21 insertions(+), 30 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/openai_inference.py b/sdks/python/apache_beam/ml/inference/openai_inference.py index 895b0fdc131d..0772d9547bc3 100644 --- a/sdks/python/apache_beam/ml/inference/openai_inference.py +++ b/sdks/python/apache_beam/ml/inference/openai_inference.py @@ -22,6 +22,10 @@ from typing import Any from typing import Optional +from apache_beam.ml.inference import utils +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RemoteModelHandler + # pylint: disable=wrong-import-order, wrong-import-position try: import openai @@ -33,10 +37,6 @@ 'OpenAI dependencies are not installed. To use OpenAI model handler,' 'run pip install apache-beam[gcp,openai]') -from apache_beam.ml.inference import utils -from apache_beam.ml.inference.base import PredictionResult -from apache_beam.ml.inference.base import RemoteModelHandler - LOGGER = logging.getLogger("OpenAIModelHandler") @@ -45,24 +45,15 @@ def _retry_on_appropriate_openai_error(exception: Exception) -> bool: Retry filter that returns True if a returned HTTP error code is 5xx or 429 (RateLimitError). """ - LOGGER.debug( - f"Checking exception for retry: {type(exception)} - {str(exception)}") if isinstance(exception, RateLimitError): - LOGGER.debug("RateLimitError detected, retrying.") - return True # Always retry RateLimitError (HTTP 429) + return True if isinstance(exception, APIError): # This covers APIStatusError as well status_code = getattr(exception, 'status_code', None) - LOGGER.debug(f"APIError detected. Status code from getattr: {status_code}") if status_code is not None: - LOGGER.debug( - f"Condition check: {status_code} >= 500 is {status_code >= 500}") - return status_code >= 500 # Retry on 5xx errors - else: - LOGGER.debug("APIError but status_code is None.") + return status_code >= 500 - LOGGER.debug("Exception not eligible for retry by this filter.") - return False # Do not retry for other errors or if status_code is not available + return False def generate_completion( @@ -133,7 +124,7 @@ class OpenAIModelHandler(RemoteModelHandler[Any, def __init__( self, api_key: str, - model: str, # Recommended to use 'model' like in openai library + model: str, request_fn: Callable[[str, Sequence[Any], openai.OpenAI, dict[str, Any]], Any] = generate_completion, *, @@ -147,7 +138,8 @@ def __init__( Args: api_key: the OpenAI API key to use for the requests. - model: The OpenAI model to use for inference (e.g., "gpt-3.5-turbo-instruct", "gpt-3.5-turbo"). + model: The OpenAI model to use for inference + (e.g., "gpt-3.5-turbo-instruct", "gpt-3.5-turbo"). request_fn: the function to use to send the request. Should take the model name and the parameters from request() and return the responses from OpenAI. The class will handle bundling the inputs and responses diff --git a/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py b/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py index 3578614a6e68..d64346edcf05 100644 --- a/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py @@ -25,7 +25,6 @@ import apache_beam as beam import pytest from apache_beam.io.filesystems import FileSystems -from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import RunInference from apache_beam.testing.test_pipeline import TestPipeline @@ -69,7 +68,7 @@ def run_pipeline( is_integration_test=True, options=pipeline_options) with test_pipeline as p: - results = ( + _ = ( p | "CreateInputs" >> beam.Create(test_data) | "RunInference" >> RunInference( @@ -99,7 +98,7 @@ def process_output_file(readable_file): self.assertTrue( any("PredictionResult(example=" in line for line in match_results)) - @pytest.mark.postcommit # Mark as postcommit as it makes external calls. + @pytest.mark.openai_postcommit # Mark as postcommit as it makes external calls. def test_openai_completion_model(self): model_handler = OpenAIModelHandler( api_key=_OPENAI_API_KEY, model=_COMPLETION_MODEL) @@ -113,7 +112,7 @@ def test_openai_completion_model(self): "output_completion.txt", inference_args=inference_args) - @pytest.mark.postcommit + @pytest.mark.openai_postcommit def test_openai_chat_model(self): model_handler = OpenAIModelHandler( api_key=_OPENAI_API_KEY, model=_CHAT_MODEL) @@ -131,7 +130,7 @@ def test_openai_chat_model(self): "output_chat.txt", inference_args=inference_args) - @pytest.mark.postcommit + @pytest.mark.openai_postcommit def test_openai_chat_model_with_system_message(self): model_handler = OpenAIModelHandler( api_key=_OPENAI_API_KEY, model=_CHAT_MODEL) diff --git a/sdks/python/apache_beam/ml/inference/openai_inference_test.py b/sdks/python/apache_beam/ml/inference/openai_inference_test.py index d49ef8651b68..b3f7cc39d638 100644 --- a/sdks/python/apache_beam/ml/inference/openai_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/openai_inference_test.py @@ -15,24 +15,23 @@ # limitations under the License. # +import httpx +import logging import unittest from unittest.mock import MagicMock from unittest.mock import patch -import httpx # Added for mocking request object -import logging # pylint: disable=wrong-import-order, wrong-import-position try: - import openai + from apache_beam.ml.inference.openai_inference import OpenAIModelHandler + from apache_beam.ml.inference.openai_inference import _retry_on_appropriate_openai_error from openai import APIError from openai import RateLimitError from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion import Choice as ChatChoice from openai.types.chat.chat_completion_message import ChatCompletionMessage from openai.types.completion import Completion - from openai.types.completion_choice import CompletionChoice # Corrected import - from apache_beam.ml.inference.openai_inference import ( - OpenAIModelHandler, _retry_on_appropriate_openai_error) + from openai.types.completion_choice import CompletionChoice except ImportError: raise unittest.SkipTest('OpenAI dependencies are not installed') diff --git a/sdks/python/apache_beam/ml/inference/openai_tests_requirements.txt b/sdks/python/apache_beam/ml/inference/openai_tests_requirements.txt index 4ce62e619f31..94416df570da 100644 --- a/sdks/python/apache_beam/ml/inference/openai_tests_requirements.txt +++ b/sdks/python/apache_beam/ml/inference/openai_tests_requirements.txt @@ -14,4 +14,5 @@ # See the License for the specific language governing permissions and # limitations under the License. # -openai>=1.0.0 + +openai>=1.0.0 \ No newline at end of file From b770711173ba4670d52269ccd338bc309605b4f8 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 1 Jul 2025 15:28:10 -0400 Subject: [PATCH 06/11] move httpx import into try block --- sdks/python/apache_beam/ml/inference/openai_inference_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/openai_inference_test.py b/sdks/python/apache_beam/ml/inference/openai_inference_test.py index b3f7cc39d638..f341c52f13a8 100644 --- a/sdks/python/apache_beam/ml/inference/openai_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/openai_inference_test.py @@ -15,7 +15,6 @@ # limitations under the License. # -import httpx import logging import unittest from unittest.mock import MagicMock @@ -23,6 +22,7 @@ # pylint: disable=wrong-import-order, wrong-import-position try: + import httpx from apache_beam.ml.inference.openai_inference import OpenAIModelHandler from apache_beam.ml.inference.openai_inference import _retry_on_appropriate_openai_error from openai import APIError From 935dd5b9dc91e05263046a76d7a3df8b2e94b577 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 1 Jul 2025 15:30:21 -0400 Subject: [PATCH 07/11] explict httpx requirement for tests --- .../apache_beam/ml/inference/openai_tests_requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/sdks/python/apache_beam/ml/inference/openai_tests_requirements.txt b/sdks/python/apache_beam/ml/inference/openai_tests_requirements.txt index 94416df570da..e563f3196eff 100644 --- a/sdks/python/apache_beam/ml/inference/openai_tests_requirements.txt +++ b/sdks/python/apache_beam/ml/inference/openai_tests_requirements.txt @@ -15,4 +15,5 @@ # limitations under the License. # +httpx>=0.28.1 openai>=1.0.0 \ No newline at end of file From f11819515d0c35d598afb7ad5e299bc433587f36 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 1 Jul 2025 16:13:07 -0400 Subject: [PATCH 08/11] linting --- .../ml/inference/openai_inference.py | 3 +- .../ml/inference/openai_inference_it_test.py | 13 +++--- .../ml/inference/openai_inference_test.py | 40 +++++++------------ 3 files changed, 24 insertions(+), 32 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/openai_inference.py b/sdks/python/apache_beam/ml/inference/openai_inference.py index 0772d9547bc3..8424d776b782 100644 --- a/sdks/python/apache_beam/ml/inference/openai_inference.py +++ b/sdks/python/apache_beam/ml/inference/openai_inference.py @@ -70,7 +70,8 @@ def generate_completion( # Note: OpenAI's library expects a single prompt for completions.create, # so we iterate and call. Batching is handled by RunInference. # For chat models, multiple messages can be part of a single request. - if "chat.completions" in client.chat.completions.with_raw_response.create.binary_relative_path: # rough check + if ("chat.completions" in client.chat.completions.with_raw_response. + create.binary_relative_path): # Assuming chat model if path indicates chat completions # User might need to format input as list of messages # For simplicity, we'll assume a single user message per prompt string diff --git a/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py b/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py index d64346edcf05..d6dbd5ec127f 100644 --- a/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py @@ -40,7 +40,7 @@ _OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") # Models for testing - one completion, one chat -_COMPLETION_MODEL = "gpt-3.5-turbo-instruct" # A smaller, faster completion model +_COMPLETION_MODEL = "gpt-3.5-turbo-instruct" _CHAT_MODEL = "gpt-3.5-turbo" @@ -98,7 +98,7 @@ def process_output_file(readable_file): self.assertTrue( any("PredictionResult(example=" in line for line in match_results)) - @pytest.mark.openai_postcommit # Mark as postcommit as it makes external calls. + @pytest.mark.openai_postcommit def test_openai_completion_model(self): model_handler = OpenAIModelHandler( api_key=_OPENAI_API_KEY, model=_COMPLETION_MODEL) @@ -116,7 +116,7 @@ def test_openai_completion_model(self): def test_openai_chat_model(self): model_handler = OpenAIModelHandler( api_key=_OPENAI_API_KEY, model=_CHAT_MODEL) - # Chat models expect a list of messages or a single string (handled as user message) + # Chat models expect a list of messages or a single string test_data = [ "What is 2+2?", # Single string prompt [{ @@ -134,10 +134,11 @@ def test_openai_chat_model(self): def test_openai_chat_model_with_system_message(self): model_handler = OpenAIModelHandler( api_key=_OPENAI_API_KEY, model=_CHAT_MODEL) - # Chat models expect a list of messages or a single string (handled as user message) + # Chat models expect a list of messages or a single string test_data = [ - # This requires the OpenAIModelHandler's generate_completion to correctly - # handle list of messages if the input element itself is a list of dicts. + # This requires the OpenAIModelHandler's generate_completion to + # correctly handle list of messages if the input element itself + # is a list of dicts. [{ "role": "system", "content": "You are a helpful assistant that speaks like a pirate." diff --git a/sdks/python/apache_beam/ml/inference/openai_inference_test.py b/sdks/python/apache_beam/ml/inference/openai_inference_test.py index f341c52f13a8..b900243a784e 100644 --- a/sdks/python/apache_beam/ml/inference/openai_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/openai_inference_test.py @@ -37,31 +37,19 @@ from apache_beam.ml.inference.base import PredictionResult -# Configure logger for debugging tests related to _retry_on_appropriate_openai_error -# This gets the logger instance used in openai_inference.py -logger_to_debug = logging.getLogger("OpenAIModelHandler") -logger_to_debug.setLevel(logging.DEBUG) -# Add a handler to see the output during tests, e.g., stream to stderr -# Check if a handler already exists to avoid duplicate messages if tests are run multiple times -if not any(isinstance(h, logging.StreamHandler) - for h in logger_to_debug.handlers): - stream_handler = logging.StreamHandler() - stream_handler.setLevel(logging.DEBUG) - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s') - stream_handler.setFormatter(formatter) - logger_to_debug.addHandler(stream_handler) - class RetryOnAPIErrorTest(unittest.TestCase): def _create_mock_error_with_status(self, status_code, error_class=APIError): """ Helper to create a mock error object (APIError or RateLimitError) with a given status code. - The key is to ensure that `getattr(err, 'status_code', None)` works as expected. + The key is to ensure that `getattr(err, 'status_code', None)` works as + expected. For real OpenAI errors: - - RateLimitError (and other APIStatusErrors) have `err.status_code` as a direct attribute. - - APIError (the base) has `err.status_code` as a property that inspects `err.request.response.status_code`. + - RateLimitError (and other APIStatusErrors) have `err.status_code` as a + direct attribute. + - APIError (the base) has `err.status_code` as a property that inspects + `err.request.response.status_code`. """ mock_response = MagicMock(spec=httpx.Response) # mock_response.status_code will be set below. @@ -92,8 +80,9 @@ def _create_mock_error_with_status(self, status_code, error_class=APIError): mock_request_that_failed.response = response_for_api_error_property err = APIError("API error", request=mock_request_that_failed, body=None) - # Directly set status_code on the instance for getattr in the retry function to pick up. - # This is simpler than ensuring the nested property mock works perfectly. + # Directly set status_code on the instance for getattr in the retry + # function to pick up. This is simpler than ensuring the nested + # property mock works perfectly. # Note: This shadows the property for this instance. err.status_code = status_code return err @@ -161,7 +150,7 @@ def test_request_completion_model_success( CompletionChoice( text=" World!", index=0, finish_reason="length", logprobs=None) ]) - mock_openai_client.completions.create.return_value = mock_completion_response + mock_openai_client.completions.create.return_value = mock_completion_response # pylint: disable=line-too-long handler = OpenAIModelHandler(api_key=self.api_key, model=self.model_name) # Initialize client by calling create_client or load_model @@ -188,9 +177,10 @@ def test_request_completion_model_success( @patch('openai.OpenAI') def test_request_chat_model_success(self, mock_openai_client_constructor): mock_openai_client = MagicMock() - # Simulate chat model by checking a mock attribute on the client's chat completions path - # This is a bit of a hack for testing the path in generate_completion - mock_openai_client.chat.completions.with_raw_response.create.binary_relative_path = "chat.completions" + # Simulate chat model by checking a mock attribute on the client's chat + # completions path. This is a bit of a hack for testing the path in + # generate_completion. + mock_openai_client.chat.completions.with_raw_response.create.binary_relative_path = "chat.completions" # pylint: disable=line-too-long mock_openai_client_constructor.return_value = mock_openai_client # Mock the response from client.chat.completions.create @@ -206,7 +196,7 @@ def test_request_chat_model_success(self, mock_openai_client_constructor): role="assistant", content="There!"), finish_reason="stop") ]) - mock_openai_client.chat.completions.create.return_value = mock_chat_response + mock_openai_client.chat.completions.create.return_value = mock_chat_response # pylint: disable=line-too-long handler = OpenAIModelHandler( api_key=self.api_key, model=self.chat_model_name) From 95ba5e627360a0f524d4f16fd410241cb5600392 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Wed, 2 Jul 2025 09:31:10 -0400 Subject: [PATCH 09/11] test linting --- .../apache_beam/ml/inference/openai_inference_test.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/openai_inference_test.py b/sdks/python/apache_beam/ml/inference/openai_inference_test.py index b900243a784e..228b1ad707da 100644 --- a/sdks/python/apache_beam/ml/inference/openai_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/openai_inference_test.py @@ -15,7 +15,6 @@ # limitations under the License. # -import logging import unittest from unittest.mock import MagicMock from unittest.mock import patch @@ -23,6 +22,7 @@ # pylint: disable=wrong-import-order, wrong-import-position try: import httpx + from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.openai_inference import OpenAIModelHandler from apache_beam.ml.inference.openai_inference import _retry_on_appropriate_openai_error from openai import APIError @@ -35,8 +35,6 @@ except ImportError: raise unittest.SkipTest('OpenAI dependencies are not installed') -from apache_beam.ml.inference.base import PredictionResult - class RetryOnAPIErrorTest(unittest.TestCase): def _create_mock_error_with_status(self, status_code, error_class=APIError): @@ -55,7 +53,7 @@ def _create_mock_error_with_status(self, status_code, error_class=APIError): # mock_response.status_code will be set below. # Ensure headers is a mock that can handle .get() for RateLimitError mock_response.headers = MagicMock(spec=httpx.Headers) - mock_response.headers.get.return_value = "test-request-id" # For x-request-id + mock_response.headers.get.return_value = "test-request-id" mock_response.content = b"{}" mock_response.text = "{}" From ed32e1250a8e826ed0068d19b7a22263024e1ce1 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Wed, 2 Jul 2025 10:01:36 -0400 Subject: [PATCH 10/11] import order --- .../apache_beam/ml/inference/openai_inference_it_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py b/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py index d6dbd5ec127f..95e8f0ef6e6d 100644 --- a/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py @@ -22,8 +22,9 @@ import unittest import uuid -import apache_beam as beam import pytest + +import apache_beam as beam from apache_beam.io.filesystems import FileSystems from apache_beam.ml.inference.base import RunInference from apache_beam.testing.test_pipeline import TestPipeline From 21652c662a00a7bf8d5e9f3e47e4971d9b847093 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Wed, 2 Jul 2025 15:26:53 -0400 Subject: [PATCH 11/11] remove unnecessary project references --- .../ml/inference/openai_inference_it_test.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py b/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py index 95e8f0ef6e6d..c11ab2f3c7c1 100644 --- a/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/openai_inference_it_test.py @@ -49,8 +49,6 @@ class OpenAIInferenceIT(unittest.TestCase): def setUp(self): self.output_dir = os.environ.get("BEAM_ML_OUTPUT_DIR", _OUTPUT_DIR_DEFAULT) - self.project = os.environ.get( - "BEAM_ML_PROJECT") # Not used by OpenAI but common in tests def run_pipeline( self, model_handler, test_data, output_path_suffix, inference_args=None): @@ -60,10 +58,6 @@ def run_pipeline( pipeline_options = { 'output': output_file, } - # Add project if available, for consistency with other IT tests, - # though OpenAI handler doesn't directly use it. - if self.project: - pipeline_options['project'] = self.project test_pipeline = TestPipeline( is_integration_test=True, options=pipeline_options) @@ -79,14 +73,7 @@ def run_pipeline( | beam.io.WriteToText(output_file)) self.assertTrue(FileSystems().exists(output_file)) - # Further checks could involve reading the output and verifying content, - # but for now, we just check if the pipeline runs and produces output. - - # Basic check for content in the output file to ensure it's not empty - # and contains expected PredictionResult structure. - # This part can be flaky if API responses change slightly. - # For a more robust check, one might mock the API in an IT setting or - # use a very deterministic, simple prompt. + match_results = [] def process_output_file(readable_file):