Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,21 @@ variables. Configure `A2A_CLIENT_BEARER_TOKEN` or `A2A_CLIENT_BASIC_AUTH` when
the remote agent protects its runtime surface. CLI outbound calls follow the
same environment-only model.

`A2AClient.send()` returns the latest response event and keeps the default
stream-first behavior. If a peer returns a non-terminal task snapshot and
expects follow-up `tasks/get` polling, enable the optional facade fallback
with:

- `A2A_CLIENT_POLLING_FALLBACK_ENABLED=true`
- `A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS`
- `A2A_CLIENT_POLLING_FALLBACK_MAX_INTERVAL_SECONDS`
- `A2A_CLIENT_POLLING_FALLBACK_BACKOFF_MULTIPLIER`
- `A2A_CLIENT_POLLING_FALLBACK_TIMEOUT_SECONDS`

The fallback only applies to `send()`, keeps `send_message()` as a thin event
stream wrapper, and stops polling once the task reaches a terminal state or a
caller-intervention state such as `input-required` or `auth-required`.

Execution-boundary metadata is intentionally declarative deployment metadata:
it is published through `RuntimeProfile`, Agent Card, OpenAPI, and `/health`,
and should not be interpreted as a live per-request privilege snapshot or a
Expand Down
81 changes: 78 additions & 3 deletions src/opencode_a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@
map_agent_card_error,
map_operation_error,
)
from .errors import A2AUnsupportedBindingError
from .errors import A2ATimeoutError, A2AUnsupportedBindingError
from .payload_text import extract_text as extract_text_from_payload
from .polling import PollingFallbackPolicy
from .request_context import build_call_context, build_client_interceptors, split_request_metadata


