diff --git a/py/packages/genkit/src/genkit/_ai/_aio.py b/py/packages/genkit/src/genkit/_ai/_aio.py index 8668102eb0..7851b960e5 100644 --- a/py/packages/genkit/src/genkit/_ai/_aio.py +++ b/py/packages/genkit/src/genkit/_ai/_aio.py @@ -298,8 +298,8 @@ def define_middleware( def middleware( self, - name: str, *, + name: str, description: str | None = None, ) -> Callable[[type[MiddlewareT]], type[MiddlewareT]]: """Decorator that registers a custom middleware on this app's registry.""" diff --git a/py/packages/genkit/src/genkit/_core/_error.py b/py/packages/genkit/src/genkit/_core/_error.py index 71f26f4a3d..82a6fb9135 100644 --- a/py/packages/genkit/src/genkit/_core/_error.py +++ b/py/packages/genkit/src/genkit/_core/_error.py @@ -195,8 +195,14 @@ def __init__( self.status: StatusName = temp_status self.http_code: int = http_status_code(temp_status) + # When this error wraps another (the common shape — the action runtime + # catches the underlying failure and re-raises as ``GenkitError(..., + # cause=original)``), surface the cause in the default string form so + # downstream consumers (logs, model-facing tool error messages, the Dev + # UI) see the real reason instead of the bare wrapper text. source_prefix = f'{source}: ' if source else '' - super().__init__(f'{source_prefix}{self.status}: {message}') + cause_suffix = f': {cause}' if cause else '' + super().__init__(f'{source_prefix}{self.status}: {message}{cause_suffix}') self.original_message: str = message if not details: diff --git a/py/packages/genkit/tests/genkit/core/error_test.py b/py/packages/genkit/tests/genkit/core/error_test.py index 5a9ed00dbd..c2dc92a8ab 100644 --- a/py/packages/genkit/tests/genkit/core/error_test.py +++ b/py/packages/genkit/tests/genkit/core/error_test.py @@ -43,6 +43,16 @@ def test_genkit_error() -> None: error_no_source = GenkitError(status='INTERNAL', message='Test message 2') assert str(error_no_source) == 'INTERNAL: Test message 2' + # When wrapping another exception the cause should appear in str(...) too, + # so the model and any plain ``f"{e}"`` log line see the real reason. + wrapped = GenkitError( + status='INTERNAL', + message='Error while running action read_file', + cause=ValueError("File not found: 'workspace/foo.py'"), + ) + assert str(wrapped) == ("INTERNAL: Error while running action read_file: File not found: 'workspace/foo.py'") + assert wrapped.original_message == 'Error while running action read_file' + def test_genkit_error_to_json() -> None: # NOT_FOUND is a valid gRPC-style status (maps to HTTP 404). diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py index 41ff00d739..cd0b59963f 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py @@ -183,11 +183,17 @@ async def to_gemini(cls, part: Part | DocumentPart) -> genai.types.Part | list[g if extra_parts: tool_output = clean_output + # Gemini's FunctionResponse requires a dict-shaped ``response``, + # but a tool can legitimately hand back any JSON value (string, + # list, int, None, ...). Envelope it as ``{name, content}`` so + # the wire payload is always a dict; the inbound converter + # unwraps the same envelope so callers see the original value. + gemini_tool_name = tool_response.name.replace('/', '__') fn_part = genai.types.Part( function_response=genai.types.FunctionResponse( id=tool_response.ref, - name=tool_response.name.replace('/', '__'), - response=tool_output, + name=gemini_tool_name, + response={'name': gemini_tool_name, 'content': tool_output}, ) ) if extra_parts: @@ -315,13 +321,19 @@ def from_gemini(cls, part: genai.types.Part, ref: str | None = None) -> Part: ) ) if part.function_response: + # If the model echoes back the ``{name, content}`` envelope we + # used on the outbound side, peel it off so the caller sees the + # original tool output. + output = part.function_response.response + if isinstance(output, dict) and 'name' in output and 'content' in output: + output = output['content'] return Part( root=ToolResponsePart( tool_response=ToolResponse( ref=getattr(part.function_response, 'id', None), # restore slashes name=(part.function_response.name or '').replace('__', '/'), - output=part.function_response.response, + output=output, ) ) ) diff --git a/py/plugins/middleware/README.md b/py/plugins/middleware/README.md new file mode 100644 index 0000000000..88f2b5d7c5 --- /dev/null +++ b/py/plugins/middleware/README.md @@ -0,0 +1,197 @@ +# Genkit Middleware Plugin + +A collection of middleware implementations for Firebase Genkit Python. + +## Overview + +This plugin provides five concrete middleware implementations for common use cases: + +- **Retry**: Retries model API calls on transient errors with exponential backoff +- **Fallback**: Falls back to alternative models when the primary model fails +- **ToolApproval**: Requires explicit approval before executing tool calls +- **Skills**: Exposes a library of skills as system prompts and tools +- **Filesystem**: Provides sandboxed filesystem operations + +## Quick start + +Import the middleware classes you need and pass instances directly into `use=[]`: + +```python +from genkit import Genkit +from genkit.plugins.middleware import Retry, Fallback, Middleware + +ai = Genkit(plugins=[Middleware()]) + +response = await ai.generate( + model='googleai/gemini-flash-latest', + prompt='Hello!', + use=[ + Retry(max_retries=5), + Fallback(models=['googleai/gemini-2.5-pro']), + ], +) +``` + +These pre-packaged middlewares will be available to play with in the Dev UI by default. + +## Installation + +```bash +pip install genkit-plugin-middleware +``` + +## Usage + +### Retry + +Automatically retries model calls on transient failures with configurable exponential backoff: + +```python +from genkit.plugins.middleware import Retry + +retry = Retry( + max_retries=3, + statuses=['UNAVAILABLE', 'DEADLINE_EXCEEDED', 'RESOURCE_EXHAUSTED'], + initial_delay_ms=1000, + max_delay_ms=60000, + backoff_factor=2.0, + jitter=True, # set False for deterministic backoff (tests) +) + +response = await ai.generate( + model='googleai/gemini-2.5-flash', + prompt='Hello!', + use=[retry], +) +``` + +### Fallback + +Falls back to alternative models on retryable errors: + +```python +from genkit.plugins.middleware import Fallback + +fallback = Fallback( + models=['googleai/gemini-2.5-pro', 'googleai/gemini-2.5-flash'], + statuses=['UNAVAILABLE', 'DEADLINE_EXCEEDED'], +) + +response = await ai.generate( + model='googleai/gemini-2.5-ultra', + prompt='Hello!', + use=[fallback], +) +``` + +### ToolApproval + +Requires approval before executing tools (useful for sensitive operations): + +```python +from genkit.plugins.middleware import ToolApproval + +approval = ToolApproval( + allowed_tools=['get_weather', 'search'], # These tools run without approval +) + +response = await ai.generate( + model='googleai/gemini-2.5-flash', + prompt='Delete the database', + tools=[delete_database_tool], + use=[approval], +) +``` + +When a non-allowed tool is called, execution is interrupted. Approve and re-run the +tool by restarting it with ``resumed_metadata`` that includes ``toolApproved`` +(the middleware only treats explicit dict metadata as approval): + +```python +first = await ai.generate( + model='googleai/gemini-flash-latest', + prompt='Delete the database', + tools=[delete_database_tool], + use=[approval], +) + +response = await ai.generate( + model='googleai/gemini-flash-latest', + prompt='Delete the database', + messages=list(first.messages), + tools=[delete_database_tool], + use=[approval], + resume_restart=delete_database_tool.restart( + None, + interrupt=first.interrupts[0], + resumed_metadata={'toolApproved': True}, + ), +) +``` + +### Skills + +Scans directories for SKILL.md files and exposes them as loadable instructions: + +```python +from genkit.plugins.middleware import Skills + +skills = Skills( + skill_paths=['skills', 'prompts/skills'], +) + +response = await ai.generate( + model='googleai/gemini-flash-latest', + prompt='Help me with Python', + use=[skills], +) +``` + +Skills are discovered by scanning for directories containing `SKILL.md` files. Each `SKILL.md` can have optional YAML frontmatter: + +```markdown +--- +name: python-expert +description: Expert Python programming assistance +--- + +You are an expert Python programmer... +``` + +### Filesystem + +Provides sandboxed file operations confined to a root directory: + +```python +from genkit.plugins.middleware import Filesystem + +fs = Filesystem( + root_dir='./workspace', + allow_write_access=True, + tool_name_prefix='', +) + +response = await ai.generate( + model='googleai/gemini-flash-latest', + prompt='List files in the current directory', + use=[fs], +) +``` + +Provides four tools: +- `list_files`: List files in a directory +- `read_file`: Read file content +- `write_file`: Write to a file (requires `allow_write_access=True`) +- `edit_file`: Edit file with string replacements (requires `allow_write_access=True`) + +## Development + +```bash +cd py/plugins/middleware +pip install -e ".[dev]" +pytest tests/ +``` + +## License + +Apache 2.0 diff --git a/py/plugins/middleware/pyproject.toml b/py/plugins/middleware/pyproject.toml new file mode 100644 index 0000000000..c24144185f --- /dev/null +++ b/py/plugins/middleware/pyproject.toml @@ -0,0 +1,79 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +[project] +authors = [ + { name = "Google" }, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Environment :: Web Environment", + "Framework :: AsyncIO", + "Framework :: Pydantic", + "Framework :: Pydantic :: 2", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries", + "Typing :: Typed", + "License :: OSI Approved :: Apache Software License", +] +dependencies = [ + "genkit>=0.5.2", + "pyyaml>=6.0", +] +description = "A collection of middleware implementations for Genkit." +keywords = [ + "genkit", + "ai", + "llm", + "middleware", +] +license = "Apache-2.0" +name = "genkit-plugin-middleware" +readme = "README.md" +requires-python = ">=3.10" +version = "0.5.2" + +[project.optional-dependencies] +dev = [ + "pytest>=8.3.4", + "pytest-asyncio>=0.25.2", + "pytest-cov>=6.0.0", + "pytest-xdist>=3.6.1", +] + +[project.urls] +"Bug Tracker" = "https://github.com/genkit-ai/genkit/issues" +"Documentation" = "https://firebase.google.com/docs/genkit" +"Homepage" = "https://github.com/genkit-ai/genkit" +"Repository" = "https://github.com/genkit-ai/genkit/tree/main/py" + +[build-system] +build-backend = "hatchling.build" +requires = ["hatchling"] + +[tool.hatch.build.targets.wheel] +only-include = ["src/genkit/plugins/middleware"] +sources = ["src"] diff --git a/py/plugins/middleware/src/genkit/plugins/middleware/__init__.py b/py/plugins/middleware/src/genkit/plugins/middleware/__init__.py new file mode 100644 index 0000000000..2f1148a8c1 --- /dev/null +++ b/py/plugins/middleware/src/genkit/plugins/middleware/__init__.py @@ -0,0 +1,127 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Genkit middleware plugin. + +Provides concrete middleware implementations: + +* ``Retry`` — retries model calls on transient errors with exponential + backoff. +* ``Fallback`` — falls back to alternative models on failure. +* ``ToolApproval`` — requires approval before executing tools. +* ``Skills`` — exposes a ``SKILL.md`` library as system prompts plus a + ``use_skill`` tool. +* ``Filesystem`` — sandboxed filesystem operations (list / read / write / + edit). + +Import the classes you need and pass instances into ``use=[...]``: + + from genkit.plugins.middleware import Retry, Fallback + + response = await ai.generate( + prompt='Hello', + use=[ + Retry(max_retries=5), + Fallback(models=['googleai/gemini-2.5-pro']), + ], + ) + +Or register all five with the ``Middleware`` plugin so they appear in +the Dev UI. +""" + +from genkit.middleware import GenerateMiddleware +from genkit.plugin_api import Action, ActionKind, ActionMetadata, Plugin, new_middleware +from genkit.plugins.middleware._fallback import Fallback +from genkit.plugins.middleware._filesystem import Filesystem +from genkit.plugins.middleware._retry import Retry +from genkit.plugins.middleware._skills import Skills +from genkit.plugins.middleware._tool_approval import ToolApproval + + +class Middleware(Plugin): + """Plugin that registers Retry, Fallback, ToolApproval, Skills, and Filesystem. + + Registers all five middleware descriptors so they show up in the Dev + UI. + + ``Filesystem`` has no default root — supply ``root_dir`` when + constructing an instance, for example + ``Filesystem(root_dir='./workspace')``. + + Usage: + from genkit.plugins.middleware import Middleware, Retry, Skills + + ai = Genkit(plugins=[GoogleAI(), Middleware()]) + await ai.generate( + prompt='Hello', + use=[Retry(max_retries=5), Skills(skill_paths=['skills'])], + ) + """ + + name = 'genkit-middleware' + + async def init(self) -> list[Action]: + """No actions to register; this plugin only contributes middleware.""" + return [] + + async def resolve(self, action_type: ActionKind, name: str) -> Action | None: + """No dynamic actions to resolve.""" + return None + + async def list_actions(self) -> list[ActionMetadata]: + """No actions to list.""" + return [] + + def list_middleware(self) -> list[GenerateMiddleware]: + """Return descriptors for all middleware exposed by this plugin.""" + return [ + new_middleware( + Retry, + name='retry', + description='Retries model calls on transient failures with exponential backoff', + ), + new_middleware( + Fallback, + name='fallback', + description='Falls back to alternative models on failure', + ), + new_middleware( + ToolApproval, + name='tool_approval', + description='Requires approval before executing tools', + ), + new_middleware( + Skills, + name='skills', + description='Provides access to skill library for specialized instructions', + ), + new_middleware( + Filesystem, + name='filesystem', + description='Sandboxed filesystem operations', + ), + ] + + +__all__ = [ + 'Fallback', + 'Filesystem', + 'Middleware', + 'Retry', + 'Skills', + 'ToolApproval', +] diff --git a/py/plugins/middleware/src/genkit/plugins/middleware/_fallback.py b/py/plugins/middleware/src/genkit/plugins/middleware/_fallback.py new file mode 100644 index 0000000000..c0d03444e3 --- /dev/null +++ b/py/plugins/middleware/src/genkit/plugins/middleware/_fallback.py @@ -0,0 +1,97 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Fallback middleware for Genkit model calls.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any + +from pydantic import BaseModel, Field + +from genkit import GenkitError +from genkit._core._action import Action, ActionKind +from genkit._core._model import ModelResponse +from genkit.middleware import BaseMiddleware, GenerateMiddlewareContext, ModelHookParams + +_DEFAULT_FALLBACK_STATUSES: list[str] = [ + 'UNAVAILABLE', + 'DEADLINE_EXCEEDED', + 'RESOURCE_EXHAUSTED', + 'ABORTED', + 'INTERNAL', + 'NOT_FOUND', + 'UNIMPLEMENTED', +] + + +class FallbackConfig(BaseModel): + """Models and statuses that trigger fallback.""" + + models: list[str] = Field(default_factory=list) + statuses: list[str] = Field(default_factory=lambda: list(_DEFAULT_FALLBACK_STATUSES)) + + +class Fallback(BaseMiddleware[FallbackConfig]): + """Fallback middleware to try alternative models on failure.""" + + async def _resolve_fallback_model( + self, + ctx: GenerateMiddlewareContext, + model_name: str, + ) -> Action[Any, Any, Any]: + """Look up a fallback model on the per-call registry.""" + action = await ctx.registry.resolve_action(ActionKind.MODEL, model_name) + if action is None: + raise GenkitError( + status='NOT_FOUND', + message=f'No model named "{model_name}" is registered on this app.', + ) + return action + + async def wrap_model( + self, + params: ModelHookParams, + next_fn: Callable[[ModelHookParams], Awaitable[ModelResponse]], + ctx: GenerateMiddlewareContext, + ) -> ModelResponse: + """Try the primary model, then fall back to alternates on retryable errors.""" + last_error: Exception | None = None + try: + return await next_fn(params) + except Exception as exc: + if not isinstance(exc, GenkitError) or exc.status not in self.config.statuses: + raise + last_error = exc + + assert last_error is not None # noqa: S101 + on_chunk = ctx.on_chunk + for model_name in self.config.models: + fallback_action = await self._resolve_fallback_model(ctx, model_name) + try: + result = await fallback_action.run( + input=params.request, + context=ctx.custom_context, + on_chunk=on_chunk, + ) + return result.response # type: ignore[return-value] + except Exception as e2: + last_error = e2 + if not isinstance(e2, GenkitError) or e2.status not in self.config.statuses: + raise + + raise last_error diff --git a/py/plugins/middleware/src/genkit/plugins/middleware/_filesystem.py b/py/plugins/middleware/src/genkit/plugins/middleware/_filesystem.py new file mode 100644 index 0000000000..815825bc50 --- /dev/null +++ b/py/plugins/middleware/src/genkit/plugins/middleware/_filesystem.py @@ -0,0 +1,353 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Filesystem middleware for Genkit. + +Provides sandboxed file operations — ``list_files``, ``read_file``, +``write_file``, ``edit_file`` — confined to a configurable root directory. + +``read_file`` queues file content as user messages so the tool response stays +small. Tool errors are queued the same way so the model can self-correct on +the next turn. + +Each ``generate()`` gets a fresh middleware instance with its own message +queue; ``wrap_generate`` drains queued messages into the request before the +next model call. +""" + +from __future__ import annotations + +import asyncio +import base64 +import mimetypes +import os +from collections.abc import Awaitable, Callable +from pathlib import Path +from typing import Any + +from pydantic import BaseModel as PydanticBaseModel + +from genkit._ai._tools import Interrupt, define_tool +from genkit._core._model import Message, ModelResponse, ModelResponseChunk +from genkit._core._registry import Registry +from genkit._core._typing import ( + Media, + MediaPart, + Part, + Role, + TextPart, +) +from genkit.middleware import ( + BaseMiddleware, + GenerateHookParams, + GenerateMiddlewareContext, + MultipartToolResponse, + ToolHookParams, +) + +# --------------------------------------------------------------------------- +# Tool input schemas (module-level so Pydantic can resolve annotations) +# --------------------------------------------------------------------------- + + +class _ListFilesInput(PydanticBaseModel): + """Input for list_files tool.""" + + dir_path: str = '' + recursive: bool = False + + +class _ReadFileInput(PydanticBaseModel): + """Input for read_file tool.""" + + file_path: str + offset: int = 0 + limit: int = 0 + + +class _WriteFileInput(PydanticBaseModel): + """Input for write_file tool.""" + + file_path: str + content: str + + +class _EditSpec(PydanticBaseModel): + """A single string-replacement edit.""" + + old_string: str + new_string: str + replace_all: bool = False + + +class _EditFileInput(PydanticBaseModel): + """Input for edit_file tool.""" + + file_path: str + edits: list[_EditSpec] + + +_MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024 # 10 MB — absolute ceiling for reading +_MAX_READ_SLICE_BYTES = 256 * 1024 # 256 KB — max bytes returned per slice + + +class FilesystemConfig(PydanticBaseModel): + """Sandbox root and write/tool naming options.""" + + root_dir: str + allow_write_access: bool = False + tool_name_prefix: str = '' + + +class Filesystem(BaseMiddleware[FilesystemConfig]): + """Filesystem middleware with sandboxed file operations. + + Contributes ``list_files``, ``read_file``, and optionally ``write_file`` + and ``edit_file``. Tool errors are queued as user messages so the model + can self-correct on the next turn. + """ + + def __init__(self, **kwargs: Any) -> None: # noqa: ANN401 + super().__init__(**kwargs) + if not self.config.root_dir or not self.config.root_dir.strip(): + raise ValueError('Filesystem.root_dir must not be empty.') + # One queue per generate() — the engine copies middleware per call. + self._message_queue: list[Message] = [] + + @property + def _root_abs(self) -> str: + return str(Path(self.config.root_dir).resolve()) + + def _tool_name(self, base: str) -> str: + return f'{self.config.tool_name_prefix}{base}' + + def _filesystem_tool_names(self) -> frozenset[str]: + names = {self._tool_name('list_files'), self._tool_name('read_file')} + if self.config.allow_write_access: + names |= {self._tool_name('write_file'), self._tool_name('edit_file')} + return frozenset(names) + + def _resolve_safe(self, rel: str) -> str: + """Resolve ``rel`` to an absolute path, raising ValueError if it escapes root.""" + rel = rel.strip().lstrip('/').lstrip('\\') + if not rel: + rel = '.' + candidate = os.path.realpath(os.path.join(self._root_abs, rel)) + c_norm = os.path.normcase(candidate) + root_norm = os.path.normcase(self._root_abs) + if c_norm != root_norm and not c_norm.startswith(root_norm + os.sep): + raise ValueError(f'Path {rel!r} escapes the root directory.') + return candidate + + def _enqueue_parts(self, parts: list[Part]) -> None: + """Append parts to the pending user message for the next model turn.""" + if self._message_queue and self._message_queue[-1].role == Role.USER: + self._message_queue[-1].content.extend(parts) + else: + self._message_queue.append(Message(role=Role.USER, content=list(parts))) + + def _list_files(self, dir_path: str = '', recursive: bool = False) -> list[dict[str, Any]]: + """List files and directories under ``dir_path`` (relative to root).""" + abs_dir = self._resolve_safe(dir_path) + if not os.path.isdir(abs_dir): + raise ValueError(f'Not a directory: {dir_path!r}') + + results: list[dict[str, Any]] = [] + if recursive: + for root, dirs, files in os.walk(abs_dir): + dirs[:] = sorted(d for d in dirs if not d.startswith('.')) + for name in sorted(files): + abs_path = os.path.join(root, name) + try: + stat = os.stat(abs_path) + rel = os.path.relpath(abs_path, abs_dir) + results.append({'path': rel, 'is_directory': False, 'size_bytes': stat.st_size}) + except OSError: + continue + for name in dirs: + rel = os.path.relpath(os.path.join(root, name), abs_dir) + results.append({'path': rel, 'is_directory': True, 'size_bytes': 0}) + else: + for name in sorted(os.listdir(abs_dir)): + abs_path = os.path.join(abs_dir, name) + try: + stat = os.stat(abs_path) + is_dir = os.path.isdir(abs_path) + results.append({'path': name, 'is_directory': is_dir, 'size_bytes': 0 if is_dir else stat.st_size}) + except OSError: + continue + + return results + + def _read_file_impl(self, file_path: str, offset: int, limit: int) -> str: + """Read a file and enqueue its content as a user message.""" + abs_path = self._resolve_safe(file_path) + if not os.path.isfile(abs_path): + raise ValueError(f'File not found: {file_path!r}') + + stat = os.stat(abs_path) + if stat.st_size > _MAX_FILE_SIZE_BYTES: + raise ValueError(f'File too large ({stat.st_size:,} bytes; max {_MAX_FILE_SIZE_BYTES:,}).') + + mime_type, _ = mimetypes.guess_type(abs_path) + is_image = bool(mime_type and mime_type.startswith('image/')) + + if is_image: + with open(abs_path, 'rb') as fh: + raw = fh.read() + if len(raw) > _MAX_READ_SLICE_BYTES: + raise ValueError(f'Image too large ({len(raw):,} bytes; max {_MAX_READ_SLICE_BYTES:,}).') + b64 = base64.b64encode(raw).decode('ascii') + data_uri = f'data:{mime_type};base64,{b64}' + self._enqueue_parts([Part(root=MediaPart(media=Media(url=data_uri, content_type=mime_type)))]) + return f'Image {file_path} queued as media part.' + + with open(abs_path, encoding='utf-8', errors='replace') as fh: + lines = fh.readlines() + + total = len(lines) + start = max(0, offset - 1) if offset > 0 else 0 + end = total if limit == 0 else min(total, start + limit) + sliced = ''.join(lines[start:end]) + + if len(sliced.encode()) > _MAX_READ_SLICE_BYTES: + raise ValueError(f'Slice too large ({len(sliced):,} chars). Use offset/limit to read smaller sections.') + + if offset > 0 or limit > 0: + wrapped = f'\n{sliced}\n' + else: + wrapped = f'\n{sliced}\n' + + self._enqueue_parts([Part(root=TextPart(text=wrapped))]) + return f'File {file_path} read successfully. Content queued as user message.' + + def _write_file_impl(self, file_path: str, content: str) -> str: + abs_path = self._resolve_safe(file_path) + os.makedirs(os.path.dirname(abs_path) or '.', exist_ok=True) + with open(abs_path, 'w', encoding='utf-8') as fh: + fh.write(content) + return f'File {file_path} written successfully.' + + def _edit_file_impl(self, file_path: str, edits: list[dict[str, Any]]) -> str: + abs_path = self._resolve_safe(file_path) + if not os.path.isfile(abs_path): + raise ValueError(f'File not found: {file_path!r}') + + with open(abs_path, encoding='utf-8', errors='replace') as fh: + content = fh.read() + + for spec in edits: + old = spec.get('old_string', '') + new = spec.get('new_string', '') + replace_all = spec.get('replace_all', False) + if not old: + raise ValueError('old_string must be non-empty.') + if old == new: + raise ValueError('old_string and new_string must differ.') + count = content.count(old) + if count == 0: + raise ValueError(f'old_string not found in file: {old!r}') + if not replace_all and count > 1: + raise ValueError(f'old_string matches {count} times but replace_all=False.') + content = content.replace(old, new) if replace_all else content.replace(old, new, 1) + + with open(abs_path, 'w', encoding='utf-8') as fh: + fh.write(content) + return f'File {file_path} edited successfully.' + + def tools(self, ctx: GenerateMiddlewareContext) -> list[Any]: + """Return filesystem tool actions for this generate() call.""" + scratch = Registry() + + async def list_files(input: _ListFilesInput) -> list[dict[str, Any]]: + return await asyncio.to_thread(self._list_files, input.dir_path, input.recursive) + + async def read_file(input: _ReadFileInput) -> str: + return await asyncio.to_thread( + self._read_file_impl, + input.file_path, + input.offset, + input.limit, + ) + + t_list = define_tool(scratch, list_files, name=self._tool_name('list_files')) + t_read = define_tool(scratch, read_file, name=self._tool_name('read_file')) + tools_out = [t_list.action(), t_read.action()] + + if self.config.allow_write_access: + + async def write_file(input: _WriteFileInput) -> str: + return await asyncio.to_thread(self._write_file_impl, input.file_path, input.content) + + async def edit_file(input: _EditFileInput) -> str: + return await asyncio.to_thread( + self._edit_file_impl, + input.file_path, + [e.model_dump() for e in input.edits], + ) + + t_write = define_tool(scratch, write_file, name=self._tool_name('write_file')) + t_edit = define_tool(scratch, edit_file, name=self._tool_name('edit_file')) + tools_out += [t_write.action(), t_edit.action()] + + return tools_out + + async def wrap_generate( + self, + params: GenerateHookParams, + next_fn: Callable[[GenerateHookParams], Awaitable[ModelResponse]], + ctx: GenerateMiddlewareContext, + ) -> ModelResponse: + """Drain queued user messages into the request before the next model turn.""" + if not self._message_queue: + return await next_fn(params) + + message_index = params.message_index + if ctx.on_chunk: + for msg in self._message_queue: + ctx.send_chunk(ModelResponseChunk(role=msg.role, content=msg.content, index=message_index)) + message_index += 1 + + new_request = params.request.model_copy() + new_request.messages = [*params.request.messages, *self._message_queue] + self._message_queue.clear() + + params = params.model_copy( + update={ + 'request': new_request, + 'message_index': message_index, + } + ) + return await next_fn(params) + + async def wrap_tool( + self, + params: ToolHookParams, + next_fn: Callable[[ToolHookParams], Awaitable[MultipartToolResponse]], + ctx: GenerateMiddlewareContext, + ) -> MultipartToolResponse: + """Catch filesystem tool errors and enqueue them as user messages.""" + if params.tool.name not in self._filesystem_tool_names(): + return await next_fn(params) + + try: + return await next_fn(params) + except Interrupt: + raise + except Exception as exc: + error_msg = f'Tool "{params.tool.name}" failed: {exc}' + self._enqueue_parts([Part(root=TextPart(text=error_msg))]) + return MultipartToolResponse(output='Tool call failed; see user message below for details.') diff --git a/py/plugins/middleware/src/genkit/plugins/middleware/_retry.py b/py/plugins/middleware/src/genkit/plugins/middleware/_retry.py new file mode 100644 index 0000000000..5ce122be70 --- /dev/null +++ b/py/plugins/middleware/src/genkit/plugins/middleware/_retry.py @@ -0,0 +1,82 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Retry middleware for Genkit model calls.""" + +from __future__ import annotations + +import asyncio +import math +import random +from collections.abc import Awaitable, Callable + +from pydantic import BaseModel, Field + +from genkit import GenkitError +from genkit._core._model import ModelResponse +from genkit.middleware import BaseMiddleware, GenerateMiddlewareContext, ModelHookParams + +_DEFAULT_RETRY_STATUSES: list[str] = [ + 'UNAVAILABLE', + 'DEADLINE_EXCEEDED', + 'RESOURCE_EXHAUSTED', + 'ABORTED', + 'INTERNAL', +] + + +class RetryConfig(BaseModel): + """Knobs for retry backoff and which error statuses are retried.""" + + max_retries: int = Field(default=3, ge=0) + statuses: list[str] = Field(default_factory=lambda: list(_DEFAULT_RETRY_STATUSES)) + initial_delay_ms: int = 1000 + max_delay_ms: int = 60000 + backoff_factor: float = 2.0 + jitter: bool = True + + +class Retry(BaseMiddleware[RetryConfig]): + """Retry middleware with exponential backoff for transient failures.""" + + async def wrap_model( + self, + params: ModelHookParams, + next_fn: Callable[[ModelHookParams], Awaitable[ModelResponse]], + ctx: GenerateMiddlewareContext, + ) -> ModelResponse: + """Retry the model call up to max_retries times on transient failures.""" + current_delay_ms = float(self.config.initial_delay_ms) + + for attempt in range(self.config.max_retries + 1): + try: + return await next_fn(params) + except Exception as e: + if attempt == self.config.max_retries: + raise + + if isinstance(e, GenkitError) and e.status not in self.config.statuses: + raise + + delay_ms = current_delay_ms + if self.config.jitter: + delay_ms += 1000.0 * math.pow(2, attempt) * random.random() + delay_ms = min(delay_ms, self.config.max_delay_ms) + + await asyncio.sleep(delay_ms / 1000.0) + current_delay_ms = min(current_delay_ms * self.config.backoff_factor, self.config.max_delay_ms) + + raise AssertionError('Retry loop exited without returning or raising') # noqa: EM101 diff --git a/py/plugins/middleware/src/genkit/plugins/middleware/_skills.py b/py/plugins/middleware/src/genkit/plugins/middleware/_skills.py new file mode 100644 index 0000000000..f1a05e5b56 --- /dev/null +++ b/py/plugins/middleware/src/genkit/plugins/middleware/_skills.py @@ -0,0 +1,179 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Skills middleware for Genkit.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +from pathlib import Path +from typing import Any + +import yaml +from pydantic import BaseModel as PydanticBaseModel, Field + +from genkit._ai._model import Message +from genkit._ai._tools import define_tool +from genkit._core._model import ModelRequest, ModelResponse +from genkit._core._registry import Registry +from genkit._core._typing import Part, Role, TextPart +from genkit.middleware import BaseMiddleware, GenerateHookParams, GenerateMiddlewareContext + +_SKILLS_MARKER = 'skills-instructions' +_MISSING_DESCRIPTION = 'No description provided.' + + +class _UseSkillInput(PydanticBaseModel): + """Input for the ``use_skill`` tool.""" + + skill_name: str = Field(description='The name of the skill to load (as listed in the system prompt).') + + +class SkillsConfig(PydanticBaseModel): + """Directories to scan for skill folders containing ``SKILL.md``.""" + + skill_paths: list[str] = Field(default_factory=lambda: ['skills']) + + +class Skills(BaseMiddleware[SkillsConfig]): + """Skills middleware that exposes ``SKILL.md`` files as loadable instructions.""" + + def _scan_skills(self) -> dict[str, dict[str, str]]: + skills: dict[str, dict[str, str]] = {} + for path_str in self.config.skill_paths: + path = Path(path_str).resolve() + if not path.is_dir(): + continue + for subdir in sorted(path.iterdir()): + if not subdir.is_dir() or subdir.name.startswith('.'): + continue + skill_file = subdir / 'SKILL.md' + if not skill_file.is_file(): + continue + name, description = self._parse_skill_file(skill_file) + if not name: + name = subdir.name + skills[name] = { + 'path': str(skill_file), + 'description': description or '', + } + return skills + + def _parse_skill_file(self, path: Path) -> tuple[str, str]: + try: + content = path.read_text(encoding='utf-8').lstrip('\ufeff') + except Exception: + return '', '' + if not content.startswith('---\n'): + return '', '' + end_idx = content.find('\n---', 4) + if end_idx == -1: + return '', '' + try: + data = yaml.safe_load(content[4:end_idx]) + if not isinstance(data, dict): + return '', '' + return data.get('name', ''), data.get('description', '') + except Exception: + return '', '' + + def _build_skills_prompt(self, skills: dict[str, dict[str, str]]) -> str: + if not skills: + return '' + lines = [ + '', + 'You have access to a library of skills that serve as specialized instructions/personas.', + 'Strongly prefer to use them when working on anything related to them.', + 'Only use them once to load the context.', + 'Here are the available skills:', + ] + for skill_name in sorted(skills.keys()): + desc = skills[skill_name]['description'] + if desc and desc != _MISSING_DESCRIPTION: + lines.append(f' - {skill_name} - {desc}') + else: + lines.append(f' - {skill_name}') + lines.append('') + return '\n'.join(lines) + + def _inject_skills_prompt(self, request: ModelRequest, prompt_text: str) -> ModelRequest: + messages = list(request.messages) + system_idx: int | None = None + for i, msg in enumerate(messages): + if msg.role == Role.SYSTEM: + system_idx = i + break + + marker_meta: dict[str, Any] = {_SKILLS_MARKER: True} + new_part = Part(root=TextPart(text=prompt_text, metadata=marker_meta)) + + if system_idx is not None: + msg = messages[system_idx] + new_content = [] + replaced = False + for part in msg.content: + meta = part.root.metadata if isinstance(part.root, TextPart) else None + if isinstance(meta, dict) and meta.get(_SKILLS_MARKER): + new_content.append(new_part) + replaced = True + else: + new_content.append(part) + if not replaced: + new_content.append(new_part) + messages[system_idx] = Message(role=Role.SYSTEM, content=new_content) + else: + messages.insert(0, Message(role=Role.SYSTEM, content=[new_part])) + + new_request = request.model_copy() + new_request.messages = messages + return new_request + + def tools(self, ctx: GenerateMiddlewareContext) -> list[Any]: + if not self._scan_skills(): + return [] + + scratch = Registry() + + async def use_skill(input: _UseSkillInput) -> str: + skill_name = input.skill_name + skills = await asyncio.to_thread(self._scan_skills) + info = skills.get(skill_name) + if info is None: + available = ', '.join(sorted(skills.keys())) + return f'Unknown skill "{skill_name}". Available skills: {available}' + try: + skill_path = Path(info['path']) + return await asyncio.to_thread(skill_path.read_text, encoding='utf-8') + except Exception as exc: + return f'Failed to read skill "{skill_name}": {exc}' + + t = define_tool(scratch, use_skill, name='use_skill') + return [t.action()] + + async def wrap_generate( + self, + params: GenerateHookParams, + next_fn: Callable[[GenerateHookParams], Awaitable[ModelResponse]], + ctx: GenerateMiddlewareContext, + ) -> ModelResponse: + skills = await asyncio.to_thread(self._scan_skills) + if skills: + prompt_text = self._build_skills_prompt(skills) + if prompt_text: + params = params.model_copy() + params.request = self._inject_skills_prompt(params.request, prompt_text) + return await next_fn(params) diff --git a/py/plugins/middleware/src/genkit/plugins/middleware/_tool_approval.py b/py/plugins/middleware/src/genkit/plugins/middleware/_tool_approval.py new file mode 100644 index 0000000000..51e0bdbcb9 --- /dev/null +++ b/py/plugins/middleware/src/genkit/plugins/middleware/_tool_approval.py @@ -0,0 +1,64 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tool approval middleware for Genkit.""" + +from __future__ import annotations + +import json +from collections.abc import Awaitable, Callable + +from pydantic import BaseModel, Field + +from genkit._ai._tools import Interrupt +from genkit._core._tracing import SpanMetadata, run_in_new_span +from genkit.middleware import BaseMiddleware, GenerateMiddlewareContext, MultipartToolResponse, ToolHookParams + + +class ToolApprovalConfig(BaseModel): + """Tools that may run without an approval interrupt.""" + + allowed_tools: list[str] = Field(default_factory=list) + + +class ToolApproval(BaseMiddleware[ToolApprovalConfig]): + """Tool approval middleware that interrupts execution for non-allowed tools.""" + + async def wrap_tool( + self, + params: ToolHookParams, + next_fn: Callable[[ToolHookParams], Awaitable[MultipartToolResponse]], + ctx: GenerateMiddlewareContext, + ) -> MultipartToolResponse: + """Intercept tool execution and require approval if not in allowed list.""" + tool_name = params.tool.name + + if tool_name in self.config.allowed_tools: + return await next_fn(params) + + metadata = params.tool_request_part.metadata or {} + resumed = metadata.get('resumed') + if isinstance(resumed, dict) and resumed.get('toolApproved'): + return await next_fn(params) + + tool_input = params.tool_request_part.tool_request.input + with run_in_new_span( + SpanMetadata(name=tool_name, type='action', subtype='tool', input=tool_input), + ) as span: + if tool_input is not None: + inp_json = tool_input.model_dump_json() if isinstance(tool_input, BaseModel) else json.dumps(tool_input) + span.set_attribute('genkit:input', inp_json) + raise Interrupt({'message': f'Tool not in approved list: {tool_name}'}) diff --git a/py/plugins/middleware/tests/conftest.py b/py/plugins/middleware/tests/conftest.py new file mode 100644 index 0000000000..424158e387 --- /dev/null +++ b/py/plugins/middleware/tests/conftest.py @@ -0,0 +1,14 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Pytest fixtures for middleware plugin unit tests.""" + +import pytest + +from genkit._core._registry import Registry +from genkit.middleware import GenerateMiddlewareContext + + +@pytest.fixture +def ctx() -> GenerateMiddlewareContext: + return GenerateMiddlewareContext(registry=Registry()) diff --git a/py/plugins/middleware/tests/fallback_test.py b/py/plugins/middleware/tests/fallback_test.py new file mode 100644 index 0000000000..7b6e98fc7c --- /dev/null +++ b/py/plugins/middleware/tests/fallback_test.py @@ -0,0 +1,82 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for Fallback middleware.""" + +from typing import NoReturn + +import pytest + +from genkit import ModelRequest, ModelResponse +from genkit._core._error import GenkitError +from genkit.middleware import ModelHookParams +from genkit.plugins.middleware import Fallback + + +def _make_params() -> ModelHookParams: + return ModelHookParams(request=ModelRequest(messages=[])) + + +def _make_fallback(**kwargs) -> Fallback: + return Fallback(**kwargs) + + +@pytest.mark.asyncio +async def test_fallback_success_on_first_model(ctx) -> None: + """Test that successful primary model calls pass through.""" + fallback = _make_fallback(models=['model2', 'model3']) + + async def next_fn(params): + return ModelResponse(message=None) + + result = await fallback.wrap_model(_make_params(), next_fn, ctx) + assert result is not None + + +@pytest.mark.asyncio +async def test_fallback_on_retryable_error(ctx) -> None: + """Test that retryable errors are classified correctly.""" + fallback = _make_fallback(models=['model2']) + + async def next_fn(params) -> NoReturn: + raise GenkitError(message='Service unavailable', status='UNAVAILABLE') + + with pytest.raises(GenkitError): + await fallback.wrap_model(_make_params(), next_fn, ctx) + + +@pytest.mark.asyncio +async def test_fallback_non_retryable_error(ctx) -> None: + """Test that non-retryable errors fail immediately.""" + fallback = _make_fallback(models=['model2']) + + async def next_fn(params) -> NoReturn: + raise GenkitError(message='Invalid argument', status='INVALID_ARGUMENT') + + with pytest.raises(GenkitError): + await fallback.wrap_model(_make_params(), next_fn, ctx) + + +@pytest.mark.asyncio +async def test_fallback_non_genkit_error(ctx) -> None: + """Test that non-GenkitError exceptions fail immediately.""" + fallback = _make_fallback(models=['model2']) + + async def next_fn(params) -> NoReturn: + raise ConnectionError('Network failure') + + with pytest.raises(ConnectionError): + await fallback.wrap_model(_make_params(), next_fn, ctx) diff --git a/py/plugins/middleware/tests/filesystem_test.py b/py/plugins/middleware/tests/filesystem_test.py new file mode 100644 index 0000000000..10f87f7d6f --- /dev/null +++ b/py/plugins/middleware/tests/filesystem_test.py @@ -0,0 +1,184 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for Filesystem middleware.""" + +import tempfile +from pathlib import Path + +import pytest + +from genkit.plugins.middleware import Filesystem + +# --------------------------------------------------------------------------- +# Construction / validation +# --------------------------------------------------------------------------- + + +def test_filesystem_validates_root_dir() -> None: + """Filesystem must reject an empty root_dir.""" + with pytest.raises(ValueError, match='root_dir'): + Filesystem(root_dir='') + + +def test_filesystem_resolves_root() -> None: + """root_dir is resolved to an absolute path.""" + with tempfile.TemporaryDirectory() as tmpdir: + fs = Filesystem(root_dir=tmpdir) + assert fs._root_abs == str(Path(tmpdir).resolve()) + + +# --------------------------------------------------------------------------- +# _resolve_safe +# --------------------------------------------------------------------------- + + +def test_resolve_safe_allows_root() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + fs = Filesystem(root_dir=tmpdir) + assert fs._resolve_safe('') == fs._root_abs + + +def test_resolve_safe_allows_child() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + fs = Filesystem(root_dir=tmpdir) + child = Path(tmpdir) / 'sub' / 'file.txt' + child.parent.mkdir(parents=True) + assert fs._resolve_safe('sub/file.txt').endswith('sub/file.txt') + + +def test_resolve_safe_blocks_escape() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + fs = Filesystem(root_dir=tmpdir) + with pytest.raises(ValueError, match='escapes'): + fs._resolve_safe('../../../etc/passwd') + + +# --------------------------------------------------------------------------- +# _list_files +# --------------------------------------------------------------------------- + + +def test_list_files_returns_paths_relative_to_queried_dir() -> None: + """list_files paths should be relative to the requested sub-dir, not root.""" + with tempfile.TemporaryDirectory() as tmpdir: + sub = Path(tmpdir) / 'docs' + sub.mkdir() + (sub / 'api.md').write_text('hello') + fs = Filesystem(root_dir=tmpdir) + entries = fs._list_files('docs') + names = [e['path'] for e in entries] + assert 'api.md' in names + assert 'docs/api.md' not in names + + +def test_list_files_root() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + (Path(tmpdir) / 'a.txt').write_text('a') + (Path(tmpdir) / 'b.txt').write_text('b') + fs = Filesystem(root_dir=tmpdir) + entries = fs._list_files() + names = {e['path'] for e in entries} + assert 'a.txt' in names + assert 'b.txt' in names + + +# --------------------------------------------------------------------------- +# _read_file_impl (text files) +# --------------------------------------------------------------------------- + + +def test_read_file_queues_content() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + f = Path(tmpdir) / 'hello.txt' + f.write_text('hello world\n') + fs = Filesystem(root_dir=tmpdir) + result = fs._read_file_impl('hello.txt', 0, 0) + assert 'queued' in result.lower() or 'read' in result.lower() + assert len(fs._message_queue) == 1 + assert len(fs._message_queue[0].content) == 1 + + +def test_read_file_rereads_each_time() -> None: + """No dedup cache — each read queues content again.""" + with tempfile.TemporaryDirectory() as tmpdir: + f = Path(tmpdir) / 'hello.txt' + f.write_text('hello world\n') + fs = Filesystem(root_dir=tmpdir) + fs._read_file_impl('hello.txt', 0, 0) + fs._message_queue.clear() + result = fs._read_file_impl('hello.txt', 0, 0) + assert 'read' in result.lower() + assert len(fs._message_queue) == 1 + + +# --------------------------------------------------------------------------- +# _write_file_impl and _edit_file_impl +# --------------------------------------------------------------------------- + + +def test_write_file_overwrites_without_prior_read() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + f = Path(tmpdir) / 'existing.txt' + f.write_text('original\n') + fs = Filesystem(root_dir=tmpdir, allow_write_access=True) + result = fs._write_file_impl('existing.txt', 'new content\n') + assert 'written' in result.lower() + assert f.read_text() == 'new content\n' + + +def test_write_new_file_succeeds() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + fs = Filesystem(root_dir=tmpdir, allow_write_access=True) + result = fs._write_file_impl('new.txt', 'content\n') + assert 'written' in result.lower() + assert (Path(tmpdir) / 'new.txt').read_text() == 'content\n' + + +def test_edit_file_reads_from_disk() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + f = Path(tmpdir) / 'edit_me.txt' + f.write_text('hello world\n') + fs = Filesystem(root_dir=tmpdir, allow_write_access=True) + result = fs._edit_file_impl('edit_me.txt', [{'old_string': 'hello', 'new_string': 'hi'}]) + assert 'edited' in result.lower() + assert f.read_text() == 'hi world\n' + + +# --------------------------------------------------------------------------- +# tools() — dynamic tool registration +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_tools_returns_read_and_list(ctx) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + fs = Filesystem(root_dir=tmpdir) + tool_actions = fs.tools(ctx) + names = {t.name for t in tool_actions} + assert 'list_files' in names + assert 'read_file' in names + assert 'write_file' not in names + + +@pytest.mark.asyncio +async def test_tools_returns_write_when_allowed(ctx) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + fs = Filesystem(root_dir=tmpdir, allow_write_access=True) + tool_actions = fs.tools(ctx) + names = {t.name for t in tool_actions} + assert 'write_file' in names + assert 'edit_file' in names diff --git a/py/plugins/middleware/tests/retry_test.py b/py/plugins/middleware/tests/retry_test.py new file mode 100644 index 0000000000..8bde3ba198 --- /dev/null +++ b/py/plugins/middleware/tests/retry_test.py @@ -0,0 +1,121 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for Retry middleware.""" + +from typing import NoReturn + +import pytest +from pydantic import ValidationError + +from genkit import ModelRequest, ModelResponse +from genkit._core._error import GenkitError +from genkit.middleware import GenerateMiddlewareContext, ModelHookParams +from genkit.plugins.middleware import Retry + + +def _make_params() -> ModelHookParams: + return ModelHookParams(request=ModelRequest(messages=[])) + + +@pytest.mark.asyncio +async def test_retry_success_on_first_attempt(ctx: GenerateMiddlewareContext) -> None: + """Test that successful calls pass through without retry.""" + retry = Retry(max_retries=3) + + async def next_fn(params): + return ModelResponse(message=None) + + result = await retry.wrap_model(_make_params(), next_fn, ctx) + assert result is not None + + +@pytest.mark.asyncio +async def test_retry_on_retryable_error(ctx: GenerateMiddlewareContext) -> None: + """Test that retryable errors trigger retry.""" + retry = Retry(max_retries=2, initial_delay_ms=10, jitter=False) + + call_count = 0 + + async def next_fn(params): + nonlocal call_count + call_count += 1 + if call_count < 2: + raise GenkitError(message='Service unavailable', status='UNAVAILABLE') + return ModelResponse(message=None) + + result = await retry.wrap_model(_make_params(), next_fn, ctx) + assert result is not None + assert call_count == 2 + + +@pytest.mark.asyncio +async def test_retry_exhausted(ctx: GenerateMiddlewareContext) -> None: + """Test that errors are raised after max retries.""" + retry = Retry(max_retries=1, initial_delay_ms=10, jitter=False) + + async def next_fn(params) -> NoReturn: + raise GenkitError(message='Service unavailable', status='UNAVAILABLE') + + with pytest.raises(GenkitError): + await retry.wrap_model(_make_params(), next_fn, ctx) + + +@pytest.mark.asyncio +async def test_retry_non_retryable_error(ctx: GenerateMiddlewareContext) -> None: + """Test that non-retryable errors fail immediately.""" + retry = Retry(max_retries=3) + + call_count = 0 + + async def next_fn(params) -> NoReturn: + nonlocal call_count + call_count += 1 + raise GenkitError(message='Invalid argument', status='INVALID_ARGUMENT') + + with pytest.raises(GenkitError): + await retry.wrap_model(_make_params(), next_fn, ctx) + assert call_count == 1 + + +def test_retry_rejects_negative_max_retries() -> None: + """``max_retries`` must be non-negative; the wrap_model fall-through is unreachable. + + Regression: without the ``Field(ge=0)`` constraint, ``max_retries=-1`` would + skip the for-loop entirely and trip the defensive ``AssertionError`` at the + end of ``wrap_model``. + """ + with pytest.raises(ValidationError): + Retry(max_retries=-1) + + +@pytest.mark.asyncio +async def test_retry_non_genkit_error(ctx: GenerateMiddlewareContext) -> None: + """Test that non-GenkitError exceptions are retried.""" + retry = Retry(max_retries=2, initial_delay_ms=10, jitter=False) + + call_count = 0 + + async def next_fn(params): + nonlocal call_count + call_count += 1 + if call_count < 2: + raise ConnectionError('Network failure') + return ModelResponse(message=None) + + result = await retry.wrap_model(_make_params(), next_fn, ctx) + assert result is not None + assert call_count == 2 diff --git a/py/plugins/middleware/tests/skills_test.py b/py/plugins/middleware/tests/skills_test.py new file mode 100644 index 0000000000..2cf68e9d2a --- /dev/null +++ b/py/plugins/middleware/tests/skills_test.py @@ -0,0 +1,142 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for Skills middleware.""" + +import tempfile +from pathlib import Path + +import pytest + +from genkit import ModelRequest, ModelResponse +from genkit._core._model import GenerateActionOptions +from genkit.middleware import GenerateHookParams, GenerateMiddlewareContext +from genkit.plugins.middleware import Skills + + +def _make_params() -> GenerateHookParams: + return GenerateHookParams( + options=GenerateActionOptions(messages=[]), + request=ModelRequest(messages=[]), + iteration=0, + ) + + +@pytest.mark.asyncio +async def test_skills_no_paths(ctx: GenerateMiddlewareContext) -> None: + """Test that middleware works with no skill paths.""" + skills = Skills(skill_paths=[]) + + async def next_fn(params): + return ModelResponse(message=None) + + result = await skills.wrap_generate(_make_params(), next_fn, ctx) + assert result is not None + + +@pytest.mark.asyncio +async def test_skills_nonexistent_path(ctx: GenerateMiddlewareContext) -> None: + """Test that nonexistent paths are silently skipped.""" + skills = Skills(skill_paths=['/nonexistent/path']) + + async def next_fn(params): + return ModelResponse(message=None) + + result = await skills.wrap_generate(_make_params(), next_fn, ctx) + assert result is not None + + +@pytest.mark.asyncio +async def test_skills_scan_with_skill(ctx: GenerateMiddlewareContext) -> None: + """Test that skills are scanned and injected into system message.""" + with tempfile.TemporaryDirectory() as tmpdir: + skill_dir = Path(tmpdir) / 'test-skill' + skill_dir.mkdir() + skill_file = skill_dir / 'SKILL.md' + skill_file.write_text("""--- +name: test-skill +description: A test skill +--- +You are a test assistant. +""") + + skills = Skills(skill_paths=[tmpdir]) + + async def next_fn(params): + # Check that skills prompt was injected + assert len(params.request.messages) > 0 + return ModelResponse(message=None) + + result = await skills.wrap_generate(_make_params(), next_fn, ctx) + assert result is not None + + +@pytest.mark.asyncio +async def test_skills_parse_frontmatter() -> None: + """Test that YAML frontmatter is parsed correctly.""" + with tempfile.TemporaryDirectory() as tmpdir: + skill_dir = Path(tmpdir) / 'python-expert' + skill_dir.mkdir() + skill_file = skill_dir / 'SKILL.md' + skill_file.write_text("""--- +name: python-expert +description: Expert Python programming assistance +--- +You are an expert Python programmer. +""") + + skills = Skills(skill_paths=[tmpdir]) + info = skills._scan_skills() + + assert 'python-expert' in info + assert info['python-expert']['description'] == 'Expert Python programming assistance' + + +def test_skills_parse_no_frontmatter() -> None: + """Test that files without frontmatter use directory name; description is empty.""" + with tempfile.TemporaryDirectory() as tmpdir: + skill_dir = Path(tmpdir) / 'test-skill' + skill_dir.mkdir() + skill_file = skill_dir / 'SKILL.md' + skill_file.write_text('You are a test assistant.') + + skills = Skills(skill_paths=[tmpdir]) + info = skills._scan_skills() + + assert 'test-skill' in info + # No frontmatter → empty description (displayed without placeholder in the prompt) + assert info['test-skill']['description'] == '' + + +def test_skills_placeholder_description_not_shown_in_prompt() -> None: + """Frontmatter that uses the placeholder sentence lists the skill name only.""" + with tempfile.TemporaryDirectory() as tmpdir: + skill_dir = Path(tmpdir) / 'bare-skill' + skill_dir.mkdir() + skill_file = skill_dir / 'SKILL.md' + skill_file.write_text("""--- +name: bare-skill +description: No description provided. +--- +Skill body. +""") + + skills = Skills(skill_paths=[tmpdir]) + scanned = skills._scan_skills() + prompt = skills._build_skills_prompt(scanned) + + assert ' - bare-skill\n' in prompt + assert 'No description provided' not in prompt diff --git a/py/plugins/middleware/tests/tool_approval_test.py b/py/plugins/middleware/tests/tool_approval_test.py new file mode 100644 index 0000000000..76fca5399f --- /dev/null +++ b/py/plugins/middleware/tests/tool_approval_test.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for ToolApproval middleware.""" + +import pytest + +from genkit._ai._tools import Interrupt, define_tool +from genkit._core._registry import Registry +from genkit._core._typing import ToolRequest, ToolRequestPart +from genkit.middleware import GenerateMiddlewareContext, MultipartToolResponse, ToolHookParams +from genkit.plugins.middleware import ToolApproval + + +def _make_tool(name: str): + """Create a minimal Action with the given name via define_tool.""" + scratch = Registry() + + async def fn() -> str: + return '' + + return define_tool(scratch, fn, name=name).action() + + +@pytest.mark.asyncio +async def test_tool_approval_allowed_tool(ctx: GenerateMiddlewareContext) -> None: + """Test that allowed tools pass through without approval.""" + approval = ToolApproval(allowed_tools=['get_weather']) + + async def next_fn(params): + return MultipartToolResponse(output='sunny') + + tool = _make_tool('get_weather') + tool_request = ToolRequest(name='get_weather', input={}) + tool_request_part = ToolRequestPart(tool_request=tool_request) + params = ToolHookParams(tool_request_part=tool_request_part, tool=tool) + + result = await approval.wrap_tool(params, next_fn, ctx) + assert result is not None + + +@pytest.mark.asyncio +async def test_tool_approval_non_allowed_tool(ctx: GenerateMiddlewareContext) -> None: + """Test that non-allowed tools raise Interrupt.""" + approval = ToolApproval(allowed_tools=['get_weather']) + + async def next_fn(params): + return MultipartToolResponse(output=None) + + tool = _make_tool('delete_database') + tool_request = ToolRequest(name='delete_database', input={}) + tool_request_part = ToolRequestPart(tool_request=tool_request) + params = ToolHookParams(tool_request_part=tool_request_part, tool=tool) + + with pytest.raises(Interrupt) as exc_info: + await approval.wrap_tool(params, next_fn, ctx) + assert 'delete_database' in exc_info.value.metadata['message'] + + +@pytest.mark.asyncio +async def test_tool_approval_resumed_with_approval(ctx: GenerateMiddlewareContext) -> None: + """Test that resumed tools with approval metadata pass through.""" + approval = ToolApproval(allowed_tools=[]) + + async def next_fn(params): + return MultipartToolResponse(output='approved') + + tool = _make_tool('some_tool') + tool_request = ToolRequest(name='some_tool', input={}) + tool_request_part = ToolRequestPart( + tool_request=tool_request, + metadata={'resumed': {'toolApproved': True}}, + ) + params = ToolHookParams(tool_request_part=tool_request_part, tool=tool) + + result = await approval.wrap_tool(params, next_fn, ctx) + assert result is not None + + +@pytest.mark.asyncio +async def test_tool_approval_empty_allowed_list(ctx: GenerateMiddlewareContext) -> None: + """Test that empty allowed list requires approval for all tools.""" + approval = ToolApproval(allowed_tools=[]) + + async def next_fn(params): + return MultipartToolResponse(output=None) + + tool = _make_tool('any_tool') + tool_request = ToolRequest(name='any_tool', input={}) + tool_request_part = ToolRequestPart(tool_request=tool_request) + params = ToolHookParams(tool_request_part=tool_request_part, tool=tool) + + with pytest.raises(Interrupt): + await approval.wrap_tool(params, next_fn, ctx) diff --git a/py/pyproject.toml b/py/pyproject.toml index abd97f3a31..9a00e18128 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "genkit-plugin-google-genai", "genkit-plugin-ollama", "genkit-plugin-evaluators", + "genkit-plugin-middleware", "genkit-plugin-vertex-ai", # Internal tools (private, not published) "liccheck>=0.9.2", @@ -154,6 +155,8 @@ flask-hello = { workspace = true } gemini-code-execution = { workspace = true } gemini-context-caching = { workspace = true } google-genai-media = { workspace = true } +middleware = { workspace = true } +middleware-coding-agent = { workspace = true } output-formats = { workspace = true } prompts = { workspace = true } tool-interrupts = { workspace = true } @@ -169,6 +172,7 @@ genkit-plugin-fastapi = { workspace = true } genkit-plugin-flask = { workspace = true } genkit-plugin-google-cloud = { workspace = true } genkit-plugin-google-genai = { workspace = true } +genkit-plugin-middleware = { workspace = true } genkit-plugin-ollama = { workspace = true } genkit-plugin-evaluators = { workspace = true } genkit-plugin-vertex-ai = { workspace = true } @@ -263,6 +267,7 @@ select = [ "packages/genkit/src/genkit/_core/trace/__init__.py" = ["F401"] # Test files don't need docstrings; test method names are self-documenting "**/tests/**/*.py" = ["D", "ANN401"] +"plugins/middleware/tests/*_test.py" = ["D", "ANN"] "packages/genkit/tests/typing/*.py" = ["D", "F821", "F841", "B018", "ANN"] # Validation/provenance/safe-defaults test methods are self-documenting via names; # S108 false positives on Path('/tmp/x') used as dummy non-existent paths in tests. diff --git a/py/samples/middleware-coding-agent/.gitignore b/py/samples/middleware-coding-agent/.gitignore new file mode 100644 index 0000000000..b41636103a --- /dev/null +++ b/py/samples/middleware-coding-agent/.gitignore @@ -0,0 +1,4 @@ +__pycache__/ + +# The agent edits files here in place; outputs are reproducible by re-running. +workspace/ diff --git a/py/samples/middleware-coding-agent/README.md b/py/samples/middleware-coding-agent/README.md new file mode 100644 index 0000000000..fada7edbd2 --- /dev/null +++ b/py/samples/middleware-coding-agent/README.md @@ -0,0 +1,58 @@ +# middleware-coding-agent + +Interactive coding-agent REPL that wires up the +[`Filesystem`](../../plugins/middleware/src/genkit/plugins/middleware/_filesystem.py), +[`Skills`](../../plugins/middleware/src/genkit/plugins/middleware/_skills.py), +and [`ToolApproval`](../../plugins/middleware/src/genkit/plugins/middleware/_tool_approval.py) +middleware against a sandboxed workspace. + +Python port of the JavaScript example at +[`js/plugins/middleware/examples/coding_agent.ts`](../../../js/plugins/middleware/examples/coding_agent.ts). + +## What's here + +``` +middleware-coding-agent/ +├── src/main.py # interactive REPL +├── skills/ +│ ├── python-expert/SKILL.md # house style for editing Python +│ └── test-writer/SKILL.md # house style for writing pytest tests +└── workspace/ # sandbox the agent reads, writes, edits in + # (created on first run; contents gitignored) +``` + +The model gets: + +- the contents of `workspace/` via `Filesystem(root_dir=…, allow_write_access=True)` — + `list_files`, `read_file`, `write_file`, `edit_file`, all confined to that + directory. +- a system prompt listing the two skills, plus a `use_skill` tool it calls + to pull in the full `SKILL.md` content on demand. +- `ToolApproval(allowed_tools=['read_file', 'list_files', 'use_skill'])` — + read-only tools run without prompting; anything that can mutate the + workspace (`write_file`, `edit_file`) interrupts and waits for your + `y/N` from the CLI before resuming. + +## Run it + +```bash +cd py/samples/middleware-coding-agent +GEMINI_API_KEY=... genkit start -- uv run src/main.py +``` + +Type a request at the REPL prompt in your terminal (e.g. `build a tiny +priority queue module with push/pop/peek and pytest tests`), hit enter, +and approve each write the agent proposes. Conversation history persists +across turns until you type `exit`. + +If you want the agent to fix or extend an existing file instead of +starting from scratch, drop the file into `workspace/` first and reference +it by name in your prompt. + +## Resetting between runs + +The agent edits `workspace/` in place. To start over: + +```bash +rm -rf py/samples/middleware-coding-agent/workspace/* +``` diff --git a/py/samples/middleware-coding-agent/pyproject.toml b/py/samples/middleware-coding-agent/pyproject.toml new file mode 100644 index 0000000000..8b3a52860a --- /dev/null +++ b/py/samples/middleware-coding-agent/pyproject.toml @@ -0,0 +1,19 @@ +[project] +name = "middleware-coding-agent" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "genkit", + "genkit-plugin-google-genai", + "genkit-plugin-middleware", + "pydantic>=2.10.5", + "structlog>=25.2.0", + "uvloop>=0.21.0", +] + +[build-system] +build-backend = "hatchling.build" +requires = ["hatchling"] + +[tool.hatch.build.targets.wheel] +packages = ["src"] diff --git a/py/samples/middleware-coding-agent/skills/python-expert/SKILL.md b/py/samples/middleware-coding-agent/skills/python-expert/SKILL.md new file mode 100644 index 0000000000..220f2f1734 --- /dev/null +++ b/py/samples/middleware-coding-agent/skills/python-expert/SKILL.md @@ -0,0 +1,16 @@ +--- +name: python-expert +description: Conventions for clean, idiomatic Python. Load whenever you read, edit, or write Python source files. +--- + +# Python expert + +When working with Python in this workspace, follow these conventions: + +- **Type-hint everything.** Parameters, returns, attributes, locals where the type isn't obvious. +- **Prefer dataclasses** for simple data containers over hand-written `__init__`s. +- **Raise specific exceptions** (`ValueError`, `KeyError`, `LookupError`) with informative messages. Avoid bare `Exception`. +- **Don't swallow errors.** Don't `except Exception: pass`. Let unexpected errors propagate. +- **Match the surrounding style.** If the file uses single quotes and 4-space indent, match it. Don't reformat unrelated lines. +- **Comments explain why, not what.** Skip narration like `# loop over items`; only comment non-obvious intent. +- **Small, focused edits.** When fixing a bug, change only what's necessary. Leave the rest of the file untouched so the diff stays readable. diff --git a/py/samples/middleware-coding-agent/skills/test-writer/SKILL.md b/py/samples/middleware-coding-agent/skills/test-writer/SKILL.md new file mode 100644 index 0000000000..17d387842e --- /dev/null +++ b/py/samples/middleware-coding-agent/skills/test-writer/SKILL.md @@ -0,0 +1,16 @@ +--- +name: test-writer +description: How to write pytest tests for modules in this workspace. Load whenever you are about to write or extend tests. +--- + +# Test writer + +When writing pytest tests in this workspace: + +- **One test file per module.** `foo.py` lives next to `foo_test.py` (suffix, not prefix). +- **Cover the happy path AND at least one edge case.** Empty input, duplicates, boundary values — pick what matters for the unit under test. +- **Use `pytest.mark.parametrize`** when the same assertion runs over a small table of inputs. Keep IDs descriptive. +- **Name tests `test___`.** Examples: `test_total_empty_cart_returns_zero`, `test_add_duplicate_item_merges_quantities`. +- **Arrange / Act / Assert.** Three clear blocks. No setup hidden in fixtures unless it's reused across at least two tests. +- **Assert behavior, not implementation.** Don't reach into private attributes or count function calls; check the observable result. +- **Imports at module top.** Don't import inside test functions. diff --git a/py/samples/middleware-coding-agent/src/main.py b/py/samples/middleware-coding-agent/src/main.py new file mode 100644 index 0000000000..7cd84e18d7 --- /dev/null +++ b/py/samples/middleware-coding-agent/src/main.py @@ -0,0 +1,148 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Agentic coding REPL — Filesystem + Skills + ToolApproval middleware. + +An interactive coding agent that reads, edits, and writes files inside a +sandboxed ``workspace/`` directory. Read-only tools (``read_file``, +``list_files``, ``use_skill``) run automatically; everything that can +mutate the workspace (``write_file``, ``edit_file``) is gated by +``ToolApproval``, so the CLI pauses and asks ``y/N`` before each write. + +The agent state — middleware instances and message history — is owned by a +``CodingAgent`` session object built once per ``main()`` invocation. +That ties every ``ai.generate()`` and resume in this REPL to the same +middleware stack. ``Filesystem`` itself keeps no cross-call cache; file +content reaches the model through enqueued messages inside each call. + +Re-running cleanly: + +* The agent mutates files in ``workspace/`` directly. To start over, + ``rm -rf workspace/*`` — the directory itself is recreated on next run. +""" + +from pathlib import Path + +from genkit import Genkit, Message, ModelResponse, Part, Role, TextPart, Tool, ToolRequestPart, restart_tool +from genkit._ai._generate import resolve_tool +from genkit.plugins.google_genai import GoogleAI +from genkit.plugins.middleware import Filesystem, Middleware, Skills, ToolApproval + +_HERE = Path(__file__).resolve().parent.parent +_WORKSPACE = _HERE / 'workspace' +_SKILLS = _HERE / 'skills' + +ai = Genkit( + plugins=[GoogleAI(), Middleware()], + model='googleai/gemini-flash-latest', +) + + +SYSTEM_PROMPT = ( + 'You are a helpful coding agent. Very terse but thoughtful and careful.\n' + f'Your working directory is {_WORKSPACE}, you are not allowed to access anything outside it.\n' + 'Use plain filenames relative to the workspace root (e.g. ``foo.py``, not ``./foo.py`` ' + 'or absolute paths). You must ``read_file`` an existing file before you can ``write_file`` ' + 'or ``edit_file`` it — new files do not need a prior read.\n' + 'Use skills. ALWAYS start by analyzing the current state of the workspace, ' + 'there might be something already there.' +) + + +class CodingAgent: + """One agent session: owns the middleware stack and the running conversation.""" + + def __init__(self) -> None: + self.middleware = [ + ToolApproval(allowed_tools=['read_file', 'list_files', 'use_skill']), + Skills(skill_paths=[str(_SKILLS)]), + Filesystem(root_dir=str(_WORKSPACE), allow_write_access=True), + ] + self.messages: list[Message] = [ + Message(role=Role.SYSTEM, content=[Part(root=TextPart(text=SYSTEM_PROMPT))]), + ] + + async def turn(self, user_input: str) -> ModelResponse: + """Drive one user turn to completion across any number of approval prompts.""" + restart: list[ToolRequestPart] | None = None + while True: + response = await ai.generate( + prompt=user_input if restart is None else None, + messages=self.messages, + resume_restart=restart, + max_turns=20, + use=self.middleware, + ) + if not response.interrupts: + self.messages = response.messages + return response + + approved = await _ask_for_approvals(response.interrupts) + if not approved: + print('Tool denied.') # noqa: T201 + self.messages = response.messages + return response + + print('Resuming...') # noqa: T201 + restart = approved + self.messages = response.messages + + +async def _ask_for_approvals(interrupts: list[ToolRequestPart]) -> list[ToolRequestPart]: + """Prompt the user y/N for each pending interrupt; return the approved restart parts.""" + approved: list[ToolRequestPart] = [] + for trp in interrupts: + print('\n*** Tool Approval Required ***') # noqa: T201 + print(f'Tool: {trp.tool_request.name}') # noqa: T201 + print(f'Input: {trp.tool_request.input}') # noqa: T201 + if input('Approve? (y/N): ').strip().lower() in ('y', 'yes'): + tool = Tool(await resolve_tool(ai.registry, trp.tool_request.name)) + approved.append( + restart_tool(tool=tool, interrupt=trp, resumed_metadata={'toolApproved': True}), + ) + return approved + + +async def main() -> None: + """Interactive REPL — one ``CodingAgent`` per process, one ``turn()`` per user line.""" + _WORKSPACE.mkdir(parents=True, exist_ok=True) + + print('--- Coding Agent ---') # noqa: T201 + print('Type your request. To exit, type "exit".') # noqa: T201 + + agent = CodingAgent() + + while True: + try: + user_input = input('\n> ').strip() + except EOFError: + break + if user_input.lower() == 'exit': + break + if not user_input: + continue + + try: + response = await agent.turn(user_input) + except Exception as e: # noqa: BLE001 - top-level REPL: surface, don't crash + print(f'Error during generation: {e}') # noqa: T201 + continue + + print(f'\nAI Response:\n{response.text}') # noqa: T201 + + +if __name__ == '__main__': + ai.run_main(main()) diff --git a/py/samples/middleware/README.md b/py/samples/middleware/README.md new file mode 100644 index 0000000000..083fd9a815 --- /dev/null +++ b/py/samples/middleware/README.md @@ -0,0 +1,17 @@ +# Middleware + +Intercept or modify model requests with `use=` on `ai.generate()`. + +```bash +export GEMINI_API_KEY=your-api-key +uv sync +uv run src/main.py +``` + +To inspect the flows in Dev UI instead: + +```bash +genkit start -- uv run src/main.py +``` + +Try `logging_demo` and `request_modifier_demo`. diff --git a/py/samples/middleware/pyproject.toml b/py/samples/middleware/pyproject.toml new file mode 100644 index 0000000000..e7aca883f6 --- /dev/null +++ b/py/samples/middleware/pyproject.toml @@ -0,0 +1,19 @@ +[project] +name = "middleware" +version = "0.2.0" +requires-python = ">=3.10" +dependencies = [ + "genkit", + "genkit-plugin-google-genai", + "genkit-plugin-middleware", + "pydantic>=2.0.0", + "structlog>=24.0.0", + "uvloop>=0.21.0", +] + +[build-system] +build-backend = "hatchling.build" +requires = ["hatchling"] + +[tool.hatch.build.targets.wheel] +packages = ["src"] diff --git a/py/samples/middleware/src/main.py b/py/samples/middleware/src/main.py new file mode 100644 index 0000000000..0af40996ea --- /dev/null +++ b/py/samples/middleware/src/main.py @@ -0,0 +1,105 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Middleware - inspect or modify requests before they reach the model.""" + +import structlog +from pydantic import BaseModel, Field + +from genkit import Genkit, Message, Part, Role, TextPart +from genkit.middleware import BaseMiddleware, GenerateMiddlewareContext +from genkit.plugins.google_genai import GoogleAI +from genkit.plugins.middleware import Middleware + +logger = structlog.get_logger(__name__) + + +class PromptInput(BaseModel): + """Input shared by middleware flows.""" + + prompt: str = Field( + default='Explain recursion simply.', + description='Prompt to send to the model', + ) + + +ai = Genkit( + plugins=[GoogleAI(), Middleware()], + model='googleai/gemini-2.5-flash', +) + + +class LoggingMiddleware(BaseMiddleware): + """Log request/response details without changing behavior.""" + + async def wrap_model(self, params, next_fn, ctx: GenerateMiddlewareContext): + await logger.ainfo('middleware saw request', message_count=len(params.request.messages)) + response = await next_fn(params) + await logger.ainfo('middleware saw response', finish_reason=response.finish_reason) + return response + + +class ConciseReplyConfig(BaseModel): + """Per-call system instruction for ConciseReplyMiddleware.""" + + instruction: str = 'Answer in one short paragraph.' + + +@ai.middleware(name='concise_reply_mw') +class ConciseReplyMiddleware(BaseMiddleware[ConciseReplyConfig]): + """Prepend a short system instruction before the model call. + + Each call can supply its own value by constructing a fresh instance: + ``ConciseReplyMiddleware(instruction=...)``. + """ + + async def wrap_model(self, params, next_fn, ctx: GenerateMiddlewareContext): + system_message = Message( + role=Role.SYSTEM, + content=[Part(root=TextPart(text=self.config.instruction))], + ) + params.request = params.request.model_copy() + params.request.messages = [system_message, *params.request.messages] + return await next_fn(params) + + +@ai.flow() +async def logging_demo(input: PromptInput) -> str: + """Pass a ``BaseMiddleware`` instance directly: no registration needed in-process.""" + + response = await ai.generate(prompt=input.prompt, use=[LoggingMiddleware()]) + return response.text + + +@ai.flow() +async def request_modifier_demo(input: PromptInput) -> str: + """Pass a configured middleware instance with a per-call override of ``instruction``.""" + + response = await ai.generate( + prompt=input.prompt, + use=[ConciseReplyMiddleware(instruction='Answer in a single haiku.')], + ) + return response.text + + +async def main() -> None: + """Run both middleware demos once.""" + print(await logging_demo(PromptInput())) # noqa: T201 + print(await request_modifier_demo(PromptInput(prompt='Write a haiku about recursion.'))) # noqa: T201 + + +if __name__ == '__main__': + ai.run_main(main()) diff --git a/py/uv.lock b/py/uv.lock index f95e41a86f..6fcc756042 100644 --- a/py/uv.lock +++ b/py/uv.lock @@ -25,10 +25,13 @@ members = [ "genkit-plugin-flask", "genkit-plugin-google-cloud", "genkit-plugin-google-genai", + "genkit-plugin-middleware", "genkit-plugin-ollama", "genkit-plugin-vertex-ai", "genkit-workspace", "google-genai-media", + "middleware", + "middleware-coding-agent", "output-formats", "prompts", "tool-interrupts", @@ -1265,6 +1268,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" }, ] +[[package]] +name = "execnet" +version = "2.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bf/89/780e11f9588d9e7128a3f87788354c7946a9cbb1401ad38a48c4db9a4f07/execnet-2.1.2.tar.gz", hash = "sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd", size = 166622, upload-time = "2025-11-12T09:56:37.75Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec", size = 40708, upload-time = "2025-11-12T09:56:36.333Z" }, +] + [[package]] name = "executing" version = "2.2.1" @@ -1637,6 +1649,34 @@ requires-dist = [ { name = "structlog", specifier = ">=25.2.0" }, ] +[[package]] +name = "genkit-plugin-middleware" +version = "0.5.2" +source = { editable = "plugins/middleware" } +dependencies = [ + { name = "genkit" }, + { name = "pyyaml" }, +] + +[package.optional-dependencies] +dev = [ + { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-cov" }, + { name = "pytest-xdist" }, +] + +[package.metadata] +requires-dist = [ + { name = "genkit", editable = "packages/genkit" }, + { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.4" }, + { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.25.2" }, + { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=6.0.0" }, + { name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.6.1" }, + { name = "pyyaml", specifier = ">=6.0" }, +] +provides-extras = ["dev"] + [[package]] name = "genkit-plugin-ollama" version = "0.5.2" @@ -1697,6 +1737,7 @@ dependencies = [ { name = "genkit-plugin-flask" }, { name = "genkit-plugin-google-cloud" }, { name = "genkit-plugin-google-genai" }, + { name = "genkit-plugin-middleware" }, { name = "genkit-plugin-ollama" }, { name = "genkit-plugin-vertex-ai" }, { name = "liccheck" }, @@ -1767,6 +1808,7 @@ requires-dist = [ { name = "genkit-plugin-flask", editable = "plugins/flask" }, { name = "genkit-plugin-google-cloud", editable = "plugins/google-cloud" }, { name = "genkit-plugin-google-genai", editable = "plugins/google-genai" }, + { name = "genkit-plugin-middleware", editable = "plugins/middleware" }, { name = "genkit-plugin-ollama", editable = "plugins/ollama" }, { name = "genkit-plugin-vertex-ai", editable = "plugins/vertex-ai" }, { name = "liccheck", specifier = ">=0.9.2" }, @@ -3461,6 +3503,52 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, ] +[[package]] +name = "middleware" +version = "0.2.0" +source = { editable = "samples/middleware" } +dependencies = [ + { name = "genkit" }, + { name = "genkit-plugin-google-genai" }, + { name = "genkit-plugin-middleware" }, + { name = "pydantic" }, + { name = "structlog" }, + { name = "uvloop" }, +] + +[package.metadata] +requires-dist = [ + { name = "genkit", editable = "packages/genkit" }, + { name = "genkit-plugin-google-genai", editable = "plugins/google-genai" }, + { name = "genkit-plugin-middleware", editable = "plugins/middleware" }, + { name = "pydantic", specifier = ">=2.0.0" }, + { name = "structlog", specifier = ">=24.0.0" }, + { name = "uvloop", specifier = ">=0.21.0" }, +] + +[[package]] +name = "middleware-coding-agent" +version = "0.1.0" +source = { editable = "samples/middleware-coding-agent" } +dependencies = [ + { name = "genkit" }, + { name = "genkit-plugin-google-genai" }, + { name = "genkit-plugin-middleware" }, + { name = "pydantic" }, + { name = "structlog" }, + { name = "uvloop" }, +] + +[package.metadata] +requires-dist = [ + { name = "genkit", editable = "packages/genkit" }, + { name = "genkit-plugin-google-genai", editable = "plugins/google-genai" }, + { name = "genkit-plugin-middleware", editable = "plugins/middleware" }, + { name = "pydantic", specifier = ">=2.10.5" }, + { name = "structlog", specifier = ">=25.2.0" }, + { name = "uvloop", specifier = ">=0.21.0" }, +] + [[package]] name = "mistune" version = "3.2.1" @@ -5326,6 +5414,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fc/3f/172d73600ad2771774cda108efb813fc724fc345e5240a81a1085f1ade5d/pytest_watcher-0.6.3-py3-none-any.whl", hash = "sha256:83e7748c933087e8276edb6078663e6afa9926434b4fd8b85cf6b32b1d5bec89", size = 12431, upload-time = "2026-01-10T23:28:17.64Z" }, ] +[[package]] +name = "pytest-xdist" +version = "3.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "execnet" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/b4/439b179d1ff526791eb921115fca8e44e596a13efeda518b9d845a619450/pytest_xdist-3.8.0.tar.gz", hash = "sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1", size = 88069, upload-time = "2025-07-01T13:30:59.346Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396, upload-time = "2025-07-01T13:30:56.632Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"