Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@

```python
# Simple boolean and dict
toolset = SendA2uiToClientToolset(a2ui_enabled=True, a2ui_schema=MY_SCHEMA)
toolset = SendA2uiToClientToolset(a2ui_enabled=True, a2ui_catalog=MY_CATALOG)

# Async providers
async def check_enabled(ctx: ReadonlyContext) -> bool:
return await some_condition(ctx)

async def get_schema(ctx: ReadonlyContext) -> dict[str, Any]:
return await fetch_schema(ctx)
async def get_catalog(ctx: ReadonlyContext) -> A2uiCatalog:
return await fetch_catalog(ctx)

toolset = SendA2uiToClientToolset(a2ui_enabled=check_enabled, a2ui_schema=get_schema)
toolset = SendA2uiToClientToolset(a2ui_enabled=check_enabled, a2ui_catalog=get_catalog)
```

2. Integration with Agent:
Expand All @@ -60,7 +60,7 @@ async def get_schema(ctx: ReadonlyContext) -> dict[str, Any]:
tools=[
SendA2uiToClientToolset(
a2ui_enabled=True,
a2ui_schema=MY_SCHEMA
a2ui_catalog=MY_CATALOG
)
]
)
Expand All @@ -86,7 +86,7 @@ async def get_schema(ctx: ReadonlyContext) -> dict[str, Any]:

from a2a import types as a2a_types
from a2ui.extension.a2ui_extension import create_a2ui_part
from a2ui.extension.a2ui_schema_utils import wrap_as_json_array
from a2ui.inference.schema.catalog import A2uiCatalog
from google.adk.a2a.converters import part_converter
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.models import LlmRequest
Expand All @@ -101,8 +101,11 @@ async def get_schema(ctx: ReadonlyContext) -> dict[str, Any]:
A2uiEnabledProvider: TypeAlias = Callable[
[ReadonlyContext], Union[bool, Awaitable[bool]]
]
A2uiSchemaProvider: TypeAlias = Callable[
[ReadonlyContext], Union[dict[str, Any], Awaitable[dict[str, Any]]]
A2uiCatalogProvider: TypeAlias = Callable[
[ReadonlyContext], Union[A2uiCatalog, Awaitable[A2uiCatalog]]
]
A2uiExamplesProvider: TypeAlias = Callable[
[ReadonlyContext], Union[str, Awaitable[str]]
]


Expand All @@ -113,11 +116,12 @@ class SendA2uiToClientToolset(base_toolset.BaseToolset):
def __init__(
self,
a2ui_enabled: Union[bool, A2uiEnabledProvider],
a2ui_schema: Union[dict[str, Any], A2uiSchemaProvider],
a2ui_catalog: Union[A2uiCatalog, A2uiCatalogProvider],
a2ui_examples: Union[str, A2uiExamplesProvider],
):
super().__init__()
self._a2ui_enabled = a2ui_enabled
self._ui_tools = [self._SendA2uiJsonToClientTool(a2ui_schema)]
self._ui_tools = [self._SendA2uiJsonToClientTool(a2ui_catalog, a2ui_examples)]

async def _resolve_a2ui_enabled(self, ctx: ReadonlyContext) -> bool:
"""The resolved self.a2ui_enabled field to construct instruction for this agent.
Expand Down Expand Up @@ -164,8 +168,13 @@ class _SendA2uiJsonToClientTool(BaseTool):
A2UI_JSON_ARG_NAME = "a2ui_json"
TOOL_ERROR_KEY = "error"

def __init__(self, a2ui_schema: Union[dict[str, Any], A2uiSchemaProvider]):
self._a2ui_schema = a2ui_schema
def __init__(
self,
a2ui_catalog: Union[A2uiCatalog, A2uiCatalogProvider],
a2ui_examples: Union[str, A2uiExamplesProvider],
):
self._a2ui_catalog = a2ui_catalog
self._a2ui_examples = a2ui_examples
super().__init__(
name=self.TOOL_NAME,
description=(
Expand Down Expand Up @@ -195,34 +204,39 @@ def _get_declaration(self) -> genai_types.FunctionDeclaration | None:
),
)

async def _resolve_a2ui_schema(self, ctx: ReadonlyContext) -> dict[str, Any]:
"""The resolved self.a2ui_schema field to construct instruction for this agent.
async def _resolve_a2ui_examples(self, ctx: ReadonlyContext) -> str:
"""The resolved self.a2ui_examples field to construct instruction for this agent.

Args:
ctx: The ReadonlyContext to resolve the provider with.

Returns:
The A2UI schema to send to the client.
The A2UI examples string.
"""
if isinstance(self._a2ui_schema, dict):
return self._a2ui_schema
if isinstance(self._a2ui_examples, str):
return self._a2ui_examples
else:
a2ui_schema = self._a2ui_schema(ctx)
if inspect.isawaitable(a2ui_schema):
a2ui_schema = await a2ui_schema
return a2ui_schema
a2ui_examples = self._a2ui_examples(ctx)
if inspect.isawaitable(a2ui_examples):
a2ui_examples = await a2ui_examples
return a2ui_examples

async def get_a2ui_schema(self, ctx: ReadonlyContext) -> dict[str, Any]:
"""Retrieves and wraps the A2UI schema.
async def _resolve_a2ui_catalog(self, ctx: ReadonlyContext) -> A2uiCatalog:
"""The resolved self.a2ui_catalog field to construct instruction for this agent.

Args:
ctx: The ReadonlyContext for resolving the schema.
ctx: The ReadonlyContext to resolve the provider with.

Returns:
The wrapped A2UI schema.
The A2UI catalog object.
"""
a2ui_schema = await self._resolve_a2ui_schema(ctx)
return wrap_as_json_array(a2ui_schema)
if isinstance(self._a2ui_catalog, A2uiCatalog):
return self._a2ui_catalog
else:
a2ui_catalog = self._a2ui_catalog(ctx)
if inspect.isawaitable(a2ui_catalog):
a2ui_catalog = await a2ui_catalog
return a2ui_catalog

async def process_llm_request(
self, *, tool_context: ToolContext, llm_request: LlmRequest
Expand All @@ -231,15 +245,14 @@ async def process_llm_request(
tool_context=tool_context, llm_request=llm_request
)

a2ui_schema = await self.get_a2ui_schema(tool_context)
a2ui_catalog = await self._resolve_a2ui_catalog(tool_context)

instruction = a2ui_catalog.render_as_llm_instructions()
examples = await self._resolve_a2ui_examples(tool_context)

llm_request.append_instructions([f"""
---BEGIN A2UI JSON SCHEMA---
{json.dumps(a2ui_schema)}
---END A2UI JSON SCHEMA---
"""])
llm_request.append_instructions([instruction, examples])

logger.info("Added a2ui_schema to system instructions")
logger.info("Added A2UI schema and examples to system instructions")

async def run_async(
self, *, args: dict[str, Any], tool_context: ToolContext
Expand All @@ -261,8 +274,9 @@ async def run_async(
)
a2ui_json_payload = [a2ui_json_payload]

a2ui_schema = await self.get_a2ui_schema(tool_context)
jsonschema.validate(instance=a2ui_json_payload, schema=a2ui_schema)
a2ui_catalog = await self._resolve_a2ui_catalog(tool_context)

a2ui_catalog.validator.validate(a2ui_json_payload)

logger.info(
f"Validated call to tool {self.TOOL_NAME} with {self.A2UI_JSON_ARG_NAME}"
Expand Down
Loading