Expand All @@ -58,6 +59,13 @@ def __init__(
self._lock = asyncio.Lock()
self._request_lock = asyncio.Lock()
self._active_requests = 0
self._polling_fallback_policy = PollingFallbackPolicy(
enabled=self._settings.polling_fallback_enabled,
initial_interval_seconds=self._settings.polling_fallback_initial_interval_seconds,
max_interval_seconds=self._settings.polling_fallback_max_interval_seconds,
backoff_multiplier=self._settings.polling_fallback_backoff_multiplier,
timeout_seconds=self._settings.polling_fallback_timeout_seconds,
)

async def close(self) -> None:
"""Close cached client resources and owned HTTP transport."""
Expand Down Expand Up @@ -149,7 +157,12 @@ async def send(
metadata: Mapping[str, Any] | None = None,
extensions: list[str] | None = None,
) -> Message | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | None:
"""Send a message and return the terminal response/event."""
"""Send a message and return the latest response event.

When polling fallback is enabled, a non-terminal `(Task, None)` result may
be followed by bounded `tasks/get` polling until a terminal task snapshot
is observed.
"""
last_event: (
Message | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | None
) = None
Expand All @@ -162,7 +175,13 @@ async def send(
extensions=extensions,
):
last_event = event
return last_event
if not self._should_poll_after_send(last_event):
return last_event
terminal_task = await self._poll_task_until_terminal(
self._extract_task_from_client_event(last_event),
metadata=metadata,
)
return (terminal_task, None)

async def get_task(
self,
Expand Down Expand Up @@ -299,6 +318,62 @@ async def _release_operation(self) -> None:
if self._active_requests > 0:
self._active_requests -= 1

def _should_poll_after_send(
self,
event: Message | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | None,
) -> bool:
if not self._polling_fallback_policy.enabled:
return False
if event is None or isinstance(event, Message) or not isinstance(event, tuple):
return False
task, update = event
if update is not None:
return False
return self._polling_fallback_policy.should_poll_state(task.status.state)

def _extract_task_from_client_event(
self,
event: Message | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | None,
) -> Task:
task, _update = cast(
tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None],
event,
)
return task

async def _poll_task_until_terminal(
self,
task: Task,
*,
metadata: Mapping[str, Any] | None = None,
) -> Task:
deadline = self._current_time() + self._polling_fallback_policy.timeout_seconds
interval = self._polling_fallback_policy.initial_interval_seconds
current_task = task

while True:
if self._polling_fallback_policy.is_terminal_state(current_task.status.state):
return current_task
if not self._polling_fallback_policy.should_poll_state(current_task.status.state):
return current_task

remaining = deadline - self._current_time()
if remaining <= 0:
raise A2ATimeoutError(
"Remote A2A peer did not reach a terminal task state "
"before polling fallback timed out"
)

await self._sleep(min(interval, remaining))
current_task = await self.get_task(current_task.id, metadata=metadata)
interval = self._polling_fallback_policy.next_interval_seconds(interval)

def _current_time(self) -> float:
return asyncio.get_running_loop().time()

async def _sleep(self, delay_seconds: float) -> None:
await asyncio.sleep(delay_seconds)

def _build_user_message(
self,
*,
Expand Down
81 changes: 81 additions & 0 deletions src/opencode_a2a/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any

from .auth import validate_basic_auth
from .polling import PollingFallbackPolicy, validate_polling_fallback_policy


def _read_setting(
Expand Down Expand Up @@ -113,6 +114,11 @@ class A2AClientSettings:
"JSONRPC",
"HTTP+JSON",
)
polling_fallback_enabled: bool = False
polling_fallback_initial_interval_seconds: float = 0.5
polling_fallback_max_interval_seconds: float = 2.0
polling_fallback_backoff_multiplier: float = 2.0
polling_fallback_timeout_seconds: float = 10.0


def load_settings(raw_settings: Any) -> A2AClientSettings:
Expand Down Expand Up @@ -177,6 +183,76 @@ def load_settings(raw_settings: Any) -> A2AClientSettings:
),
default=("JSONRPC", "HTTP+JSON"),
)
polling_fallback_enabled = _coerce_bool(
"A2A_CLIENT_POLLING_FALLBACK_ENABLED",
_read_setting(
raw_settings,
keys=(
"A2A_CLIENT_POLLING_FALLBACK_ENABLED",
"a2a_client_polling_fallback_enabled",
),
default=False,
),
default=False,
)
polling_fallback_initial_interval_seconds = _coerce_float(
"A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS",
_read_setting(
raw_settings,
keys=(
"A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS",
"a2a_client_polling_fallback_initial_interval_seconds",
),
default=0.5,
),
default=0.5,
)
polling_fallback_max_interval_seconds = _coerce_float(
"A2A_CLIENT_POLLING_FALLBACK_MAX_INTERVAL_SECONDS",
_read_setting(
raw_settings,
keys=(
"A2A_CLIENT_POLLING_FALLBACK_MAX_INTERVAL_SECONDS",
"a2a_client_polling_fallback_max_interval_seconds",
),
default=2.0,
),
default=2.0,
)
polling_fallback_backoff_multiplier = _coerce_float(
"A2A_CLIENT_POLLING_FALLBACK_BACKOFF_MULTIPLIER",
_read_setting(
raw_settings,
keys=(
"A2A_CLIENT_POLLING_FALLBACK_BACKOFF_MULTIPLIER",
"a2a_client_polling_fallback_backoff_multiplier",
),
default=2.0,
),
default=2.0,
)
polling_fallback_timeout_seconds = _coerce_float(
"A2A_CLIENT_POLLING_FALLBACK_TIMEOUT_SECONDS",
_read_setting(
raw_settings,
keys=(
"A2A_CLIENT_POLLING_FALLBACK_TIMEOUT_SECONDS",
"a2a_client_polling_fallback_timeout_seconds",
),
default=10.0,
),
default=10.0,
)

validate_polling_fallback_policy(
PollingFallbackPolicy(
enabled=polling_fallback_enabled,
initial_interval_seconds=polling_fallback_initial_interval_seconds,
max_interval_seconds=polling_fallback_max_interval_seconds,
backoff_multiplier=polling_fallback_backoff_multiplier,
timeout_seconds=polling_fallback_timeout_seconds,
)
)

return A2AClientSettings(
default_timeout=default_timeout,
Expand All @@ -185,6 +261,11 @@ def load_settings(raw_settings: Any) -> A2AClientSettings:
bearer_token=bearer_token,
basic_auth=basic_auth,
supported_transports=supported_transports,
polling_fallback_enabled=polling_fallback_enabled,
polling_fallback_initial_interval_seconds=polling_fallback_initial_interval_seconds,
polling_fallback_max_interval_seconds=polling_fallback_max_interval_seconds,
polling_fallback_backoff_multiplier=polling_fallback_backoff_multiplier,
polling_fallback_timeout_seconds=polling_fallback_timeout_seconds,
)


