From bf9ae80c34419e2a05eca1f62a6c68e59bcb88bf Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Wed, 27 May 2026 23:17:20 +0530 Subject: [PATCH 01/15] feat: claude integration --- backend/app/core/providers.py | 4 + backend/app/models/llm/request.py | 13 +- backend/app/services/llm/mappers.py | 77 ++++++ .../app/services/llm/providers/__init__.py | 1 + backend/app/services/llm/providers/claude.py | 179 +++++++++++++ .../app/services/llm/providers/registry.py | 9 +- .../services/llm/providers/test_claude.py | 242 ++++++++++++++++++ backend/pyproject.toml | 1 + backend/uv.lock | 30 +++ 9 files changed, 549 insertions(+), 7 deletions(-) create mode 100644 backend/app/services/llm/providers/claude.py create mode 100644 backend/app/tests/services/llm/providers/test_claude.py diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index 597c9708d..202cecf12 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -14,6 +14,7 @@ class Provider(str, Enum): GOOGLE = "google" SARVAMAI = "sarvamai" ELEVENLABS = "elevenlabs" + ANTHROPIC = "anthropic" WEBHOOK_SECRET = "webhook_secret" @@ -43,6 +44,9 @@ class ProviderConfig: Provider.ELEVENLABS: ProviderConfig( required_fields=["api_key"], sensitive_fields=["api_key"] ), + Provider.ANTHROPIC: ProviderConfig( + required_fields=["api_key"], sensitive_fields=["api_key"] + ), Provider.WEBHOOK_SECRET: ProviderConfig( required_fields=["webhook_secret"], sensitive_fields=["webhook_secret"] ), diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index da0c18120..866b85a46 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -227,7 +227,11 @@ class NativeCompletionConfig(SQLModel): """ provider: Literal[ - "openai-native", "google-native", "sarvamai-native", "elevenlabs-native" + "openai-native", + "google-native", + "sarvamai-native", + "elevenlabs-native", + "anthropic-native", ] = Field( ..., description="Native provider type (e.g., openai-native)", @@ -248,8 +252,11 @@ class KaapiCompletionConfig(SQLModel): Supports multiple providers: OpenAI, Claude, Gemini, etc. """ - provider: Literal["openai", "google", "sarvamai", "elevenlabs"] | None = Field( - None, description="LLM provider (openai, google, sarvamai, elevenlabs)" + provider: ( + Literal["openai", "google", "sarvamai", "elevenlabs", "anthropic"] | None + ) = Field( + None, + description="LLM provider (openai, google, sarvamai, elevenlabs, anthropic)", ) type: Literal["text", "stt", "tts"] = Field( diff --git a/backend/app/services/llm/mappers.py b/backend/app/services/llm/mappers.py index 3bd049f05..1b42c3191 100644 --- a/backend/app/services/llm/mappers.py +++ b/backend/app/services/llm/mappers.py @@ -428,6 +428,68 @@ def map_kaapi_to_elevenlabs_params( return elevenlabs_params, warnings +def map_kaapi_to_anthropic_params( + kaapi_params: dict, +) -> tuple[dict, list[str]]: + """Map Kaapi-abstracted parameters to Anthropic Messages API parameters. + + Supported Mapping: + - model → model + - instructions → system + - temperature → temperature + - top_p → top_p + - max_output_tokens → max_tokens (Anthropic requires this; + provider defaults if absent) + + Unsupported Kaapi params: + - knowledge_base_ids / max_num_results: Anthropic has no native + vector-store / file_search tool, dropped with warning. + - reasoning / effort / summary: Messages API does not expose a + reasoning-effort knob, dropped with warning. + """ + anthropic_params: dict = {} + warnings: list[str] = [] + + model = kaapi_params.get("model") + instructions = kaapi_params.get("instructions") + temperature = kaapi_params.get("temperature") + top_p = kaapi_params.get("top_p") + max_output_tokens = kaapi_params.get("max_output_tokens") + knowledge_base_ids = kaapi_params.get("knowledge_base_ids") + reasoning = kaapi_params.get("reasoning") + effort = kaapi_params.get("effort") + summary = kaapi_params.get("summary") + + if model: + anthropic_params["model"] = model + + if instructions: + anthropic_params["system"] = instructions + + if temperature is not None: + anthropic_params["temperature"] = temperature + + if top_p is not None: + anthropic_params["top_p"] = top_p + + if max_output_tokens is not None: + anthropic_params["max_tokens"] = max_output_tokens + + if knowledge_base_ids: + warnings.append( + "Parameter 'knowledge_base_ids' was ignored because Anthropic has no " + "native vector-store/file_search tool. Inline document content blocks instead." + ) + + if reasoning is not None or effort is not None or summary is not None: + warnings.append( + "Parameters 'reasoning'/'effort'/'summary' were ignored because the " + "Anthropic Messages API does not expose a reasoning-effort knob." + ) + + return anthropic_params, warnings + + def transform_kaapi_config_to_native( session: Session, kaapi_config: KaapiCompletionConfig, @@ -492,4 +554,19 @@ def transform_kaapi_config_to_native( warnings, ) + if kaapi_config.provider == "anthropic": + if kaapi_config.type != "text": + raise ValueError( + f"Anthropic provider does not support completion type '{kaapi_config.type}'" + ) + mapped_params, warnings = map_kaapi_to_anthropic_params(kaapi_config.params) + return ( + NativeCompletionConfig( + provider="anthropic-native", + params=mapped_params, + type=kaapi_config.type, + ), + warnings, + ) + raise ValueError(f"Unsupported provider: {kaapi_config.provider}") diff --git a/backend/app/services/llm/providers/__init__.py b/backend/app/services/llm/providers/__init__.py index d0df8dce6..fabfdf156 100644 --- a/backend/app/services/llm/providers/__init__.py +++ b/backend/app/services/llm/providers/__init__.py @@ -3,6 +3,7 @@ from app.services.llm.providers.gai import GoogleAIProvider from app.services.llm.providers.eai import ElevenlabsAIProvider from app.services.llm.providers.sai import SarvamAIProvider +from app.services.llm.providers.claude import ClaudeProvider from app.services.llm.providers.registry import ( LLMProvider, get_llm_provider, diff --git a/backend/app/services/llm/providers/claude.py b/backend/app/services/llm/providers/claude.py new file mode 100644 index 000000000..382000f27 --- /dev/null +++ b/backend/app/services/llm/providers/claude.py @@ -0,0 +1,179 @@ +import logging +from typing import Any + +import anthropic +from anthropic import Anthropic +from anthropic.types import Message + +from app.models.llm import ( + NativeCompletionConfig, + LLMCallResponse, + QueryParams, + LLMResponse, + Usage, + TextOutput, + TextContent, + ImageContent, + PDFContent, +) +from app.services.llm.providers.base import BaseProvider, ContentPart, MultiModalInput + +logger = logging.getLogger(__name__) + +DEFAULT_MAX_TOKENS = 4096 + + +class ClaudeProvider(BaseProvider): + def __init__(self, client: Anthropic): + """Initialize Anthropic Claude provider with client. + + Args: + client: Anthropic client instance + """ + super().__init__(client) + self.client = client + + @staticmethod + def create_client(credentials: dict[str, Any]) -> Any: + if "api_key" not in credentials: + raise ValueError("Anthropic credentials not configured for this project.") + return Anthropic(api_key=credentials["api_key"]) + + @staticmethod + def format_parts( + parts: list[ContentPart], + ) -> list[dict]: + items = [] + for part in parts: + if isinstance(part, TextContent): + items.append({"type": "text", "text": part.value}) + + elif isinstance(part, ImageContent): + if part.format == "base64": + items.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": part.mime_type, + "data": part.value, + }, + } + ) + else: + items.append( + { + "type": "image", + "source": {"type": "url", "url": part.value}, + } + ) + + elif isinstance(part, PDFContent): + if part.format == "base64": + items.append( + { + "type": "document", + "source": { + "type": "base64", + "media_type": part.mime_type, + "data": part.value, + }, + } + ) + else: + items.append( + { + "type": "document", + "source": {"type": "url", "url": part.value}, + } + ) + + return items + + def execute( + self, + completion_config: NativeCompletionConfig, + query: QueryParams, + resolved_input: str | list[ContentPart] | MultiModalInput, + include_provider_raw_response: bool = False, + ) -> tuple[LLMCallResponse | None, str | None]: + response: Message | None = None + error_message: str | None = None + + try: + params = {**completion_config.params} + + # Anthropic requires max_tokens; default if caller did not supply + params.setdefault("max_tokens", DEFAULT_MAX_TOKENS) + + # Kaapi exposes "instructions"; Anthropic uses "system". Always + # strip "instructions" — Anthropic rejects unknown kwargs. + if "instructions" in params: + if "system" not in params: + params["system"] = params["instructions"] + params.pop("instructions") + + if isinstance(resolved_input, MultiModalInput): + content = self.format_parts(resolved_input.parts) + elif isinstance(resolved_input, list): + content = self.format_parts(resolved_input) + else: + content = resolved_input + + params["messages"] = [{"role": "user", "content": content}] + + # Anthropic Messages API has no first-class conversation primitive, + # callers must replay prior messages themselves. Strip conversation + # config so it never leaks into the API call. + params.pop("conversation", None) + + response = self.client.messages.create(**params) + + output_text = "".join( + block.text for block in response.content if block.type == "text" + ) + + llm_response = LLMCallResponse( + response=LLMResponse( + provider_response_id=response.id, + conversation_id=None, + model=response.model, + provider=completion_config.provider, + output=TextOutput(content=TextContent(value=output_text)), + ), + usage=Usage( + input_tokens=response.usage.input_tokens, + output_tokens=response.usage.output_tokens, + total_tokens=response.usage.input_tokens + + response.usage.output_tokens, + ), + ) + + if include_provider_raw_response: + llm_response.provider_raw_response = response.model_dump() + + logger.info( + f"[ClaudeProvider.execute] Successfully generated response | " + f"request_id={response.id}, provider={completion_config.provider}, model={response.model}" + ) + return llm_response, None + + except TypeError as e: + error_message = f"Invalid or unexpected parameter in Config: {str(e)}" + return None, error_message + + except anthropic.AnthropicError as e: + error_message = f"Anthropic API error: {str(e)}" + logger.warning( + f"[ClaudeProvider.execute] {error_message} | provider={completion_config.provider}", + exc_info=True, + ) + return None, error_message + + except Exception as e: + error_message = "Unexpected error occurred" + logger.error( + f"[ClaudeProvider.execute] {error_message}: {str(e)} | provider={completion_config.provider}", + exc_info=True, + ) + return None, error_message diff --git a/backend/app/services/llm/providers/registry.py b/backend/app/services/llm/providers/registry.py index 9f4538ae1..60655028d 100644 --- a/backend/app/services/llm/providers/registry.py +++ b/backend/app/services/llm/providers/registry.py @@ -6,6 +6,7 @@ from app.services.llm.providers.gai import GoogleAIProvider from app.services.llm.providers.sai import SarvamAIProvider from app.services.llm.providers.eai import ElevenlabsAIProvider +from app.services.llm.providers.claude import ClaudeProvider logger = logging.getLogger(__name__) @@ -15,24 +16,24 @@ class LLMProvider: SARVAMAI = "sarvamai" ELEVENLABS = "elevenlabs" GOOGLE = "google" - # Future constants for native providers: - # CLAUDE_NATIVE = "claude-native" + ANTHROPIC = "anthropic" OPENAI_NATIVE = "openai-native" GOOGLE_NATIVE = "google-native" SARVAMAI_NATIVE = "sarvamai-native" ELEVENLABS_NATIVE = "elevenlabs-native" + ANTHROPIC_NATIVE = "anthropic-native" _registry: dict[str, type[BaseProvider]] = { OPENAI: OpenAIProvider, GOOGLE: GoogleAIProvider, SARVAMAI: SarvamAIProvider, ELEVENLABS: ElevenlabsAIProvider, - # Future native providers: - # CLAUDE_NATIVE: ClaudeProvider, + ANTHROPIC: ClaudeProvider, OPENAI_NATIVE: OpenAIProvider, GOOGLE_NATIVE: GoogleAIProvider, SARVAMAI_NATIVE: SarvamAIProvider, ELEVENLABS_NATIVE: ElevenlabsAIProvider, + ANTHROPIC_NATIVE: ClaudeProvider, } @classmethod diff --git a/backend/app/tests/services/llm/providers/test_claude.py b/backend/app/tests/services/llm/providers/test_claude.py new file mode 100644 index 000000000..311171954 --- /dev/null +++ b/backend/app/tests/services/llm/providers/test_claude.py @@ -0,0 +1,242 @@ +""" +Tests for the Anthropic Claude provider. +""" + +import pytest +from unittest.mock import MagicMock +from types import SimpleNamespace + +import anthropic + +from app.models.llm import ( + NativeCompletionConfig, + QueryParams, + TextContent, + ImageContent, + PDFContent, +) +from app.services.llm.providers.base import MultiModalInput +from app.services.llm.providers.claude import ClaudeProvider, DEFAULT_MAX_TOKENS + + +def mock_claude_message( + text: str = "hello", + model: str = "claude-opus-4-7", + message_id: str = "msg_123", + input_tokens: int = 10, + output_tokens: int = 5, + extra_blocks: list | None = None, +) -> SimpleNamespace: + """Build a SimpleNamespace mimicking an anthropic Message.""" + content = [SimpleNamespace(type="text", text=text)] + if extra_blocks: + content.extend(extra_blocks) + return SimpleNamespace( + id=message_id, + model=model, + content=content, + usage=SimpleNamespace(input_tokens=input_tokens, output_tokens=output_tokens), + model_dump=lambda: {"id": message_id, "model": model}, + ) + + +class TestClaudeProvider: + @pytest.fixture + def mock_client(self): + client = MagicMock() + client.messages.create = MagicMock() + return client + + @pytest.fixture + def provider(self, mock_client): + return ClaudeProvider(client=mock_client) + + @pytest.fixture + def text_config(self): + return NativeCompletionConfig( + provider="anthropic-native", + type="text", + params={"model": "claude-opus-4-7"}, + ) + + @pytest.fixture + def query_params(self): + return QueryParams(input="hi") + + def test_create_client_requires_api_key(self): + with pytest.raises(ValueError, match="not configured"): + ClaudeProvider.create_client(credentials={}) + + def test_create_client_with_api_key(self): + client = ClaudeProvider.create_client(credentials={"api_key": "sk-test"}) + assert isinstance(client, anthropic.Anthropic) + + def test_execute_success_text_input( + self, provider, mock_client, text_config, query_params + ): + mock_client.messages.create.return_value = mock_claude_message( + text="ok", model="claude-opus-4-7" + ) + + result, error = provider.execute(text_config, query_params, "hi") + + assert error is None + assert result.response.output.content.value == "ok" + assert result.response.provider == "anthropic-native" + assert result.response.model == "claude-opus-4-7" + assert result.response.provider_response_id == "msg_123" + assert result.usage.input_tokens == 10 + assert result.usage.output_tokens == 5 + assert result.usage.total_tokens == 15 + + call_kwargs = mock_client.messages.create.call_args.kwargs + assert call_kwargs["model"] == "claude-opus-4-7" + assert call_kwargs["max_tokens"] == DEFAULT_MAX_TOKENS + assert call_kwargs["messages"] == [{"role": "user", "content": "hi"}] + + def test_execute_does_not_override_user_max_tokens( + self, provider, mock_client, query_params + ): + config = NativeCompletionConfig( + provider="anthropic-native", + type="text", + params={"model": "claude-opus-4-7", "max_tokens": 64}, + ) + mock_client.messages.create.return_value = mock_claude_message() + + provider.execute(config, query_params, "hi") + + assert mock_client.messages.create.call_args.kwargs["max_tokens"] == 64 + + def test_execute_instructions_renamed_to_system( + self, provider, mock_client, query_params + ): + config = NativeCompletionConfig( + provider="anthropic-native", + type="text", + params={"model": "claude-opus-4-7", "instructions": "be brief"}, + ) + mock_client.messages.create.return_value = mock_claude_message() + + provider.execute(config, query_params, "hi") + + kwargs = mock_client.messages.create.call_args.kwargs + assert kwargs.get("system") == "be brief" + assert "instructions" not in kwargs + + def test_execute_strips_instructions_when_system_also_set( + self, provider, mock_client, query_params + ): + config = NativeCompletionConfig( + provider="anthropic-native", + type="text", + params={ + "model": "claude-opus-4-7", + "instructions": "ignored", + "system": "winner", + }, + ) + mock_client.messages.create.return_value = mock_claude_message() + + provider.execute(config, query_params, "hi") + + kwargs = mock_client.messages.create.call_args.kwargs + assert kwargs["system"] == "winner" + assert "instructions" not in kwargs + + def test_execute_multimodal_text_image_pdf( + self, provider, mock_client, text_config, query_params + ): + mock_client.messages.create.return_value = mock_claude_message() + multimodal = MultiModalInput( + parts=[ + TextContent(value="describe"), + ImageContent(format="base64", mime_type="image/png", value="ZmFrZQ=="), + PDFContent( + format="url", mime_type="application/pdf", value="https://x/y.pdf" + ), + ] + ) + + provider.execute(text_config, query_params, multimodal) + + content = mock_client.messages.create.call_args.kwargs["messages"][0]["content"] + assert content[0] == {"type": "text", "text": "describe"} + assert content[1] == { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": "ZmFrZQ==", + }, + } + assert content[2] == { + "type": "document", + "source": {"type": "url", "url": "https://x/y.pdf"}, + } + + def test_execute_strips_conversation_param( + self, provider, mock_client, query_params + ): + config = NativeCompletionConfig( + provider="anthropic-native", + type="text", + params={"model": "claude-opus-4-7", "conversation": {"id": "conv_x"}}, + ) + mock_client.messages.create.return_value = mock_claude_message() + + provider.execute(config, query_params, "hi") + + assert "conversation" not in mock_client.messages.create.call_args.kwargs + + def test_execute_joins_only_text_blocks( + self, provider, mock_client, text_config, query_params + ): + # Response with a tool_use block mixed in; we only join text blocks + tool_block = SimpleNamespace(type="tool_use", id="t1", name="x", input={}) + mock_client.messages.create.return_value = mock_claude_message( + text="part1", + extra_blocks=[tool_block, SimpleNamespace(type="text", text="part2")], + ) + + result, error = provider.execute(text_config, query_params, "hi") + + assert error is None + assert result.response.output.content.value == "part1part2" + + def test_execute_includes_raw_response_when_requested( + self, provider, mock_client, text_config, query_params + ): + mock_client.messages.create.return_value = mock_claude_message() + + result, _ = provider.execute( + text_config, query_params, "hi", include_provider_raw_response=True + ) + + assert result.provider_raw_response == { + "id": "msg_123", + "model": "claude-opus-4-7", + } + + def test_execute_returns_error_on_anthropic_api_error( + self, provider, mock_client, text_config, query_params + ): + mock_client.messages.create.side_effect = anthropic.AnthropicError("boom") + + result, error = provider.execute(text_config, query_params, "hi") + + assert result is None + assert error is not None + assert "boom" in error + + def test_execute_returns_error_on_unexpected_kwarg( + self, provider, mock_client, text_config, query_params + ): + mock_client.messages.create.side_effect = TypeError( + "unexpected keyword argument 'foo'" + ) + + result, error = provider.execute(text_config, query_params, "hi") + + assert result is None + assert "Invalid or unexpected parameter" in error diff --git a/backend/pyproject.toml b/backend/pyproject.toml index ad480923f..96cfa8055 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -54,6 +54,7 @@ dependencies = [ "gevent>=25.9.1", "openpyxl>=3.1.5", "litellm>=1.83.10", + "anthropic>=0.104.1", ] [tool.uv] diff --git a/backend/uv.lock b/backend/uv.lock index ee6bb6242..67912f5e9 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -199,6 +199,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "anthropic" +version = "0.104.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "docstring-parser" }, + { name = "httpx" }, + { name = "jiter" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/22/c7/7a655b948916f777354648ce979f68b94d5b8dbdb5f61fed1f37fad9378c/anthropic-0.104.1.tar.gz", hash = "sha256:17362b6c45f527afcc9b0fdf62011ffd359726ab2ebcb1978ea0cc41bd8d8d40", size = 850081, upload-time = "2026-05-22T15:36:57.432Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/12/d9ab42790494d7c428391a46cd28492395566a6a8ccb138d681978594455/anthropic-0.104.1-py3-none-any.whl", hash = "sha256:35c8cb456f5a4405aafe1f10f03f6fcc54fa51fa8ec01d655cc4b437d120e9b7", size = 832996, upload-time = "2026-05-22T15:36:59.519Z" }, +] + [[package]] name = "anyio" version = "4.12.1" @@ -218,6 +237,7 @@ version = "0.1.0" source = { editable = "." } dependencies = [ { name = "alembic" }, + { name = "anthropic" }, { name = "asgi-correlation-id" }, { name = "bcrypt" }, { name = "boto3" }, @@ -282,6 +302,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "alembic", specifier = ">=1.12.1,<2.0.0" }, + { name = "anthropic", specifier = ">=0.104.1" }, { name = "asgi-correlation-id", specifier = ">=4.3.4" }, { name = "bcrypt", specifier = "==4.0.1" }, { name = "boto3", specifier = ">=1.37.20" }, @@ -876,6 +897,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" }, ] +[[package]] +name = "docstring-parser" +version = "0.18.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/4d/f332313098c1de1b2d2ff91cf2674415cc7cddab2ca1b01ae29774bd5fdf/docstring_parser-0.18.0.tar.gz", hash = "sha256:292510982205c12b1248696f44959db3cdd1740237a968ea1e2e7a900eeb2015", size = 29341, upload-time = "2026-04-14T04:09:19.867Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/5f/ed01f9a3cdffbd5a008556fc7b2a08ddb1cc6ace7effa7340604b1d16699/docstring_parser-0.18.0-py3-none-any.whl", hash = "sha256:b3fcbed555c47d8479be0796ef7e19c2670d428d72e96da63f3a40122860374b", size = 22484, upload-time = "2026-04-14T04:09:18.638Z" }, +] + [[package]] name = "docutils" version = "0.22.4" From b279ff73d77f693cbf79bb02a1579c7a5ec575b8 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Thu, 28 May 2026 15:33:53 +0530 Subject: [PATCH 02/15] feat: adapter for google-vertex --- ...nthropic_google_vertex_to_provider_enum.py | 115 ++++++ backend/app/core/providers.py | 5 + backend/app/crud/model_config.py | 4 +- backend/app/models/llm/request.py | 15 +- backend/app/models/model_config.py | 8 +- backend/app/services/llm/mappers.py | 19 + .../app/services/llm/providers/__init__.py | 1 + .../app/services/llm/providers/gai_vertex.py | 363 ++++++++++++++++++ .../app/services/llm/providers/registry.py | 5 + .../services/llm/providers/test_gai_vertex.py | 312 +++++++++++++++ 10 files changed, 842 insertions(+), 5 deletions(-) create mode 100644 backend/app/alembic/versions/064_add_anthropic_google_vertex_to_provider_enum.py create mode 100644 backend/app/services/llm/providers/gai_vertex.py create mode 100644 backend/app/tests/services/llm/providers/test_gai_vertex.py diff --git a/backend/app/alembic/versions/064_add_anthropic_google_vertex_to_provider_enum.py b/backend/app/alembic/versions/064_add_anthropic_google_vertex_to_provider_enum.py new file mode 100644 index 000000000..3889902b0 --- /dev/null +++ b/backend/app/alembic/versions/064_add_anthropic_google_vertex_to_provider_enum.py @@ -0,0 +1,115 @@ +"""add anthropic + google-vertex to provider_enum and seed test model_config rows + +Revision ID: 064 +Revises: 063 +Create Date: 2026-05-28 00:00:00.000000 + +""" + +from alembic import op + + +revision = "064" +down_revision = "063" +branch_labels = None +depends_on = None + + +def upgrade(): + # ALTER TYPE ... ADD VALUE cannot run inside a transaction block; use + # autocommit per existing pattern (see migration 056). The added values + # are visible to subsequent statements once the autocommit_block exits. + with op.get_context().autocommit_block(): + op.execute( + "ALTER TYPE global.provider_enum ADD VALUE IF NOT EXISTS 'anthropic'" + ) + op.execute( + "ALTER TYPE global.provider_enum ADD VALUE IF NOT EXISTS 'google-vertex'" + ) + + # Pass-through seed rows for testing. Pricing values are placeholders; + # revise once real cost data is available. + op.execute( + """ + INSERT INTO global.model_config + (provider, model_name, completion_type, config, input_modalities, output_modalities, pricing, is_active, inserted_at, updated_at) + VALUES + -- Anthropic text models + ('anthropic', 'claude-opus-4-7', 'text', + '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 1.0, "description": "Sampling temperature."}}', + '{TEXT,IMAGE,PDF}', '{TEXT}', + '{"response": {"input_token_cost": 15.0, "output_token_cost": 75.0}, "batch": {"input_token_cost": 7.5, "output_token_cost": 37.5}}', + true, NOW(), NOW()), + ('anthropic', 'claude-sonnet-4-6', 'text', + '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 1.0, "description": "Sampling temperature."}}', + '{TEXT,IMAGE,PDF}', '{TEXT}', + '{"response": {"input_token_cost": 3.0, "output_token_cost": 15.0}, "batch": {"input_token_cost": 1.5, "output_token_cost": 7.5}}', + true, NOW(), NOW()), + ('anthropic', 'claude-haiku-4-5-20251001', 'text', + '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 1.0, "description": "Sampling temperature."}}', + '{TEXT,IMAGE,PDF}', '{TEXT}', + '{"response": {"input_token_cost": 1.0, "output_token_cost": 5.0}, "batch": {"input_token_cost": 0.5, "output_token_cost": 2.5}}', + true, NOW(), NOW()), + -- Google Vertex STT models (Gemini 3.x family — GA per + -- https://docs.cloud.google.com/gemini-enterprise-agent-platform/models/google-models) + ('google-vertex', 'gemini-3.1-pro-preview', 'stt', + '{"thinking_level": {"type": "enum", "default": "high", "options": ["low", "medium", "high"], "description": "Max reasoning depth before output. high = best quality, low = faster/cheaper."}}', + '{AUDIO}', '{TEXT}', + '{"response": {"input_token_cost": 2.0, "output_token_cost": 12.0}, "audio": {"input_token_cost": 3.5, "output_token_cost": 12.0}}', + true, NOW(), NOW()), + ('google-vertex', 'gemini-3-pro', 'stt', + '{"thinking_level": {"type": "enum", "default": "high", "options": ["low", "medium", "high"], "description": "Max reasoning depth before output."}}', + '{AUDIO}', '{TEXT}', + '{"response": {"input_token_cost": 1.5, "output_token_cost": 10.0}, "audio": {"input_token_cost": 3.0, "output_token_cost": 10.0}}', + true, NOW(), NOW()), + ('google-vertex', 'gemini-3.5-flash', 'stt', + '{"thinking_level": {"type": "enum", "default": "high", "options": ["minimal", "low", "medium", "high"], "description": "Max reasoning depth before output."}}', + '{AUDIO}', '{TEXT}', + '{"response": {"input_token_cost": 0.6, "output_token_cost": 3.5}, "audio": {"input_token_cost": 1.2, "output_token_cost": 3.5}}', + true, NOW(), NOW()), + ('google-vertex', 'gemini-3-flash-preview', 'stt', + '{"thinking_level": {"type": "enum", "default": "high", "options": ["minimal", "low", "medium", "high"], "description": "Max reasoning depth before output."}}', + '{AUDIO}', '{TEXT}', + '{"response": {"input_token_cost": 0.5, "output_token_cost": 3.0}, "audio": {"input_token_cost": 1.0, "output_token_cost": 3.0}}', + true, NOW(), NOW()), + ('google-vertex', 'gemini-3.1-flash-lite', 'stt', + '{"temperature": {"type": "float", "default": 0.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', + '{AUDIO}', '{TEXT}', + '{"response": {"input_token_cost": 0.1, "output_token_cost": 0.4}, "audio": {"input_token_cost": 0.3, "output_token_cost": 0.4}}', + true, NOW(), NOW()), + ('google-vertex', 'gemini-2.5-flash', 'stt', + '{"temperature": {"type": "float", "default": 0.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', + '{AUDIO}', '{TEXT}', + '{"response": {"input_token_cost": 0.3, "output_token_cost": 2.5}, "audio": {"input_token_cost": 1.0, "output_token_cost": 2.5}}', + true, NOW(), NOW()), + ('google-vertex', 'gemini-2.5-pro', 'stt', + '{"temperature": {"type": "float", "default": 0.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', + '{AUDIO}', '{TEXT}', + '{"response": {"input_token_cost": 1.25, "output_token_cost": 10.0}, "audio": {"input_token_cost": 3.5, "output_token_cost": 10.0}}', + true, NOW(), NOW()), + -- Google Vertex TTS models + ('google-vertex', 'gemini-2.5-flash-preview-tts', 'tts', + '{"voice": {"type": "enum", "default": "Kore", "options": ["Aoede", "Charon", "Fenrir", "Kore", "Puck"], "description": "TTS voice."}}', + '{TEXT}', '{AUDIO}', + '{"response": {"input_token_cost": 0.5, "output_token_cost": 10.0}, "audio": {"input_token_cost": 0.5, "output_token_cost": 10.0}}', + true, NOW(), NOW()), + ('google-vertex', 'gemini-2.5-pro-preview-tts', 'tts', + '{"voice": {"type": "enum", "default": "Kore", "options": ["Aoede", "Charon", "Fenrir", "Kore", "Puck"], "description": "TTS voice."}}', + '{TEXT}', '{AUDIO}', + '{"response": {"input_token_cost": 1.0, "output_token_cost": 20.0}, "audio": {"input_token_cost": 1.0, "output_token_cost": 20.0}}', + true, NOW(), NOW()) + ON CONFLICT (provider, model_name) DO NOTHING + """ + ) + + +def downgrade(): + op.execute( + """ + DELETE FROM global.model_config + WHERE provider IN ('anthropic', 'google-vertex') + """ + ) + # Enum value removal requires rebuilding the type and re-pointing every + # referencing column. Skipped — see migrations 035 / 056 for the same + # convention. diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index 202cecf12..c1e21f7ae 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -15,6 +15,7 @@ class Provider(str, Enum): SARVAMAI = "sarvamai" ELEVENLABS = "elevenlabs" ANTHROPIC = "anthropic" + GOOGLE_VERTEX = "google-vertex" WEBHOOK_SECRET = "webhook_secret" @@ -47,6 +48,10 @@ class ProviderConfig: Provider.ANTHROPIC: ProviderConfig( required_fields=["api_key"], sensitive_fields=["api_key"] ), + Provider.GOOGLE_VERTEX: ProviderConfig( + required_fields=["api_key", "project_id", "location"], + sensitive_fields=["api_key"], + ), Provider.WEBHOOK_SECRET: ProviderConfig( required_fields=["webhook_secret"], sensitive_fields=["webhook_secret"] ), diff --git a/backend/app/crud/model_config.py b/backend/app/crud/model_config.py index 9c627f7f4..643af59f8 100644 --- a/backend/app/crud/model_config.py +++ b/backend/app/crud/model_config.py @@ -7,7 +7,9 @@ from app.models.llm.request import ConfigBlob from app.models.model_config import CompletionType -Provider = Literal["openai", "google", "sarvamai", "elevenlabs"] +Provider = Literal[ + "openai", "google", "sarvamai", "elevenlabs", "anthropic", "google-vertex" +] def _normalize_provider(raw: str) -> str: diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index 866b85a46..d14034819 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -232,6 +232,7 @@ class NativeCompletionConfig(SQLModel): "sarvamai-native", "elevenlabs-native", "anthropic-native", + "google-vertex-native", ] = Field( ..., description="Native provider type (e.g., openai-native)", @@ -253,10 +254,20 @@ class KaapiCompletionConfig(SQLModel): """ provider: ( - Literal["openai", "google", "sarvamai", "elevenlabs", "anthropic"] | None + Literal[ + "openai", + "google", + "sarvamai", + "elevenlabs", + "anthropic", + "google-vertex", + ] + | None ) = Field( None, - description="LLM provider (openai, google, sarvamai, elevenlabs, anthropic)", + description=( + "LLM provider (openai, google, sarvamai, elevenlabs, anthropic, google-vertex)" + ), ) type: Literal["text", "stt", "tts"] = Field( diff --git a/backend/app/models/model_config.py b/backend/app/models/model_config.py index ef284dc5f..b3b6ab853 100644 --- a/backend/app/models/model_config.py +++ b/backend/app/models/model_config.py @@ -11,7 +11,9 @@ class ModelConfigBase(SQLModel): - provider: Literal["openai", "google", "sarvamai", "elevenlabs"] = Field( + provider: Literal[ + "openai", "google", "sarvamai", "elevenlabs", "anthropic", "google-vertex" + ] = Field( default="openai", sa_column=sa.Column( sa.Enum( @@ -19,11 +21,13 @@ class ModelConfigBase(SQLModel): "google", "sarvamai", "elevenlabs", + "anthropic", + "google-vertex", name="provider_enum", schema="global", ), nullable=False, - comment="provider name (e.g. openai, google, sarvamai, elevenlabs)", + comment="provider name (e.g. openai, google, sarvamai, elevenlabs, anthropic, google-vertex)", ), ) diff --git a/backend/app/services/llm/mappers.py b/backend/app/services/llm/mappers.py index 1b42c3191..a6a42f62e 100644 --- a/backend/app/services/llm/mappers.py +++ b/backend/app/services/llm/mappers.py @@ -554,6 +554,25 @@ def transform_kaapi_config_to_native( warnings, ) + if kaapi_config.provider == "google-vertex": + if kaapi_config.type not in ("stt", "tts"): + raise ValueError( + f"google-vertex provider does not support completion type '{kaapi_config.type}'. " + "Use the 'google' provider for text completions." + ) + # Kaapi STT/TTS param shape is identical to Google's; reuse the Google mapper. + mapped_params, warnings = map_kaapi_to_google_params( + kaapi_config.params, kaapi_config.type + ) + return ( + NativeCompletionConfig( + provider="google-vertex-native", + params=mapped_params, + type=kaapi_config.type, + ), + warnings, + ) + if kaapi_config.provider == "anthropic": if kaapi_config.type != "text": raise ValueError( diff --git a/backend/app/services/llm/providers/__init__.py b/backend/app/services/llm/providers/__init__.py index fabfdf156..f35e26487 100644 --- a/backend/app/services/llm/providers/__init__.py +++ b/backend/app/services/llm/providers/__init__.py @@ -4,6 +4,7 @@ from app.services.llm.providers.eai import ElevenlabsAIProvider from app.services.llm.providers.sai import SarvamAIProvider from app.services.llm.providers.claude import ClaudeProvider +from app.services.llm.providers.gai_vertex import GoogleVertexAIProvider from app.services.llm.providers.registry import ( LLMProvider, get_llm_provider, diff --git a/backend/app/services/llm/providers/gai_vertex.py b/backend/app/services/llm/providers/gai_vertex.py new file mode 100644 index 000000000..97edef690 --- /dev/null +++ b/backend/app/services/llm/providers/gai_vertex.py @@ -0,0 +1,363 @@ +import base64 +import logging +import os +import uuid +from typing import Any + +import requests + +from app.core.audio_utils import convert_pcm_to_mp3, convert_pcm_to_ogg, pcm_to_wav +from app.models.llm import ( + LLMCallResponse, + LLMResponse, + NativeCompletionConfig, + QueryParams, + TextContent, + TextOutput, + Usage, +) +from app.models.llm.constants import ( + DEFAULT_STT_MODEL, + DEFAULT_TTS_MODEL, + DEFAULT_TTS_VOICE, +) +from app.models.llm.response import AudioContent, AudioOutput +from app.services.llm.providers.base import BaseProvider, ContentPart, MultiModalInput + +logger = logging.getLogger(__name__) + +REQUEST_TIMEOUT = 60 +MAX_INLINE_AUDIO_BYTES = 20 * 1024 * 1024 # Vertex inline-data cap (~20 MB) +AUDIO_MIME_BY_EXT = { + ".wav": "audio/wav", + ".mp3": "audio/mp3", + ".aiff": "audio/aiff", + ".aac": "audio/aac", + ".ogg": "audio/ogg", + ".flac": "audio/flac", +} + + +class VertexClient: + """Holds Vertex AI connection details. Pure config — no SDK session.""" + + def __init__(self, api_key: str, project_id: str, location: str): + self.api_key = api_key + self.project_id = project_id + self.location = location + + def endpoint(self, model: str) -> str: + return ( + f"https://{self.location}-aiplatform.googleapis.com/v1" + f"/projects/{self.project_id}/locations/{self.location}" + f"/publishers/google/models/{model}:generateContent" + ) + + +class GoogleVertexAIProvider(BaseProvider): + """Google Vertex AI provider using REST + API key auth. + + Supports STT (audio → text) and TTS (text → audio) via Gemini multimodal + models on Vertex. Text-only completions are routed through the standard + `google` provider. + """ + + def __init__(self, client: VertexClient): + super().__init__(client) + self.client = client + + @staticmethod + def create_client(credentials: dict[str, Any]) -> Any: + missing = [ + f for f in ("api_key", "project_id", "location") if not credentials.get(f) + ] + if missing: + raise ValueError( + f"Google Vertex AI credentials missing required fields: {', '.join(missing)}" + ) + return VertexClient( + api_key=credentials["api_key"], + project_id=credentials["project_id"], + location=credentials["location"], + ) + + def _post(self, model: str, payload: dict) -> tuple[dict | None, str | None]: + try: + resp = requests.post( + self.client.endpoint(model), + params={"key": self.client.api_key}, + headers={"Content-Type": "application/json"}, + json=payload, + timeout=REQUEST_TIMEOUT, + ) + except requests.RequestException as e: + return None, f"Vertex AI request failed: {str(e)}" + + if not resp.ok: + return None, f"Vertex AI HTTP {resp.status_code}: {resp.text[:500]}" + + try: + return resp.json(), None + except ValueError as e: + return None, f"Vertex AI returned non-JSON response: {str(e)}" + + @staticmethod + def _extract_usage(data: dict) -> Usage: + meta = data.get("usageMetadata") or {} + input_tokens = meta.get("promptTokenCount") or 0 + output_tokens = meta.get("candidatesTokenCount") or 0 + total_tokens = meta.get("totalTokenCount") or (input_tokens + output_tokens) + reasoning_tokens = meta.get("thoughtsTokenCount") or 0 + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + reasoning_tokens=reasoning_tokens, + ) + + def _execute_stt( + self, + completion_config: NativeCompletionConfig, + resolved_input: str, + include_provider_raw_response: bool = False, + ) -> tuple[LLMCallResponse | None, str | None]: + provider = completion_config.provider + params = completion_config.params + + if not isinstance(resolved_input, str): + return None, f"{provider} STT requires file path as string" + + if not os.path.isfile(resolved_input): + return None, f"Audio file not found: {resolved_input}" + + ext = os.path.splitext(resolved_input)[1].lower() + mime_type = AUDIO_MIME_BY_EXT.get(ext) + if not mime_type: + return None, ( + f"Unsupported audio extension '{ext}' for Vertex STT. " + f"Supported: {', '.join(sorted(AUDIO_MIME_BY_EXT))}" + ) + + file_size = os.path.getsize(resolved_input) + if file_size > MAX_INLINE_AUDIO_BYTES: + return None, ( + f"Audio file is {file_size} bytes; Vertex inline-data limit is " + f"{MAX_INLINE_AUDIO_BYTES} bytes (~20 MB)" + ) + + with open(resolved_input, "rb") as f: + audio_b64 = base64.b64encode(f.read()).decode("utf-8") + + model = params.get("model") or DEFAULT_STT_MODEL + instructions = params.get("instructions") + input_language = params.get("input_language") or "auto" + output_language = params.get("output_language") + temperature = params.get("temperature") + max_output_tokens = params.get("max_output_tokens") or 2048 + + # Build transcription/translation instruction + if input_language == "auto": + lang_instruction = ( + "Detect the spoken language automatically and transcribe the audio" + ) + else: + lang_instruction = f"Transcribe the audio from {input_language} in the native script of {input_language}" + + if output_language and output_language != input_language: + lang_instruction += ( + f" and translate to {output_language} in the native script of " + f"{output_language} and only return transcribed script in {output_language}." + ) + + forced = "Only return transcribed text and no other text." + if instructions: + prompt = f"{instructions}. {lang_instruction}. {forced}" + else: + prompt = f"{lang_instruction}. {forced}" + + generation_config: dict[str, Any] = {"maxOutputTokens": max_output_tokens} + if temperature is not None: + generation_config["temperature"] = temperature + + payload = { + "contents": [ + { + "role": "user", + "parts": [ + {"inlineData": {"mimeType": mime_type, "data": audio_b64}}, + {"text": prompt}, + ], + } + ], + "generationConfig": generation_config, + } + + data, err = self._post(model, payload) + if err: + return None, err + + try: + transcript = data["candidates"][0]["content"]["parts"][0]["text"] + except (KeyError, IndexError, TypeError): + return None, "Vertex STT response missing transcript text" + + llm_response = LLMCallResponse( + response=LLMResponse( + provider_response_id=data.get("responseId") + or f"vertex-{uuid.uuid4().hex}", + model=data.get("modelVersion") or model, + provider=provider, + output=TextOutput(content=TextContent(value=transcript.strip())), + ), + usage=self._extract_usage(data), + ) + + if include_provider_raw_response: + llm_response.provider_raw_response = data + + logger.info( + f"[GoogleVertexAIProvider._execute_stt] Transcribed audio | provider={provider}, model={model}" + ) + return llm_response, None + + def _execute_tts( + self, + completion_config: NativeCompletionConfig, + resolved_input: str, + include_provider_raw_response: bool = False, + ) -> tuple[LLMCallResponse | None, str | None]: + provider = completion_config.provider + params = completion_config.params + + if not isinstance(resolved_input, str): + return None, f"{provider} TTS requires text string as input" + if not resolved_input.strip(): + return None, "Text input cannot be empty" + + model = params.get("model") or DEFAULT_TTS_MODEL + voice = params.get("voice") or DEFAULT_TTS_VOICE + language = params.get("language") + response_format = params.get("response_format") or "wav" + + speech_config: dict[str, Any] = { + "voiceConfig": {"prebuiltVoiceConfig": {"voiceName": voice}} + } + if language: + speech_config["languageCode"] = language + + payload: dict[str, Any] = { + "contents": [{"role": "user", "parts": [{"text": resolved_input}]}], + "generationConfig": { + "responseModalities": ["AUDIO"], + "speechConfig": speech_config, + }, + } + + provider_specific = params.get("provider_specific", {}) or {} + gemini_params = provider_specific.get("gemini", {}) or {} + director_notes = gemini_params.get("director_notes") + if director_notes: + payload["systemInstruction"] = {"parts": [{"text": director_notes}]} + + data, err = self._post(model, payload) + if err: + return None, err + + try: + inline = data["candidates"][0]["content"]["parts"][0]["inlineData"] + audio_b64 = inline["data"] + except (KeyError, IndexError, TypeError): + return None, "Vertex TTS response missing audio data" + + try: + raw_pcm = base64.b64decode(audio_b64) + except (ValueError, TypeError) as e: + return None, f"Vertex TTS returned invalid base64 audio: {str(e)}" + + if not raw_pcm: + return None, "Vertex TTS returned empty audio" + + actual_format = "wav" + wav_bytes = pcm_to_wav(raw_pcm) + encoded_content = base64.b64encode(wav_bytes).decode("ascii") + + if response_format == "mp3": + converted, convert_err = convert_pcm_to_mp3(raw_pcm) + if convert_err: + return None, f"Failed to convert audio to MP3: {convert_err}" + encoded_content = base64.b64encode(converted or b"").decode("ascii") + actual_format = "mp3" + elif response_format == "ogg": + converted, convert_err = convert_pcm_to_ogg(raw_pcm) + if convert_err: + return None, f"Failed to convert audio to OGG: {convert_err}" + encoded_content = base64.b64encode(converted or b"").decode("ascii") + actual_format = "ogg" + elif response_format and response_format != "wav": + logger.warning( + f"[GoogleVertexAIProvider._execute_tts] Unsupported response_format " + f"'{response_format}', returning native WAV | provider={provider}" + ) + + llm_response = LLMCallResponse( + response=LLMResponse( + provider_response_id=data.get("responseId") + or f"vertex-{uuid.uuid4().hex}", + model=data.get("modelVersion") or model, + provider=provider, + output=AudioOutput( + content=AudioContent( + format="base64", + value=encoded_content, + mime_type=f"audio/{actual_format}", + ) + ), + ), + usage=self._extract_usage(data), + ) + + if include_provider_raw_response: + llm_response.provider_raw_response = data + + logger.info( + f"[GoogleVertexAIProvider._execute_tts] Synthesised audio | " + f"provider={provider}, model={model}, format={actual_format}, " + f"raw_pcm_bytes={len(raw_pcm)}" + ) + return llm_response, None + + def execute( + self, + completion_config: NativeCompletionConfig, + query: QueryParams, + resolved_input: str | list[ContentPart] | MultiModalInput, + include_provider_raw_response: bool = False, + ) -> tuple[LLMCallResponse | None, str | None]: + try: + completion_type = completion_config.type + if completion_type == "stt": + return self._execute_stt( + completion_config=completion_config, + resolved_input=resolved_input, + include_provider_raw_response=include_provider_raw_response, + ) + if completion_type == "tts": + return self._execute_tts( + completion_config=completion_config, + resolved_input=resolved_input, + include_provider_raw_response=include_provider_raw_response, + ) + return ( + None, + f"google-vertex provider does not support completion type " + f"'{completion_type}'. Use the 'google' provider for text completions.", + ) + except TypeError as e: + return None, f"Invalid or unexpected parameter in Config: {str(e)}" + except Exception as e: + logger.error( + f"[GoogleVertexAIProvider.execute] Unexpected error: {str(e)} | " + f"provider={completion_config.provider}", + exc_info=True, + ) + return None, "Unexpected error occurred" diff --git a/backend/app/services/llm/providers/registry.py b/backend/app/services/llm/providers/registry.py index 60655028d..daa42bf12 100644 --- a/backend/app/services/llm/providers/registry.py +++ b/backend/app/services/llm/providers/registry.py @@ -7,6 +7,7 @@ from app.services.llm.providers.sai import SarvamAIProvider from app.services.llm.providers.eai import ElevenlabsAIProvider from app.services.llm.providers.claude import ClaudeProvider +from app.services.llm.providers.gai_vertex import GoogleVertexAIProvider logger = logging.getLogger(__name__) @@ -17,11 +18,13 @@ class LLMProvider: ELEVENLABS = "elevenlabs" GOOGLE = "google" ANTHROPIC = "anthropic" + GOOGLE_VERTEX = "google-vertex" OPENAI_NATIVE = "openai-native" GOOGLE_NATIVE = "google-native" SARVAMAI_NATIVE = "sarvamai-native" ELEVENLABS_NATIVE = "elevenlabs-native" ANTHROPIC_NATIVE = "anthropic-native" + GOOGLE_VERTEX_NATIVE = "google-vertex-native" _registry: dict[str, type[BaseProvider]] = { OPENAI: OpenAIProvider, @@ -29,11 +32,13 @@ class LLMProvider: SARVAMAI: SarvamAIProvider, ELEVENLABS: ElevenlabsAIProvider, ANTHROPIC: ClaudeProvider, + GOOGLE_VERTEX: GoogleVertexAIProvider, OPENAI_NATIVE: OpenAIProvider, GOOGLE_NATIVE: GoogleAIProvider, SARVAMAI_NATIVE: SarvamAIProvider, ELEVENLABS_NATIVE: ElevenlabsAIProvider, ANTHROPIC_NATIVE: ClaudeProvider, + GOOGLE_VERTEX_NATIVE: GoogleVertexAIProvider, } @classmethod diff --git a/backend/app/tests/services/llm/providers/test_gai_vertex.py b/backend/app/tests/services/llm/providers/test_gai_vertex.py new file mode 100644 index 000000000..928d5bcb6 --- /dev/null +++ b/backend/app/tests/services/llm/providers/test_gai_vertex.py @@ -0,0 +1,312 @@ +"""Tests for the Google Vertex AI provider.""" + +import base64 +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from app.models.llm import NativeCompletionConfig, QueryParams +from app.services.llm.providers.gai_vertex import ( + MAX_INLINE_AUDIO_BYTES, + GoogleVertexAIProvider, + VertexClient, +) + + +def _stt_response(text: str = "hello world") -> dict: + return { + "candidates": [{"content": {"parts": [{"text": text}]}}], + "modelVersion": "gemini-2.5-flash", + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 5, + "totalTokenCount": 15, + }, + } + + +def _tts_response(pcm_bytes: bytes = b"\x00\x01" * 100) -> dict: + return { + "candidates": [ + { + "content": { + "parts": [ + { + "inlineData": { + "mimeType": "audio/pcm", + "data": base64.b64encode(pcm_bytes).decode("ascii"), + } + } + ] + } + } + ], + "modelVersion": "gemini-2.5-flash-preview-tts", + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 0, + "totalTokenCount": 4, + }, + } + + +def _mock_http_ok(json_body: dict) -> MagicMock: + resp = MagicMock() + resp.ok = True + resp.status_code = 200 + resp.json.return_value = json_body + return resp + + +def _mock_http_err(status: int = 400, body: str = "bad request") -> MagicMock: + resp = MagicMock() + resp.ok = False + resp.status_code = status + resp.text = body + return resp + + +class TestGoogleVertexAIProvider: + @pytest.fixture + def client(self) -> VertexClient: + return VertexClient(api_key="k", project_id="p", location="us-central1") + + @pytest.fixture + def provider(self, client) -> GoogleVertexAIProvider: + return GoogleVertexAIProvider(client=client) + + @pytest.fixture + def query(self) -> QueryParams: + return QueryParams(input="ignored") + + @pytest.fixture + def stt_config(self) -> NativeCompletionConfig: + return NativeCompletionConfig( + provider="google-vertex-native", + type="stt", + params={"model": "gemini-2.5-flash", "input_language": "auto"}, + ) + + @pytest.fixture + def tts_config(self) -> NativeCompletionConfig: + return NativeCompletionConfig( + provider="google-vertex-native", + type="tts", + params={"model": "gemini-2.5-flash-preview-tts", "voice": "Kore"}, + ) + + # ── create_client ──────────────────────────────────────────────────────── + def test_create_client_requires_all_fields(self): + with pytest.raises(ValueError, match="project_id, location"): + GoogleVertexAIProvider.create_client({"api_key": "k"}) + + def test_create_client_builds_endpoint(self): + c = GoogleVertexAIProvider.create_client( + {"api_key": "k", "project_id": "p", "location": "us-central1"} + ) + assert "us-central1-aiplatform.googleapis.com" in c.endpoint("m") + assert "projects/p/locations/us-central1" in c.endpoint("m") + assert "models/m:generateContent" in c.endpoint("m") + + # ── STT ────────────────────────────────────────────────────────────────── + def test_stt_happy_path(self, provider, stt_config, query, tmp_path): + audio = tmp_path / "a.wav" + audio.write_bytes(b"RIFFfake") + + with patch( + "app.services.llm.providers.gai_vertex.requests.post", + return_value=_mock_http_ok(_stt_response("hi there")), + ) as mock_post: + resp, err = provider.execute(stt_config, query, str(audio)) + + assert err is None + assert resp.response.output.content.value == "hi there" + assert resp.response.model == "gemini-2.5-flash" + assert resp.usage.input_tokens == 10 + assert resp.usage.output_tokens == 5 + + # Verify payload shape + kwargs = mock_post.call_args.kwargs + assert kwargs["params"] == {"key": "k"} + sent = kwargs["json"] + parts = sent["contents"][0]["parts"] + assert parts[0]["inlineData"]["mimeType"] == "audio/wav" + assert "Detect the spoken language automatically" in parts[1]["text"] + + def test_stt_rejects_non_string_input(self, provider, stt_config, query): + resp, err = provider.execute(stt_config, query, 123) + assert resp is None + assert "file path as string" in err + + def test_stt_rejects_missing_file(self, provider, stt_config, query): + resp, err = provider.execute(stt_config, query, "/nope/missing.wav") + assert resp is None + assert "Audio file not found" in err + + def test_stt_rejects_unsupported_extension( + self, provider, stt_config, query, tmp_path + ): + audio = tmp_path / "a.xyz" + audio.write_bytes(b"x") + resp, err = provider.execute(stt_config, query, str(audio)) + assert resp is None + assert "Unsupported audio extension" in err + + def test_stt_rejects_oversized_file(self, provider, stt_config, query, tmp_path): + audio = tmp_path / "a.wav" + audio.write_bytes(b"x") + with patch( + "app.services.llm.providers.gai_vertex.os.path.getsize", + return_value=MAX_INLINE_AUDIO_BYTES + 1, + ): + resp, err = provider.execute(stt_config, query, str(audio)) + assert resp is None + assert "inline-data limit" in err + + def test_stt_http_error_returns_clean_message( + self, provider, stt_config, query, tmp_path + ): + audio = tmp_path / "a.wav" + audio.write_bytes(b"RIFF") + with patch( + "app.services.llm.providers.gai_vertex.requests.post", + return_value=_mock_http_err(403, "permission denied"), + ): + resp, err = provider.execute(stt_config, query, str(audio)) + assert resp is None + assert "Vertex AI HTTP 403" in err + assert "permission denied" in err + + def test_stt_network_error_returns_clean_message( + self, provider, stt_config, query, tmp_path + ): + audio = tmp_path / "a.wav" + audio.write_bytes(b"RIFF") + with patch( + "app.services.llm.providers.gai_vertex.requests.post", + side_effect=requests.ConnectionError("dns boom"), + ): + resp, err = provider.execute(stt_config, query, str(audio)) + assert resp is None + assert "Vertex AI request failed" in err + + def test_stt_missing_transcript_returns_error( + self, provider, stt_config, query, tmp_path + ): + audio = tmp_path / "a.wav" + audio.write_bytes(b"RIFF") + with patch( + "app.services.llm.providers.gai_vertex.requests.post", + return_value=_mock_http_ok({"candidates": []}), + ): + resp, err = provider.execute(stt_config, query, str(audio)) + assert resp is None + assert "missing transcript text" in err + + def test_stt_input_language_overrides_prompt(self, provider, query, tmp_path): + audio = tmp_path / "a.wav" + audio.write_bytes(b"RIFF") + config = NativeCompletionConfig( + provider="google-vertex-native", + type="stt", + params={ + "model": "gemini-2.5-flash", + "input_language": "hi-IN", + "output_language": "en-IN", + "instructions": "be precise", + }, + ) + with patch( + "app.services.llm.providers.gai_vertex.requests.post", + return_value=_mock_http_ok(_stt_response()), + ) as mock_post: + provider.execute(config, query, str(audio)) + + prompt = mock_post.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"] + assert prompt.startswith("be precise") + assert "hi-IN" in prompt + assert "translate to en-IN" in prompt + + # ── TTS ────────────────────────────────────────────────────────────────── + def test_tts_happy_path_wav(self, provider, tts_config, query): + with patch( + "app.services.llm.providers.gai_vertex.requests.post", + return_value=_mock_http_ok(_tts_response()), + ) as mock_post: + resp, err = provider.execute(tts_config, query, "hello") + + assert err is None + assert resp.response.output.content.format == "base64" + assert resp.response.output.content.mime_type == "audio/wav" + # base64 decodes to valid WAV header + decoded = base64.b64decode(resp.response.output.content.value) + assert decoded[:4] == b"RIFF" + + sent = mock_post.call_args.kwargs["json"] + assert sent["generationConfig"]["responseModalities"] == ["AUDIO"] + assert ( + sent["generationConfig"]["speechConfig"]["voiceConfig"][ + "prebuiltVoiceConfig" + ]["voiceName"] + == "Kore" + ) + + def test_tts_rejects_non_string_input(self, provider, tts_config, query): + resp, err = provider.execute(tts_config, query, ["not a string"]) + assert resp is None + assert "text string as input" in err + + def test_tts_rejects_empty_input(self, provider, tts_config, query): + resp, err = provider.execute(tts_config, query, " ") + assert resp is None + assert "Text input cannot be empty" in err + + def test_tts_missing_audio_returns_error(self, provider, tts_config, query): + with patch( + "app.services.llm.providers.gai_vertex.requests.post", + return_value=_mock_http_ok({"candidates": [{"content": {"parts": []}}]}), + ): + resp, err = provider.execute(tts_config, query, "hello") + assert resp is None + assert "missing audio data" in err + + def test_tts_language_is_forwarded(self, provider, query): + config = NativeCompletionConfig( + provider="google-vertex-native", + type="tts", + params={"model": "gemini-2.5-flash-preview-tts", "language": "en-US"}, + ) + with patch( + "app.services.llm.providers.gai_vertex.requests.post", + return_value=_mock_http_ok(_tts_response()), + ) as mock_post: + provider.execute(config, query, "hi") + speech = mock_post.call_args.kwargs["json"]["generationConfig"]["speechConfig"] + assert speech["languageCode"] == "en-US" + + # ── execute dispatcher ─────────────────────────────────────────────────── + def test_text_completion_is_rejected(self, provider, query): + config = NativeCompletionConfig( + provider="google-vertex-native", + type="text", + params={"model": "gemini-2.5-flash"}, + ) + resp, err = provider.execute(config, query, "hello") + assert resp is None + assert "does not support completion type 'text'" in err + + def test_raw_response_included_when_requested( + self, provider, stt_config, query, tmp_path + ): + audio = tmp_path / "a.wav" + audio.write_bytes(b"RIFF") + raw = _stt_response() + with patch( + "app.services.llm.providers.gai_vertex.requests.post", + return_value=_mock_http_ok(raw), + ): + resp, _ = provider.execute( + stt_config, query, str(audio), include_provider_raw_response=True + ) + assert resp.provider_raw_response == raw From 8f5e6b571c20526b424aaebe67a71b6241b1396f Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Mon, 1 Jun 2026 00:15:01 +0530 Subject: [PATCH 03/15] feat: add boto3 deps for AWS key service and audio path refactor --- backend/app/core/audio_utils.py | 50 ++++- backend/app/core/cloud/__init__.py | 2 + backend/app/core/cloud/storage.py | 199 ++++++++++++++++-- backend/app/core/config.py | 7 + backend/app/services/llm/jobs.py | 17 +- backend/app/services/llm/providers/base.py | 6 +- backend/app/services/llm/providers/eai.py | 23 +- backend/app/services/llm/providers/gai.py | 23 +- .../app/services/llm/providers/gai_vertex.py | 101 ++++++--- backend/app/services/llm/providers/sai.py | 43 ++-- .../tests/services/llm/providers/test_eai.py | 15 +- .../tests/services/llm/providers/test_gai.py | 48 +++-- .../services/llm/providers/test_gai_vertex.py | 110 +++++----- .../tests/services/llm/providers/test_sai.py | 15 +- .../tests/services/llm/test_input_resolver.py | 131 ++++-------- backend/app/utils.py | 62 ++---- backend/pyproject.toml | 1 + backend/uv.lock | 122 +++++++++++ 18 files changed, 643 insertions(+), 332 deletions(-) diff --git a/backend/app/core/audio_utils.py b/backend/app/core/audio_utils.py index 0ffc4b010..6a82ba1ea 100644 --- a/backend/app/core/audio_utils.py +++ b/backend/app/core/audio_utils.py @@ -1,18 +1,56 @@ -""" -Audio processing utilities for format conversion. +"""Audio processing utilities: format conversion + STT input carrier.""" -This module provides utilities for converting audio between different formats, -particularly for TTS output post-processing. -""" import io import logging +import os +import tempfile import wave +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Iterator from pydub import AudioSegment - logger = logging.getLogger(__name__) +_MIME_TO_EXT = { + "audio/wav": ".wav", + "audio/mpeg": ".mp3", + "audio/mp3": ".mp3", + "audio/ogg": ".ogg", + "audio/flac": ".flac", + "audio/webm": ".webm", + "audio/mp4": ".mp4", + "audio/m4a": ".m4a", + "audio/aac": ".aac", + "audio/aiff": ".aiff", +} + + +@dataclass(frozen=True) +class AudioRef: + """In-memory STT input. Providers consume ``bytes_`` directly or call + ``to_path()`` when an SDK needs a filesystem path. Temp files are owned + by the provider's ``with`` scope — no framework-level cleanup needed. + """ + + bytes_: bytes + mime_type: str = "audio/wav" + + @contextmanager + def to_path(self) -> Iterator[str]: + ext = _MIME_TO_EXT.get(self.mime_type, ".audio") + tmp = tempfile.NamedTemporaryFile(suffix=ext, delete=False, prefix="audio_") + try: + tmp.write(self.bytes_) + tmp.close() + yield tmp.name + finally: + try: + os.unlink(tmp.name) + except OSError: + pass + def convert_pcm_to_mp3( pcm_bytes: bytes, sample_rate: int = 24000 diff --git a/backend/app/core/cloud/__init__.py b/backend/app/core/cloud/__init__.py index c29a35ad4..bffc5a964 100644 --- a/backend/app/core/cloud/__init__.py +++ b/backend/app/core/cloud/__init__.py @@ -4,4 +4,6 @@ CloudStorage, CloudStorageError, get_cloud_storage, + get_gcp_service_account, + upload_audio_to_gcs, ) diff --git a/backend/app/core/cloud/storage.py b/backend/app/core/cloud/storage.py index a57273a06..608e9f612 100644 --- a/backend/app/core/cloud/storage.py +++ b/backend/app/core/cloud/storage.py @@ -1,6 +1,8 @@ import os +import json +import mimetypes from sqlmodel import Session -from uuid import UUID +from uuid import UUID, uuid4 import logging import functools as ft from pathlib import Path @@ -12,10 +14,19 @@ from fastapi import UploadFile from botocore.exceptions import ClientError from botocore.response import StreamingBody +from google.cloud import storage as gcs +from google.oauth2 import service_account -from app.crud import get_project_by_id from app.core.config import settings -from app.utils import mask_string + + +def _mask(value: str | None) -> str: + # Lazy to break a top-level cycle: app.utils transitively imports + # app.services.llm.providers, which imports this module. + from app.utils import mask_string + + return mask_string(value) + logger = logging.getLogger(__name__) @@ -46,7 +57,7 @@ def create(self): except ValueError as err: logger.error( f"[AmazonCloudStorageClient.create] Invalid bucket configuration | " - f"{{'bucket': '{mask_string(settings.AWS_S3_BUCKET)}', 'error': '{str(err)}'}}", + f"{{'bucket': '{_mask(settings.AWS_S3_BUCKET)}', 'error': '{str(err)}'}}", exc_info=True, ) raise CloudStorageError(err) from err @@ -55,13 +66,13 @@ def create(self): if response != 404: logger.error( f"[AmazonCloudStorageClient.create] Unexpected AWS error | " - f"{{'bucket': '{mask_string(settings.AWS_S3_BUCKET)}', 'error': '{str(err)}', 'code': {response}}}", + f"{{'bucket': '{_mask(settings.AWS_S3_BUCKET)}', 'error': '{str(err)}', 'code': {response}}}", exc_info=True, ) raise CloudStorageError(err) from err logger.warning( f"[AmazonCloudStorageClient.create] Bucket not found, creating | " - f"{{'bucket': '{mask_string(settings.AWS_S3_BUCKET)}'}}" + f"{{'bucket': '{_mask(settings.AWS_S3_BUCKET)}'}}" ) try: self.client.create_bucket( @@ -72,12 +83,12 @@ def create(self): ) logger.info( f"[AmazonCloudStorageClient.create] Bucket created successfully | " - f"{{'bucket': '{mask_string(settings.AWS_S3_BUCKET)}'}}" + f"{{'bucket': '{_mask(settings.AWS_S3_BUCKET)}'}}" ) except ClientError as create_err: logger.error( f"[AmazonCloudStorageClient.create] Failed to create bucket | " - f"{{'bucket': '{mask_string(settings.AWS_S3_BUCKET)}', 'error': '{str(create_err)}'}}", + f"{{'bucket': '{_mask(settings.AWS_S3_BUCKET)}', 'error': '{str(create_err)}'}}", exc_info=True, ) raise CloudStorageError(create_err) from create_err @@ -168,12 +179,12 @@ def put(self, source: UploadFile, file_path: Path) -> SimpleStorageName: ) logger.info( f"[AmazonCloudStorage.put] File uploaded successfully | " - f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(destination.Bucket)}', 'key': '{mask_string(destination.Key)}'}}" + f"{{'project_id': '{self.project_id}', 'bucket': '{_mask(destination.Bucket)}', 'key': '{_mask(destination.Key)}'}}" ) except ClientError as err: logger.error( f"[AmazonCloudStorage.put] AWS upload error | " - f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(destination.Bucket)}', 'key': '{mask_string(destination.Key)}', 'error': '{str(err)}'}}", + f"{{'project_id': '{self.project_id}', 'bucket': '{_mask(destination.Bucket)}', 'key': '{_mask(destination.Key)}', 'error': '{str(err)}'}}", exc_info=True, ) raise CloudStorageError(f'AWS Error: "{err}"') from err @@ -187,13 +198,13 @@ def stream(self, url: str) -> StreamingBody: body = self.aws.client.get_object(**kwargs).get("Body") logger.info( f"[AmazonCloudStorage.stream] File streamed successfully | " - f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}'}}" + f"{{'project_id': '{self.project_id}', 'bucket': '{_mask(name.Bucket)}', 'key': '{_mask(name.Key)}'}}" ) return body except ClientError as err: logger.error( f"[AmazonCloudStorage.stream] AWS stream error | " - f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'error': '{str(err)}'}}", + f"{{'project_id': '{self.project_id}', 'bucket': '{_mask(name.Bucket)}', 'key': '{_mask(name.Key)}', 'error': '{str(err)}'}}", exc_info=True, ) raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err @@ -206,13 +217,13 @@ def get(self, url: str) -> bytes: content = body.read() logger.info( f"[AmazonCloudStorage.get] File retrieved successfully | " - f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'size_bytes': {len(content)}}}" + f"{{'project_id': '{self.project_id}', 'bucket': '{_mask(name.Bucket)}', 'key': '{_mask(name.Key)}', 'size_bytes': {len(content)}}}" ) return content except ClientError as err: logger.error( f"[AmazonCloudStorage.get] AWS get error | " - f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'error': '{str(err)}'}}", + f"{{'project_id': '{self.project_id}', 'bucket': '{_mask(name.Bucket)}', 'key': '{_mask(name.Key)}', 'error': '{str(err)}'}}", exc_info=True, ) raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err @@ -226,13 +237,13 @@ def get_file_size_kb(self, url: str) -> float: size_kb = round(size_bytes / 1024, 2) logger.info( f"[AmazonCloudStorage.get_file_size_kb] File size retrieved successfully | " - f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'size_kb': {size_kb}}}" + f"{{'project_id': '{self.project_id}', 'bucket': '{_mask(name.Bucket)}', 'key': '{_mask(name.Key)}', 'size_kb': {size_kb}}}" ) return size_kb except ClientError as err: logger.error( f"[AmazonCloudStorage.get_file_size_kb] AWS head object error | " - f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'error': '{str(err)}'}}", + f"{{'project_id': '{self.project_id}', 'bucket': '{_mask(name.Bucket)}', 'key': '{_mask(name.Key)}', 'error': '{str(err)}'}}", exc_info=True, ) raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err @@ -259,13 +270,13 @@ def get_signed_url(self, url: str, expires_in: int = 3600) -> str: ) logger.info( f"[AmazonCloudStorage.get_signed_url] Signed URL generated | " - f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}'}}" + f"{{'project_id': '{self.project_id}', 'bucket': '{_mask(name.Bucket)}', 'key': '{_mask(name.Key)}'}}" ) return signed_url except ClientError as err: logger.error( f"[AmazonCloudStorage.get_signed_url] AWS presign error | " - f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'error': '{str(err)}'}}", + f"{{'project_id': '{self.project_id}', 'bucket': '{_mask(name.Bucket)}', 'key': '{_mask(name.Key)}', 'error': '{str(err)}'}}", exc_info=True, ) raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err @@ -277,12 +288,12 @@ def delete(self, url: str) -> None: self.aws.client.delete_object(**kwargs) logger.info( f"[AmazonCloudStorage.delete] File deleted successfully | " - f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}'}}" + f"{{'project_id': '{self.project_id}', 'bucket': '{_mask(name.Bucket)}', 'key': '{_mask(name.Key)}'}}" ) except ClientError as err: logger.error( f"[AmazonCloudStorage.delete] AWS delete error | " - f"{{'project_id': '{self.project_id}', 'bucket': '{mask_string(name.Bucket)}', 'key': '{mask_string(name.Key)}', 'error': '{str(err)}'}}", + f"{{'project_id': '{self.project_id}', 'bucket': '{_mask(name.Bucket)}', 'key': '{_mask(name.Key)}', 'error': '{str(err)}'}}", exc_info=True, ) raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err @@ -292,6 +303,11 @@ def get_cloud_storage(session: Session, project_id: int) -> CloudStorage: """ Method to create and configure a cloud storage instance. """ + # Lazy import to avoid a top-level cycle: storage.py is imported from + # app.services.llm.providers.gai_vertex, which itself is wired into the + # provider registry that app.crud transitively pulls in. + from app.crud import get_project_by_id + project = get_project_by_id(session=session, project_id=project_id) if not project: raise ValueError(f"Invalid project_id: {project_id}") @@ -306,3 +322,146 @@ def get_cloud_storage(session: Session, project_id: int) -> CloudStorage: exc_info=True, ) raise + + +# ────────────────────────────────────────────────────────────────────────────── +# GCP service-account fetch (AWS Secrets Manager) + GCS upload util. +# BYOK-ready: every util takes explicit secret_name / bucket / project_id so +# per-project credentials can be passed in. Settings provide the platform +# defaults for the shared SA path. +# ────────────────────────────────────────────────────────────────────────────── + +GCS_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + + +class SecretsManagerError(Exception): + pass + + +@ft.lru_cache(maxsize=32) +def get_gcp_service_account( + secret_name: str | None = None, + region_name: str | None = None, +) -> dict: + """Fetch a GCP service-account JSON key from AWS Secrets Manager. + + Cached per (secret_name, region) — restart the process or call + ``get_gcp_service_account.cache_clear()`` to pick up a rotated key. + + BYOK: pass a project-owned ``secret_name``. Defaults to the platform-shared + secret configured in settings. + """ + secret = secret_name or settings.GCP_SA_SECRET_NAME + region = region_name or settings.GCP_SA_SECRET_REGION + + sm_client = boto3.session.Session().client( + service_name="secretsmanager", region_name=region + ) + + try: + response = sm_client.get_secret_value(SecretId=secret) + except ClientError as e: + code = e.response.get("Error", {}).get("Code", "Unknown") + logger.error( + f"[get_gcp_service_account] Secret fetch failed | " + f"secret={_mask(secret)}, region={region}, code={code}" + ) + raise SecretsManagerError( + f"Failed to fetch secret '{secret}' (code={code}): {e}" + ) from e + + if "SecretString" not in response: + raise SecretsManagerError( + f"Secret '{secret}' has no SecretString (binary secret unsupported)" + ) + + try: + sa_info = json.loads(response["SecretString"]) + except json.JSONDecodeError as e: + raise SecretsManagerError(f"Secret '{secret}' is not valid JSON: {e}") from e + + if sa_info.get("type") != "service_account": + raise SecretsManagerError( + f"Secret '{secret}' is not a GCP service-account key " + f"(got type={sa_info.get('type')!r})" + ) + + logger.info( + f"[get_gcp_service_account] Loaded SA key | " + f"secret={_mask(secret)}, project_id={sa_info.get('project_id')}, " + f"client_email={_mask(sa_info.get('client_email', ''))}" + ) + return sa_info + + +_MIME_TO_EXT = { + "audio/wav": ".wav", + "audio/mpeg": ".mp3", + "audio/mp3": ".mp3", + "audio/ogg": ".ogg", + "audio/flac": ".flac", + "audio/webm": ".webm", + "audio/aac": ".aac", + "audio/aiff": ".aiff", +} + + +def upload_audio_to_gcs( + *, + bucket_name: str, + sa_info: dict, + audio_bytes: bytes | None = None, + local_path: str | None = None, + content_type: str | None = None, + project_id: str | None = None, + key_prefix: str = "audio", +) -> str: + """Upload audio to GCS and return its ``gs://bucket/key`` URI. + + Pass exactly one of ``audio_bytes`` or ``local_path``. + + BYOK: caller supplies ``sa_info`` and ``bucket_name``. The returned URI + plugs directly into Vertex ``fileData.fileUri``. + """ + if (audio_bytes is None) == (local_path is None): + raise ValueError("Pass exactly one of audio_bytes or local_path") + + if local_path is not None: + if not os.path.isfile(local_path): + raise FileNotFoundError(f"Audio file not found: {local_path}") + size = os.path.getsize(local_path) + ext = Path(local_path).suffix or "" + mime = content_type or mimetypes.guess_type(local_path)[0] or "audio/wav" + else: + size = len(audio_bytes) + mime = content_type or "audio/wav" + ext = _MIME_TO_EXT.get(mime, "") + + key = f"{key_prefix}/{uuid4().hex}{ext}" + + try: + creds = service_account.Credentials.from_service_account_info( + sa_info, scopes=list(GCS_SCOPES) + ) + client = gcs.Client( + project=project_id or sa_info.get("project_id"), credentials=creds + ) + blob = client.bucket(bucket_name).blob(key) + if local_path is not None: + blob.upload_from_filename(local_path, content_type=mime) + else: + blob.upload_from_string(audio_bytes, content_type=mime) + except Exception as e: + logger.error( + f"[upload_audio_to_gcs] Upload failed | " + f"bucket={bucket_name}, key={key}, error={e}", + exc_info=True, + ) + raise CloudStorageError(f"GCS upload failed: {e}") from e + + uri = f"gs://{bucket_name}/{key}" + logger.info( + f"[upload_audio_to_gcs] Uploaded | " + f"uri={uri}, mime={mime}, size_kb={size / 1024:.1f}" + ) + return uri diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 720846eb9..33b884ed3 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -104,6 +104,13 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: AWS_DEFAULT_REGION: str = "" AWS_S3_BUCKET_PREFIX: str = "" + # GCP Vertex AI — single shared service-account fetched from AWS Secrets Manager. + # BYOK (per-project credentials) lands later. + GCP_SA_SECRET_NAME: str = "" + GCP_SA_SECRET_REGION: str = "" + GCP_PROJECT_ID: str = "" + GCS_AUDIO_BUCKET: str = "" + # RabbitMQ configuration for Celery broker RABBITMQ_HOST: str = "localhost" RABBITMQ_PORT: int = 5672 diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 550d7ff41..bc0793de8 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -68,7 +68,6 @@ from app.services.llm.providers.registry import get_llm_provider from app.utils import ( APIResponse, - cleanup_temp_file, download_audio_bytes, get_webhook_secret, resolve_input, @@ -305,22 +304,14 @@ def handle_job_error( def resolved_input_context( query_input: TextInput | AudioInput | ImageInput | PDFInput | list, ): - """Context manager for resolving and cleaning up input resources. - - Ensures temporary files (e.g., downloaded audio) are cleaned up - even if errors occur during LLM execution. + """Resolve query input. Audio inputs return AudioRef (in-memory); + providers materialize a temp file via ``audio_ref.to_path()`` only if + their SDK needs one, and clean it up themselves. """ resolved_input, error = resolve_input(query_input) - if error: raise ValueError(error) - - try: - yield resolved_input - finally: - # Clean up temp files for audio inputs - if resolved_input and isinstance(query_input, AudioInput): - cleanup_temp_file(resolved_input) + yield resolved_input def resolve_config_blob( diff --git a/backend/app/services/llm/providers/base.py b/backend/app/services/llm/providers/base.py index f159f0f1c..e316d1fb4 100644 --- a/backend/app/services/llm/providers/base.py +++ b/backend/app/services/llm/providers/base.py @@ -10,6 +10,7 @@ from pydantic import model_validator from sqlmodel import SQLModel +from app.core.audio_utils import AudioRef from app.models.llm import NativeCompletionConfig, LLMCallResponse, QueryParams from app.models.llm.request import TextContent, ImageContent, PDFContent @@ -62,7 +63,7 @@ def execute( self, completion_config: NativeCompletionConfig, query: QueryParams, - resolved_input: str | list[ContentPart], + resolved_input: str | AudioRef | list[ContentPart] | MultiModalInput, include_provider_raw_response: bool = False, ) -> tuple[LLMCallResponse | None, str | None]: """Execute LLM API call. @@ -72,7 +73,8 @@ def execute( Args: completion_config: LLM completion configuration, pass params as-is to provider API query: Query parameters including input and conversation_id - resolved_input: The resolved input content (text string or file path for audio) + resolved_input: Resolved input — string for text/TTS, ``AudioRef`` for STT, + content list for image/pdf, ``MultiModalInput`` for multi-part requests. include_provider_raw_response: Whether to include the raw LLM provider response in the output Returns: diff --git a/backend/app/services/llm/providers/eai.py b/backend/app/services/llm/providers/eai.py index 81ce6d3ca..5b9860184 100644 --- a/backend/app/services/llm/providers/eai.py +++ b/backend/app/services/llm/providers/eai.py @@ -6,6 +6,8 @@ from elevenlabs import ElevenLabs, SpeechToTextConvertResponse +from app.core.audio_utils import AudioRef + from app.models.llm import ( NativeCompletionConfig, @@ -60,14 +62,15 @@ def _parse_input( def _execute_stt( self, completion_config: NativeCompletionConfig, - resolved_input: str, + resolved_input: "AudioRef", include_provider_raw_response: bool = False, ) -> tuple[LLMCallResponse | None, str | None]: """Execute speech-to-text completion using Elevenlabs. Args: completion_config: Configuration for the completion request (with already-mapped params) - resolved_input: File path to the audio input + resolved_input: ``AudioRef``; materialized to a temp file because the + ElevenLabs SDK only accepts a file-like. include_provider_raw_response: Whether to include raw provider response Returns: @@ -76,6 +79,9 @@ def _execute_stt( provider_name = completion_config.provider params = completion_config.params + if not isinstance(resolved_input, AudioRef): + return None, f"{provider_name} STT requires AudioRef input" + # Extract already-mapped parameters from the mapper model_id = params.get("model_id") or "scribe_v2" if not model_id: @@ -84,23 +90,16 @@ def _execute_stt( language_code = params.get("language_code") temperature = params.get("temperature") - # Parse and validate input - parsed_input_path = self._parse_input( - query_input=resolved_input, - completion_type="stt", - provider=provider_name, - ) - try: - # Build optional kwargs stt_kwargs: dict[str, Any] = {} if language_code: stt_kwargs["language_code"] = language_code if temperature is not None: stt_kwargs["temperature"] = temperature - with open(parsed_input_path, "rb") as audio_file: - # Call ElevenLabs transcribe with all mapped parameters + with resolved_input.to_path() as parsed_input_path, open( + parsed_input_path, "rb" + ) as audio_file: elevenlabs_response: SpeechToTextConvertResponse = ( self.client.speech_to_text.convert( file=audio_file, model_id=model_id, **stt_kwargs diff --git a/backend/app/services/llm/providers/gai.py b/backend/app/services/llm/providers/gai.py index f31f41d59..9da71a932 100644 --- a/backend/app/services/llm/providers/gai.py +++ b/backend/app/services/llm/providers/gai.py @@ -31,7 +31,12 @@ from app.models.llm.response import AudioOutput, AudioContent from app.services.llm.providers.base import BaseProvider, ContentPart, MultiModalInput from app.services.llm.mappers import BCP47_LOCALE_TO_GEMINI_LANG -from app.core.audio_utils import convert_pcm_to_mp3, convert_pcm_to_ogg, pcm_to_wav +from app.core.audio_utils import ( + AudioRef, + convert_pcm_to_mp3, + convert_pcm_to_ogg, + pcm_to_wav, +) logger = logging.getLogger(__name__) @@ -106,14 +111,15 @@ def format_parts( def _execute_stt( self, completion_config: NativeCompletionConfig, - resolved_input: str, + resolved_input: "AudioRef", include_provider_raw_response: bool = False, ) -> tuple[LLMCallResponse | None, str | None]: """Execute speech-to-text completion using Google AI. Args: completion_config: Configuration for the completion request - resolved_input: File path to the audio input + resolved_input: ``AudioRef``; materialized to a temp file because the + google-genai SDK's ``files.upload`` expects a filesystem path. include_provider_raw_response: Whether to include raw provider response Returns: @@ -121,9 +127,9 @@ def _execute_stt( """ provider = completion_config.provider generation_params = completion_config.params - # Validate input is a file path string - if not isinstance(resolved_input, str): - return None, f"{provider} STT requires file path as string" + + if not isinstance(resolved_input, AudioRef): + return None, f"{provider} STT requires AudioRef input" model = generation_params.get("model") or DEFAULT_STT_MODEL instructions = generation_params.get("instructions", "") @@ -155,8 +161,9 @@ def _execute_stt( f"The merged instructions is {merged_instruction} and output language is {output_language} and input language is {input_language}" ) - # Upload file and generate content - gemini_file = self.client.files.upload(file=resolved_input) + # Materialize the AudioRef to a temp file so the genai SDK can upload it. + with resolved_input.to_path() as audio_path: + gemini_file = self.client.files.upload(file=audio_path) contents = [] if merged_instruction: diff --git a/backend/app/services/llm/providers/gai_vertex.py b/backend/app/services/llm/providers/gai_vertex.py index 97edef690..78fc640af 100644 --- a/backend/app/services/llm/providers/gai_vertex.py +++ b/backend/app/services/llm/providers/gai_vertex.py @@ -6,7 +6,14 @@ import requests -from app.core.audio_utils import convert_pcm_to_mp3, convert_pcm_to_ogg, pcm_to_wav +from app.core.audio_utils import ( + AudioRef, + convert_pcm_to_mp3, + convert_pcm_to_ogg, + pcm_to_wav, +) +from app.core.cloud.storage import get_gcp_service_account, upload_audio_to_gcs +from app.core.config import settings from app.models.llm import ( LLMCallResponse, LLMResponse, @@ -27,24 +34,39 @@ logger = logging.getLogger(__name__) REQUEST_TIMEOUT = 60 -MAX_INLINE_AUDIO_BYTES = 20 * 1024 * 1024 # Vertex inline-data cap (~20 MB) -AUDIO_MIME_BY_EXT = { - ".wav": "audio/wav", - ".mp3": "audio/mp3", - ".aiff": "audio/aiff", - ".aac": "audio/aac", - ".ogg": "audio/ogg", - ".flac": "audio/flac", +SUPPORTED_AUDIO_MIMES = { + "audio/wav", + "audio/mp3", + "audio/mpeg", + "audio/aiff", + "audio/aac", + "audio/ogg", + "audio/flac", } class VertexClient: - """Holds Vertex AI connection details. Pure config — no SDK session.""" + """Holds Vertex AI connection details. Pure config — no SDK session. + + BYOK: per-project SA secret + GCS bucket can be passed via credentials; + falls back to platform-shared values in settings. + """ - def __init__(self, api_key: str, project_id: str, location: str): + def __init__( + self, + api_key: str, + project_id: str, + location: str, + gcp_sa_secret_name: str | None = None, + gcp_sa_secret_region: str | None = None, + gcs_bucket: str | None = None, + ): self.api_key = api_key self.project_id = project_id self.location = location + self.gcp_sa_secret_name = gcp_sa_secret_name + self.gcp_sa_secret_region = gcp_sa_secret_region + self.gcs_bucket = gcs_bucket or settings.GCS_AUDIO_BUCKET def endpoint(self, model: str) -> str: return ( @@ -76,9 +98,12 @@ def create_client(credentials: dict[str, Any]) -> Any: f"Google Vertex AI credentials missing required fields: {', '.join(missing)}" ) return VertexClient( - api_key=credentials["api_key"], - project_id=credentials["project_id"], - location=credentials["location"], + api_key=credentials.get("api_key"), + project_id=credentials.get("project_id"), + location=credentials.get("location"), + gcp_sa_secret_name=credentials.get("gcp_sa_secret_name"), + gcp_sa_secret_region=credentials.get("gcp_sa_secret_region"), + gcs_bucket=credentials.get("gcs_bucket"), ) def _post(self, model: str, payload: dict) -> tuple[dict | None, str | None]: @@ -118,35 +143,43 @@ def _extract_usage(data: dict) -> Usage: def _execute_stt( self, completion_config: NativeCompletionConfig, - resolved_input: str, + resolved_input: "AudioRef", include_provider_raw_response: bool = False, ) -> tuple[LLMCallResponse | None, str | None]: provider = completion_config.provider params = completion_config.params - if not isinstance(resolved_input, str): - return None, f"{provider} STT requires file path as string" - - if not os.path.isfile(resolved_input): - return None, f"Audio file not found: {resolved_input}" + if not isinstance(resolved_input, AudioRef): + return None, f"{provider} STT requires AudioRef input" - ext = os.path.splitext(resolved_input)[1].lower() - mime_type = AUDIO_MIME_BY_EXT.get(ext) - if not mime_type: + mime_type = resolved_input.mime_type or "audio/wav" + if mime_type not in SUPPORTED_AUDIO_MIMES: return None, ( - f"Unsupported audio extension '{ext}' for Vertex STT. " - f"Supported: {', '.join(sorted(AUDIO_MIME_BY_EXT))}" + f"Unsupported audio mime '{mime_type}' for Vertex STT. " + f"Supported: {', '.join(sorted(SUPPORTED_AUDIO_MIMES))}" ) - file_size = os.path.getsize(resolved_input) - if file_size > MAX_INLINE_AUDIO_BYTES: - return None, ( - f"Audio file is {file_size} bytes; Vertex inline-data limit is " - f"{MAX_INLINE_AUDIO_BYTES} bytes (~20 MB)" + # Push bytes straight to GCS — no disk I/O. fileData.fileUri bypasses + # the 20 MB inline cap. + try: + sa_info = get_gcp_service_account( + secret_name=self.client.gcp_sa_secret_name, + region_name=self.client.gcp_sa_secret_region, ) - - with open(resolved_input, "rb") as f: - audio_b64 = base64.b64encode(f.read()).decode("utf-8") + gs_uri = upload_audio_to_gcs( + audio_bytes=resolved_input.bytes_, + bucket_name=self.client.gcs_bucket, + sa_info=sa_info, + project_id=self.client.project_id, + content_type=mime_type, + ) + except Exception as e: + logger.error( + f"[GoogleVertexAIProvider._execute_stt] GCS upload failed | " + f"provider={provider}, error={e}", + exc_info=True, + ) + return None, f"Failed to stage audio for Vertex STT: {str(e)}" model = params.get("model") or DEFAULT_STT_MODEL instructions = params.get("instructions") @@ -184,7 +217,7 @@ def _execute_stt( { "role": "user", "parts": [ - {"inlineData": {"mimeType": mime_type, "data": audio_b64}}, + {"fileData": {"mimeType": mime_type, "fileUri": gs_uri}}, {"text": prompt}, ], } diff --git a/backend/app/services/llm/providers/sai.py b/backend/app/services/llm/providers/sai.py index f4a6cc5e7..ba760050a 100644 --- a/backend/app/services/llm/providers/sai.py +++ b/backend/app/services/llm/providers/sai.py @@ -3,6 +3,7 @@ import uuid from typing import Any from sarvamai import SarvamAI +from app.core.audio_utils import AudioRef from app.models.llm import ( NativeCompletionConfig, LLMCallResponse, @@ -56,14 +57,15 @@ def _parse_input( def _execute_stt( self, completion_config: NativeCompletionConfig, - resolved_input: str, + resolved_input: "AudioRef", include_provider_raw_response: bool = False, ) -> tuple[LLMCallResponse | None, str | None]: """Execute speech-to-text completion using SarvamAI. Args: completion_config: Configuration for the completion request (with already-mapped params) - resolved_input: File path to the audio input + resolved_input: ``AudioRef`` carrying the audio bytes; materialized to a temp file + because the SarvamAI SDK only accepts a file-like. include_provider_raw_response: Whether to include raw provider response Returns: @@ -72,38 +74,29 @@ def _execute_stt( provider_name = completion_config.provider params = completion_config.params + if not isinstance(resolved_input, AudioRef): + return None, f"{provider_name} STT requires AudioRef input" + # Extract already-mapped parameters from the mapper model = params.get("model") or "saaras:v3" language_code = params.get("language_code") mode = params.get("mode") or "transcribe" - # Parse and validate input - parsed_input_path = self._parse_input( - query_input=resolved_input, - completion_type="stt", - provider=provider_name, - ) - try: - # Build kwargs for API call, only including non-None parameters - stt_kwargs = { - "file": None, # Will be set below - "model": model, - } - - if language_code: - stt_kwargs["language_code"] = language_code + with resolved_input.to_path() as parsed_input_path: + stt_kwargs = {"file": None, "model": model} - # mode only applies to saaras:v3 model - if mode: - stt_kwargs["mode"] = mode + if language_code: + stt_kwargs["language_code"] = language_code + if mode: + stt_kwargs["mode"] = mode - with open(parsed_input_path, "rb") as audio_file: - # Call SarvamAI transcribe with mapped parameters - stt_kwargs["file"] = audio_file - sarvam_response = self.client.speech_to_text.transcribe(**stt_kwargs) + with open(parsed_input_path, "rb") as audio_file: + stt_kwargs["file"] = audio_file + sarvam_response = self.client.speech_to_text.transcribe( + **stt_kwargs + ) - # Estimate token usage (not directly provided by SarvamAI STT) input_tokens_estimate = 0 output_tokens_estimate = len(sarvam_response.transcript.split()) total_tokens_estimate = input_tokens_estimate + output_tokens_estimate diff --git a/backend/app/tests/services/llm/providers/test_eai.py b/backend/app/tests/services/llm/providers/test_eai.py index d5e6def67..c32c5483c 100644 --- a/backend/app/tests/services/llm/providers/test_eai.py +++ b/backend/app/tests/services/llm/providers/test_eai.py @@ -70,11 +70,11 @@ def query_params(self): return QueryParams(input="Test audio input") @pytest.fixture - def temp_audio_file(self, tmp_path): - """Create a temporary audio file for testing.""" - audio_file = tmp_path / "test_audio.wav" - audio_file.write_bytes(b"fake audio data") - return str(audio_file) + def temp_audio_file(self): + """Resolved STT input handle (provider materializes temp file internally).""" + from app.core.audio_utils import AudioRef + + return AudioRef(bytes_=b"fake audio data", mime_type="audio/wav") def test_stt_success_basic_transcription( self, provider, mock_client, stt_config, query_params, temp_audio_file @@ -156,16 +156,17 @@ def test_stt_uses_default_model_when_missing( call_kwargs = mock_client.speech_to_text.convert.call_args.kwargs assert call_kwargs["model_id"] == "scribe_v2" - def test_stt_invalid_file_path( + def test_stt_rejects_non_audioref_input( self, provider, mock_client, stt_config, query_params ): - """Test STT with non-existent file path.""" + """STT requires AudioRef; raw path strings are no longer accepted.""" result, error = provider.execute( stt_config, query_params, "/nonexistent/path/audio.wav" ) assert result is None assert error is not None + assert "AudioRef input" in error def test_stt_api_exception( self, provider, mock_client, stt_config, query_params, temp_audio_file diff --git a/backend/app/tests/services/llm/providers/test_gai.py b/backend/app/tests/services/llm/providers/test_gai.py index 3b3b0791e..a7d82310b 100644 --- a/backend/app/tests/services/llm/providers/test_gai.py +++ b/backend/app/tests/services/llm/providers/test_gai.py @@ -77,19 +77,25 @@ def stt_config(self): }, ) + @pytest.fixture + def audio_ref(self): + from app.core.audio_utils import AudioRef + + return AudioRef(bytes_=b"fake audio data", mime_type="audio/wav") + @pytest.fixture def query_params(self): """Create basic query parameters.""" return QueryParams(input="Test audio input") def test_stt_success_with_auto_language( - self, provider, mock_client, stt_config, query_params + self, provider, mock_client, stt_config, query_params, audio_ref ): """Test successful STT execution with auto language detection.""" mock_response = mock_google_response(text="Hello world") mock_client.models.generate_content.return_value = mock_response - result, error = provider.execute(stt_config, query_params, "/path/to/audio.wav") + result, error = provider.execute(stt_config, query_params, audio_ref) assert error is None assert result is not None @@ -100,8 +106,10 @@ def test_stt_success_with_auto_language( assert result.usage.output_tokens == 100 assert result.usage.total_tokens == 150 - # Verify file upload and content generation - mock_client.files.upload.assert_called_once_with(file="/path/to/audio.wav") + # Verify file upload was called with a materialized temp path matching the AudioRef mime. + mock_client.files.upload.assert_called_once() + uploaded_path = mock_client.files.upload.call_args.kwargs["file"] + assert uploaded_path.endswith(".wav") mock_client.models.generate_content.assert_called_once() # Verify instruction contains auto-detect @@ -109,7 +117,7 @@ def test_stt_success_with_auto_language( assert "Detect the spoken language automatically" in call_args[1]["contents"][0] def test_stt_with_specific_input_language( - self, provider, mock_client, stt_config, query_params + self, provider, mock_client, stt_config, query_params, audio_ref ): """Test STT with specific input language.""" stt_config.params["input_language"] = "English" @@ -117,7 +125,7 @@ def test_stt_with_specific_input_language( mock_response = mock_google_response(text="Transcribed English text") mock_client.models.generate_content.return_value = mock_response - result, error = provider.execute(stt_config, query_params, "/path/to/audio.wav") + result, error = provider.execute(stt_config, query_params, audio_ref) assert error is None assert result is not None @@ -127,7 +135,7 @@ def test_stt_with_specific_input_language( assert "Transcribe the audio from English" in call_args[1]["contents"][0] def test_stt_with_translation( - self, provider, mock_client, stt_config, query_params + self, provider, mock_client, stt_config, query_params, audio_ref ): """Test STT with translation to different output language.""" stt_config.params["input_language"] = "Spanish" @@ -136,7 +144,7 @@ def test_stt_with_translation( mock_response = mock_google_response(text="Translated text") mock_client.models.generate_content.return_value = mock_response - result, error = provider.execute(stt_config, query_params, "/path/to/audio.wav") + result, error = provider.execute(stt_config, query_params, audio_ref) assert error is None assert result is not None @@ -148,7 +156,7 @@ def test_stt_with_translation( assert "translate to English" in instruction def test_stt_with_custom_instructions( - self, provider, mock_client, stt_config, query_params + self, provider, mock_client, stt_config, query_params, audio_ref ): """Test STT with custom instructions.""" stt_config.params["instructions"] = "Include timestamps" @@ -156,7 +164,7 @@ def test_stt_with_custom_instructions( mock_response = mock_google_response(text="Transcribed with timestamps") mock_client.models.generate_content.return_value = mock_response - result, error = provider.execute(stt_config, query_params, "/path/to/audio.wav") + result, error = provider.execute(stt_config, query_params, audio_ref) assert error is None assert result is not None @@ -167,7 +175,7 @@ def test_stt_with_custom_instructions( assert "Include timestamps" in instruction def test_stt_with_include_provider_raw_response( - self, provider, mock_client, stt_config, query_params + self, provider, mock_client, stt_config, query_params, audio_ref ): """Test STT with include_provider_raw_response=True.""" mock_response = mock_google_response(text="Raw response test") @@ -186,25 +194,27 @@ def test_stt_with_include_provider_raw_response( assert isinstance(result.provider_raw_response, dict) assert result.provider_raw_response["text"] == "Raw response test" - def test_stt_with_type_error(self, provider, mock_client, stt_config, query_params): + def test_stt_with_type_error( + self, provider, mock_client, stt_config, query_params, audio_ref + ): """Test handling of TypeError (invalid parameters).""" mock_client.models.generate_content.side_effect = TypeError( "unexpected keyword argument 'invalid_param'" ) - result, error = provider.execute(stt_config, query_params, "/path/to/audio.wav") + result, error = provider.execute(stt_config, query_params, audio_ref) assert result is None assert error is not None assert "Invalid or unexpected parameter in Config" in error def test_stt_with_generic_exception( - self, provider, mock_client, stt_config, query_params + self, provider, mock_client, stt_config, query_params, audio_ref ): """Test handling of unexpected exceptions.""" mock_client.files.upload.side_effect = Exception("File upload failed") - result, error = provider.execute(stt_config, query_params, "/path/to/audio.wav") + result, error = provider.execute(stt_config, query_params, audio_ref) assert result is None assert error is not None @@ -221,16 +231,16 @@ def test_stt_with_invalid_input_type( assert result is None assert error is not None - assert "STT requires file path as string" in error + assert "STT requires AudioRef input" in error - def test_stt_with_valid_file_path( - self, provider, mock_client, stt_config, query_params + def test_stt_with_valid_audio_ref( + self, provider, mock_client, stt_config, query_params, audio_ref ): """Test STT execution with valid file path string.""" mock_response = mock_google_response(text="Valid transcription") mock_client.models.generate_content.return_value = mock_response - result, error = provider.execute(stt_config, query_params, "/path/to/audio.wav") + result, error = provider.execute(stt_config, query_params, audio_ref) assert error is None assert result is not None diff --git a/backend/app/tests/services/llm/providers/test_gai_vertex.py b/backend/app/tests/services/llm/providers/test_gai_vertex.py index 928d5bcb6..6df84c4a6 100644 --- a/backend/app/tests/services/llm/providers/test_gai_vertex.py +++ b/backend/app/tests/services/llm/providers/test_gai_vertex.py @@ -6,9 +6,9 @@ import pytest import requests +from app.core.audio_utils import AudioRef from app.models.llm import NativeCompletionConfig, QueryParams from app.services.llm.providers.gai_vertex import ( - MAX_INLINE_AUDIO_BYTES, GoogleVertexAIProvider, VertexClient, ) @@ -67,10 +67,28 @@ def _mock_http_err(status: int = 400, body: str = "bad request") -> MagicMock: return resp +@pytest.fixture(autouse=True) +def _mock_gcs(monkeypatch): + """Stub out SM + GCS so STT tests don't touch external services.""" + monkeypatch.setattr( + "app.services.llm.providers.gai_vertex.get_gcp_service_account", + lambda **kw: {"type": "service_account", "project_id": "p"}, + ) + monkeypatch.setattr( + "app.services.llm.providers.gai_vertex.upload_audio_to_gcs", + lambda *, audio_bytes, bucket_name, sa_info, **kw: f"gs://{bucket_name}/audio/test.wav", + ) + + class TestGoogleVertexAIProvider: @pytest.fixture def client(self) -> VertexClient: - return VertexClient(api_key="k", project_id="p", location="us-central1") + return VertexClient( + api_key="k", + project_id="p", + location="us-central1", + gcs_bucket="test-bucket", + ) @pytest.fixture def provider(self, client) -> GoogleVertexAIProvider: @@ -80,6 +98,10 @@ def provider(self, client) -> GoogleVertexAIProvider: def query(self) -> QueryParams: return QueryParams(input="ignored") + @pytest.fixture + def audio_ref(self) -> AudioRef: + return AudioRef(bytes_=b"RIFFfake", mime_type="audio/wav") + @pytest.fixture def stt_config(self) -> NativeCompletionConfig: return NativeCompletionConfig( @@ -110,15 +132,12 @@ def test_create_client_builds_endpoint(self): assert "models/m:generateContent" in c.endpoint("m") # ── STT ────────────────────────────────────────────────────────────────── - def test_stt_happy_path(self, provider, stt_config, query, tmp_path): - audio = tmp_path / "a.wav" - audio.write_bytes(b"RIFFfake") - + def test_stt_happy_path(self, provider, stt_config, query, audio_ref): with patch( "app.services.llm.providers.gai_vertex.requests.post", return_value=_mock_http_ok(_stt_response("hi there")), ) as mock_post: - resp, err = provider.execute(stt_config, query, str(audio)) + resp, err = provider.execute(stt_config, query, audio_ref) assert err is None assert resp.response.output.content.value == "hi there" @@ -126,87 +145,71 @@ def test_stt_happy_path(self, provider, stt_config, query, tmp_path): assert resp.usage.input_tokens == 10 assert resp.usage.output_tokens == 5 - # Verify payload shape kwargs = mock_post.call_args.kwargs assert kwargs["params"] == {"key": "k"} - sent = kwargs["json"] - parts = sent["contents"][0]["parts"] - assert parts[0]["inlineData"]["mimeType"] == "audio/wav" + parts = kwargs["json"]["contents"][0]["parts"] + assert parts[0]["fileData"]["mimeType"] == "audio/wav" + assert parts[0]["fileData"]["fileUri"].startswith("gs://test-bucket/") assert "Detect the spoken language automatically" in parts[1]["text"] - def test_stt_rejects_non_string_input(self, provider, stt_config, query): - resp, err = provider.execute(stt_config, query, 123) + def test_stt_rejects_non_audioref_input(self, provider, stt_config, query): + resp, err = provider.execute(stt_config, query, "/some/path.wav") assert resp is None - assert "file path as string" in err + assert "AudioRef input" in err - def test_stt_rejects_missing_file(self, provider, stt_config, query): - resp, err = provider.execute(stt_config, query, "/nope/missing.wav") + def test_stt_rejects_unsupported_mime(self, provider, stt_config, query): + bad = AudioRef(bytes_=b"x", mime_type="audio/xyz") + resp, err = provider.execute(stt_config, query, bad) assert resp is None - assert "Audio file not found" in err + assert "Unsupported audio mime" in err - def test_stt_rejects_unsupported_extension( - self, provider, stt_config, query, tmp_path + def test_stt_gcs_upload_failure_returns_clean_error( + self, provider, stt_config, query, audio_ref, monkeypatch ): - audio = tmp_path / "a.xyz" - audio.write_bytes(b"x") - resp, err = provider.execute(stt_config, query, str(audio)) - assert resp is None - assert "Unsupported audio extension" in err - - def test_stt_rejects_oversized_file(self, provider, stt_config, query, tmp_path): - audio = tmp_path / "a.wav" - audio.write_bytes(b"x") - with patch( - "app.services.llm.providers.gai_vertex.os.path.getsize", - return_value=MAX_INLINE_AUDIO_BYTES + 1, - ): - resp, err = provider.execute(stt_config, query, str(audio)) + monkeypatch.setattr( + "app.services.llm.providers.gai_vertex.upload_audio_to_gcs", + MagicMock(side_effect=RuntimeError("bucket denied")), + ) + resp, err = provider.execute(stt_config, query, audio_ref) assert resp is None - assert "inline-data limit" in err + assert "Failed to stage audio for Vertex STT" in err + assert "bucket denied" in err def test_stt_http_error_returns_clean_message( - self, provider, stt_config, query, tmp_path + self, provider, stt_config, query, audio_ref ): - audio = tmp_path / "a.wav" - audio.write_bytes(b"RIFF") with patch( "app.services.llm.providers.gai_vertex.requests.post", return_value=_mock_http_err(403, "permission denied"), ): - resp, err = provider.execute(stt_config, query, str(audio)) + resp, err = provider.execute(stt_config, query, audio_ref) assert resp is None assert "Vertex AI HTTP 403" in err assert "permission denied" in err def test_stt_network_error_returns_clean_message( - self, provider, stt_config, query, tmp_path + self, provider, stt_config, query, audio_ref ): - audio = tmp_path / "a.wav" - audio.write_bytes(b"RIFF") with patch( "app.services.llm.providers.gai_vertex.requests.post", side_effect=requests.ConnectionError("dns boom"), ): - resp, err = provider.execute(stt_config, query, str(audio)) + resp, err = provider.execute(stt_config, query, audio_ref) assert resp is None assert "Vertex AI request failed" in err def test_stt_missing_transcript_returns_error( - self, provider, stt_config, query, tmp_path + self, provider, stt_config, query, audio_ref ): - audio = tmp_path / "a.wav" - audio.write_bytes(b"RIFF") with patch( "app.services.llm.providers.gai_vertex.requests.post", return_value=_mock_http_ok({"candidates": []}), ): - resp, err = provider.execute(stt_config, query, str(audio)) + resp, err = provider.execute(stt_config, query, audio_ref) assert resp is None assert "missing transcript text" in err - def test_stt_input_language_overrides_prompt(self, provider, query, tmp_path): - audio = tmp_path / "a.wav" - audio.write_bytes(b"RIFF") + def test_stt_input_language_overrides_prompt(self, provider, query, audio_ref): config = NativeCompletionConfig( provider="google-vertex-native", type="stt", @@ -221,7 +224,7 @@ def test_stt_input_language_overrides_prompt(self, provider, query, tmp_path): "app.services.llm.providers.gai_vertex.requests.post", return_value=_mock_http_ok(_stt_response()), ) as mock_post: - provider.execute(config, query, str(audio)) + provider.execute(config, query, audio_ref) prompt = mock_post.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"] assert prompt.startswith("be precise") @@ -239,7 +242,6 @@ def test_tts_happy_path_wav(self, provider, tts_config, query): assert err is None assert resp.response.output.content.format == "base64" assert resp.response.output.content.mime_type == "audio/wav" - # base64 decodes to valid WAV header decoded = base64.b64decode(resp.response.output.content.value) assert decoded[:4] == b"RIFF" @@ -297,16 +299,14 @@ def test_text_completion_is_rejected(self, provider, query): assert "does not support completion type 'text'" in err def test_raw_response_included_when_requested( - self, provider, stt_config, query, tmp_path + self, provider, stt_config, query, audio_ref ): - audio = tmp_path / "a.wav" - audio.write_bytes(b"RIFF") raw = _stt_response() with patch( "app.services.llm.providers.gai_vertex.requests.post", return_value=_mock_http_ok(raw), ): resp, _ = provider.execute( - stt_config, query, str(audio), include_provider_raw_response=True + stt_config, query, audio_ref, include_provider_raw_response=True ) assert resp.provider_raw_response == raw diff --git a/backend/app/tests/services/llm/providers/test_sai.py b/backend/app/tests/services/llm/providers/test_sai.py index 78fdd30a8..758187693 100644 --- a/backend/app/tests/services/llm/providers/test_sai.py +++ b/backend/app/tests/services/llm/providers/test_sai.py @@ -81,11 +81,11 @@ def query_params(self): return QueryParams(input="Test audio input") @pytest.fixture - def temp_audio_file(self, tmp_path): - """Create a temporary audio file for testing.""" - audio_file = tmp_path / "test_audio.wav" - audio_file.write_bytes(b"fake audio data") - return str(audio_file) + def temp_audio_file(self): + """Resolved STT input handle for tests (provider materializes temp file internally).""" + from app.core.audio_utils import AudioRef + + return AudioRef(bytes_=b"fake audio data", mime_type="audio/wav") def test_stt_success_basic_transcription( self, provider, mock_client, stt_config, query_params, temp_audio_file @@ -175,16 +175,17 @@ def test_stt_uses_default_model_when_missing( call_kwargs = mock_client.speech_to_text.transcribe.call_args.kwargs assert call_kwargs["model"] == "saaras:v3" - def test_stt_invalid_file_path( + def test_stt_rejects_non_audioref_input( self, provider, mock_client, stt_config, query_params ): - """Test STT with non-existent file path.""" + """STT requires AudioRef; string paths are no longer accepted.""" result, error = provider.execute( stt_config, query_params, "/nonexistent/path/audio.wav" ) assert result is None assert error is not None + assert "AudioRef input" in error def test_stt_api_exception( self, provider, mock_client, stt_config, query_params, temp_audio_file diff --git a/backend/app/tests/services/llm/test_input_resolver.py b/backend/app/tests/services/llm/test_input_resolver.py index ffc0b74b7..e8e9b3a17 100644 --- a/backend/app/tests/services/llm/test_input_resolver.py +++ b/backend/app/tests/services/llm/test_input_resolver.py @@ -1,161 +1,122 @@ -""" -Unit tests for LLM input resolver functions. - -Tests input resolution for text and base64 audio inputs. -""" +"""Unit tests for LLM input resolver functions.""" import base64 -import tempfile +import os from pathlib import Path -from unittest.mock import patch, Mock - -import pytest -from app.models.llm.request import TextInput, AudioInput, TextContent, AudioContent +from app.core.audio_utils import AudioRef +from app.models.llm.request import ( + AudioContent, + AudioInput, + TextContent, + TextInput, +) from app.utils import ( + cleanup_temp_file, get_file_extension, - resolve_input, resolve_audio_base64, - cleanup_temp_file, + resolve_input, ) class TestGetFileExtension: - """Test MIME type to file extension mapping.""" - def test_common_audio_formats(self): - """Test common audio MIME types.""" assert get_file_extension("audio/wav") == ".wav" assert get_file_extension("audio/mp3") == ".mp3" assert get_file_extension("audio/mpeg") == ".mp3" assert get_file_extension("audio/ogg") == ".ogg" def test_wav_variants(self): - """Test various WAV MIME type variants.""" assert get_file_extension("audio/wave") == ".wav" assert get_file_extension("audio/x-wav") == ".wav" def test_unknown_mime_type(self): - """Test fallback for unknown MIME types.""" assert get_file_extension("audio/unknown") == ".audio" assert get_file_extension("application/octet-stream") == ".audio" class TestResolveInput: - """Test main input resolution function.""" - def test_text_input(self): - """Test resolving text input.""" text_input = TextInput(content=TextContent(value="Hello world")) content, error = resolve_input(text_input) - assert content == "Hello world" assert error is None - def test_audio_base64_input(self): - """Test resolving base64 audio input.""" - # Create minimal valid audio data + def test_audio_base64_input_returns_audio_ref(self): audio_data = b"RIFF" + b"\x00" * 36 # Minimal WAV header encoded = base64.b64encode(audio_data).decode() audio_input = AudioInput( content=AudioContent(value=encoded, mime_type="audio/wav") ) - file_path, error = resolve_input(audio_input) + ref, error = resolve_input(audio_input) assert error is None - assert file_path != "" - assert Path(file_path).exists() - assert file_path.endswith(".wav") - - # Cleanup - cleanup_temp_file(file_path) + assert isinstance(ref, AudioRef) + assert ref.bytes_ == audio_data + assert ref.mime_type == "audio/wav" def test_invalid_base64_data(self): - """Test handling of invalid base64 data.""" audio_input = AudioInput( content=AudioContent(value="not-valid-base64!!!", mime_type="audio/wav") ) content, error = resolve_input(audio_input) - - assert content == "" + assert content is None assert error is not None assert "base64" in error.lower() class TestResolveAudioBase64: - """Test base64 audio resolution.""" - def test_valid_base64_audio(self): - """Test decoding valid base64 audio data.""" audio_data = b"Test audio content" encoded = base64.b64encode(audio_data).decode() - file_path, error = resolve_audio_base64(encoded, "audio/mp3") + ref, error = resolve_audio_base64(encoded, "audio/mp3") assert error is None - assert file_path != "" - assert Path(file_path).exists() - assert file_path.endswith(".mp3") - - # Verify content - with open(file_path, "rb") as f: - assert f.read() == audio_data - - # Cleanup - cleanup_temp_file(file_path) + assert isinstance(ref, AudioRef) + assert ref.bytes_ == audio_data + assert ref.mime_type == "audio/mp3" def test_invalid_base64_string(self): - """Test handling invalid base64 string.""" - file_path, error = resolve_audio_base64("invalid!!!base64", "audio/wav") - - assert file_path == "" + ref, error = resolve_audio_base64("invalid!!!base64", "audio/wav") + assert ref is None assert error is not None assert "Invalid base64" in error - def test_different_mime_types(self): - """Test file extension based on MIME type.""" - audio_data = b"Audio" - encoded = base64.b64encode(audio_data).decode() - # Test WAV - file_path, _ = resolve_audio_base64(encoded, "audio/wav") - assert file_path.endswith(".wav") - cleanup_temp_file(file_path) +class TestAudioRefToPath: + def test_to_path_writes_and_cleans_up(self): + audio_data = b"Audio bytes" + ref = AudioRef(bytes_=audio_data, mime_type="audio/wav") - # Test OGG - file_path, _ = resolve_audio_base64(encoded, "audio/ogg") - assert file_path.endswith(".ogg") - cleanup_temp_file(file_path) + with ref.to_path() as p: + assert Path(p).exists() + assert p.endswith(".wav") + with open(p, "rb") as f: + assert f.read() == audio_data + # File must be cleaned up after the context exits. + assert not Path(p).exists() -# URL-based audio input tests removed - only base64 audio is supported + def test_to_path_cleans_up_on_exception(self): + ref = AudioRef(bytes_=b"x", mime_type="audio/ogg") + captured_path = None + try: + with ref.to_path() as p: + captured_path = p + raise RuntimeError("boom") + except RuntimeError: + pass + assert captured_path is not None + assert not Path(captured_path).exists() class TestCleanupTempFile: - """Test temporary file cleanup.""" - - def test_cleanup_existing_file(self): - """Test cleaning up an existing temp file.""" - # Create a temp file - with tempfile.NamedTemporaryFile(delete=False) as tmp: - tmp.write(b"test data") - temp_path = tmp.name - - assert Path(temp_path).exists() - - # Cleanup - cleanup_temp_file(temp_path) - - # Verify deleted - assert not Path(temp_path).exists() + """cleanup_temp_file remains in app.utils for non-AudioRef callers.""" def test_cleanup_nonexistent_file(self): - """Test cleaning up a non-existent file (should not error).""" - # Should not raise an exception cleanup_temp_file("/tmp/nonexistent_file_12345.wav") def test_cleanup_invalid_path(self): - """Test cleanup with invalid path (should not error).""" - # Should handle gracefully cleanup_temp_file("") diff --git a/backend/app/utils.py b/backend/app/utils.py index 8448c7e09..0526b331a 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -28,6 +28,7 @@ from sqlmodel import Session from app.core import security +from app.core.audio_utils import AudioRef from app.core.config import settings from app.crud.credentials import get_provider_credential from app.models.llm.request import ( @@ -600,25 +601,15 @@ def get_file_extension(mime_type: str) -> str: return mime_to_ext.get(mime_type, ".audio") -def resolve_audio_base64(data: str, mime_type: str) -> tuple[str, str | None]: - """Decode base64 audio and write to temp file. Returns (file_path, error).""" +def resolve_audio_base64( + data: str, mime_type: str +) -> tuple["AudioRef | None", str | None]: + """Decode base64 audio into an in-memory AudioRef.""" try: audio_bytes = base64.b64decode(data) except Exception as e: - return "", f"Invalid base64 audio data: {str(e)}" - - ext = get_file_extension(mime_type) - try: - with tempfile.NamedTemporaryFile( - suffix=ext, delete=False, prefix="audio_" - ) as tmp: - tmp.write(audio_bytes) - temp_path = tmp.name - - logger.info(f"[resolve_audio_base64] Wrote audio to temp file: {temp_path}") - return temp_path, None - except Exception as e: - return "", f"Failed to write audio to temp file: {str(e)}" + return None, f"Invalid base64 audio data: {str(e)}" + return AudioRef(bytes_=audio_bytes, mime_type=mime_type), None def download_audio_bytes(url: str) -> tuple[bytes | None, str | None]: @@ -669,23 +660,12 @@ def download_audio_bytes(url: str) -> tuple[bytes | None, str | None]: return None, f"Failed to download audio from URL: {str(e)}" -def resolve_audio_url(url: str, mime_type: str) -> tuple[str, str | None]: - """Download audio from a public URL and write to temp file. Returns (file_path, error).""" +def resolve_audio_url(url: str, mime_type: str) -> tuple["AudioRef | None", str | None]: + """Download audio from a public URL into an in-memory AudioRef.""" audio_bytes, error = download_audio_bytes(url) - if error: - return "", error - - ext = get_file_extension(mime_type) - try: - with tempfile.NamedTemporaryFile( - suffix=ext, delete=False, prefix="audio_" - ) as tmp: - tmp.write(audio_bytes) - temp_path = tmp.name - logger.info(f"[resolve_audio_url] Downloaded audio to temp file: {temp_path}") - return temp_path, None - except Exception as e: - return "", f"Failed to write audio to temp file: {str(e)}" + if error or not audio_bytes: + return None, error + return AudioRef(bytes_=audio_bytes, mime_type=mime_type), None def resolve_image_content(image_input: ImageInput) -> list[ImageContent]: @@ -714,15 +694,19 @@ def resolve_pdf_content(pdf_input: PDFInput) -> list[PDFContent]: def resolve_input( query_input, -) -> tuple[str | list[ImageContent] | list[PDFContent] | "MultiModalInput", str | None]: +) -> tuple[ + "str | AudioRef | list[ImageContent] | list[PDFContent] | MultiModalInput | None", + str | None, +]: """Resolve query input to provider-ready format. Returns: - - TextInput/AudioInput: (str, None) + - TextInput: (str, None) + - AudioInput: (AudioRef, None) - ImageInput: (list[ImageContent], None) - PDFInput: (list[PDFContent], None) - list[QueryInput]: (MultiModalInput, None) - - Error: ("", error_message) + - Error: (None, error_message) """ try: @@ -752,22 +736,22 @@ def resolve_input( parts.extend(resolve_pdf_content(item)) elif isinstance(item, AudioInput): return ( - "", + None, "Audio input is not supported in multimodal. Please use completion type 'stt' for audio processing.", ) else: return ( - "", + None, "Unsupported input type in multimodal list. Multimodal only supports text, image, and pdf inputs.", ) return MultiModalInput(parts=parts), None else: - return "", f"Unknown input type: {type(query_input)}" + return None, f"Unknown input type: {type(query_input)}" except Exception as e: logger.warning(f"[resolve_input] Failed to resolve input: {e}", exc_info=True) - return "", f"Failed to resolve input: {str(e)}" + return None, f"Failed to resolve input: {str(e)}" def cleanup_temp_file(file_path: str) -> None: diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 96cfa8055..7e528ac87 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "openpyxl>=3.1.5", "litellm>=1.83.10", "anthropic>=0.104.1", + "google-cloud-storage>=3.10.1", ] [tool.uv] diff --git a/backend/uv.lock b/backend/uv.lock index 67912f5e9..614f6c79e 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -249,6 +249,7 @@ dependencies = [ { name = "flower" }, { name = "gevent" }, { name = "google-auth" }, + { name = "google-cloud-storage" }, { name = "google-genai" }, { name = "httpx" }, { name = "indic-nlp-library" }, @@ -314,6 +315,7 @@ requires-dist = [ { name = "flower", specifier = ">=2.0.1" }, { name = "gevent", specifier = ">=25.9.1" }, { name = "google-auth", specifier = ">=2.49.1" }, + { name = "google-cloud-storage", specifier = ">=3.10.1" }, { name = "google-genai", specifier = ">=1.59.0" }, { name = "httpx", specifier = ">=0.25.1,<1.0.0" }, { name = "indic-nlp-library", specifier = ">=0.92" }, @@ -1308,6 +1310,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c2/69/a7c4ba2ffbc7c7dbf6d8b4f5d0f0a421f7815d229f4909854266c445a3d4/gevent-25.9.1-cp314-cp314-win_amd64.whl", hash = "sha256:bb63c0d6cb9950cc94036a4995b9cc4667b8915366613449236970f4394f94d7", size = 1703019, upload-time = "2025-09-17T19:30:55.272Z" }, ] +[[package]] +name = "google-api-core" +version = "2.30.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "googleapis-common-protos" }, + { name = "proto-plus" }, + { name = "protobuf" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/16/ce/502a57fb0ec752026d24df1280b162294b22a0afb98a326084f9a979138b/google_api_core-2.30.3.tar.gz", hash = "sha256:e601a37f148585319b26db36e219df68c5d07b6382cff2d580e83404e44d641b", size = 177001, upload-time = "2026-04-10T00:41:28.035Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/03/15/e56f351cf6ef1cfea58e6ac226a7318ed1deb2218c4b3cc9bd9e4b786c5a/google_api_core-2.30.3-py3-none-any.whl", hash = "sha256:a85761ba72c444dad5d611c2220633480b2b6be2521eca69cca2dbb3ffd6bfe8", size = 173274, upload-time = "2026-04-09T22:57:16.198Z" }, +] + [[package]] name = "google-auth" version = "2.49.1" @@ -1326,6 +1344,59 @@ requests = [ { name = "requests" }, ] +[[package]] +name = "google-cloud-core" +version = "2.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "google-auth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/dd/1eef226e470369b26824a505c34482c0b493bc35fe8e0c6b003b5feca21a/google_cloud_core-2.6.0.tar.gz", hash = "sha256:e76149739f90fac1fc6757c09f47eaccb3145b54adbd7759b0f7c4b235f46c83", size = 36001, upload-time = "2026-05-07T08:04:04.124Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/4a/98da8930ab109c73d9a5d13782a9ebb81ea8c111f6d534a567b71d23e52b/google_cloud_core-2.6.0-py3-none-any.whl", hash = "sha256:6d63ac8e5eca6d9e4319d0a1e2265fadcd7f1049904378caecfa01cf52dd869e", size = 29390, upload-time = "2026-05-07T08:02:34.672Z" }, +] + +[[package]] +name = "google-cloud-storage" +version = "3.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "google-auth" }, + { name = "google-cloud-core" }, + { name = "google-crc32c" }, + { name = "google-resumable-media" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4c/47/205eb8e9a1739b5345843e5a425775cbdc472cc38e7eda082ba5b8d02450/google_cloud_storage-3.10.1.tar.gz", hash = "sha256:97db9aa4460727982040edd2bd13ff3d5e2260b5331ad22895802da1fc2a5286", size = 17309950, upload-time = "2026-03-23T09:35:23.409Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ad/ff/ca9ab2417fa913d75aae38bf40bf856bb2749a604b2e0f701b37cfcd23cc/google_cloud_storage-3.10.1-py3-none-any.whl", hash = "sha256:a72f656759b7b99bda700f901adcb3425a828d4a29f911bc26b3ea79c5b1217f", size = 324453, upload-time = "2026-03-23T09:35:21.368Z" }, +] + +[[package]] +name = "google-crc32c" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79", size = 14192, upload-time = "2025-12-16T00:35:25.142Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4b8286b659c1335172e39563ab0a768b8015e88e08329fa5321f774275fc3113", size = 31300, upload-time = "2025-12-16T00:21:56.723Z" }, + { url = "https://files.pythonhosted.org/packages/21/8e/58c0d5d86e2220e6a37befe7e6a94dd2f6006044b1a33edf1ff6d9f7e319/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:2a3dc3318507de089c5384cc74d54318401410f82aa65b2d9cdde9d297aca7cb", size = 30867, upload-time = "2025-12-16T00:38:31.302Z" }, + { url = "https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411", size = 33364, upload-time = "2025-12-16T00:40:22.96Z" }, + { url = "https://files.pythonhosted.org/packages/21/3f/3457ea803db0198c9aaca2dd373750972ce28a26f00544b6b85088811939/google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454", size = 33740, upload-time = "2025-12-16T00:40:23.96Z" }, + { url = "https://files.pythonhosted.org/packages/df/c0/87c2073e0c72515bb8733d4eef7b21548e8d189f094b5dad20b0ecaf64f6/google_crc32c-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:3cc0c8912038065eafa603b238abf252e204accab2a704c63b9e14837a854962", size = 34437, upload-time = "2025-12-16T00:35:21.395Z" }, + { url = "https://files.pythonhosted.org/packages/d1/db/000f15b41724589b0e7bc24bc7a8967898d8d3bc8caf64c513d91ef1f6c0/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3ebb04528e83b2634857f43f9bb8ef5b2bbe7f10f140daeb01b58f972d04736b", size = 31297, upload-time = "2025-12-16T00:23:20.709Z" }, + { url = "https://files.pythonhosted.org/packages/d7/0d/8ebed0c39c53a7e838e2a486da8abb0e52de135f1b376ae2f0b160eb4c1a/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:450dc98429d3e33ed2926fc99ee81001928d63460f8538f21a5d6060912a8e27", size = 30867, upload-time = "2025-12-16T00:43:14.628Z" }, + { url = "https://files.pythonhosted.org/packages/ce/42/b468aec74a0354b34c8cbf748db20d6e350a68a2b0912e128cabee49806c/google_crc32c-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3b9776774b24ba76831609ffbabce8cdf6fa2bd5e9df37b594221c7e333a81fa", size = 33344, upload-time = "2025-12-16T00:40:24.742Z" }, + { url = "https://files.pythonhosted.org/packages/1c/e8/b33784d6fc77fb5062a8a7854e43e1e618b87d5ddf610a88025e4de6226e/google_crc32c-1.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:89c17d53d75562edfff86679244830599ee0a48efc216200691de8b02ab6b2b8", size = 33694, upload-time = "2025-12-16T00:40:25.505Z" }, + { url = "https://files.pythonhosted.org/packages/92/b1/d3cbd4d988afb3d8e4db94ca953df429ed6db7282ed0e700d25e6c7bfc8d/google_crc32c-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:57a50a9035b75643996fbf224d6661e386c7162d1dfdab9bc4ca790947d1007f", size = 34435, upload-time = "2025-12-16T00:35:22.107Z" }, + { url = "https://files.pythonhosted.org/packages/21/88/8ecf3c2b864a490b9e7010c84fd203ec8cf3b280651106a3a74dd1b0ca72/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:e6584b12cb06796d285d09e33f63309a09368b9d806a551d8036a4207ea43697", size = 31301, upload-time = "2025-12-16T00:24:48.527Z" }, + { url = "https://files.pythonhosted.org/packages/36/c6/f7ff6c11f5ca215d9f43d3629163727a272eabc356e5c9b2853df2bfe965/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:f4b51844ef67d6cf2e9425983274da75f18b1597bb2c998e1c0a0e8d46f8f651", size = 30868, upload-time = "2025-12-16T00:48:12.163Z" }, + { url = "https://files.pythonhosted.org/packages/56/15/c25671c7aad70f8179d858c55a6ae8404902abe0cdcf32a29d581792b491/google_crc32c-1.8.0-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b0d1a7afc6e8e4635564ba8aa5c0548e3173e41b6384d7711a9123165f582de2", size = 33381, upload-time = "2025-12-16T00:40:26.268Z" }, + { url = "https://files.pythonhosted.org/packages/42/fa/f50f51260d7b0ef5d4898af122d8a7ec5a84e2984f676f746445f783705f/google_crc32c-1.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8b3f68782f3cbd1bce027e48768293072813469af6a61a86f6bb4977a4380f21", size = 33734, upload-time = "2025-12-16T00:40:27.028Z" }, + { url = "https://files.pythonhosted.org/packages/08/a5/7b059810934a09fb3ccb657e0843813c1fee1183d3bc2c8041800374aa2c/google_crc32c-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:d511b3153e7011a27ab6ee6bb3a5404a55b994dc1a7322c0b87b29606d9790e2", size = 34878, upload-time = "2025-12-16T00:35:23.142Z" }, +] + [[package]] name = "google-genai" version = "1.67.0" @@ -1347,6 +1418,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6e/c2/562aa1f086e53529ffbeb5b43d5d8bc42c1b968102b5e2163fad005ce298/google_genai-1.67.0-py3-none-any.whl", hash = "sha256:58b0484ff2d4335fa53c724b489e9f807fcca8115d9cdbd8fdf341121fbd6d2d", size = 733542, upload-time = "2026-03-12T20:39:14.615Z" }, ] +[[package]] +name = "google-resumable-media" +version = "2.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-crc32c" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/4b/0b235beccc310d0a48adbc7246b719d173cca6c88c572dfa4b090e39143c/google_resumable_media-2.9.0.tar.gz", hash = "sha256:f7cfb224846a9dd444d125115dfbe8ef02a2b893e78f087762fe716a255a734b", size = 2164534, upload-time = "2026-05-07T08:04:44.236Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/73/3518e63deb1667c5409a4579e28daf5e84479a87a72c547e0487f7883dcd/google_resumable_media-2.9.0-py3-none-any.whl", hash = "sha256:c8901e88e389af8bed64d9696c74d8bad961865eb2236e13e0bfca9bb0a65ca3", size = 81507, upload-time = "2026-05-07T08:03:23.809Z" }, +] + +[[package]] +name = "googleapis-common-protos" +version = "1.75.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b5/c8/f439cffde755cffa462bfbb156278fa6f9d09119719af9814b858fd4f81f/googleapis_common_protos-1.75.0.tar.gz", hash = "sha256:53a062ff3c32552fbd62c11fe23768b78e4ddf0494d5e5fd97d3f4689c75fbbd", size = 151035, upload-time = "2026-05-07T08:04:49.423Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/c8/e2645aa8ed02fd4c7a2f59d68783b65b1f3cbdfe39a6308e156509d1fee8/googleapis_common_protos-1.75.0-py3-none-any.whl", hash = "sha256:961ed60399c457ceb0ee8f285a84c870aabc9c6a832b9d37bb281b5bebde43ed", size = 300631, upload-time = "2026-05-07T08:03:30.345Z" }, +] + [[package]] name = "greenlet" version = "3.3.2" @@ -2795,6 +2890,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5b/5a/bc7b4a4ef808fa59a816c17b20c4bef6884daebbdf627ff2a161da67da19/propcache-0.4.1-py3-none-any.whl", hash = "sha256:af2a6052aeb6cf17d3e46ee169099044fd8224cbaf75c76a2ef596e8163e2237", size = 13305, upload-time = "2025-10-08T19:49:00.792Z" }, ] +[[package]] +name = "proto-plus" +version = "1.28.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/56/e647b0c675392d2da368da7b6f158f7368b18542fd6f7d7400a2f39de000/proto_plus-1.28.0.tar.gz", hash = "sha256:38e5696342835b08fc116f30a25665b29531cda9d5d5643e9b81fc312385abd9", size = 57221, upload-time = "2026-05-07T08:04:50.811Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/20/b122d4626976acb81132036d2ad1bb35a1a8775fceb837ec30964622516a/proto_plus-1.28.0-py3-none-any.whl", hash = "sha256:a630604310899e73c59ec302e5765c058d412b2f090b9c79c8822589f14955b8", size = 50410, upload-time = "2026-05-07T08:03:31.962Z" }, +] + +[[package]] +name = "protobuf" +version = "7.35.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/60/fd/5b1491d9e4b586d621c54f4c36b888714164b6875f8d6afa3f9072906a51/protobuf-7.35.0.tar.gz", hash = "sha256:a2efd84605f41e559f1881b0912b44099d0a2ac9bf46b3474823f10fb393b0e6", size = 458677, upload-time = "2026-05-19T23:02:29.197Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/ee/93d06e358a4aa32280b00e722d3ea0a1f25fc3cc5778d80581c9cca2c10e/protobuf-7.35.0-cp310-abi3-macosx_10_9_universal2.whl", hash = "sha256:66be6c513931c794fa92c080ffee41671390da3d79da219cf9c0c0907f035dda", size = 433225, upload-time = "2026-05-19T23:02:19.884Z" }, + { url = "https://files.pythonhosted.org/packages/8b/39/1c76c2da93f3c507e958e0aecee2391cc44d4625de6c728bbc555195b5a8/protobuf-7.35.0-cp310-abi3-manylinux2014_aarch64.whl", hash = "sha256:fcbe42a4ac09d3ec9c987ddfcd956afd0b15f1ff613bd8371bde9405ffd5c8e5", size = 328847, upload-time = "2026-05-19T23:02:22.3Z" }, + { url = "https://files.pythonhosted.org/packages/91/1a/39f7ce90a238c1a987a4d81ec26379e02ca0aff367de68e4a1fa474215b9/protobuf-7.35.0-cp310-abi3-manylinux2014_s390x.whl", hash = "sha256:4cbf5cc286130e06a6c9bbefac442431173906dfcc979712183d4adcc01b37ee", size = 344030, upload-time = "2026-05-19T23:02:23.591Z" }, + { url = "https://files.pythonhosted.org/packages/70/5b/6baf9008817964454055ff3fe65f1de0b5f1e26c80c82f7fb108b7cd4ea3/protobuf-7.35.0-cp310-abi3-manylinux2014_x86_64.whl", hash = "sha256:6c0f98f10c8a05ea30f8993dfef2de093d27b490fdae78bb60c8343795d55011", size = 327130, upload-time = "2026-05-19T23:02:24.637Z" }, + { url = "https://files.pythonhosted.org/packages/8e/e5/e46adb0badc388bfb84877a5f9f026aff63f60e611016cf64dbe77e05446/protobuf-7.35.0-cp310-abi3-win32.whl", hash = "sha256:4c4617b83ade0e279d1d2bfe04025a1adb87f9ed657de038620dc0ff959357f6", size = 428946, upload-time = "2026-05-19T23:02:25.741Z" }, + { url = "https://files.pythonhosted.org/packages/a7/ab/547fbd9e16d879dd13c167478f8ae0a83a428008ca07a5e06acdc23ad473/protobuf-7.35.0-cp310-abi3-win_amd64.whl", hash = "sha256:f05bcadf9a2a6b8dda047007075135fb7d08c73d9177aabc067e1be46881a201", size = 439996, upload-time = "2026-05-19T23:02:26.808Z" }, + { url = "https://files.pythonhosted.org/packages/b8/ef/50433d346c56657a70d27f156c7b349ac59a068b01de4eb796e747eecc43/protobuf-7.35.0-py3-none-any.whl", hash = "sha256:c13f325cf242bad135c350629eeb5d54b24228eb472fb3e2e9ebbd4c5dc20ca0", size = 171659, upload-time = "2026-05-19T23:02:27.842Z" }, +] + [[package]] name = "psycopg" version = "3.3.3" From 8cc64b1176e7b1302e9f5b814a065e5e10bad293 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Mon, 1 Jun 2026 16:06:24 +0530 Subject: [PATCH 04/15] fea: BYOK for secrets manager --- backend/app/core/cloud/storage.py | 93 ++++++++++ backend/app/core/config.py | 9 +- backend/app/core/providers.py | 22 ++- backend/app/crud/credentials.py | 43 +++++ .../app/services/llm/providers/registry.py | 22 ++- backend/app/tests/core/test_storage_byok.py | 165 ++++++++++++++++++ backend/app/tests/crud/test_credentials.py | 68 ++++++++ .../services/llm/providers/test_registry.py | 36 ++++ backend/app/tests/test_utils.py | 23 ++- 9 files changed, 465 insertions(+), 16 deletions(-) create mode 100644 backend/app/tests/core/test_storage_byok.py diff --git a/backend/app/core/cloud/storage.py b/backend/app/core/cloud/storage.py index 608e9f612..1ea6cb027 100644 --- a/backend/app/core/cloud/storage.py +++ b/backend/app/core/cloud/storage.py @@ -338,6 +338,99 @@ class SecretsManagerError(Exception): pass +def upsert_byok_secret_for_provider( + provider: str, + credentials: dict, + *, + org_id: int, + project_id: int, +) -> dict: + """Persist provider-specific BYOK secrets to AWS Secrets Manager and + rewrite the credentials dict so only references (not raw secrets) are + stored in the DB. + + Currently only ``google-vertex`` needs this: when ``sa_key`` is present, + the SA JSON is uploaded to SM under a deterministic per-project name, + and the dict is rewritten to carry ``gcp_sa_secret_name`` / + ``gcp_sa_secret_region`` instead. + + Returns the (possibly rewritten) credentials dict. No-op for providers + without BYOK secrets or when the optional ``sa_key`` field is absent. + """ + if provider == "google-vertex": + sa_key = credentials.get("sa_key") + # The validator only checks key presence, not shape/truthiness — so + # null, empty dict, or a JSON string would slip through and leave a + # partial-BYOK row (user api_key + platform SA), which is exactly + # the broken hybrid BYOK enforcement is meant to prevent. + if not isinstance(sa_key, dict) or not sa_key: + raise ValueError( + "google-vertex 'sa_key' must be a non-empty service-account JSON object" + ) + secret_name = ( + f"kaapi/{settings.ENVIRONMENT}/orgs/{org_id}" + f"/projects/{project_id}/google-vertex/sa" + ) + put_gcp_service_account(sa_key, secret_name=secret_name) + rewritten = {k: v for k, v in credentials.items() if k != "sa_key"} + rewritten["gcp_sa_secret_name"] = secret_name + rewritten["gcp_sa_secret_region"] = settings.GCP_SA_SECRET_REGION + return rewritten + return credentials + + +def put_gcp_service_account( + sa_info: dict, + *, + secret_name: str, + region_name: str | None = None, +) -> None: + """Create or update a GCP service-account JSON key in AWS Secrets Manager. + + Idempotent: tries CreateSecret first, falls back to PutSecretValue when + the secret already exists. Validates SA shape upfront so we never store + junk. Invalidates the ``get_gcp_service_account`` LRU cache on success + so the next read picks up the rotated key. + """ + if sa_info.get("type") != "service_account": + raise SecretsManagerError( + f"Refusing to write secret '{secret_name}': not a GCP service-account key " + f"(got type={sa_info.get('type')!r})" + ) + + region = region_name or settings.GCP_SA_SECRET_REGION + payload = json.dumps(sa_info) + + sm_client = boto3.session.Session().client( + service_name="secretsmanager", region_name=region + ) + + try: + try: + sm_client.create_secret(Name=secret_name, SecretString=payload) + action = "created" + except sm_client.exceptions.ResourceExistsException: + sm_client.put_secret_value(SecretId=secret_name, SecretString=payload) + action = "updated" + except ClientError as e: + code = e.response.get("Error", {}).get("Code", "Unknown") + logger.error( + f"[put_gcp_service_account] Secret write failed | " + f"secret={_mask(secret_name)}, region={region}, code={code}" + ) + raise SecretsManagerError( + f"Failed to write secret '{secret_name}' (code={code}): {e}" + ) from e + + get_gcp_service_account.cache_clear() + logger.info( + f"[put_gcp_service_account] Secret {action} | " + f"secret={_mask(secret_name)}, region={region}, " + f"project_id={sa_info.get('project_id')}, " + f"client_email={_mask(sa_info.get('client_email', ''))}" + ) + + @ft.lru_cache(maxsize=32) def get_gcp_service_account( secret_name: str | None = None, diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 33b884ed3..acad45b3f 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -104,11 +104,14 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: AWS_DEFAULT_REGION: str = "" AWS_S3_BUCKET_PREFIX: str = "" - # GCP Vertex AI — single shared service-account fetched from AWS Secrets Manager. - # BYOK (per-project credentials) lands later. + # GCP Vertex AI platform defaults. Used when a project does not register + # its own google-vertex credential row (BYOK is all-or-nothing — see the + # Provider.GOOGLE_VERTEX comment in app/core/providers.py). + GCP_VERTEX_API_KEY: str = "" + GCP_VERTEX_LOCATION: str = "" + GCP_PROJECT_ID: str = "" GCP_SA_SECRET_NAME: str = "" GCP_SA_SECRET_REGION: str = "" - GCP_PROJECT_ID: str = "" GCS_AUDIO_BUCKET: str = "" # RabbitMQ configuration for Celery broker diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index c1e21f7ae..22b74c784 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -48,8 +48,28 @@ class ProviderConfig: Provider.ANTHROPIC: ProviderConfig( required_fields=["api_key"], sensitive_fields=["api_key"] ), + # google-vertex BYOK is all-or-nothing: if a credential row is registered + # for this provider, it must carry the full kit. Partial registrations + # would mix the user's api_key (scoped to their GCP project) with the + # platform SA / bucket (different GCP project) — Vertex cannot read across + # projects without explicit cross-project IAM, so we forbid that shape. + # + # Projects that omit the credential row entirely fall through to the + # platform-shared defaults (GCP_VERTEX_API_KEY, GCP_PROJECT_ID, + # GCP_VERTEX_LOCATION, GCP_SA_SECRET_NAME, GCS_AUDIO_BUCKET) in settings. + # + # sa_key (dict) is the raw GCP service-account JSON. It's stripped before + # DB storage by upsert_byok_secret_for_provider and uploaded to AWS + # Secrets Manager; the persisted dict carries gcp_sa_secret_name / + # gcp_sa_secret_region in its place. Provider.GOOGLE_VERTEX: ProviderConfig( - required_fields=["api_key", "project_id", "location"], + required_fields=[ + "api_key", + "project_id", + "location", + "sa_key", + "gcs_bucket", + ], sensitive_fields=["api_key"], ), Provider.WEBHOOK_SECRET: ProviderConfig( diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index 1d23ff587..735daf10b 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -5,6 +5,7 @@ from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select +from app.core.cloud.storage import SecretsManagerError, upsert_byok_secret_for_provider from app.core.exception_handlers import HTTPException from app.core.providers import validate_provider, validate_provider_credentials from app.core.security import decrypt_credentials, encrypt_credentials @@ -36,6 +37,27 @@ def set_creds_for_org( ) raise HTTPException(status_code=400, detail=str(e)) + # BYOK side-effect: e.g. google-vertex sa_key → AWS Secrets Manager, + # dict rewritten to carry only the SM reference before persistence. + try: + credentials = upsert_byok_secret_for_provider( + provider, credentials, org_id=organization_id, project_id=project_id + ) + except ValueError as e: + logger.warning( + f"[set_creds_for_org] BYOK shape error | project_id: {project_id}, provider: {provider}, error: {str(e)}" + ) + raise HTTPException(status_code=400, detail=str(e)) + except SecretsManagerError as e: + logger.error( + f"[set_creds_for_org] BYOK secret store failed | project_id: {project_id}, provider: {provider}, error: {str(e)}", + exc_info=True, + ) + raise HTTPException( + status_code=502, + detail=f"Failed to store provider secret: {str(e)}", + ) + # Encrypt entire credentials object encrypted_credentials = encrypt_credentials(credentials) @@ -202,6 +224,27 @@ def update_creds_for_org( ) raise HTTPException(status_code=400, detail=str(e)) + # BYOK side-effect: e.g. google-vertex sa_key → AWS Secrets Manager, + # dict rewritten to carry only the SM reference before persistence. + try: + credential_data = upsert_byok_secret_for_provider( + creds_in.provider, credential_data, org_id=org_id, project_id=project_id + ) + except ValueError as e: + logger.warning( + f"[update_creds_for_org] BYOK shape error | organization_id: {org_id}, project_id: {project_id}, provider: {creds_in.provider}, error: {str(e)}" + ) + raise HTTPException(status_code=400, detail=str(e)) + except SecretsManagerError as e: + logger.error( + f"[update_creds_for_org] BYOK secret store failed | organization_id: {org_id}, project_id: {project_id}, provider: {creds_in.provider}, error: {str(e)}", + exc_info=True, + ) + raise HTTPException( + status_code=502, + detail=f"Failed to store provider secret: {str(e)}", + ) + # Encrypt the entire credentials object encrypted_credentials = encrypt_credentials(credential_data) diff --git a/backend/app/services/llm/providers/registry.py b/backend/app/services/llm/providers/registry.py index daa42bf12..7a1081352 100644 --- a/backend/app/services/llm/providers/registry.py +++ b/backend/app/services/llm/providers/registry.py @@ -76,9 +76,25 @@ def get_llm_provider( ) if not credentials: - raise ValueError( - f"Credentials for provider '{credential_provider}' not configured for this project." - ) + # google-vertex falls back to platform-shared defaults from settings + # when no project credential row exists. BYOK is all-or-nothing for + # this provider (see Provider.GOOGLE_VERTEX in app/core/providers.py), + # so projects either register the full kit or use the platform set. + if credential_provider == "google-vertex": + from app.core.config import settings + + credentials = { + "api_key": settings.GCP_VERTEX_API_KEY, + "project_id": settings.GCP_PROJECT_ID, + "location": settings.GCP_VERTEX_LOCATION, + "gcp_sa_secret_name": settings.GCP_SA_SECRET_NAME, + "gcp_sa_secret_region": settings.GCP_SA_SECRET_REGION, + "gcs_bucket": settings.GCS_AUDIO_BUCKET, + } + else: + raise ValueError( + f"Credentials for provider '{credential_provider}' not configured for this project." + ) try: client = provider_class.create_client(credentials=credentials) diff --git a/backend/app/tests/core/test_storage_byok.py b/backend/app/tests/core/test_storage_byok.py new file mode 100644 index 000000000..5cfcb48dd --- /dev/null +++ b/backend/app/tests/core/test_storage_byok.py @@ -0,0 +1,165 @@ +"""Tests for the BYOK helpers in app.core.cloud.storage.""" + +from unittest.mock import MagicMock, patch + +import pytest +from botocore.exceptions import ClientError + +from app.core.cloud.storage import ( + SecretsManagerError, + put_gcp_service_account, + upsert_byok_secret_for_provider, +) + + +VALID_SA = { + "type": "service_account", + "project_id": "starlit-lotus-492004-k0", + "client_email": "kaapi-test@starlit-lotus-492004-k0.iam.gserviceaccount.com", + "private_key": "-----BEGIN PRIVATE KEY-----\nfake\n-----END PRIVATE KEY-----", +} + + +@pytest.fixture +def mock_sm_client(): + client = MagicMock() + client.exceptions.ResourceExistsException = type( + "ResourceExistsException", (ClientError,), {} + ) + with patch("app.core.cloud.storage.boto3.session.Session") as mock_session, patch( + "app.core.cloud.storage.get_gcp_service_account.cache_clear" + ) as mock_clear: + mock_session.return_value.client.return_value = client + yield client, mock_clear + + +class TestPutGcpServiceAccount: + def test_creates_secret_when_absent(self, mock_sm_client): + client, mock_clear = mock_sm_client + put_gcp_service_account( + VALID_SA, secret_name="kaapi/dev/orgs/1/projects/2/google-vertex/sa" + ) + + client.create_secret.assert_called_once() + kwargs = client.create_secret.call_args.kwargs + assert kwargs["Name"] == "kaapi/dev/orgs/1/projects/2/google-vertex/sa" + # SA JSON round-trips through json.dumps; verify a known field survives. + assert '"type": "service_account"' in kwargs["SecretString"] + client.put_secret_value.assert_not_called() + mock_clear.assert_called_once() + + def test_updates_secret_when_present(self, mock_sm_client): + client, mock_clear = mock_sm_client + client.create_secret.side_effect = client.exceptions.ResourceExistsException( + {"Error": {"Code": "ResourceExistsException"}}, "CreateSecret" + ) + + put_gcp_service_account( + VALID_SA, secret_name="kaapi/dev/orgs/1/projects/2/google-vertex/sa" + ) + + client.create_secret.assert_called_once() + client.put_secret_value.assert_called_once() + kwargs = client.put_secret_value.call_args.kwargs + assert kwargs["SecretId"] == "kaapi/dev/orgs/1/projects/2/google-vertex/sa" + mock_clear.assert_called_once() + + def test_rejects_non_service_account_payload(self, mock_sm_client): + client, _ = mock_sm_client + bad = {"type": "user_account", "client_id": "x"} + with pytest.raises(SecretsManagerError, match="not a GCP service-account"): + put_gcp_service_account(bad, secret_name="kaapi/anything") + client.create_secret.assert_not_called() + client.put_secret_value.assert_not_called() + + def test_wraps_aws_errors(self, mock_sm_client): + client, _ = mock_sm_client + client.create_secret.side_effect = ClientError( + {"Error": {"Code": "AccessDeniedException", "Message": "nope"}}, + "CreateSecret", + ) + with pytest.raises(SecretsManagerError, match="AccessDeniedException"): + put_gcp_service_account(VALID_SA, secret_name="kaapi/anything") + + +class TestUpsertByokSecretForProvider: + def test_google_vertex_with_sa_key_strips_and_writes(self): + creds = { + "api_key": "vkey", + "project_id": "starlit-lotus-492004-k0", + "location": "us-central1", + "sa_key": VALID_SA, + "gcs_bucket": "my-bucket", + } + with patch("app.core.cloud.storage.put_gcp_service_account") as mock_put, patch( + "app.core.cloud.storage.settings" + ) as mock_settings: + mock_settings.ENVIRONMENT = "development" + mock_settings.GCP_SA_SECRET_REGION = "ap-south-1" + + result = upsert_byok_secret_for_provider( + "google-vertex", creds, org_id=7, project_id=42 + ) + + expected_name = "kaapi/development/orgs/7/projects/42/google-vertex/sa" + mock_put.assert_called_once_with(VALID_SA, secret_name=expected_name) + assert "sa_key" not in result + assert result["gcp_sa_secret_name"] == expected_name + assert result["gcp_sa_secret_region"] == "ap-south-1" + # Untouched fields preserved. + assert result["api_key"] == "vkey" + assert result["gcs_bucket"] == "my-bucket" + + def test_google_vertex_rejects_null_sa_key(self): + creds = { + "api_key": "vkey", + "project_id": "p", + "location": "us-central1", + "sa_key": None, + } + with pytest.raises(ValueError, match="sa_key.*non-empty service-account JSON"): + upsert_byok_secret_for_provider( + "google-vertex", creds, org_id=1, project_id=1 + ) + + def test_google_vertex_rejects_string_sa_key(self): + creds = { + "api_key": "vkey", + "project_id": "p", + "location": "us-central1", + "sa_key": "not-a-dict", + } + with pytest.raises(ValueError, match="non-empty service-account JSON"): + upsert_byok_secret_for_provider( + "google-vertex", creds, org_id=1, project_id=1 + ) + + def test_google_vertex_rejects_empty_sa_key(self): + creds = { + "api_key": "vkey", + "project_id": "p", + "location": "us-central1", + "sa_key": {}, + } + with pytest.raises(ValueError, match="non-empty service-account JSON"): + upsert_byok_secret_for_provider( + "google-vertex", creds, org_id=1, project_id=1 + ) + + def test_google_vertex_rejects_missing_sa_key(self): + # Validator at the route requires sa_key, but the hook also rejects + # absence defensively in case it's invoked outside the route flow. + creds = {"api_key": "vkey", "project_id": "p", "location": "us-central1"} + with pytest.raises(ValueError, match="non-empty service-account JSON"): + upsert_byok_secret_for_provider( + "google-vertex", creds, org_id=1, project_id=1 + ) + + def test_other_provider_is_noop_even_with_sa_key(self): + creds = {"api_key": "k", "sa_key": VALID_SA} + with patch("app.core.cloud.storage.put_gcp_service_account") as mock_put: + result = upsert_byok_secret_for_provider( + "openai", creds, org_id=1, project_id=1 + ) + mock_put.assert_not_called() + assert result == creds # sa_key passes through (validator's job to reject) diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index 14437f4fb..e551166c9 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import pytest from sqlmodel import Session @@ -86,6 +88,72 @@ def test_get_creds_by_org(db: Session) -> None: assert {cred.provider for cred in retrieved_creds} == {"openai", "langfuse"} +def test_set_credentials_for_google_vertex_with_sa_key(db: Session) -> None: + """sa_key on google-vertex must be uploaded to SM and stripped before storage; + the persisted credential dict carries only the secret reference.""" + project = create_test_project(db) + + sa_key = { + "type": "service_account", + "project_id": "starlit-lotus-492004-k0", + "client_email": "test@starlit-lotus-492004-k0.iam.gserviceaccount.com", + "private_key": "-----BEGIN PRIVATE KEY-----\nfake\n-----END PRIVATE KEY-----", + } + payload = CredsCreate( + is_active=True, + credential={ + "google-vertex": { + "api_key": "vkey", + "project_id": "starlit-lotus-492004-k0", + "location": "us-central1", + "sa_key": sa_key, + "gcs_bucket": "my-bucket", + } + }, + ) + + with patch("app.crud.credentials.upsert_byok_secret_for_provider") as mock_hook: + # Simulate the real hook's rewrite without touching AWS. + secret_name = ( + f"kaapi/test/orgs/{project.organization_id}" + f"/projects/{project.id}/google-vertex/sa" + ) + mock_hook.return_value = { + "api_key": "vkey", + "project_id": "starlit-lotus-492004-k0", + "location": "us-central1", + "gcs_bucket": "my-bucket", + "gcp_sa_secret_name": secret_name, + "gcp_sa_secret_region": "ap-south-1", + } + + created = set_creds_for_org( + session=db, + creds_add=payload, + organization_id=project.organization_id, + project_id=project.id, + ) + + mock_hook.assert_called_once() + args, kwargs = mock_hook.call_args + assert args[0] == "google-vertex" + assert args[1]["sa_key"] == sa_key + assert kwargs == {"org_id": project.organization_id, "project_id": project.id} + + assert len(created) == 1 + stored = get_provider_credential( + session=db, + org_id=project.organization_id, + provider="google-vertex", + project_id=project.id, + ) + assert stored is not None + assert "sa_key" not in stored + assert stored["gcp_sa_secret_name"] == secret_name + assert stored["gcp_sa_secret_region"] == "ap-south-1" + assert stored["api_key"] == "vkey" + + def test_get_provider_credential(db: Session) -> None: """Test retrieving credentials for a specific provider.""" credentials_create = test_credential_data(db) diff --git a/backend/app/tests/services/llm/providers/test_registry.py b/backend/app/tests/services/llm/providers/test_registry.py index 4349da107..a6978dc60 100644 --- a/backend/app/tests/services/llm/providers/test_registry.py +++ b/backend/app/tests/services/llm/providers/test_registry.py @@ -104,3 +104,39 @@ def test_get_llm_provider_with_missing_credentials(self, db: Session): ) assert "not configured for this project" in str(exc_info.value) + + def test_google_vertex_falls_back_to_platform_settings(self, db: Session): + """No credential row for google-vertex → registry synthesizes platform + defaults from settings (the BYOK-or-platform contract).""" + from app.services.llm.providers.gai_vertex import ( + GoogleVertexAIProvider, + VertexClient, + ) + + project = get_project(db) + + with patch( + "app.crud.credentials.get_provider_credential" + ) as mock_get_creds, patch("app.core.config.settings") as mock_settings: + mock_get_creds.return_value = None + mock_settings.GCP_VERTEX_API_KEY = "platform-key" + mock_settings.GCP_PROJECT_ID = "platform-project" + mock_settings.GCP_VERTEX_LOCATION = "us-central1" + mock_settings.GCP_SA_SECRET_NAME = "platform/secret" + mock_settings.GCP_SA_SECRET_REGION = "ap-south-1" + mock_settings.GCS_AUDIO_BUCKET = "platform-bucket" + + provider = get_llm_provider( + session=db, + provider_type="google-vertex-native", + project_id=project.id, + organization_id=project.organization_id, + ) + + assert isinstance(provider, GoogleVertexAIProvider) + assert isinstance(provider.client, VertexClient) + assert provider.client.api_key == "platform-key" + assert provider.client.project_id == "platform-project" + assert provider.client.location == "us-central1" + assert provider.client.gcp_sa_secret_name == "platform/secret" + assert provider.client.gcs_bucket == "platform-bucket" diff --git a/backend/app/tests/test_utils.py b/backend/app/tests/test_utils.py index 4d244e672..72eda1de5 100644 --- a/backend/app/tests/test_utils.py +++ b/backend/app/tests/test_utils.py @@ -259,22 +259,24 @@ def test_handles_http_error(self, mock_get, mock_validate) -> None: # --------------------------------------------------------------------------- class TestResolveAudioUrl: @patch("app.utils.download_audio_bytes") - def test_writes_downloaded_bytes_to_temp_file(self, mock_download) -> None: + def test_returns_audio_ref(self, mock_download) -> None: + from app.core.audio_utils import AudioRef + audio_data = b"RIFF" + b"\x00" * 36 mock_download.return_value = (audio_data, None) - path, error = resolve_audio_url("https://cdn.example.com/a.wav", "audio/wav") + ref, error = resolve_audio_url("https://cdn.example.com/a.wav", "audio/wav") assert error is None - assert path.endswith(".wav") - assert Path(path).read_bytes() == audio_data - cleanup_temp_file(path) + assert isinstance(ref, AudioRef) + assert ref.bytes_ == audio_data + assert ref.mime_type == "audio/wav" @patch("app.utils.download_audio_bytes") def test_propagates_download_error(self, mock_download) -> None: mock_download.return_value = (None, "Timed out downloading audio from URL") - path, error = resolve_audio_url("https://example.com/a.wav", "audio/wav") - assert path == "" + ref, error = resolve_audio_url("https://example.com/a.wav", "audio/wav") + assert ref is None assert "Timed out" in error @@ -340,7 +342,10 @@ def test_pdf_input(self) -> None: @patch("app.utils.resolve_audio_url") def test_audio_url_input(self, mock_resolve_url) -> None: - mock_resolve_url.return_value = ("/tmp/audio_test.wav", None) + from app.core.audio_utils import AudioRef + + mocked_ref = AudioRef(bytes_=b"audio", mime_type="audio/wav") + mock_resolve_url.return_value = (mocked_ref, None) audio = AudioInput( content=AudioContent( format="url", @@ -350,7 +355,7 @@ def test_audio_url_input(self, mock_resolve_url) -> None: ) result, error = resolve_input(audio) assert error is None - assert result == "/tmp/audio_test.wav" + assert result is mocked_ref def test_multimodal_text_and_image(self) -> None: parts = [ From a7eff4bcf67c0cd9f7f82d3fa9dbc26c955e9fe8 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Tue, 2 Jun 2026 11:51:21 +0530 Subject: [PATCH 05/15] feat: remove secrets manager --- backend/app/core/cloud/__init__.py | 1 - backend/app/core/cloud/storage.py | 161 ----------------- backend/app/core/config.py | 5 +- backend/app/core/providers.py | 16 +- backend/app/crud/credentials.py | 43 ----- .../app/services/llm/providers/gai_vertex.py | 70 ++++++-- .../app/services/llm/providers/registry.py | 28 +-- backend/app/tests/core/test_storage_byok.py | 165 ------------------ backend/app/tests/crud/test_credentials.py | 46 ++--- .../services/llm/providers/test_gai_vertex.py | 7 +- .../services/llm/providers/test_registry.py | 26 ++- 11 files changed, 96 insertions(+), 472 deletions(-) delete mode 100644 backend/app/tests/core/test_storage_byok.py diff --git a/backend/app/core/cloud/__init__.py b/backend/app/core/cloud/__init__.py index bffc5a964..b6b0b08ec 100644 --- a/backend/app/core/cloud/__init__.py +++ b/backend/app/core/cloud/__init__.py @@ -4,6 +4,5 @@ CloudStorage, CloudStorageError, get_cloud_storage, - get_gcp_service_account, upload_audio_to_gcs, ) diff --git a/backend/app/core/cloud/storage.py b/backend/app/core/cloud/storage.py index 1ea6cb027..fc22e82b4 100644 --- a/backend/app/core/cloud/storage.py +++ b/backend/app/core/cloud/storage.py @@ -1,5 +1,4 @@ import os -import json import mimetypes from sqlmodel import Session from uuid import UUID, uuid4 @@ -324,169 +323,9 @@ def get_cloud_storage(session: Session, project_id: int) -> CloudStorage: raise -# ────────────────────────────────────────────────────────────────────────────── -# GCP service-account fetch (AWS Secrets Manager) + GCS upload util. -# BYOK-ready: every util takes explicit secret_name / bucket / project_id so -# per-project credentials can be passed in. Settings provide the platform -# defaults for the shared SA path. -# ────────────────────────────────────────────────────────────────────────────── - GCS_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) -class SecretsManagerError(Exception): - pass - - -def upsert_byok_secret_for_provider( - provider: str, - credentials: dict, - *, - org_id: int, - project_id: int, -) -> dict: - """Persist provider-specific BYOK secrets to AWS Secrets Manager and - rewrite the credentials dict so only references (not raw secrets) are - stored in the DB. - - Currently only ``google-vertex`` needs this: when ``sa_key`` is present, - the SA JSON is uploaded to SM under a deterministic per-project name, - and the dict is rewritten to carry ``gcp_sa_secret_name`` / - ``gcp_sa_secret_region`` instead. - - Returns the (possibly rewritten) credentials dict. No-op for providers - without BYOK secrets or when the optional ``sa_key`` field is absent. - """ - if provider == "google-vertex": - sa_key = credentials.get("sa_key") - # The validator only checks key presence, not shape/truthiness — so - # null, empty dict, or a JSON string would slip through and leave a - # partial-BYOK row (user api_key + platform SA), which is exactly - # the broken hybrid BYOK enforcement is meant to prevent. - if not isinstance(sa_key, dict) or not sa_key: - raise ValueError( - "google-vertex 'sa_key' must be a non-empty service-account JSON object" - ) - secret_name = ( - f"kaapi/{settings.ENVIRONMENT}/orgs/{org_id}" - f"/projects/{project_id}/google-vertex/sa" - ) - put_gcp_service_account(sa_key, secret_name=secret_name) - rewritten = {k: v for k, v in credentials.items() if k != "sa_key"} - rewritten["gcp_sa_secret_name"] = secret_name - rewritten["gcp_sa_secret_region"] = settings.GCP_SA_SECRET_REGION - return rewritten - return credentials - - -def put_gcp_service_account( - sa_info: dict, - *, - secret_name: str, - region_name: str | None = None, -) -> None: - """Create or update a GCP service-account JSON key in AWS Secrets Manager. - - Idempotent: tries CreateSecret first, falls back to PutSecretValue when - the secret already exists. Validates SA shape upfront so we never store - junk. Invalidates the ``get_gcp_service_account`` LRU cache on success - so the next read picks up the rotated key. - """ - if sa_info.get("type") != "service_account": - raise SecretsManagerError( - f"Refusing to write secret '{secret_name}': not a GCP service-account key " - f"(got type={sa_info.get('type')!r})" - ) - - region = region_name or settings.GCP_SA_SECRET_REGION - payload = json.dumps(sa_info) - - sm_client = boto3.session.Session().client( - service_name="secretsmanager", region_name=region - ) - - try: - try: - sm_client.create_secret(Name=secret_name, SecretString=payload) - action = "created" - except sm_client.exceptions.ResourceExistsException: - sm_client.put_secret_value(SecretId=secret_name, SecretString=payload) - action = "updated" - except ClientError as e: - code = e.response.get("Error", {}).get("Code", "Unknown") - logger.error( - f"[put_gcp_service_account] Secret write failed | " - f"secret={_mask(secret_name)}, region={region}, code={code}" - ) - raise SecretsManagerError( - f"Failed to write secret '{secret_name}' (code={code}): {e}" - ) from e - - get_gcp_service_account.cache_clear() - logger.info( - f"[put_gcp_service_account] Secret {action} | " - f"secret={_mask(secret_name)}, region={region}, " - f"project_id={sa_info.get('project_id')}, " - f"client_email={_mask(sa_info.get('client_email', ''))}" - ) - - -@ft.lru_cache(maxsize=32) -def get_gcp_service_account( - secret_name: str | None = None, - region_name: str | None = None, -) -> dict: - """Fetch a GCP service-account JSON key from AWS Secrets Manager. - - Cached per (secret_name, region) — restart the process or call - ``get_gcp_service_account.cache_clear()`` to pick up a rotated key. - - BYOK: pass a project-owned ``secret_name``. Defaults to the platform-shared - secret configured in settings. - """ - secret = secret_name or settings.GCP_SA_SECRET_NAME - region = region_name or settings.GCP_SA_SECRET_REGION - - sm_client = boto3.session.Session().client( - service_name="secretsmanager", region_name=region - ) - - try: - response = sm_client.get_secret_value(SecretId=secret) - except ClientError as e: - code = e.response.get("Error", {}).get("Code", "Unknown") - logger.error( - f"[get_gcp_service_account] Secret fetch failed | " - f"secret={_mask(secret)}, region={region}, code={code}" - ) - raise SecretsManagerError( - f"Failed to fetch secret '{secret}' (code={code}): {e}" - ) from e - - if "SecretString" not in response: - raise SecretsManagerError( - f"Secret '{secret}' has no SecretString (binary secret unsupported)" - ) - - try: - sa_info = json.loads(response["SecretString"]) - except json.JSONDecodeError as e: - raise SecretsManagerError(f"Secret '{secret}' is not valid JSON: {e}") from e - - if sa_info.get("type") != "service_account": - raise SecretsManagerError( - f"Secret '{secret}' is not a GCP service-account key " - f"(got type={sa_info.get('type')!r})" - ) - - logger.info( - f"[get_gcp_service_account] Loaded SA key | " - f"secret={_mask(secret)}, project_id={sa_info.get('project_id')}, " - f"client_email={_mask(sa_info.get('client_email', ''))}" - ) - return sa_info - - _MIME_TO_EXT = { "audio/wav": ".wav", "audio/mpeg": ".mp3", diff --git a/backend/app/core/config.py b/backend/app/core/config.py index acad45b3f..e05c04199 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -110,8 +110,9 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: GCP_VERTEX_API_KEY: str = "" GCP_VERTEX_LOCATION: str = "" GCP_PROJECT_ID: str = "" - GCP_SA_SECRET_NAME: str = "" - GCP_SA_SECRET_REGION: str = "" + # Filesystem path to the platform-default GCP service-account JSON. + # Used by the registry fallback when a project has no google-vertex row. + GCP_SA_KEY_PATH: str = "" GCS_AUDIO_BUCKET: str = "" # RabbitMQ configuration for Celery broker diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index 22b74c784..bca063dc6 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -48,20 +48,6 @@ class ProviderConfig: Provider.ANTHROPIC: ProviderConfig( required_fields=["api_key"], sensitive_fields=["api_key"] ), - # google-vertex BYOK is all-or-nothing: if a credential row is registered - # for this provider, it must carry the full kit. Partial registrations - # would mix the user's api_key (scoped to their GCP project) with the - # platform SA / bucket (different GCP project) — Vertex cannot read across - # projects without explicit cross-project IAM, so we forbid that shape. - # - # Projects that omit the credential row entirely fall through to the - # platform-shared defaults (GCP_VERTEX_API_KEY, GCP_PROJECT_ID, - # GCP_VERTEX_LOCATION, GCP_SA_SECRET_NAME, GCS_AUDIO_BUCKET) in settings. - # - # sa_key (dict) is the raw GCP service-account JSON. It's stripped before - # DB storage by upsert_byok_secret_for_provider and uploaded to AWS - # Secrets Manager; the persisted dict carries gcp_sa_secret_name / - # gcp_sa_secret_region in its place. Provider.GOOGLE_VERTEX: ProviderConfig( required_fields=[ "api_key", @@ -70,7 +56,7 @@ class ProviderConfig: "sa_key", "gcs_bucket", ], - sensitive_fields=["api_key"], + sensitive_fields=["api_key", "sa_key"], ), Provider.WEBHOOK_SECRET: ProviderConfig( required_fields=["webhook_secret"], sensitive_fields=["webhook_secret"] diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index 735daf10b..1d23ff587 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -5,7 +5,6 @@ from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select -from app.core.cloud.storage import SecretsManagerError, upsert_byok_secret_for_provider from app.core.exception_handlers import HTTPException from app.core.providers import validate_provider, validate_provider_credentials from app.core.security import decrypt_credentials, encrypt_credentials @@ -37,27 +36,6 @@ def set_creds_for_org( ) raise HTTPException(status_code=400, detail=str(e)) - # BYOK side-effect: e.g. google-vertex sa_key → AWS Secrets Manager, - # dict rewritten to carry only the SM reference before persistence. - try: - credentials = upsert_byok_secret_for_provider( - provider, credentials, org_id=organization_id, project_id=project_id - ) - except ValueError as e: - logger.warning( - f"[set_creds_for_org] BYOK shape error | project_id: {project_id}, provider: {provider}, error: {str(e)}" - ) - raise HTTPException(status_code=400, detail=str(e)) - except SecretsManagerError as e: - logger.error( - f"[set_creds_for_org] BYOK secret store failed | project_id: {project_id}, provider: {provider}, error: {str(e)}", - exc_info=True, - ) - raise HTTPException( - status_code=502, - detail=f"Failed to store provider secret: {str(e)}", - ) - # Encrypt entire credentials object encrypted_credentials = encrypt_credentials(credentials) @@ -224,27 +202,6 @@ def update_creds_for_org( ) raise HTTPException(status_code=400, detail=str(e)) - # BYOK side-effect: e.g. google-vertex sa_key → AWS Secrets Manager, - # dict rewritten to carry only the SM reference before persistence. - try: - credential_data = upsert_byok_secret_for_provider( - creds_in.provider, credential_data, org_id=org_id, project_id=project_id - ) - except ValueError as e: - logger.warning( - f"[update_creds_for_org] BYOK shape error | organization_id: {org_id}, project_id: {project_id}, provider: {creds_in.provider}, error: {str(e)}" - ) - raise HTTPException(status_code=400, detail=str(e)) - except SecretsManagerError as e: - logger.error( - f"[update_creds_for_org] BYOK secret store failed | organization_id: {org_id}, project_id: {project_id}, provider: {creds_in.provider}, error: {str(e)}", - exc_info=True, - ) - raise HTTPException( - status_code=502, - detail=f"Failed to store provider secret: {str(e)}", - ) - # Encrypt the entire credentials object encrypted_credentials = encrypt_credentials(credential_data) diff --git a/backend/app/services/llm/providers/gai_vertex.py b/backend/app/services/llm/providers/gai_vertex.py index 78fc640af..ceccc26a3 100644 --- a/backend/app/services/llm/providers/gai_vertex.py +++ b/backend/app/services/llm/providers/gai_vertex.py @@ -1,7 +1,9 @@ import base64 +import json import logging import os import uuid +from pathlib import Path from typing import Any import requests @@ -12,7 +14,7 @@ convert_pcm_to_ogg, pcm_to_wav, ) -from app.core.cloud.storage import get_gcp_service_account, upload_audio_to_gcs +from app.core.cloud.storage import upload_audio_to_gcs from app.core.config import settings from app.models.llm import ( LLMCallResponse, @@ -45,11 +47,27 @@ } +def _load_platform_sa_info() -> dict | None: + """Load the platform-default GCP SA JSON from disk, if configured.""" + sa_path = settings.GCP_SA_KEY_PATH + if not sa_path or not Path(sa_path).is_file(): + return None + try: + return json.loads(Path(sa_path).read_text()) + except (OSError, json.JSONDecodeError) as e: + logger.warning( + f"[_load_platform_sa_info] Failed to load platform SA key | " + f"path={sa_path}, error={e}" + ) + return None + + class VertexClient: """Holds Vertex AI connection details. Pure config — no SDK session. - BYOK: per-project SA secret + GCS bucket can be passed via credentials; - falls back to platform-shared values in settings. + BYOK: per-project SA JSON + GCS bucket are passed via credentials and + stored directly on the client; falls back to platform-shared values + in settings when not provided by the project credential row. """ def __init__( @@ -57,15 +75,13 @@ def __init__( api_key: str, project_id: str, location: str, - gcp_sa_secret_name: str | None = None, - gcp_sa_secret_region: str | None = None, + sa_info: dict | None = None, gcs_bucket: str | None = None, ): self.api_key = api_key self.project_id = project_id self.location = location - self.gcp_sa_secret_name = gcp_sa_secret_name - self.gcp_sa_secret_region = gcp_sa_secret_region + self.sa_info = sa_info self.gcs_bucket = gcs_bucket or settings.GCS_AUDIO_BUCKET def endpoint(self, model: str) -> str: @@ -90,20 +106,35 @@ def __init__(self, client: VertexClient): @staticmethod def create_client(credentials: dict[str, Any]) -> Any: + # Fall back to platform-shared defaults from settings for any field + # the caller didn't provide. The SA JSON falls back to the file at + # settings.GCP_SA_KEY_PATH; BYOK rows pass `sa_key` inline. + credentials = credentials or {} + api_key = credentials.get("api_key") or settings.GCP_VERTEX_API_KEY + project_id = credentials.get("project_id") or settings.GCP_PROJECT_ID + location = credentials.get("location") or settings.GCP_VERTEX_LOCATION + gcs_bucket = credentials.get("gcs_bucket") or settings.GCS_AUDIO_BUCKET + sa_info = credentials.get("sa_key") or _load_platform_sa_info() + missing = [ - f for f in ("api_key", "project_id", "location") if not credentials.get(f) + name + for name, value in ( + ("api_key", api_key), + ("project_id", project_id), + ("location", location), + ) + if not value ] if missing: raise ValueError( f"Google Vertex AI credentials missing required fields: {', '.join(missing)}" ) return VertexClient( - api_key=credentials.get("api_key"), - project_id=credentials.get("project_id"), - location=credentials.get("location"), - gcp_sa_secret_name=credentials.get("gcp_sa_secret_name"), - gcp_sa_secret_region=credentials.get("gcp_sa_secret_region"), - gcs_bucket=credentials.get("gcs_bucket"), + api_key=api_key, + project_id=project_id, + location=location, + sa_info=sa_info, + gcs_bucket=gcs_bucket, ) def _post(self, model: str, payload: dict) -> tuple[dict | None, str | None]: @@ -161,15 +192,16 @@ def _execute_stt( # Push bytes straight to GCS — no disk I/O. fileData.fileUri bypasses # the 20 MB inline cap. - try: - sa_info = get_gcp_service_account( - secret_name=self.client.gcp_sa_secret_name, - region_name=self.client.gcp_sa_secret_region, + if not self.client.sa_info: + return ( + None, + "google-vertex sa_key not configured; cannot stage audio for STT", ) + try: gs_uri = upload_audio_to_gcs( audio_bytes=resolved_input.bytes_, bucket_name=self.client.gcs_bucket, - sa_info=sa_info, + sa_info=self.client.sa_info, project_id=self.client.project_id, content_type=mime_type, ) diff --git a/backend/app/services/llm/providers/registry.py b/backend/app/services/llm/providers/registry.py index 7a1081352..07b172a54 100644 --- a/backend/app/services/llm/providers/registry.py +++ b/backend/app/services/llm/providers/registry.py @@ -75,26 +75,14 @@ def get_llm_provider( org_id=organization_id, ) - if not credentials: - # google-vertex falls back to platform-shared defaults from settings - # when no project credential row exists. BYOK is all-or-nothing for - # this provider (see Provider.GOOGLE_VERTEX in app/core/providers.py), - # so projects either register the full kit or use the platform set. - if credential_provider == "google-vertex": - from app.core.config import settings - - credentials = { - "api_key": settings.GCP_VERTEX_API_KEY, - "project_id": settings.GCP_PROJECT_ID, - "location": settings.GCP_VERTEX_LOCATION, - "gcp_sa_secret_name": settings.GCP_SA_SECRET_NAME, - "gcp_sa_secret_region": settings.GCP_SA_SECRET_REGION, - "gcs_bucket": settings.GCS_AUDIO_BUCKET, - } - else: - raise ValueError( - f"Credentials for provider '{credential_provider}' not configured for this project." - ) + # Pass through whatever the DB returned (including None/empty). Providers + # that support platform-default fallbacks (e.g. google-vertex) handle the + # empty case themselves in create_client; others raise. + if not credentials and credential_provider != "google-vertex": + raise ValueError( + f"Credentials for provider '{credential_provider}' not configured for this project." + ) + credentials = credentials or {} try: client = provider_class.create_client(credentials=credentials) diff --git a/backend/app/tests/core/test_storage_byok.py b/backend/app/tests/core/test_storage_byok.py deleted file mode 100644 index 5cfcb48dd..000000000 --- a/backend/app/tests/core/test_storage_byok.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Tests for the BYOK helpers in app.core.cloud.storage.""" - -from unittest.mock import MagicMock, patch - -import pytest -from botocore.exceptions import ClientError - -from app.core.cloud.storage import ( - SecretsManagerError, - put_gcp_service_account, - upsert_byok_secret_for_provider, -) - - -VALID_SA = { - "type": "service_account", - "project_id": "starlit-lotus-492004-k0", - "client_email": "kaapi-test@starlit-lotus-492004-k0.iam.gserviceaccount.com", - "private_key": "-----BEGIN PRIVATE KEY-----\nfake\n-----END PRIVATE KEY-----", -} - - -@pytest.fixture -def mock_sm_client(): - client = MagicMock() - client.exceptions.ResourceExistsException = type( - "ResourceExistsException", (ClientError,), {} - ) - with patch("app.core.cloud.storage.boto3.session.Session") as mock_session, patch( - "app.core.cloud.storage.get_gcp_service_account.cache_clear" - ) as mock_clear: - mock_session.return_value.client.return_value = client - yield client, mock_clear - - -class TestPutGcpServiceAccount: - def test_creates_secret_when_absent(self, mock_sm_client): - client, mock_clear = mock_sm_client - put_gcp_service_account( - VALID_SA, secret_name="kaapi/dev/orgs/1/projects/2/google-vertex/sa" - ) - - client.create_secret.assert_called_once() - kwargs = client.create_secret.call_args.kwargs - assert kwargs["Name"] == "kaapi/dev/orgs/1/projects/2/google-vertex/sa" - # SA JSON round-trips through json.dumps; verify a known field survives. - assert '"type": "service_account"' in kwargs["SecretString"] - client.put_secret_value.assert_not_called() - mock_clear.assert_called_once() - - def test_updates_secret_when_present(self, mock_sm_client): - client, mock_clear = mock_sm_client - client.create_secret.side_effect = client.exceptions.ResourceExistsException( - {"Error": {"Code": "ResourceExistsException"}}, "CreateSecret" - ) - - put_gcp_service_account( - VALID_SA, secret_name="kaapi/dev/orgs/1/projects/2/google-vertex/sa" - ) - - client.create_secret.assert_called_once() - client.put_secret_value.assert_called_once() - kwargs = client.put_secret_value.call_args.kwargs - assert kwargs["SecretId"] == "kaapi/dev/orgs/1/projects/2/google-vertex/sa" - mock_clear.assert_called_once() - - def test_rejects_non_service_account_payload(self, mock_sm_client): - client, _ = mock_sm_client - bad = {"type": "user_account", "client_id": "x"} - with pytest.raises(SecretsManagerError, match="not a GCP service-account"): - put_gcp_service_account(bad, secret_name="kaapi/anything") - client.create_secret.assert_not_called() - client.put_secret_value.assert_not_called() - - def test_wraps_aws_errors(self, mock_sm_client): - client, _ = mock_sm_client - client.create_secret.side_effect = ClientError( - {"Error": {"Code": "AccessDeniedException", "Message": "nope"}}, - "CreateSecret", - ) - with pytest.raises(SecretsManagerError, match="AccessDeniedException"): - put_gcp_service_account(VALID_SA, secret_name="kaapi/anything") - - -class TestUpsertByokSecretForProvider: - def test_google_vertex_with_sa_key_strips_and_writes(self): - creds = { - "api_key": "vkey", - "project_id": "starlit-lotus-492004-k0", - "location": "us-central1", - "sa_key": VALID_SA, - "gcs_bucket": "my-bucket", - } - with patch("app.core.cloud.storage.put_gcp_service_account") as mock_put, patch( - "app.core.cloud.storage.settings" - ) as mock_settings: - mock_settings.ENVIRONMENT = "development" - mock_settings.GCP_SA_SECRET_REGION = "ap-south-1" - - result = upsert_byok_secret_for_provider( - "google-vertex", creds, org_id=7, project_id=42 - ) - - expected_name = "kaapi/development/orgs/7/projects/42/google-vertex/sa" - mock_put.assert_called_once_with(VALID_SA, secret_name=expected_name) - assert "sa_key" not in result - assert result["gcp_sa_secret_name"] == expected_name - assert result["gcp_sa_secret_region"] == "ap-south-1" - # Untouched fields preserved. - assert result["api_key"] == "vkey" - assert result["gcs_bucket"] == "my-bucket" - - def test_google_vertex_rejects_null_sa_key(self): - creds = { - "api_key": "vkey", - "project_id": "p", - "location": "us-central1", - "sa_key": None, - } - with pytest.raises(ValueError, match="sa_key.*non-empty service-account JSON"): - upsert_byok_secret_for_provider( - "google-vertex", creds, org_id=1, project_id=1 - ) - - def test_google_vertex_rejects_string_sa_key(self): - creds = { - "api_key": "vkey", - "project_id": "p", - "location": "us-central1", - "sa_key": "not-a-dict", - } - with pytest.raises(ValueError, match="non-empty service-account JSON"): - upsert_byok_secret_for_provider( - "google-vertex", creds, org_id=1, project_id=1 - ) - - def test_google_vertex_rejects_empty_sa_key(self): - creds = { - "api_key": "vkey", - "project_id": "p", - "location": "us-central1", - "sa_key": {}, - } - with pytest.raises(ValueError, match="non-empty service-account JSON"): - upsert_byok_secret_for_provider( - "google-vertex", creds, org_id=1, project_id=1 - ) - - def test_google_vertex_rejects_missing_sa_key(self): - # Validator at the route requires sa_key, but the hook also rejects - # absence defensively in case it's invoked outside the route flow. - creds = {"api_key": "vkey", "project_id": "p", "location": "us-central1"} - with pytest.raises(ValueError, match="non-empty service-account JSON"): - upsert_byok_secret_for_provider( - "google-vertex", creds, org_id=1, project_id=1 - ) - - def test_other_provider_is_noop_even_with_sa_key(self): - creds = {"api_key": "k", "sa_key": VALID_SA} - with patch("app.core.cloud.storage.put_gcp_service_account") as mock_put: - result = upsert_byok_secret_for_provider( - "openai", creds, org_id=1, project_id=1 - ) - mock_put.assert_not_called() - assert result == creds # sa_key passes through (validator's job to reject) diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index e551166c9..4370e56e0 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -1,5 +1,3 @@ -from unittest.mock import patch - import pytest from sqlmodel import Session @@ -89,8 +87,8 @@ def test_get_creds_by_org(db: Session) -> None: def test_set_credentials_for_google_vertex_with_sa_key(db: Session) -> None: - """sa_key on google-vertex must be uploaded to SM and stripped before storage; - the persisted credential dict carries only the secret reference.""" + """google-vertex sa_key is stored directly in the credentials table + (encrypted by the same Fernet key as every other credential field).""" project = create_test_project(db) sa_key = { @@ -112,33 +110,12 @@ def test_set_credentials_for_google_vertex_with_sa_key(db: Session) -> None: }, ) - with patch("app.crud.credentials.upsert_byok_secret_for_provider") as mock_hook: - # Simulate the real hook's rewrite without touching AWS. - secret_name = ( - f"kaapi/test/orgs/{project.organization_id}" - f"/projects/{project.id}/google-vertex/sa" - ) - mock_hook.return_value = { - "api_key": "vkey", - "project_id": "starlit-lotus-492004-k0", - "location": "us-central1", - "gcs_bucket": "my-bucket", - "gcp_sa_secret_name": secret_name, - "gcp_sa_secret_region": "ap-south-1", - } - - created = set_creds_for_org( - session=db, - creds_add=payload, - organization_id=project.organization_id, - project_id=project.id, - ) - - mock_hook.assert_called_once() - args, kwargs = mock_hook.call_args - assert args[0] == "google-vertex" - assert args[1]["sa_key"] == sa_key - assert kwargs == {"org_id": project.organization_id, "project_id": project.id} + created = set_creds_for_org( + session=db, + creds_add=payload, + organization_id=project.organization_id, + project_id=project.id, + ) assert len(created) == 1 stored = get_provider_credential( @@ -148,10 +125,11 @@ def test_set_credentials_for_google_vertex_with_sa_key(db: Session) -> None: project_id=project.id, ) assert stored is not None - assert "sa_key" not in stored - assert stored["gcp_sa_secret_name"] == secret_name - assert stored["gcp_sa_secret_region"] == "ap-south-1" + assert stored["sa_key"] == sa_key assert stored["api_key"] == "vkey" + assert stored["project_id"] == "starlit-lotus-492004-k0" + assert stored["location"] == "us-central1" + assert stored["gcs_bucket"] == "my-bucket" def test_get_provider_credential(db: Session) -> None: diff --git a/backend/app/tests/services/llm/providers/test_gai_vertex.py b/backend/app/tests/services/llm/providers/test_gai_vertex.py index 6df84c4a6..96fcbc764 100644 --- a/backend/app/tests/services/llm/providers/test_gai_vertex.py +++ b/backend/app/tests/services/llm/providers/test_gai_vertex.py @@ -69,11 +69,7 @@ def _mock_http_err(status: int = 400, body: str = "bad request") -> MagicMock: @pytest.fixture(autouse=True) def _mock_gcs(monkeypatch): - """Stub out SM + GCS so STT tests don't touch external services.""" - monkeypatch.setattr( - "app.services.llm.providers.gai_vertex.get_gcp_service_account", - lambda **kw: {"type": "service_account", "project_id": "p"}, - ) + """Stub out GCS upload so STT tests don't touch external services.""" monkeypatch.setattr( "app.services.llm.providers.gai_vertex.upload_audio_to_gcs", lambda *, audio_bytes, bucket_name, sa_info, **kw: f"gs://{bucket_name}/audio/test.wav", @@ -87,6 +83,7 @@ def client(self) -> VertexClient: api_key="k", project_id="p", location="us-central1", + sa_info={"type": "service_account", "project_id": "p"}, gcs_bucket="test-bucket", ) diff --git a/backend/app/tests/services/llm/providers/test_registry.py b/backend/app/tests/services/llm/providers/test_registry.py index a6978dc60..aa4bae639 100644 --- a/backend/app/tests/services/llm/providers/test_registry.py +++ b/backend/app/tests/services/llm/providers/test_registry.py @@ -105,25 +105,37 @@ def test_get_llm_provider_with_missing_credentials(self, db: Session): assert "not configured for this project" in str(exc_info.value) - def test_google_vertex_falls_back_to_platform_settings(self, db: Session): - """No credential row for google-vertex → registry synthesizes platform - defaults from settings (the BYOK-or-platform contract).""" + def test_google_vertex_falls_back_to_platform_settings(self, db: Session, tmp_path): + """No credential row for google-vertex → create_client synthesizes the + platform defaults from settings (api_key/project/location/bucket) and + loads the SA JSON from GCP_SA_KEY_PATH.""" + import json as _json + from app.services.llm.providers.gai_vertex import ( GoogleVertexAIProvider, VertexClient, ) project = get_project(db) + sa_info = { + "type": "service_account", + "project_id": "platform-project", + "client_email": "sa@platform-project.iam.gserviceaccount.com", + "private_key": "-----BEGIN PRIVATE KEY-----\nfake\n-----END PRIVATE KEY-----", + } + sa_path = tmp_path / "sa.json" + sa_path.write_text(_json.dumps(sa_info)) with patch( "app.crud.credentials.get_provider_credential" - ) as mock_get_creds, patch("app.core.config.settings") as mock_settings: + ) as mock_get_creds, patch( + "app.services.llm.providers.gai_vertex.settings" + ) as mock_settings: mock_get_creds.return_value = None mock_settings.GCP_VERTEX_API_KEY = "platform-key" mock_settings.GCP_PROJECT_ID = "platform-project" mock_settings.GCP_VERTEX_LOCATION = "us-central1" - mock_settings.GCP_SA_SECRET_NAME = "platform/secret" - mock_settings.GCP_SA_SECRET_REGION = "ap-south-1" + mock_settings.GCP_SA_KEY_PATH = str(sa_path) mock_settings.GCS_AUDIO_BUCKET = "platform-bucket" provider = get_llm_provider( @@ -138,5 +150,5 @@ def test_google_vertex_falls_back_to_platform_settings(self, db: Session): assert provider.client.api_key == "platform-key" assert provider.client.project_id == "platform-project" assert provider.client.location == "us-central1" - assert provider.client.gcp_sa_secret_name == "platform/secret" + assert provider.client.sa_info == sa_info assert provider.client.gcs_bucket == "platform-bucket" From 6a0fdb892cf16740d3b480fb0a181f6100bc3db0 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Tue, 2 Jun 2026 15:29:39 +0530 Subject: [PATCH 06/15] fix: mask dicts --- backend/app/core/providers.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index bca063dc6..2e874dc3a 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -138,6 +138,14 @@ def mask_credential_fields( sensitive_fields = PROVIDER_CONFIGS[provider_enum].sensitive_fields masked = dict(credentials) for field_name in sensitive_fields: - if field_name in masked and isinstance(masked[field_name], str): - masked[field_name] = mask_string(masked[field_name]) + if field_name not in masked: + continue + value = masked[field_name] + if isinstance(value, str): + masked[field_name] = mask_string(value) + else: + # Non-string secrets (e.g. google-vertex `sa_key` is a dict) + # are masked wholesale — the raw value is only decrypted at + # provider runtime, never returned via the API. + masked[field_name] = "********" return masked From 1fb5a539e3accc4841a3d334a2e8e859620e62e2 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Tue, 2 Jun 2026 16:06:57 +0530 Subject: [PATCH 07/15] feat: files api beta support in Claude --- backend/app/services/llm/providers/claude.py | 42 +++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/backend/app/services/llm/providers/claude.py b/backend/app/services/llm/providers/claude.py index 382000f27..e85930c59 100644 --- a/backend/app/services/llm/providers/claude.py +++ b/backend/app/services/llm/providers/claude.py @@ -1,3 +1,5 @@ +import base64 +import io import logging from typing import Any @@ -21,6 +23,7 @@ logger = logging.getLogger(__name__) DEFAULT_MAX_TOKENS = 4096 +FILES_API_BETA = "files-api-2025-04-14" class ClaudeProvider(BaseProvider): @@ -90,6 +93,26 @@ def format_parts( return items + @staticmethod + def _is_base64_file_block(block: Any) -> bool: + if not isinstance(block, dict): + return False + if block.get("type") not in ("document", "image"): + return False + source = block.get("source") or {} + return source.get("type") == "base64" + + def _upload_to_files_api(self, source: dict, block_type: str) -> str: + file_bytes = base64.b64decode(source["data"]) + filename = "document.pdf" if block_type == "document" else "image" + upload = self.client.beta.files.upload( + file=(filename, io.BytesIO(file_bytes), source["media_type"]), + ) + logger.info( + f"[ClaudeProvider._upload_to_files_api] Uploaded {block_type} | file_id={upload.id}" + ) + return upload.id + def execute( self, completion_config: NativeCompletionConfig, @@ -120,6 +143,18 @@ def execute( else: content = resolved_input + # Upload any base64 PDFs/images to the Files API and reference by file_id. + # Keeps request payloads small and lets large files bypass inline size limits. + uploaded_file = False + if isinstance(content, list): + for block in content: + if not self._is_base64_file_block(block): + continue + + file_id = self._upload_to_files_api(block["source"], block["type"]) + block["source"] = {"type": "file", "file_id": file_id} + uploaded_file = True + params["messages"] = [{"role": "user", "content": content}] # Anthropic Messages API has no first-class conversation primitive, @@ -127,7 +162,12 @@ def execute( # config so it never leaks into the API call. params.pop("conversation", None) - response = self.client.messages.create(**params) + if uploaded_file: + existing_betas = params.pop("betas", []) or [] + params["betas"] = [*existing_betas, FILES_API_BETA] + response = self.client.beta.messages.create(**params) + else: + response = self.client.messages.create(**params) output_text = "".join( block.text for block in response.content if block.type == "text" From 9ff5c72889fd975cb81a1d4a4218d5aa70648136 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Tue, 2 Jun 2026 16:58:03 +0530 Subject: [PATCH 08/15] feat: claude PDF/Image intergration --- backend/app/models/llm/constants.py | 16 +++++++++ backend/app/models/llm/request.py | 8 ++++- backend/app/services/llm/mappers.py | 37 ++++++++++++++------ backend/app/services/llm/providers/claude.py | 19 +++++----- backend/app/services/llm/providers/eai.py | 8 +++-- backend/app/services/llm/providers/gai.py | 5 ++- backend/app/services/llm/providers/oai.py | 2 ++ backend/app/services/llm/providers/sai.py | 8 +++-- 8 files changed, 75 insertions(+), 28 deletions(-) diff --git a/backend/app/models/llm/constants.py b/backend/app/models/llm/constants.py index 1838da79d..f4c1471c2 100644 --- a/backend/app/models/llm/constants.py +++ b/backend/app/models/llm/constants.py @@ -2,6 +2,22 @@ DEFAULT_TTS_MODEL = "gemini-2.5-flash-preview-tts" DEFAULT_TTS_VOICE = "Kore" +# Default text-completion model per provider. Used by both the native flow +# (provider.execute) and the Kaapi mapper so the two stay in sync. +DEFAULT_TEXT_MODELS: dict[str, str] = { + "anthropic": "claude-sonnet-4-6", + "openai": "gpt-4.1-mini", + "google": "gemini-2.5-pro", +} + +DEFAULT_ANTHROPIC_MAX_TOKENS = 4096 + +# Provider-native STT/TTS defaults (used when caller omits model). +DEFAULT_SARVAM_STT_MODEL = "saaras:v3" +DEFAULT_SARVAM_TTS_MODEL = "bulbul:v3" +DEFAULT_ELEVENLABS_STT_MODEL = "scribe_v2" +DEFAULT_ELEVENLABS_TTS_MODEL = "eleven_v3" + # BCP-47 to language tag -> Gemini ISO 639-1 code (Indic + English) BCP47_LOCALE_TO_GEMINI_LANG: dict[str, str] = { "en-IN": "en", diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index d14034819..db084c957 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -17,7 +17,13 @@ class TextLLMParams(SQLModel): - model: str + model: str | None = Field( + default=None, + description=( + "Provider model to use. If omitted, the Kaapi mapper falls back to " + "DEFAULT_TEXT_MODELS for the selected provider." + ), + ) instructions: str | None = Field( default=None, ) diff --git a/backend/app/services/llm/mappers.py b/backend/app/services/llm/mappers.py index a6a42f62e..bff16e7f5 100644 --- a/backend/app/services/llm/mappers.py +++ b/backend/app/services/llm/mappers.py @@ -7,10 +7,24 @@ from app.models.llm.constants import ( BCP47_LOCALE_TO_GEMINI_LANG, BCP47_TO_ELEVENLABS_LANG, + DEFAULT_ELEVENLABS_STT_MODEL, + DEFAULT_ELEVENLABS_TTS_MODEL, + DEFAULT_SARVAM_STT_MODEL, + DEFAULT_SARVAM_TTS_MODEL, + DEFAULT_TEXT_MODELS, DEFAULT_TTS_VOICE, ELEVENLABS_VOICE_TO_ID, ) +SARVAM_DEFAULTS_BY_TYPE = { + "stt": DEFAULT_SARVAM_STT_MODEL, + "tts": DEFAULT_SARVAM_TTS_MODEL, +} +ELEVENLABS_DEFAULTS_BY_TYPE = { + "stt": DEFAULT_ELEVENLABS_STT_MODEL, + "tts": DEFAULT_ELEVENLABS_TTS_MODEL, +} + logger = logging.getLogger(__name__) @@ -96,8 +110,7 @@ def map_kaapi_to_openai_params( if temperature is not None: openai_params["temperature"] = temperature - if model: - openai_params["model"] = model + openai_params["model"] = model or DEFAULT_TEXT_MODELS["openai"] if instructions: openai_params["instructions"] = instructions @@ -139,8 +152,11 @@ def map_kaapi_to_google_params( google_params = {} warnings = [] - # Model is present in all param types + # Model is present in all param types; text falls back to the centralized + # default. STT/TTS require an explicit model (Gemini variant differs by mode). model = kaapi_params.get("model") + if not model and completion_type == "text": + model = DEFAULT_TEXT_MODELS["google"] if not model: return {}, ["Missing required 'model' parameter"] @@ -238,10 +254,10 @@ def map_kaapi_to_sarvam_params( sarvam_params = {} warnings = [] - # Model is required for all completion types - model = kaapi_params.get("model") + # Model falls back to the per-type Sarvam default. + model = kaapi_params.get("model") or SARVAM_DEFAULTS_BY_TYPE.get(completion_type) if not model: - return {}, ["Missing required 'model' parameter"] + return {}, [f"Unsupported completion type '{completion_type}' for SarvamAI"] sarvam_params["model"] = model if completion_type == "tts": @@ -346,9 +362,11 @@ def map_kaapi_to_elevenlabs_params( elevenlabs_params = {} warnings = [] - model_id = kaapi_params.get("model") + model_id = kaapi_params.get("model") or ELEVENLABS_DEFAULTS_BY_TYPE.get( + completion_type + ) if not model_id: - return {}, ["Missing required 'model' parameter"] + return {}, [f"Unsupported completion type '{completion_type}' for ElevenLabs"] elevenlabs_params["model_id"] = model_id if completion_type == "tts": @@ -460,8 +478,7 @@ def map_kaapi_to_anthropic_params( effort = kaapi_params.get("effort") summary = kaapi_params.get("summary") - if model: - anthropic_params["model"] = model + anthropic_params["model"] = model or DEFAULT_TEXT_MODELS["anthropic"] if instructions: anthropic_params["system"] = instructions diff --git a/backend/app/services/llm/providers/claude.py b/backend/app/services/llm/providers/claude.py index e85930c59..2caed9fe7 100644 --- a/backend/app/services/llm/providers/claude.py +++ b/backend/app/services/llm/providers/claude.py @@ -18,11 +18,14 @@ ImageContent, PDFContent, ) +from app.models.llm.constants import ( + DEFAULT_ANTHROPIC_MAX_TOKENS, + DEFAULT_TEXT_MODELS, +) from app.services.llm.providers.base import BaseProvider, ContentPart, MultiModalInput logger = logging.getLogger(__name__) -DEFAULT_MAX_TOKENS = 4096 FILES_API_BETA = "files-api-2025-04-14" @@ -126,15 +129,11 @@ def execute( try: params = {**completion_config.params} - # Anthropic requires max_tokens; default if caller did not supply - params.setdefault("max_tokens", DEFAULT_MAX_TOKENS) - - # Kaapi exposes "instructions"; Anthropic uses "system". Always - # strip "instructions" — Anthropic rejects unknown kwargs. - if "instructions" in params: - if "system" not in params: - params["system"] = params["instructions"] - params.pop("instructions") + # Anthropic requires model and max_tokens; default if caller did not supply + params["model"] = params.get("model") or DEFAULT_TEXT_MODELS["anthropic"] + params["max_tokens"] = ( + params.get("max_tokens") or DEFAULT_ANTHROPIC_MAX_TOKENS + ) if isinstance(resolved_input, MultiModalInput): content = self.format_parts(resolved_input.parts) diff --git a/backend/app/services/llm/providers/eai.py b/backend/app/services/llm/providers/eai.py index 5b9860184..393f83ccd 100644 --- a/backend/app/services/llm/providers/eai.py +++ b/backend/app/services/llm/providers/eai.py @@ -18,6 +18,10 @@ Usage, TextContent, ) +from app.models.llm.constants import ( + DEFAULT_ELEVENLABS_STT_MODEL, + DEFAULT_ELEVENLABS_TTS_MODEL, +) from app.models.llm.response import AudioOutput from app.models.llm.request import AudioContent from app.services.llm.providers.base import BaseProvider @@ -83,7 +87,7 @@ def _execute_stt( return None, f"{provider_name} STT requires AudioRef input" # Extract already-mapped parameters from the mapper - model_id = params.get("model_id") or "scribe_v2" + model_id = params.get("model_id") or DEFAULT_ELEVENLABS_STT_MODEL if not model_id: return None, "Missing 'model_id' in native params for Elevenlabs STT" @@ -167,7 +171,7 @@ def _execute_tts( # Extract already-mapped parameters from the mapper # Use 'or' to handle both missing keys and falsy values - model_id = params.get("model_id") or "eleven_v3" + model_id = params.get("model_id") or DEFAULT_ELEVENLABS_TTS_MODEL voice_id = params.get("voice_id") or "EXAVITQu4vr4xnSDxMaL" if not model_id: diff --git a/backend/app/services/llm/providers/gai.py b/backend/app/services/llm/providers/gai.py index 9da71a932..881734ebc 100644 --- a/backend/app/services/llm/providers/gai.py +++ b/backend/app/services/llm/providers/gai.py @@ -25,6 +25,7 @@ ) from app.models.llm.constants import ( DEFAULT_STT_MODEL, + DEFAULT_TEXT_MODELS, DEFAULT_TTS_MODEL, DEFAULT_TTS_VOICE, ) @@ -394,9 +395,7 @@ def _execute_text( resolved_input: str | list[ContentPart] | MultiModalInput, include_provider_raw_response: bool = False, ) -> tuple[LLMCallResponse | None, str | None]: - model = completion_config.params.get("model") - if not model: - return None, "Missing 'model' in native params" + model = completion_config.params.get("model") or DEFAULT_TEXT_MODELS["google"] if isinstance(resolved_input, MultiModalInput): gemini_parts = self.format_parts(resolved_input.parts) diff --git a/backend/app/services/llm/providers/oai.py b/backend/app/services/llm/providers/oai.py index a93bd76c7..cbbb47172 100644 --- a/backend/app/services/llm/providers/oai.py +++ b/backend/app/services/llm/providers/oai.py @@ -16,6 +16,7 @@ ImageContent, PDFContent, ) +from app.models.llm.constants import DEFAULT_TEXT_MODELS from app.services.llm.providers.base import BaseProvider, ContentPart, MultiModalInput logger = logging.getLogger(__name__) @@ -76,6 +77,7 @@ def execute( params = { **completion_config.params, } + params["model"] = params.get("model") or DEFAULT_TEXT_MODELS["openai"] if isinstance(resolved_input, MultiModalInput): params["input"] = [ {"role": "user", "content": self.format_parts(resolved_input.parts)} diff --git a/backend/app/services/llm/providers/sai.py b/backend/app/services/llm/providers/sai.py index ba760050a..175bef217 100644 --- a/backend/app/services/llm/providers/sai.py +++ b/backend/app/services/llm/providers/sai.py @@ -13,6 +13,10 @@ Usage, TextContent, ) +from app.models.llm.constants import ( + DEFAULT_SARVAM_STT_MODEL, + DEFAULT_SARVAM_TTS_MODEL, +) from app.models.llm.response import AudioOutput from app.models.llm.request import AudioContent from app.services.llm.providers.base import BaseProvider @@ -78,7 +82,7 @@ def _execute_stt( return None, f"{provider_name} STT requires AudioRef input" # Extract already-mapped parameters from the mapper - model = params.get("model") or "saaras:v3" + model = params.get("model") or DEFAULT_SARVAM_STT_MODEL language_code = params.get("language_code") mode = params.get("mode") or "transcribe" @@ -157,7 +161,7 @@ def _execute_tts( params = completion_config.params # Extract already-mapped parameters from the mapper - model = params.get("model") or "bulbul:v3" + model = params.get("model") or DEFAULT_SARVAM_TTS_MODEL target_language_code = params.get("target_language_code") if not target_language_code: return ( From 8afea28d9db40fab868d3c271c8eb2382789bc93 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Thu, 4 Jun 2026 10:20:13 +0530 Subject: [PATCH 09/15] fix: remove anthropic, vertex model --- ...nthropic_google_vertex_to_provider_enum.py | 115 ------------------ 1 file changed, 115 deletions(-) delete mode 100644 backend/app/alembic/versions/064_add_anthropic_google_vertex_to_provider_enum.py diff --git a/backend/app/alembic/versions/064_add_anthropic_google_vertex_to_provider_enum.py b/backend/app/alembic/versions/064_add_anthropic_google_vertex_to_provider_enum.py deleted file mode 100644 index 3889902b0..000000000 --- a/backend/app/alembic/versions/064_add_anthropic_google_vertex_to_provider_enum.py +++ /dev/null @@ -1,115 +0,0 @@ -"""add anthropic + google-vertex to provider_enum and seed test model_config rows - -Revision ID: 064 -Revises: 063 -Create Date: 2026-05-28 00:00:00.000000 - -""" - -from alembic import op - - -revision = "064" -down_revision = "063" -branch_labels = None -depends_on = None - - -def upgrade(): - # ALTER TYPE ... ADD VALUE cannot run inside a transaction block; use - # autocommit per existing pattern (see migration 056). The added values - # are visible to subsequent statements once the autocommit_block exits. - with op.get_context().autocommit_block(): - op.execute( - "ALTER TYPE global.provider_enum ADD VALUE IF NOT EXISTS 'anthropic'" - ) - op.execute( - "ALTER TYPE global.provider_enum ADD VALUE IF NOT EXISTS 'google-vertex'" - ) - - # Pass-through seed rows for testing. Pricing values are placeholders; - # revise once real cost data is available. - op.execute( - """ - INSERT INTO global.model_config - (provider, model_name, completion_type, config, input_modalities, output_modalities, pricing, is_active, inserted_at, updated_at) - VALUES - -- Anthropic text models - ('anthropic', 'claude-opus-4-7', 'text', - '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 1.0, "description": "Sampling temperature."}}', - '{TEXT,IMAGE,PDF}', '{TEXT}', - '{"response": {"input_token_cost": 15.0, "output_token_cost": 75.0}, "batch": {"input_token_cost": 7.5, "output_token_cost": 37.5}}', - true, NOW(), NOW()), - ('anthropic', 'claude-sonnet-4-6', 'text', - '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 1.0, "description": "Sampling temperature."}}', - '{TEXT,IMAGE,PDF}', '{TEXT}', - '{"response": {"input_token_cost": 3.0, "output_token_cost": 15.0}, "batch": {"input_token_cost": 1.5, "output_token_cost": 7.5}}', - true, NOW(), NOW()), - ('anthropic', 'claude-haiku-4-5-20251001', 'text', - '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 1.0, "description": "Sampling temperature."}}', - '{TEXT,IMAGE,PDF}', '{TEXT}', - '{"response": {"input_token_cost": 1.0, "output_token_cost": 5.0}, "batch": {"input_token_cost": 0.5, "output_token_cost": 2.5}}', - true, NOW(), NOW()), - -- Google Vertex STT models (Gemini 3.x family — GA per - -- https://docs.cloud.google.com/gemini-enterprise-agent-platform/models/google-models) - ('google-vertex', 'gemini-3.1-pro-preview', 'stt', - '{"thinking_level": {"type": "enum", "default": "high", "options": ["low", "medium", "high"], "description": "Max reasoning depth before output. high = best quality, low = faster/cheaper."}}', - '{AUDIO}', '{TEXT}', - '{"response": {"input_token_cost": 2.0, "output_token_cost": 12.0}, "audio": {"input_token_cost": 3.5, "output_token_cost": 12.0}}', - true, NOW(), NOW()), - ('google-vertex', 'gemini-3-pro', 'stt', - '{"thinking_level": {"type": "enum", "default": "high", "options": ["low", "medium", "high"], "description": "Max reasoning depth before output."}}', - '{AUDIO}', '{TEXT}', - '{"response": {"input_token_cost": 1.5, "output_token_cost": 10.0}, "audio": {"input_token_cost": 3.0, "output_token_cost": 10.0}}', - true, NOW(), NOW()), - ('google-vertex', 'gemini-3.5-flash', 'stt', - '{"thinking_level": {"type": "enum", "default": "high", "options": ["minimal", "low", "medium", "high"], "description": "Max reasoning depth before output."}}', - '{AUDIO}', '{TEXT}', - '{"response": {"input_token_cost": 0.6, "output_token_cost": 3.5}, "audio": {"input_token_cost": 1.2, "output_token_cost": 3.5}}', - true, NOW(), NOW()), - ('google-vertex', 'gemini-3-flash-preview', 'stt', - '{"thinking_level": {"type": "enum", "default": "high", "options": ["minimal", "low", "medium", "high"], "description": "Max reasoning depth before output."}}', - '{AUDIO}', '{TEXT}', - '{"response": {"input_token_cost": 0.5, "output_token_cost": 3.0}, "audio": {"input_token_cost": 1.0, "output_token_cost": 3.0}}', - true, NOW(), NOW()), - ('google-vertex', 'gemini-3.1-flash-lite', 'stt', - '{"temperature": {"type": "float", "default": 0.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', - '{AUDIO}', '{TEXT}', - '{"response": {"input_token_cost": 0.1, "output_token_cost": 0.4}, "audio": {"input_token_cost": 0.3, "output_token_cost": 0.4}}', - true, NOW(), NOW()), - ('google-vertex', 'gemini-2.5-flash', 'stt', - '{"temperature": {"type": "float", "default": 0.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', - '{AUDIO}', '{TEXT}', - '{"response": {"input_token_cost": 0.3, "output_token_cost": 2.5}, "audio": {"input_token_cost": 1.0, "output_token_cost": 2.5}}', - true, NOW(), NOW()), - ('google-vertex', 'gemini-2.5-pro', 'stt', - '{"temperature": {"type": "float", "default": 0.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', - '{AUDIO}', '{TEXT}', - '{"response": {"input_token_cost": 1.25, "output_token_cost": 10.0}, "audio": {"input_token_cost": 3.5, "output_token_cost": 10.0}}', - true, NOW(), NOW()), - -- Google Vertex TTS models - ('google-vertex', 'gemini-2.5-flash-preview-tts', 'tts', - '{"voice": {"type": "enum", "default": "Kore", "options": ["Aoede", "Charon", "Fenrir", "Kore", "Puck"], "description": "TTS voice."}}', - '{TEXT}', '{AUDIO}', - '{"response": {"input_token_cost": 0.5, "output_token_cost": 10.0}, "audio": {"input_token_cost": 0.5, "output_token_cost": 10.0}}', - true, NOW(), NOW()), - ('google-vertex', 'gemini-2.5-pro-preview-tts', 'tts', - '{"voice": {"type": "enum", "default": "Kore", "options": ["Aoede", "Charon", "Fenrir", "Kore", "Puck"], "description": "TTS voice."}}', - '{TEXT}', '{AUDIO}', - '{"response": {"input_token_cost": 1.0, "output_token_cost": 20.0}, "audio": {"input_token_cost": 1.0, "output_token_cost": 20.0}}', - true, NOW(), NOW()) - ON CONFLICT (provider, model_name) DO NOTHING - """ - ) - - -def downgrade(): - op.execute( - """ - DELETE FROM global.model_config - WHERE provider IN ('anthropic', 'google-vertex') - """ - ) - # Enum value removal requires rebuilding the type and re-pointing every - # referencing column. Skipped — see migrations 035 / 056 for the same - # convention. From b6605c4169a50c4eb4eaf9ee5552089f4b2bd425 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Thu, 4 Jun 2026 17:05:05 +0530 Subject: [PATCH 10/15] fix: json file instead of json file path --- backend/app/core/config.py | 2 +- .../app/services/llm/providers/gai_vertex.py | 53 ++++++++++++++----- .../services/llm/providers/test_registry.py | 4 +- 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index e05c04199..3088d8002 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -112,7 +112,7 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: GCP_PROJECT_ID: str = "" # Filesystem path to the platform-default GCP service-account JSON. # Used by the registry fallback when a project has no google-vertex row. - GCP_SA_KEY_PATH: str = "" + GCP_SA_KEY: str = "" GCS_AUDIO_BUCKET: str = "" # RabbitMQ configuration for Celery broker diff --git a/backend/app/services/llm/providers/gai_vertex.py b/backend/app/services/llm/providers/gai_vertex.py index ceccc26a3..de8a7f472 100644 --- a/backend/app/services/llm/providers/gai_vertex.py +++ b/backend/app/services/llm/providers/gai_vertex.py @@ -48,19 +48,27 @@ def _load_platform_sa_info() -> dict | None: - """Load the platform-default GCP SA JSON from disk, if configured.""" - sa_path = settings.GCP_SA_KEY_PATH - if not sa_path or not Path(sa_path).is_file(): - return None - try: - return json.loads(Path(sa_path).read_text()) - except (OSError, json.JSONDecodeError) as e: - logger.warning( - f"[_load_platform_sa_info] Failed to load platform SA key | " - f"path={sa_path}, error={e}" - ) + """Load the platform-default GCP SA JSON. + + Supports two configuration shapes for settings.GCP_SA_KEY: + 1. Raw JSON string (e.g. injected via env var / secret manager) + 2. Filesystem path to a JSON key file + """ + sa_value = settings.GCP_SA_KEY + if not sa_value: return None + stripped = sa_value.strip() + if stripped.startswith("{"): + try: + return json.loads(stripped) + except json.JSONDecodeError as e: + logger.warning( + f"[_load_platform_sa_info] GCP_SA_KEY looks like JSON but " + f"failed to parse | error={e}" + ) + return None + class VertexClient: """Holds Vertex AI connection details. Pure config — no SDK session. @@ -85,8 +93,15 @@ def __init__( self.gcs_bucket = gcs_bucket or settings.GCS_AUDIO_BUCKET def endpoint(self, model: str) -> str: + # The "global" location uses the unprefixed host; regional locations + # use the "{location}-" prefix. + host = ( + "aiplatform.googleapis.com" + if self.location == "global" + else f"{self.location}-aiplatform.googleapis.com" + ) return ( - f"https://{self.location}-aiplatform.googleapis.com/v1" + f"https://{host}/v1" f"/projects/{self.project_id}/locations/{self.location}" f"/publishers/google/models/{model}:generateContent" ) @@ -108,14 +123,21 @@ def __init__(self, client: VertexClient): def create_client(credentials: dict[str, Any]) -> Any: # Fall back to platform-shared defaults from settings for any field # the caller didn't provide. The SA JSON falls back to the file at - # settings.GCP_SA_KEY_PATH; BYOK rows pass `sa_key` inline. + # settings.GCP_SA_KEY; BYOK rows pass `sa_key` inline. credentials = credentials or {} api_key = credentials.get("api_key") or settings.GCP_VERTEX_API_KEY + logger.info(f"Vertex API Key {api_key}") project_id = credentials.get("project_id") or settings.GCP_PROJECT_ID location = credentials.get("location") or settings.GCP_VERTEX_LOCATION gcs_bucket = credentials.get("gcs_bucket") or settings.GCS_AUDIO_BUCKET sa_info = credentials.get("sa_key") or _load_platform_sa_info() + source = "byok" if credentials.get("api_key") else "platform" + logger.info( + f"[create_client] vertex creds | source={source}, " + f"project_id={project_id}, location={location}" + ) + missing = [ name for name, value in ( @@ -138,9 +160,11 @@ def create_client(credentials: dict[str, Any]) -> Any: ) def _post(self, model: str, payload: dict) -> tuple[dict | None, str | None]: + url = self.client.endpoint(model) + logger.debug(f"[_post] vertex url={url}") try: resp = requests.post( - self.client.endpoint(model), + url, params={"key": self.client.api_key}, headers={"Content-Type": "application/json"}, json=payload, @@ -258,6 +282,7 @@ def _execute_stt( } data, err = self._post(model, payload) + logger.error(f"[_execute_stt] Error post making the call to Vertes is {err}") if err: return None, err diff --git a/backend/app/tests/services/llm/providers/test_registry.py b/backend/app/tests/services/llm/providers/test_registry.py index aa4bae639..c68eb527e 100644 --- a/backend/app/tests/services/llm/providers/test_registry.py +++ b/backend/app/tests/services/llm/providers/test_registry.py @@ -108,7 +108,7 @@ def test_get_llm_provider_with_missing_credentials(self, db: Session): def test_google_vertex_falls_back_to_platform_settings(self, db: Session, tmp_path): """No credential row for google-vertex → create_client synthesizes the platform defaults from settings (api_key/project/location/bucket) and - loads the SA JSON from GCP_SA_KEY_PATH.""" + loads the SA JSON from GCP_SA_KEY.""" import json as _json from app.services.llm.providers.gai_vertex import ( @@ -135,7 +135,7 @@ def test_google_vertex_falls_back_to_platform_settings(self, db: Session, tmp_pa mock_settings.GCP_VERTEX_API_KEY = "platform-key" mock_settings.GCP_PROJECT_ID = "platform-project" mock_settings.GCP_VERTEX_LOCATION = "us-central1" - mock_settings.GCP_SA_KEY_PATH = str(sa_path) + mock_settings.GCP_SA_KEY = str(sa_path) mock_settings.GCS_AUDIO_BUCKET = "platform-bucket" provider = get_llm_provider( From 3c387e770a2aa93d0b57b02e38687d5772326f80 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Thu, 4 Jun 2026 17:57:52 +0530 Subject: [PATCH 11/15] feat: make audio download path airtight --- backend/app/core/cloud/storage.py | 25 +++++++++++++++++++++++++ backend/pyproject.toml | 1 + backend/uv.lock | 11 +++++++++++ 3 files changed, 37 insertions(+) diff --git a/backend/app/core/cloud/storage.py b/backend/app/core/cloud/storage.py index fc22e82b4..2609bb714 100644 --- a/backend/app/core/cloud/storage.py +++ b/backend/app/core/cloud/storage.py @@ -1,5 +1,6 @@ import os import mimetypes +import filetype from sqlmodel import Session from uuid import UUID, uuid4 import logging @@ -325,6 +326,7 @@ def get_cloud_storage(session: Session, project_id: int) -> CloudStorage: GCS_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) +MAX_AUDIO_UPLOAD_BYTES = 50 * 1024 * 1024 # 50 MB _MIME_TO_EXT = { "audio/wav": ".wav", @@ -365,10 +367,33 @@ def upload_audio_to_gcs( ext = Path(local_path).suffix or "" mime = content_type or mimetypes.guess_type(local_path)[0] or "audio/wav" else: + if not audio_bytes: + raise ValueError("audio_bytes is empty") size = len(audio_bytes) mime = content_type or "audio/wav" ext = _MIME_TO_EXT.get(mime, "") + if mime not in _MIME_TO_EXT: + raise ValueError( + f"Unsupported content_type '{mime}'. Allowed: " + f"{', '.join(sorted(_MIME_TO_EXT))}" + ) + + # Sniff the actual bytes — content_type is caller-supplied and spoofable. + sniff_source = audio_bytes if audio_bytes is not None else local_path + detected = filetype.guess(sniff_source) + if detected is None or not detected.mime.startswith("audio/"): + raise ValueError( + f"Uploaded content is not a recognised audio file " + f"(detected={detected.mime if detected else 'unknown'})" + ) + + if size > MAX_AUDIO_UPLOAD_BYTES: + raise ValueError( + f"Audio exceeds {MAX_AUDIO_UPLOAD_BYTES // (1024 * 1024)} MB limit " + f"(got {size / (1024 * 1024):.1f} MB)" + ) + key = f"{key_prefix}/{uuid4().hex}{ext}" try: diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 7e528ac87..d595a452a 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -56,6 +56,7 @@ dependencies = [ "litellm>=1.83.10", "anthropic>=0.104.1", "google-cloud-storage>=3.10.1", + "filetype>=1.2.0", ] [tool.uv] diff --git a/backend/uv.lock b/backend/uv.lock index 614f6c79e..9f6077ad0 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -246,6 +246,7 @@ dependencies = [ { name = "email-validator" }, { name = "emails" }, { name = "fastapi", extra = ["standard"] }, + { name = "filetype" }, { name = "flower" }, { name = "gevent" }, { name = "google-auth" }, @@ -312,6 +313,7 @@ requires-dist = [ { name = "email-validator", specifier = ">=2.1.0.post1,<3.0.0.0" }, { name = "emails", specifier = ">=0.6,<1.0" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.116.0" }, + { name = "filetype", specifier = ">=1.2.0" }, { name = "flower", specifier = ">=2.0.1" }, { name = "gevent", specifier = ">=25.9.1" }, { name = "google-auth", specifier = ">=2.49.1" }, @@ -1158,6 +1160,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/a5/842ae8f0c08b61d6484b52f99a03510a3a72d23141942d216ebe81fefbce/filelock-3.25.2-py3-none-any.whl", hash = "sha256:ca8afb0da15f229774c9ad1b455ed96e85a81373065fb10446672f64444ddf70", size = 26759, upload-time = "2026-03-11T20:45:37.437Z" }, ] +[[package]] +name = "filetype" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bb/29/745f7d30d47fe0f251d3ad3dc2978a23141917661998763bebb6da007eb1/filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb", size = 998020, upload-time = "2022-11-02T17:34:04.141Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/79/1b8fa1bb3568781e84c9200f951c735f3f157429f44be0495da55894d620/filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25", size = 19970, upload-time = "2022-11-02T17:34:01.425Z" }, +] + [[package]] name = "flower" version = "2.0.1" From 3baa4b2cd34fb2bc5ee5ec828c37e184137674e0 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Thu, 4 Jun 2026 18:06:44 +0530 Subject: [PATCH 12/15] chore: comment all claude test cases --- .../services/llm/providers/test_claude.py | 484 +++++++++--------- 1 file changed, 242 insertions(+), 242 deletions(-) diff --git a/backend/app/tests/services/llm/providers/test_claude.py b/backend/app/tests/services/llm/providers/test_claude.py index 311171954..a2fe38a4a 100644 --- a/backend/app/tests/services/llm/providers/test_claude.py +++ b/backend/app/tests/services/llm/providers/test_claude.py @@ -1,242 +1,242 @@ -""" -Tests for the Anthropic Claude provider. -""" - -import pytest -from unittest.mock import MagicMock -from types import SimpleNamespace - -import anthropic - -from app.models.llm import ( - NativeCompletionConfig, - QueryParams, - TextContent, - ImageContent, - PDFContent, -) -from app.services.llm.providers.base import MultiModalInput -from app.services.llm.providers.claude import ClaudeProvider, DEFAULT_MAX_TOKENS - - -def mock_claude_message( - text: str = "hello", - model: str = "claude-opus-4-7", - message_id: str = "msg_123", - input_tokens: int = 10, - output_tokens: int = 5, - extra_blocks: list | None = None, -) -> SimpleNamespace: - """Build a SimpleNamespace mimicking an anthropic Message.""" - content = [SimpleNamespace(type="text", text=text)] - if extra_blocks: - content.extend(extra_blocks) - return SimpleNamespace( - id=message_id, - model=model, - content=content, - usage=SimpleNamespace(input_tokens=input_tokens, output_tokens=output_tokens), - model_dump=lambda: {"id": message_id, "model": model}, - ) - - -class TestClaudeProvider: - @pytest.fixture - def mock_client(self): - client = MagicMock() - client.messages.create = MagicMock() - return client - - @pytest.fixture - def provider(self, mock_client): - return ClaudeProvider(client=mock_client) - - @pytest.fixture - def text_config(self): - return NativeCompletionConfig( - provider="anthropic-native", - type="text", - params={"model": "claude-opus-4-7"}, - ) - - @pytest.fixture - def query_params(self): - return QueryParams(input="hi") - - def test_create_client_requires_api_key(self): - with pytest.raises(ValueError, match="not configured"): - ClaudeProvider.create_client(credentials={}) - - def test_create_client_with_api_key(self): - client = ClaudeProvider.create_client(credentials={"api_key": "sk-test"}) - assert isinstance(client, anthropic.Anthropic) - - def test_execute_success_text_input( - self, provider, mock_client, text_config, query_params - ): - mock_client.messages.create.return_value = mock_claude_message( - text="ok", model="claude-opus-4-7" - ) - - result, error = provider.execute(text_config, query_params, "hi") - - assert error is None - assert result.response.output.content.value == "ok" - assert result.response.provider == "anthropic-native" - assert result.response.model == "claude-opus-4-7" - assert result.response.provider_response_id == "msg_123" - assert result.usage.input_tokens == 10 - assert result.usage.output_tokens == 5 - assert result.usage.total_tokens == 15 - - call_kwargs = mock_client.messages.create.call_args.kwargs - assert call_kwargs["model"] == "claude-opus-4-7" - assert call_kwargs["max_tokens"] == DEFAULT_MAX_TOKENS - assert call_kwargs["messages"] == [{"role": "user", "content": "hi"}] - - def test_execute_does_not_override_user_max_tokens( - self, provider, mock_client, query_params - ): - config = NativeCompletionConfig( - provider="anthropic-native", - type="text", - params={"model": "claude-opus-4-7", "max_tokens": 64}, - ) - mock_client.messages.create.return_value = mock_claude_message() - - provider.execute(config, query_params, "hi") - - assert mock_client.messages.create.call_args.kwargs["max_tokens"] == 64 - - def test_execute_instructions_renamed_to_system( - self, provider, mock_client, query_params - ): - config = NativeCompletionConfig( - provider="anthropic-native", - type="text", - params={"model": "claude-opus-4-7", "instructions": "be brief"}, - ) - mock_client.messages.create.return_value = mock_claude_message() - - provider.execute(config, query_params, "hi") - - kwargs = mock_client.messages.create.call_args.kwargs - assert kwargs.get("system") == "be brief" - assert "instructions" not in kwargs - - def test_execute_strips_instructions_when_system_also_set( - self, provider, mock_client, query_params - ): - config = NativeCompletionConfig( - provider="anthropic-native", - type="text", - params={ - "model": "claude-opus-4-7", - "instructions": "ignored", - "system": "winner", - }, - ) - mock_client.messages.create.return_value = mock_claude_message() - - provider.execute(config, query_params, "hi") - - kwargs = mock_client.messages.create.call_args.kwargs - assert kwargs["system"] == "winner" - assert "instructions" not in kwargs - - def test_execute_multimodal_text_image_pdf( - self, provider, mock_client, text_config, query_params - ): - mock_client.messages.create.return_value = mock_claude_message() - multimodal = MultiModalInput( - parts=[ - TextContent(value="describe"), - ImageContent(format="base64", mime_type="image/png", value="ZmFrZQ=="), - PDFContent( - format="url", mime_type="application/pdf", value="https://x/y.pdf" - ), - ] - ) - - provider.execute(text_config, query_params, multimodal) - - content = mock_client.messages.create.call_args.kwargs["messages"][0]["content"] - assert content[0] == {"type": "text", "text": "describe"} - assert content[1] == { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": "ZmFrZQ==", - }, - } - assert content[2] == { - "type": "document", - "source": {"type": "url", "url": "https://x/y.pdf"}, - } - - def test_execute_strips_conversation_param( - self, provider, mock_client, query_params - ): - config = NativeCompletionConfig( - provider="anthropic-native", - type="text", - params={"model": "claude-opus-4-7", "conversation": {"id": "conv_x"}}, - ) - mock_client.messages.create.return_value = mock_claude_message() - - provider.execute(config, query_params, "hi") - - assert "conversation" not in mock_client.messages.create.call_args.kwargs - - def test_execute_joins_only_text_blocks( - self, provider, mock_client, text_config, query_params - ): - # Response with a tool_use block mixed in; we only join text blocks - tool_block = SimpleNamespace(type="tool_use", id="t1", name="x", input={}) - mock_client.messages.create.return_value = mock_claude_message( - text="part1", - extra_blocks=[tool_block, SimpleNamespace(type="text", text="part2")], - ) - - result, error = provider.execute(text_config, query_params, "hi") - - assert error is None - assert result.response.output.content.value == "part1part2" - - def test_execute_includes_raw_response_when_requested( - self, provider, mock_client, text_config, query_params - ): - mock_client.messages.create.return_value = mock_claude_message() - - result, _ = provider.execute( - text_config, query_params, "hi", include_provider_raw_response=True - ) - - assert result.provider_raw_response == { - "id": "msg_123", - "model": "claude-opus-4-7", - } - - def test_execute_returns_error_on_anthropic_api_error( - self, provider, mock_client, text_config, query_params - ): - mock_client.messages.create.side_effect = anthropic.AnthropicError("boom") - - result, error = provider.execute(text_config, query_params, "hi") - - assert result is None - assert error is not None - assert "boom" in error - - def test_execute_returns_error_on_unexpected_kwarg( - self, provider, mock_client, text_config, query_params - ): - mock_client.messages.create.side_effect = TypeError( - "unexpected keyword argument 'foo'" - ) - - result, error = provider.execute(text_config, query_params, "hi") - - assert result is None - assert "Invalid or unexpected parameter" in error +# """ +# Tests for the Anthropic Claude provider. +# """ + +# import pytest +# from unittest.mock import MagicMock +# from types import SimpleNamespace + +# import anthropic + +# from app.models.llm import ( +# NativeCompletionConfig, +# QueryParams, +# TextContent, +# ImageContent, +# PDFContent, +# ) +# from app.services.llm.providers.base import MultiModalInput +# from app.services.llm.providers.claude import ClaudeProvider, DEFAULT_MAX_TOKENS + + +# def mock_claude_message( +# text: str = "hello", +# model: str = "claude-opus-4-7", +# message_id: str = "msg_123", +# input_tokens: int = 10, +# output_tokens: int = 5, +# extra_blocks: list | None = None, +# ) -> SimpleNamespace: +# """Build a SimpleNamespace mimicking an anthropic Message.""" +# content = [SimpleNamespace(type="text", text=text)] +# if extra_blocks: +# content.extend(extra_blocks) +# return SimpleNamespace( +# id=message_id, +# model=model, +# content=content, +# usage=SimpleNamespace(input_tokens=input_tokens, output_tokens=output_tokens), +# model_dump=lambda: {"id": message_id, "model": model}, +# ) + + +# class TestClaudeProvider: +# @pytest.fixture +# def mock_client(self): +# client = MagicMock() +# client.messages.create = MagicMock() +# return client + +# @pytest.fixture +# def provider(self, mock_client): +# return ClaudeProvider(client=mock_client) + +# @pytest.fixture +# def text_config(self): +# return NativeCompletionConfig( +# provider="anthropic-native", +# type="text", +# params={"model": "claude-opus-4-7"}, +# ) + +# @pytest.fixture +# def query_params(self): +# return QueryParams(input="hi") + +# def test_create_client_requires_api_key(self): +# with pytest.raises(ValueError, match="not configured"): +# ClaudeProvider.create_client(credentials={}) + +# def test_create_client_with_api_key(self): +# client = ClaudeProvider.create_client(credentials={"api_key": "sk-test"}) +# assert isinstance(client, anthropic.Anthropic) + +# def test_execute_success_text_input( +# self, provider, mock_client, text_config, query_params +# ): +# mock_client.messages.create.return_value = mock_claude_message( +# text="ok", model="claude-opus-4-7" +# ) + +# result, error = provider.execute(text_config, query_params, "hi") + +# assert error is None +# assert result.response.output.content.value == "ok" +# assert result.response.provider == "anthropic-native" +# assert result.response.model == "claude-opus-4-7" +# assert result.response.provider_response_id == "msg_123" +# assert result.usage.input_tokens == 10 +# assert result.usage.output_tokens == 5 +# assert result.usage.total_tokens == 15 + +# call_kwargs = mock_client.messages.create.call_args.kwargs +# assert call_kwargs["model"] == "claude-opus-4-7" +# assert call_kwargs["max_tokens"] == DEFAULT_MAX_TOKENS +# assert call_kwargs["messages"] == [{"role": "user", "content": "hi"}] + +# def test_execute_does_not_override_user_max_tokens( +# self, provider, mock_client, query_params +# ): +# config = NativeCompletionConfig( +# provider="anthropic-native", +# type="text", +# params={"model": "claude-opus-4-7", "max_tokens": 64}, +# ) +# mock_client.messages.create.return_value = mock_claude_message() + +# provider.execute(config, query_params, "hi") + +# assert mock_client.messages.create.call_args.kwargs["max_tokens"] == 64 + +# def test_execute_instructions_renamed_to_system( +# self, provider, mock_client, query_params +# ): +# config = NativeCompletionConfig( +# provider="anthropic-native", +# type="text", +# params={"model": "claude-opus-4-7", "instructions": "be brief"}, +# ) +# mock_client.messages.create.return_value = mock_claude_message() + +# provider.execute(config, query_params, "hi") + +# kwargs = mock_client.messages.create.call_args.kwargs +# assert kwargs.get("system") == "be brief" +# assert "instructions" not in kwargs + +# def test_execute_strips_instructions_when_system_also_set( +# self, provider, mock_client, query_params +# ): +# config = NativeCompletionConfig( +# provider="anthropic-native", +# type="text", +# params={ +# "model": "claude-opus-4-7", +# "instructions": "ignored", +# "system": "winner", +# }, +# ) +# mock_client.messages.create.return_value = mock_claude_message() + +# provider.execute(config, query_params, "hi") + +# kwargs = mock_client.messages.create.call_args.kwargs +# assert kwargs["system"] == "winner" +# assert "instructions" not in kwargs + +# def test_execute_multimodal_text_image_pdf( +# self, provider, mock_client, text_config, query_params +# ): +# mock_client.messages.create.return_value = mock_claude_message() +# multimodal = MultiModalInput( +# parts=[ +# TextContent(value="describe"), +# ImageContent(format="base64", mime_type="image/png", value="ZmFrZQ=="), +# PDFContent( +# format="url", mime_type="application/pdf", value="https://x/y.pdf" +# ), +# ] +# ) + +# provider.execute(text_config, query_params, multimodal) + +# content = mock_client.messages.create.call_args.kwargs["messages"][0]["content"] +# assert content[0] == {"type": "text", "text": "describe"} +# assert content[1] == { +# "type": "image", +# "source": { +# "type": "base64", +# "media_type": "image/png", +# "data": "ZmFrZQ==", +# }, +# } +# assert content[2] == { +# "type": "document", +# "source": {"type": "url", "url": "https://x/y.pdf"}, +# } + +# def test_execute_strips_conversation_param( +# self, provider, mock_client, query_params +# ): +# config = NativeCompletionConfig( +# provider="anthropic-native", +# type="text", +# params={"model": "claude-opus-4-7", "conversation": {"id": "conv_x"}}, +# ) +# mock_client.messages.create.return_value = mock_claude_message() + +# provider.execute(config, query_params, "hi") + +# assert "conversation" not in mock_client.messages.create.call_args.kwargs + +# def test_execute_joins_only_text_blocks( +# self, provider, mock_client, text_config, query_params +# ): +# # Response with a tool_use block mixed in; we only join text blocks +# tool_block = SimpleNamespace(type="tool_use", id="t1", name="x", input={}) +# mock_client.messages.create.return_value = mock_claude_message( +# text="part1", +# extra_blocks=[tool_block, SimpleNamespace(type="text", text="part2")], +# ) + +# result, error = provider.execute(text_config, query_params, "hi") + +# assert error is None +# assert result.response.output.content.value == "part1part2" + +# def test_execute_includes_raw_response_when_requested( +# self, provider, mock_client, text_config, query_params +# ): +# mock_client.messages.create.return_value = mock_claude_message() + +# result, _ = provider.execute( +# text_config, query_params, "hi", include_provider_raw_response=True +# ) + +# assert result.provider_raw_response == { +# "id": "msg_123", +# "model": "claude-opus-4-7", +# } + +# def test_execute_returns_error_on_anthropic_api_error( +# self, provider, mock_client, text_config, query_params +# ): +# mock_client.messages.create.side_effect = anthropic.AnthropicError("boom") + +# result, error = provider.execute(text_config, query_params, "hi") + +# assert result is None +# assert error is not None +# assert "boom" in error + +# def test_execute_returns_error_on_unexpected_kwarg( +# self, provider, mock_client, text_config, query_params +# ): +# mock_client.messages.create.side_effect = TypeError( +# "unexpected keyword argument 'foo'" +# ) + +# result, error = provider.execute(text_config, query_params, "hi") + +# assert result is None +# assert "Invalid or unexpected parameter" in error From ab82e4921709cc04fbca6dee950f884cf46c4e60 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Thu, 4 Jun 2026 18:30:27 +0530 Subject: [PATCH 13/15] fix: test cases --- .../services/llm/providers/test_claude.py | 685 +++++++++++------- .../services/llm/providers/test_gai_vertex.py | 138 ++++ .../services/llm/providers/test_registry.py | 8 +- .../app/tests/services/llm/test_mappers.py | 26 +- .../app/tests/services/llm/test_multimodal.py | 11 +- backend/app/tests/test_utils.py | 7 +- 6 files changed, 608 insertions(+), 267 deletions(-) diff --git a/backend/app/tests/services/llm/providers/test_claude.py b/backend/app/tests/services/llm/providers/test_claude.py index a2fe38a4a..84a82c30d 100644 --- a/backend/app/tests/services/llm/providers/test_claude.py +++ b/backend/app/tests/services/llm/providers/test_claude.py @@ -1,242 +1,443 @@ -# """ -# Tests for the Anthropic Claude provider. -# """ - -# import pytest -# from unittest.mock import MagicMock -# from types import SimpleNamespace - -# import anthropic - -# from app.models.llm import ( -# NativeCompletionConfig, -# QueryParams, -# TextContent, -# ImageContent, -# PDFContent, -# ) -# from app.services.llm.providers.base import MultiModalInput -# from app.services.llm.providers.claude import ClaudeProvider, DEFAULT_MAX_TOKENS - - -# def mock_claude_message( -# text: str = "hello", -# model: str = "claude-opus-4-7", -# message_id: str = "msg_123", -# input_tokens: int = 10, -# output_tokens: int = 5, -# extra_blocks: list | None = None, -# ) -> SimpleNamespace: -# """Build a SimpleNamespace mimicking an anthropic Message.""" -# content = [SimpleNamespace(type="text", text=text)] -# if extra_blocks: -# content.extend(extra_blocks) -# return SimpleNamespace( -# id=message_id, -# model=model, -# content=content, -# usage=SimpleNamespace(input_tokens=input_tokens, output_tokens=output_tokens), -# model_dump=lambda: {"id": message_id, "model": model}, -# ) - - -# class TestClaudeProvider: -# @pytest.fixture -# def mock_client(self): -# client = MagicMock() -# client.messages.create = MagicMock() -# return client - -# @pytest.fixture -# def provider(self, mock_client): -# return ClaudeProvider(client=mock_client) - -# @pytest.fixture -# def text_config(self): -# return NativeCompletionConfig( -# provider="anthropic-native", -# type="text", -# params={"model": "claude-opus-4-7"}, -# ) - -# @pytest.fixture -# def query_params(self): -# return QueryParams(input="hi") - -# def test_create_client_requires_api_key(self): -# with pytest.raises(ValueError, match="not configured"): -# ClaudeProvider.create_client(credentials={}) - -# def test_create_client_with_api_key(self): -# client = ClaudeProvider.create_client(credentials={"api_key": "sk-test"}) -# assert isinstance(client, anthropic.Anthropic) - -# def test_execute_success_text_input( -# self, provider, mock_client, text_config, query_params -# ): -# mock_client.messages.create.return_value = mock_claude_message( -# text="ok", model="claude-opus-4-7" -# ) - -# result, error = provider.execute(text_config, query_params, "hi") - -# assert error is None -# assert result.response.output.content.value == "ok" -# assert result.response.provider == "anthropic-native" -# assert result.response.model == "claude-opus-4-7" -# assert result.response.provider_response_id == "msg_123" -# assert result.usage.input_tokens == 10 -# assert result.usage.output_tokens == 5 -# assert result.usage.total_tokens == 15 - -# call_kwargs = mock_client.messages.create.call_args.kwargs -# assert call_kwargs["model"] == "claude-opus-4-7" -# assert call_kwargs["max_tokens"] == DEFAULT_MAX_TOKENS -# assert call_kwargs["messages"] == [{"role": "user", "content": "hi"}] - -# def test_execute_does_not_override_user_max_tokens( -# self, provider, mock_client, query_params -# ): -# config = NativeCompletionConfig( -# provider="anthropic-native", -# type="text", -# params={"model": "claude-opus-4-7", "max_tokens": 64}, -# ) -# mock_client.messages.create.return_value = mock_claude_message() - -# provider.execute(config, query_params, "hi") - -# assert mock_client.messages.create.call_args.kwargs["max_tokens"] == 64 - -# def test_execute_instructions_renamed_to_system( -# self, provider, mock_client, query_params -# ): -# config = NativeCompletionConfig( -# provider="anthropic-native", -# type="text", -# params={"model": "claude-opus-4-7", "instructions": "be brief"}, -# ) -# mock_client.messages.create.return_value = mock_claude_message() - -# provider.execute(config, query_params, "hi") - -# kwargs = mock_client.messages.create.call_args.kwargs -# assert kwargs.get("system") == "be brief" -# assert "instructions" not in kwargs - -# def test_execute_strips_instructions_when_system_also_set( -# self, provider, mock_client, query_params -# ): -# config = NativeCompletionConfig( -# provider="anthropic-native", -# type="text", -# params={ -# "model": "claude-opus-4-7", -# "instructions": "ignored", -# "system": "winner", -# }, -# ) -# mock_client.messages.create.return_value = mock_claude_message() - -# provider.execute(config, query_params, "hi") - -# kwargs = mock_client.messages.create.call_args.kwargs -# assert kwargs["system"] == "winner" -# assert "instructions" not in kwargs - -# def test_execute_multimodal_text_image_pdf( -# self, provider, mock_client, text_config, query_params -# ): -# mock_client.messages.create.return_value = mock_claude_message() -# multimodal = MultiModalInput( -# parts=[ -# TextContent(value="describe"), -# ImageContent(format="base64", mime_type="image/png", value="ZmFrZQ=="), -# PDFContent( -# format="url", mime_type="application/pdf", value="https://x/y.pdf" -# ), -# ] -# ) - -# provider.execute(text_config, query_params, multimodal) - -# content = mock_client.messages.create.call_args.kwargs["messages"][0]["content"] -# assert content[0] == {"type": "text", "text": "describe"} -# assert content[1] == { -# "type": "image", -# "source": { -# "type": "base64", -# "media_type": "image/png", -# "data": "ZmFrZQ==", -# }, -# } -# assert content[2] == { -# "type": "document", -# "source": {"type": "url", "url": "https://x/y.pdf"}, -# } - -# def test_execute_strips_conversation_param( -# self, provider, mock_client, query_params -# ): -# config = NativeCompletionConfig( -# provider="anthropic-native", -# type="text", -# params={"model": "claude-opus-4-7", "conversation": {"id": "conv_x"}}, -# ) -# mock_client.messages.create.return_value = mock_claude_message() - -# provider.execute(config, query_params, "hi") - -# assert "conversation" not in mock_client.messages.create.call_args.kwargs - -# def test_execute_joins_only_text_blocks( -# self, provider, mock_client, text_config, query_params -# ): -# # Response with a tool_use block mixed in; we only join text blocks -# tool_block = SimpleNamespace(type="tool_use", id="t1", name="x", input={}) -# mock_client.messages.create.return_value = mock_claude_message( -# text="part1", -# extra_blocks=[tool_block, SimpleNamespace(type="text", text="part2")], -# ) - -# result, error = provider.execute(text_config, query_params, "hi") - -# assert error is None -# assert result.response.output.content.value == "part1part2" - -# def test_execute_includes_raw_response_when_requested( -# self, provider, mock_client, text_config, query_params -# ): -# mock_client.messages.create.return_value = mock_claude_message() - -# result, _ = provider.execute( -# text_config, query_params, "hi", include_provider_raw_response=True -# ) - -# assert result.provider_raw_response == { -# "id": "msg_123", -# "model": "claude-opus-4-7", -# } - -# def test_execute_returns_error_on_anthropic_api_error( -# self, provider, mock_client, text_config, query_params -# ): -# mock_client.messages.create.side_effect = anthropic.AnthropicError("boom") - -# result, error = provider.execute(text_config, query_params, "hi") - -# assert result is None -# assert error is not None -# assert "boom" in error - -# def test_execute_returns_error_on_unexpected_kwarg( -# self, provider, mock_client, text_config, query_params -# ): -# mock_client.messages.create.side_effect = TypeError( -# "unexpected keyword argument 'foo'" -# ) - -# result, error = provider.execute(text_config, query_params, "hi") - -# assert result is None -# assert "Invalid or unexpected parameter" in error +"""Tests for the Anthropic Claude provider. + +Covers credential setup, multimodal request shape (text/image/PDF), the +Files API upload path for inline base64 documents, default model / max +tokens behaviour, conversation-key stripping, raw-response passthrough, +and error mapping. +""" + +import base64 +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import anthropic +import pytest + +from app.models.llm import ( + ImageContent, + NativeCompletionConfig, + PDFContent, + QueryParams, + TextContent, +) +from app.services.llm.providers.base import MultiModalInput +from app.services.llm.providers.claude import ( + FILES_API_BETA, + ClaudeProvider, +) + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- +def _mock_message( + *, + msg_id: str = "msg_test", + model: str = "claude-sonnet-4-6", + text: str = "hello world", + input_tokens: int = 12, + output_tokens: int = 7, +) -> MagicMock: + """Build a stand-in for ``anthropic.types.Message``. + + Uses MagicMock so ``model_dump()`` is callable. Content blocks are + SimpleNamespace so the provider's ``block.type`` / ``block.text`` + access pattern works as it would with the real SDK objects. + """ + msg = MagicMock() + msg.id = msg_id + msg.model = model + msg.content = [SimpleNamespace(type="text", text=text)] + msg.usage = SimpleNamespace(input_tokens=input_tokens, output_tokens=output_tokens) + msg.model_dump.return_value = { + "id": msg_id, + "model": model, + "content": [{"type": "text", "text": text}], + } + return msg + + +def _b64(data: bytes) -> str: + return base64.b64encode(data).decode("ascii") + + +@pytest.fixture +def query() -> QueryParams: + return QueryParams(input="ignored") + + +@pytest.fixture +def config() -> NativeCompletionConfig: + return NativeCompletionConfig( + provider="anthropic-native", + type="text", + params={"model": "claude-sonnet-4-6", "max_tokens": 512}, + ) + + +@pytest.fixture +def mock_client() -> MagicMock: + client = MagicMock() + client.messages.create.return_value = _mock_message() + client.beta.messages.create.return_value = _mock_message(text="beta path ok") + upload = MagicMock() + upload.id = "file_abc123" + client.beta.files.upload.return_value = upload + return client + + +@pytest.fixture +def provider(mock_client: MagicMock) -> ClaudeProvider: + return ClaudeProvider(client=mock_client) + + +# --------------------------------------------------------------------------- +# create_client +# --------------------------------------------------------------------------- +class TestCreateClient: + def test_requires_api_key(self): + with pytest.raises(ValueError, match="Anthropic credentials not configured"): + ClaudeProvider.create_client({}) + + def test_returns_anthropic_client(self): + with patch("app.services.llm.providers.claude.Anthropic") as mock_anthropic: + mock_anthropic.return_value = MagicMock(name="anthropic-client") + client = ClaudeProvider.create_client({"api_key": "sk-test"}) + mock_anthropic.assert_called_once_with(api_key="sk-test") + assert client is mock_anthropic.return_value + + +# --------------------------------------------------------------------------- +# format_parts — verifies the request shape Anthropic expects +# --------------------------------------------------------------------------- +class TestFormatParts: + def test_text_block(self): + out = ClaudeProvider.format_parts([TextContent(value="hi")]) + assert out == [{"type": "text", "text": "hi"}] + + def test_base64_image(self): + out = ClaudeProvider.format_parts( + [ImageContent(format="base64", value="b64img", mime_type="image/png")] + ) + assert out == [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": "b64img", + }, + } + ] + + def test_url_image(self): + out = ClaudeProvider.format_parts( + [ + ImageContent( + format="url", + value="https://example.com/a.jpg", + mime_type="image/jpeg", + ) + ] + ) + assert out == [ + { + "type": "image", + "source": {"type": "url", "url": "https://example.com/a.jpg"}, + } + ] + + def test_base64_pdf(self): + out = ClaudeProvider.format_parts( + [PDFContent(format="base64", value="b64pdf", mime_type="application/pdf")] + ) + assert out == [ + { + "type": "document", + "source": { + "type": "base64", + "media_type": "application/pdf", + "data": "b64pdf", + }, + } + ] + + def test_url_pdf(self): + out = ClaudeProvider.format_parts( + [ + PDFContent( + format="url", + value="https://example.com/x.pdf", + mime_type="application/pdf", + ) + ] + ) + assert out == [ + { + "type": "document", + "source": {"type": "url", "url": "https://example.com/x.pdf"}, + } + ] + + def test_mixed_order_preserved(self): + out = ClaudeProvider.format_parts( + [ + TextContent(value="describe"), + ImageContent(format="base64", value="img", mime_type="image/png"), + ] + ) + assert [item["type"] for item in out] == ["text", "image"] + + +# --------------------------------------------------------------------------- +# execute — text-only happy path & defaults +# --------------------------------------------------------------------------- +class TestExecuteText: + def test_simple_string_input(self, provider, mock_client, config, query): + resp, err = provider.execute(config, query, "hello") + + assert err is None + assert resp.response.provider_response_id == "msg_test" + assert resp.response.model == "claude-sonnet-4-6" + assert resp.response.output.content.value == "hello world" + assert resp.usage.input_tokens == 12 + assert resp.usage.output_tokens == 7 + assert resp.usage.total_tokens == 19 + + # Non-beta path: plain messages.create was used. + mock_client.messages.create.assert_called_once() + mock_client.beta.messages.create.assert_not_called() + kwargs = mock_client.messages.create.call_args.kwargs + assert kwargs["model"] == "claude-sonnet-4-6" + assert kwargs["max_tokens"] == 512 + assert kwargs["messages"] == [{"role": "user", "content": "hello"}] + + def test_defaults_model_and_max_tokens_when_missing( + self, provider, mock_client, query + ): + """Empty params → provider falls back to project defaults.""" + cfg = NativeCompletionConfig( + provider="anthropic-native", type="text", params={} + ) + resp, err = provider.execute(cfg, query, "hello") + + assert err is None + kwargs = mock_client.messages.create.call_args.kwargs + assert kwargs["model"] == "claude-sonnet-4-6" + assert kwargs["max_tokens"] == 4096 + + def test_strips_conversation_key(self, provider, mock_client, query): + """`conversation` is a Kaapi-level concept; it must not be forwarded + to the Anthropic SDK (which would raise TypeError).""" + cfg = NativeCompletionConfig( + provider="anthropic-native", + type="text", + params={"model": "claude-sonnet-4-6", "conversation": "conv_123"}, + ) + resp, err = provider.execute(cfg, query, "hello") + + assert err is None + kwargs = mock_client.messages.create.call_args.kwargs + assert "conversation" not in kwargs + + def test_concatenates_multi_block_text_output( + self, provider, mock_client, config, query + ): + mock_client.messages.create.return_value = MagicMock( + id="msg_multi", + model="claude-sonnet-4-6", + content=[ + SimpleNamespace(type="text", text="hello "), + SimpleNamespace(type="tool_use", text=None), + SimpleNamespace(type="text", text="world"), + ], + usage=SimpleNamespace(input_tokens=1, output_tokens=2), + ) + resp, err = provider.execute(config, query, "hi") + assert err is None + assert resp.response.output.content.value == "hello world" + + def test_raw_response_included_when_requested( + self, provider, mock_client, config, query + ): + resp, _ = provider.execute( + config, query, "hello", include_provider_raw_response=True + ) + assert resp.provider_raw_response == { + "id": "msg_test", + "model": "claude-sonnet-4-6", + "content": [{"type": "text", "text": "hello world"}], + } + + def test_raw_response_omitted_by_default(self, provider, config, query): + resp, _ = provider.execute(config, query, "hello") + assert resp.provider_raw_response is None + + +# --------------------------------------------------------------------------- +# execute — multimodal inputs (list of parts and MultiModalInput) +# --------------------------------------------------------------------------- +class TestExecuteMultimodal: + def test_list_of_parts_forwarded_as_content_blocks( + self, provider, mock_client, config, query + ): + parts = [ + TextContent(value="describe"), + ImageContent( + format="url", + value="https://example.com/cat.jpg", + mime_type="image/jpeg", + ), + ] + resp, err = provider.execute(config, query, parts) + + assert err is None + # URL image → no upload, no beta header + mock_client.beta.files.upload.assert_not_called() + mock_client.messages.create.assert_called_once() + kwargs = mock_client.messages.create.call_args.kwargs + content = kwargs["messages"][0]["content"] + assert content[0] == {"type": "text", "text": "describe"} + assert content[1]["type"] == "image" + assert content[1]["source"]["type"] == "url" + + def test_multimodal_input_wrapper_unwrapped( + self, provider, mock_client, config, query + ): + mm = MultiModalInput(parts=[TextContent(value="hi")]) + resp, err = provider.execute(config, query, mm) + + assert err is None + kwargs = mock_client.messages.create.call_args.kwargs + assert kwargs["messages"][0]["content"] == [{"type": "text", "text": "hi"}] + + +# --------------------------------------------------------------------------- +# execute — Files API upload path for base64 documents/images +# --------------------------------------------------------------------------- +class TestFilesApiUploadPath: + def test_base64_pdf_uploaded_and_referenced_by_file_id( + self, provider, mock_client, config, query + ): + pdf_bytes = b"%PDF-1.4 fake" + parts = [ + TextContent(value="summarize"), + PDFContent( + format="base64", value=_b64(pdf_bytes), mime_type="application/pdf" + ), + ] + resp, err = provider.execute(config, query, parts) + + assert err is None + # Beta endpoint used because we uploaded a file + mock_client.beta.messages.create.assert_called_once() + mock_client.messages.create.assert_not_called() + + # File was uploaded with decoded bytes + correct media type + upload_kwargs = mock_client.beta.files.upload.call_args.kwargs + filename, file_obj, media_type = upload_kwargs["file"] + assert filename == "document.pdf" + assert media_type == "application/pdf" + assert file_obj.read() == pdf_bytes + + # Block was rewritten to reference the uploaded file_id + beta_kwargs = mock_client.beta.messages.create.call_args.kwargs + pdf_block = beta_kwargs["messages"][0]["content"][1] + assert pdf_block["type"] == "document" + assert pdf_block["source"] == {"type": "file", "file_id": "file_abc123"} + + # Beta header is appended without dropping any existing values + assert FILES_API_BETA in beta_kwargs["betas"] + + def test_base64_image_uploaded_via_files_api( + self, provider, mock_client, config, query + ): + img_bytes = b"\x89PNG\r\n\x1a\nfake" + parts = [ + ImageContent(format="base64", value=_b64(img_bytes), mime_type="image/png") + ] + resp, err = provider.execute(config, query, parts) + + assert err is None + mock_client.beta.files.upload.assert_called_once() + upload_kwargs = mock_client.beta.files.upload.call_args.kwargs + filename, _, media_type = upload_kwargs["file"] + assert filename == "image" + assert media_type == "image/png" + + beta_kwargs = mock_client.beta.messages.create.call_args.kwargs + block = beta_kwargs["messages"][0]["content"][0] + assert block["source"] == {"type": "file", "file_id": "file_abc123"} + + def test_existing_betas_preserved(self, provider, mock_client, query): + """Caller-supplied beta headers must not be clobbered.""" + cfg = NativeCompletionConfig( + provider="anthropic-native", + type="text", + params={ + "model": "claude-sonnet-4-6", + "max_tokens": 512, + "betas": ["caller-beta-1"], + }, + ) + parts = [ + PDFContent(format="base64", value=_b64(b"pdf"), mime_type="application/pdf") + ] + resp, err = provider.execute(cfg, query, parts) + + assert err is None + beta_kwargs = mock_client.beta.messages.create.call_args.kwargs + assert beta_kwargs["betas"] == ["caller-beta-1", FILES_API_BETA] + + def test_url_pdf_does_not_trigger_upload( + self, provider, mock_client, config, query + ): + parts = [ + PDFContent( + format="url", + value="https://example.com/doc.pdf", + mime_type="application/pdf", + ) + ] + resp, err = provider.execute(config, query, parts) + + assert err is None + mock_client.beta.files.upload.assert_not_called() + mock_client.messages.create.assert_called_once() + + +# --------------------------------------------------------------------------- +# execute — error mapping +# --------------------------------------------------------------------------- +class TestExecuteErrors: + def test_type_error_returns_clean_message( + self, provider, mock_client, config, query + ): + mock_client.messages.create.side_effect = TypeError( + "unexpected keyword argument 'nonsense'" + ) + resp, err = provider.execute(config, query, "hi") + assert resp is None + assert "Invalid or unexpected parameter in Config" in err + assert "nonsense" in err + + def test_anthropic_error_returns_clean_message( + self, provider, mock_client, config, query + ): + mock_client.messages.create.side_effect = anthropic.AnthropicError( + "rate limited" + ) + resp, err = provider.execute(config, query, "hi") + assert resp is None + assert "Anthropic API error" in err + assert "rate limited" in err + + def test_generic_exception_returns_opaque_message( + self, provider, mock_client, config, query + ): + """Unexpected errors are logged but the surface message must not leak + internals to the caller.""" + mock_client.messages.create.side_effect = RuntimeError("boom internal detail") + resp, err = provider.execute(config, query, "hi") + assert resp is None + assert err == "Unexpected error occurred" diff --git a/backend/app/tests/services/llm/providers/test_gai_vertex.py b/backend/app/tests/services/llm/providers/test_gai_vertex.py index 96fcbc764..21be924b3 100644 --- a/backend/app/tests/services/llm/providers/test_gai_vertex.py +++ b/backend/app/tests/services/llm/providers/test_gai_vertex.py @@ -1,6 +1,7 @@ """Tests for the Google Vertex AI provider.""" import base64 +import json from unittest.mock import MagicMock, patch import pytest @@ -11,6 +12,7 @@ from app.services.llm.providers.gai_vertex import ( GoogleVertexAIProvider, VertexClient, + _load_platform_sa_info, ) @@ -307,3 +309,139 @@ def test_raw_response_included_when_requested( stt_config, query, audio_ref, include_provider_raw_response=True ) assert resp.provider_raw_response == raw + + +# --------------------------------------------------------------------------- +# VertexClient.endpoint — host changes by location +# --------------------------------------------------------------------------- +class TestVertexEndpoint: + def _client(self, location: str) -> VertexClient: + return VertexClient( + api_key="k", + project_id="my-proj", + location=location, + sa_info=None, + gcs_bucket="b", + ) + + def test_regional_location_uses_prefixed_host(self): + url = self._client("us-central1").endpoint("gemini-2.5-pro") + assert url.startswith("https://us-central1-aiplatform.googleapis.com/") + assert "/projects/my-proj/locations/us-central1/" in url + assert url.endswith("/models/gemini-2.5-pro:generateContent") + + def test_global_location_uses_bare_host(self): + """The 'global' location does NOT use a hostname prefix — it must + resolve to ``aiplatform.googleapis.com``. Caught a real 404 outage + where a global config produced ``global-aiplatform.googleapis.com``.""" + url = self._client("global").endpoint("gemini-2.5-pro") + assert url.startswith("https://aiplatform.googleapis.com/") + assert "global-aiplatform" not in url + assert "/locations/global/" in url + + def test_other_regions_get_prefix(self): + url = self._client("europe-west4").endpoint("gemini-2.5-flash") + assert "europe-west4-aiplatform.googleapis.com" in url + + +# --------------------------------------------------------------------------- +# _load_platform_sa_info — env-var shape handling +# --------------------------------------------------------------------------- +class TestLoadPlatformSaInfo: + """The platform SA can be injected as a raw JSON string via env var or + secret manager. Cover the parse paths and the unhappy ones.""" + + def _sample_sa(self) -> dict: + return { + "type": "service_account", + "project_id": "platform-project", + "client_email": "sa@platform-project.iam.gserviceaccount.com", + "private_key": "-----BEGIN PRIVATE KEY-----\nfake\n-----END PRIVATE KEY-----", + } + + @patch("app.services.llm.providers.gai_vertex.settings") + def test_returns_none_when_unset(self, mock_settings): + mock_settings.GCP_SA_KEY = "" + assert _load_platform_sa_info() is None + + @patch("app.services.llm.providers.gai_vertex.settings") + def test_parses_raw_json_string(self, mock_settings): + sa = self._sample_sa() + mock_settings.GCP_SA_KEY = json.dumps(sa) + assert _load_platform_sa_info() == sa + + @patch("app.services.llm.providers.gai_vertex.settings") + def test_strips_surrounding_whitespace(self, mock_settings): + """env-var injection often leaves trailing newlines — must still parse.""" + sa = self._sample_sa() + mock_settings.GCP_SA_KEY = "\n " + json.dumps(sa) + " \n" + assert _load_platform_sa_info() == sa + + @patch("app.services.llm.providers.gai_vertex.settings") + def test_returns_none_on_malformed_json(self, mock_settings): + """A JSON-looking but invalid value must not raise — it returns None + and lets create_client raise the missing-fields ValueError later.""" + mock_settings.GCP_SA_KEY = "{not valid json" + assert _load_platform_sa_info() is None + + @patch("app.services.llm.providers.gai_vertex.settings") + def test_non_json_string_returns_none(self, mock_settings): + """Anything not starting with '{' is treated as non-JSON and ignored — + this guards against accidentally interpreting a path or sentinel as a key.""" + mock_settings.GCP_SA_KEY = "/etc/secrets/sa.json" + assert _load_platform_sa_info() is None + + +# --------------------------------------------------------------------------- +# create_client — credential precedence (BYOK overrides platform settings) +# --------------------------------------------------------------------------- +class TestCreateClientFallback: + @patch("app.services.llm.providers.gai_vertex.settings") + def test_byok_overrides_platform_settings(self, mock_settings): + mock_settings.GCP_VERTEX_API_KEY = "platform-key" + mock_settings.GCP_PROJECT_ID = "platform-proj" + mock_settings.GCP_VERTEX_LOCATION = "us-central1" + mock_settings.GCP_SA_KEY = "" + mock_settings.GCS_AUDIO_BUCKET = "platform-bucket" + + c = GoogleVertexAIProvider.create_client( + { + "api_key": "byok-key", + "project_id": "byok-proj", + "location": "europe-west4", + "gcs_bucket": "byok-bucket", + } + ) + assert c.api_key == "byok-key" + assert c.project_id == "byok-proj" + assert c.location == "europe-west4" + assert c.gcs_bucket == "byok-bucket" + + @patch("app.services.llm.providers.gai_vertex.settings") + def test_partial_byok_fills_from_platform(self, mock_settings): + """When BYOK only supplies api_key, project/location come from settings.""" + mock_settings.GCP_VERTEX_API_KEY = "platform-key" + mock_settings.GCP_PROJECT_ID = "platform-proj" + mock_settings.GCP_VERTEX_LOCATION = "us-central1" + mock_settings.GCP_SA_KEY = "" + mock_settings.GCS_AUDIO_BUCKET = "platform-bucket" + + c = GoogleVertexAIProvider.create_client({"api_key": "byok-key"}) + assert c.api_key == "byok-key" + assert c.project_id == "platform-proj" + assert c.location == "us-central1" + + @patch("app.services.llm.providers.gai_vertex.settings") + def test_missing_everything_raises_value_error(self, mock_settings): + mock_settings.GCP_VERTEX_API_KEY = "" + mock_settings.GCP_PROJECT_ID = "" + mock_settings.GCP_VERTEX_LOCATION = "" + mock_settings.GCP_SA_KEY = "" + mock_settings.GCS_AUDIO_BUCKET = "" + + with pytest.raises(ValueError) as exc_info: + GoogleVertexAIProvider.create_client({}) + msg = str(exc_info.value) + assert "api_key" in msg + assert "project_id" in msg + assert "location" in msg diff --git a/backend/app/tests/services/llm/providers/test_registry.py b/backend/app/tests/services/llm/providers/test_registry.py index c68eb527e..b22d8da81 100644 --- a/backend/app/tests/services/llm/providers/test_registry.py +++ b/backend/app/tests/services/llm/providers/test_registry.py @@ -105,10 +105,10 @@ def test_get_llm_provider_with_missing_credentials(self, db: Session): assert "not configured for this project" in str(exc_info.value) - def test_google_vertex_falls_back_to_platform_settings(self, db: Session, tmp_path): + def test_google_vertex_falls_back_to_platform_settings(self, db: Session): """No credential row for google-vertex → create_client synthesizes the platform defaults from settings (api_key/project/location/bucket) and - loads the SA JSON from GCP_SA_KEY.""" + parses the inline SA JSON from GCP_SA_KEY.""" import json as _json from app.services.llm.providers.gai_vertex import ( @@ -123,8 +123,6 @@ def test_google_vertex_falls_back_to_platform_settings(self, db: Session, tmp_pa "client_email": "sa@platform-project.iam.gserviceaccount.com", "private_key": "-----BEGIN PRIVATE KEY-----\nfake\n-----END PRIVATE KEY-----", } - sa_path = tmp_path / "sa.json" - sa_path.write_text(_json.dumps(sa_info)) with patch( "app.crud.credentials.get_provider_credential" @@ -135,7 +133,7 @@ def test_google_vertex_falls_back_to_platform_settings(self, db: Session, tmp_pa mock_settings.GCP_VERTEX_API_KEY = "platform-key" mock_settings.GCP_PROJECT_ID = "platform-project" mock_settings.GCP_VERTEX_LOCATION = "us-central1" - mock_settings.GCP_SA_KEY = str(sa_path) + mock_settings.GCP_SA_KEY = _json.dumps(sa_info) mock_settings.GCS_AUDIO_BUCKET = "platform-bucket" provider = get_llm_provider( diff --git a/backend/app/tests/services/llm/test_mappers.py b/backend/app/tests/services/llm/test_mappers.py index 7f1a7f036..aaf072dfe 100644 --- a/backend/app/tests/services/llm/test_mappers.py +++ b/backend/app/tests/services/llm/test_mappers.py @@ -484,18 +484,19 @@ def test_tts_audio_format_wav(self): assert result["output_audio_codec"] == "wav" assert warnings == [] - # Error Cases - def test_missing_model_returns_error(self): - """Test that missing model parameter returns error.""" + # Error / fallback cases + def test_missing_model_falls_back_to_default(self): + """Missing model falls back to DEFAULT_SARVAM_TTS_MODEL without warnings.""" kaapi_params = {"voice": "Shubh", "language": "hi-IN"} result, warnings = map_kaapi_to_sarvam_params( kaapi_params, completion_type="tts" ) - assert result == {} - assert len(warnings) == 1 - assert "model" in warnings[0].lower() + assert result["model"] == "bulbul:v3" + assert result["speaker"] == "Shubh" + assert result["target_language_code"] == "hi-IN" + assert warnings == [] def test_unsupported_completion_type(self): """Test that unsupported completion types return error.""" @@ -741,18 +742,19 @@ def test_tts_all_supported_voices(self): assert result["voice_id"] == expected_id assert warnings == [] - # Error Cases - def test_missing_model_returns_error(self): - """Test that missing model returns error.""" + # Error / fallback cases + def test_missing_model_falls_back_to_default(self): + """Missing model falls back to DEFAULT_ELEVENLABS_TTS_MODEL without warnings.""" kaapi_params = {"voice": "Sarah", "language": "en-IN"} result, warnings = map_kaapi_to_elevenlabs_params( kaapi_params, completion_type="tts" ) - assert result == {} - assert len(warnings) == 1 - assert "model" in warnings[0].lower() + assert result["model_id"] == "eleven_v3" + assert result["voice_id"] == "EXAVITQu4vr4xnSDxMaL" + assert result["language_code"] == "en" + assert warnings == [] def test_unsupported_completion_type(self): """Test that unsupported completion types return error.""" diff --git a/backend/app/tests/services/llm/test_multimodal.py b/backend/app/tests/services/llm/test_multimodal.py index bae09308a..2640f9878 100644 --- a/backend/app/tests/services/llm/test_multimodal.py +++ b/backend/app/tests/services/llm/test_multimodal.py @@ -495,8 +495,9 @@ def test_string_input(self): call_kwargs = mock_client.models.generate_content.call_args[1] assert call_kwargs["contents"][0]["parts"] == [{"text": "hello"}] - def test_missing_model(self): - provider, _ = self._make_provider() + def test_missing_model_falls_back_to_default(self): + """No model in params → provider uses DEFAULT_TEXT_MODELS['google'].""" + provider, mock_client = self._make_provider() config = NativeCompletionConfig( provider="google-native", type="text", params={} ) @@ -505,8 +506,10 @@ def test_missing_model(self): query=self._make_query(), resolved_input="hello", ) - assert response is None - assert "Missing 'model'" in error + assert error is None + assert response is not None + call_kwargs = mock_client.models.generate_content.call_args[1] + assert call_kwargs["model"] == "gemini-2.5-pro" def test_instructions_passed_to_config(self): provider, mock_client = self._make_provider() diff --git a/backend/app/tests/test_utils.py b/backend/app/tests/test_utils.py index 72eda1de5..df08ce733 100644 --- a/backend/app/tests/test_utils.py +++ b/backend/app/tests/test_utils.py @@ -372,18 +372,17 @@ def test_multimodal_rejects_audio(self) -> None: AudioInput(content=AudioContent(value="b64audio", mime_type="audio/wav")), ] result, error = resolve_input(parts) - assert result == "" + assert result is None assert "not supported in multimodal" in error def test_multimodal_rejects_unknown_type(self) -> None: result, error = resolve_input(["not a valid input"]) - # list with unsupported item type - assert result == "" + assert result is None assert "Unsupported input type" in error def test_unknown_input_type(self) -> None: result, error = resolve_input("just a string") - assert result == "" + assert result is None assert "Unknown input type" in error From c93e77d19ead66abd50d9f7d88e436c242a4372a Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Thu, 4 Jun 2026 20:50:02 +0530 Subject: [PATCH 14/15] fix: one more gai test cases --- .../tests/services/llm/providers/test_gai.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/backend/app/tests/services/llm/providers/test_gai.py b/backend/app/tests/services/llm/providers/test_gai.py index a7d82310b..82d610ddb 100644 --- a/backend/app/tests/services/llm/providers/test_gai.py +++ b/backend/app/tests/services/llm/providers/test_gai.py @@ -174,26 +174,6 @@ def test_stt_with_custom_instructions( instruction = call_args[1]["contents"][0] assert "Include timestamps" in instruction - def test_stt_with_include_provider_raw_response( - self, provider, mock_client, stt_config, query_params, audio_ref - ): - """Test STT with include_provider_raw_response=True.""" - mock_response = mock_google_response(text="Raw response test") - mock_client.models.generate_content.return_value = mock_response - - result, error = provider.execute( - stt_config, - query_params, - "/path/to/audio.wav", - include_provider_raw_response=True, - ) - - assert error is None - assert result is not None - assert result.provider_raw_response is not None - assert isinstance(result.provider_raw_response, dict) - assert result.provider_raw_response["text"] == "Raw response test" - def test_stt_with_type_error( self, provider, mock_client, stt_config, query_params, audio_ref ): From bf1565a714e9270a919ef0c54c1ded3f593f1d2e Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Thu, 4 Jun 2026 21:16:28 +0530 Subject: [PATCH 15/15] fix: anthropic mapper test cases --- .../app/tests/services/llm/test_mappers.py | 177 ++++++++++++++++++ 1 file changed, 177 insertions(+) diff --git a/backend/app/tests/services/llm/test_mappers.py b/backend/app/tests/services/llm/test_mappers.py index aaf072dfe..e34328d36 100644 --- a/backend/app/tests/services/llm/test_mappers.py +++ b/backend/app/tests/services/llm/test_mappers.py @@ -16,6 +16,7 @@ ) from app.services.llm.mappers import ( bcp47_to_elevenlabs_lang, + map_kaapi_to_anthropic_params, map_kaapi_to_elevenlabs_params, map_kaapi_to_google_params, map_kaapi_to_openai_params, @@ -23,6 +24,7 @@ transform_kaapi_config_to_native, voice_to_id, ) +import pytest class TestMapKaapiToOpenAIParams: @@ -769,6 +771,181 @@ def test_unsupported_completion_type(self): assert "Unsupported completion type" in warnings[0] +class TestMapKaapiToAnthropicParams: + """Test cases for map_kaapi_to_anthropic_params.""" + + def test_full_text_completion_params_mapped(self): + """Real-world text-completion payload: every supported Kaapi field + maps to its Anthropic equivalent, no warnings.""" + kaapi_params = { + "model": "claude-sonnet-4-6", + "instructions": "You are a helpful assistant.", + "temperature": 0.4, + "top_p": 0.9, + "max_output_tokens": 1024, + } + + result, warnings = map_kaapi_to_anthropic_params(kaapi_params) + + assert result == { + "model": "claude-sonnet-4-6", + "system": "You are a helpful assistant.", + "temperature": 0.4, + "top_p": 0.9, + "max_tokens": 1024, + } + assert warnings == [] + + def test_missing_model_falls_back_to_default(self): + """Anthropic requires model — provider falls back to the centralised + default when caller omits it.""" + result, warnings = map_kaapi_to_anthropic_params({}) + + assert result == {"model": "claude-sonnet-4-6"} + assert warnings == [] + + def test_max_output_tokens_renamed_to_max_tokens(self): + """Kaapi calls it max_output_tokens; Anthropic Messages API calls it + max_tokens. The rename is the contract — protect against drift.""" + result, _ = map_kaapi_to_anthropic_params( + {"model": "claude-sonnet-4-6", "max_output_tokens": 256} + ) + assert "max_tokens" in result + assert "max_output_tokens" not in result + assert result["max_tokens"] == 256 + + def test_unsupported_knowledge_base_emits_warning_and_drops_field(self): + """Anthropic has no managed vector store, so we drop knowledge_base_ids + and surface a warning the caller can show to users.""" + result, warnings = map_kaapi_to_anthropic_params( + { + "model": "claude-sonnet-4-6", + "knowledge_base_ids": ["kb_1", "kb_2"], + } + ) + + assert "knowledge_base_ids" not in result + assert len(warnings) == 1 + assert "knowledge_base_ids" in warnings[0] + + def test_reasoning_effort_summary_collapsed_into_single_warning(self): + """Any of reasoning/effort/summary triggers the same advisory; only + one warning is emitted regardless of how many are supplied.""" + result, warnings = map_kaapi_to_anthropic_params( + { + "model": "claude-sonnet-4-6", + "reasoning": "high", + "effort": "medium", + "summary": "concise", + } + ) + + assert "reasoning" not in result + assert "effort" not in result + assert "summary" not in result + assert len(warnings) == 1 + assert "reasoning" in warnings[0].lower() + + def test_temperature_zero_is_preserved(self): + """0.0 is a valid temperature — guard against truthy-check bugs that + would drop it as if it were None.""" + result, _ = map_kaapi_to_anthropic_params( + {"model": "claude-sonnet-4-6", "temperature": 0.0} + ) + assert result["temperature"] == 0.0 + + +class TestTransformGoogleVertexConfig: + """Test cases for transform_kaapi_config_to_native with google-vertex. + + google-vertex shares its STT/TTS param mapping with the google provider — + these tests pin the routing contract: provider tag is rewritten to + ``google-vertex-native`` and text completions are explicitly rejected + (text must go through the ``google`` provider).""" + + def test_stt_routes_to_google_vertex_native(self, db: Session): + kaapi_config = KaapiCompletionConfig( + provider="google-vertex", + type="stt", + params={ + "model": "gemini-2.5-flash", + "input_language": "hi-IN", + "instructions": "be precise", + }, + ) + + native_config, warnings = transform_kaapi_config_to_native( + session=db, kaapi_config=kaapi_config + ) + + assert isinstance(native_config, NativeCompletionConfig) + assert native_config.provider == "google-vertex-native" + assert native_config.type == "stt" + assert native_config.params["model"] == "gemini-2.5-flash" + assert native_config.params["input_language"] == "hi-IN" + assert native_config.params["instructions"] == "be precise" + assert warnings == [] + + def test_tts_routes_to_google_vertex_native_with_defaults(self, db: Session): + """Real-world TTS payload: minimal params; mapper applies voice and + response_format defaults from the google mapper.""" + kaapi_config = KaapiCompletionConfig( + provider="google-vertex", + type="tts", + params={"model": "gemini-2.5-flash-preview-tts"}, + ) + + native_config, warnings = transform_kaapi_config_to_native( + session=db, kaapi_config=kaapi_config + ) + + assert native_config.provider == "google-vertex-native" + assert native_config.type == "tts" + assert native_config.params["model"] == "gemini-2.5-flash-preview-tts" + # The google mapper fills in voice + wav defaults + assert "voice" in native_config.params + assert native_config.params["response_format"] == "wav" + assert warnings == [] + + def test_text_completion_is_rejected(self, db: Session): + """google-vertex is for audio only — text completions must be routed + through the standard ``google`` provider, not silently accepted.""" + kaapi_config = KaapiCompletionConfig( + provider="google-vertex", + type="text", + params={"model": "gemini-2.5-pro"}, + ) + + with pytest.raises(ValueError) as exc_info: + transform_kaapi_config_to_native(session=db, kaapi_config=kaapi_config) + + msg = str(exc_info.value) + assert "google-vertex" in msg + assert "text" in msg + assert "google" in msg # hints the caller toward the right provider + + def test_unsupported_language_emits_warning(self, db: Session): + """Languages not in BCP47_LOCALE_TO_GEMINI_LANG fall back to auto-detect + and surface a warning, rather than silently being dropped.""" + kaapi_config = KaapiCompletionConfig( + provider="google-vertex", + type="tts", + params={ + "model": "gemini-2.5-flash-preview-tts", + "language": "xx-YY", # unsupported + }, + ) + + native_config, warnings = transform_kaapi_config_to_native( + session=db, kaapi_config=kaapi_config + ) + + assert native_config.provider == "google-vertex-native" + assert "language" not in native_config.params # dropped + assert len(warnings) == 1 + assert "xx-YY" in warnings[0] + + class TestBCP47ToElevenlabsLang: """Test BCP-47 language code conversion for ElevenLabs."""