Expand Down
68 changes: 68 additions & 0 deletions src/opencode_a2a/client/polling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Polling fallback policy helpers for the A2A client facade."""

from __future__ import annotations

from dataclasses import dataclass

from a2a.types import TaskState

_TERMINAL_TASK_STATES = frozenset(
{
TaskState.completed,
TaskState.canceled,
TaskState.failed,
TaskState.rejected,
}
)
_AUTO_POLLING_TASK_STATES = frozenset(
{
TaskState.submitted,
TaskState.working,
TaskState.unknown,
}
)


@dataclass(frozen=True)
class PollingFallbackPolicy:
"""Encapsulates polling fallback configuration and task-state rules."""

enabled: bool = False
initial_interval_seconds: float = 0.5
max_interval_seconds: float = 2.0
backoff_multiplier: float = 2.0
timeout_seconds: float = 10.0

def should_poll_state(self, state: TaskState) -> bool:
return state in _AUTO_POLLING_TASK_STATES

def is_terminal_state(self, state: TaskState) -> bool:
return state in _TERMINAL_TASK_STATES

def next_interval_seconds(self, current_interval_seconds: float) -> float:
return min(
max(current_interval_seconds, 0.0) * self.backoff_multiplier,
self.max_interval_seconds,
)


def validate_polling_fallback_policy(policy: PollingFallbackPolicy) -> None:
"""Validate polling fallback settings before they are used at runtime."""
if policy.initial_interval_seconds <= 0:
raise ValueError("A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS must be positive")
if policy.max_interval_seconds <= 0:
raise ValueError("A2A_CLIENT_POLLING_FALLBACK_MAX_INTERVAL_SECONDS must be positive")
if policy.backoff_multiplier < 1.0:
raise ValueError(
"A2A_CLIENT_POLLING_FALLBACK_BACKOFF_MULTIPLIER must be greater than or equal to 1"
)
if policy.timeout_seconds <= 0:
raise ValueError("A2A_CLIENT_POLLING_FALLBACK_TIMEOUT_SECONDS must be positive")
if policy.max_interval_seconds < policy.initial_interval_seconds:
raise ValueError(
"A2A_CLIENT_POLLING_FALLBACK_MAX_INTERVAL_SECONDS must be greater than or "
"equal to A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS"
)


__all__ = ["PollingFallbackPolicy", "validate_polling_fallback_policy"]
20 changes: 20 additions & 0 deletions tests/client/test_client_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def test_load_settings_from_mapping() -> None:
"A2A_CLIENT_BEARER_TOKEN": "peer-token",
"A2A_CLIENT_BASIC_AUTH": "user:pass",
"A2A_CLIENT_SUPPORTED_TRANSPORTS": "json-rpc,http-json",
"A2A_CLIENT_POLLING_FALLBACK_ENABLED": "true",
"A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS": "0.75",
"A2A_CLIENT_POLLING_FALLBACK_MAX_INTERVAL_SECONDS": "3",
"A2A_CLIENT_POLLING_FALLBACK_BACKOFF_MULTIPLIER": "1.5",
"A2A_CLIENT_POLLING_FALLBACK_TIMEOUT_SECONDS": "12",
}

settings = load_settings(raw)
Expand All @@ -31,6 +36,11 @@ def test_load_settings_from_mapping() -> None:
assert settings.bearer_token == "peer-token"
assert settings.basic_auth == "user:pass"
assert settings.supported_transports == ("JSONRPC", "HTTP+JSON")
assert settings.polling_fallback_enabled is True
assert settings.polling_fallback_initial_interval_seconds == 0.75
assert settings.polling_fallback_max_interval_seconds == 3.0
assert settings.polling_fallback_backoff_multiplier == 1.5
assert settings.polling_fallback_timeout_seconds == 12.0


def test_load_settings_invalid_transport_raises() -> None:
Expand Down Expand Up @@ -59,3 +69,13 @@ def test_load_settings_accepts_base64_basic_auth() -> None:
def test_load_settings_invalid_basic_auth_raises() -> None:
with pytest.raises(ValueError, match="username:password"):
load_settings({"A2A_CLIENT_BASIC_AUTH": "not-basic-auth"})


def test_load_settings_invalid_polling_fallback_interval_raises() -> None:
with pytest.raises(ValueError, match="INITIAL_INTERVAL_SECONDS must be positive"):
load_settings({"A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS": "0"})


def test_load_settings_invalid_polling_fallback_backoff_raises() -> None:
with pytest.raises(ValueError, match="BACKOFF_MULTIPLIER must be greater than or equal to 1"):
load_settings({"A2A_CLIENT_POLLING_FALLBACK_BACKOFF_MULTIPLIER": "0.5"})
Loading