diff --git a/icij-worker/icij_worker/cli/tasks.py b/icij-worker/icij_worker/cli/tasks.py index 51c7f6e..6870039 100644 --- a/icij-worker/icij_worker/cli/tasks.py +++ b/icij-worker/icij_worker/cli/tasks.py @@ -11,7 +11,7 @@ from icij_worker import TaskState from icij_worker.cli.utils import AsyncTyper, eprint from icij_worker.http_ import TaskClient -from icij_worker.objects import READY_STATES, Task, TaskError +from icij_worker.objects import ErrorEvent, READY_STATES, Task logger = logging.getLogger(__name__) @@ -155,7 +155,8 @@ async def _handle_alive( await _handle_ready(task, client) -def _format_error(error: TaskError) -> str: +def _format_error(error: ErrorEvent) -> str: + error = error.error stack = StackSummary.from_list( [FrameSummary(f.name, f.lineno, f.name) for f in error.stacktrace] ) diff --git a/icij-worker/icij_worker/event_publisher/amqp.py b/icij-worker/icij_worker/event_publisher/amqp.py index 7d053b9..5049ac2 100644 --- a/icij-worker/icij_worker/event_publisher/amqp.py +++ b/icij-worker/icij_worker/event_publisher/amqp.py @@ -4,31 +4,27 @@ from contextlib import AsyncExitStack from functools import cached_property -from aio_pika import ( - Exchange as AioPikaExchange, - RobustChannel, - connect_robust, -) +from aio_pika import Exchange as AioPikaExchange, RobustChannel from aio_pika.abc import AbstractRobustConnection from icij_common.logging_utils import LogWithNameMixin from icij_worker import ManagerEvent from . import EventPublisher from ..routing_strategy import Routing -from ..utils.amqp import AMQPMixin, RobustConnection +from ..utils.amqp import AMQPMixin class AMQPPublisher(AMQPMixin, EventPublisher, LogWithNameMixin): def __init__( self, - logger: Optional[logging.Logger] = None, + logger: logging.Logger | None = None, *, broker_url: str, connection_timeout_s: float = 1.0, reconnection_wait_s: float = 5.0, is_qpid: bool = False, app_id: str | None = None, - connection: Optional[AbstractRobustConnection] = None, + connection: AbstractRobustConnection | None = None, ): super().__init__( broker_url, @@ -42,8 +38,8 @@ def __init__( self._app_id = app_id self._broker_url = broker_url self._connection_ = connection - self._channel_: Optional[RobustChannel] = None - self._manager_evt_x: Optional[AioPikaExchange] = None + self._channel_: RobustChannel | None = None + self._manager_evt_x: AioPikaExchange | None = None self._connection_timeout_s = connection_timeout_s self._reconnection_wait_s = reconnection_wait_s self._exit_stack = AsyncExitStack() @@ -73,14 +69,11 @@ async def _publish_event(self, event: ManagerEvent): async def _connection_workflow(self): self.debug("creating connection...") - if self._connection_ is None: - self._connection_ = await connect_robust( - self._broker_url, - timeout=self._connection_timeout_s, - reconnect_interval=self._reconnection_wait_s, - connection_class=RobustConnection, - ) - await self._exit_stack.enter_async_context(self._connection) + try: + _ = self.connection + except ValueError: + await self._connect() + await self._exit_stack.enter_async_context(self.connection) self.debug("creating channel...") self._channel_ = await self._connection.channel( publisher_confirms=self._publisher_confirms, diff --git a/icij-worker/icij_worker/task_manager/amqp.py b/icij-worker/icij_worker/task_manager/amqp.py index fac29a5..6cfa938 100644 --- a/icij-worker/icij_worker/task_manager/amqp.py +++ b/icij-worker/icij_worker/task_manager/amqp.py @@ -5,7 +5,6 @@ from functools import cached_property from typing import TypeVar, cast -from aio_pika import connect_robust from aio_pika.abc import AbstractExchange, AbstractQueueIterator from aiormq import DeliveryError from pydantic import Field @@ -32,7 +31,6 @@ AMQPConfigMixin, AMQPManagementClient, AMQPMixin, - RobustConnection, amqp_task_group_policy, health_policy, ) @@ -217,12 +215,10 @@ async def shutdown_workers(self): async def _connection_workflow(self): await self._exit_stack.enter_async_context(self._management_client) logger.debug("creating connection...") - self._connection_ = await connect_robust( - self._broker_url, - timeout=self._connection_timeout_s, - reconnect_interval=self._reconnection_wait_s, - connection_class=RobustConnection, - ) + try: + _ = self.connection + except ValueError: + await self._connect() await self._exit_stack.enter_async_context(self._connection) logger.debug("creating channel...") self._channel_ = await self._connection.channel( diff --git a/icij-worker/icij_worker/utils/amqp.py b/icij-worker/icij_worker/utils/amqp.py index 69258a8..e1ad686 100644 --- a/icij-worker/icij_worker/utils/amqp.py +++ b/icij-worker/icij_worker/utils/amqp.py @@ -14,8 +14,9 @@ from aio_pika import ( DeliveryMode, Message as AioPikaMessage, - RobustChannel as RobustChannel_, - RobustConnection as RobustConnection_, + RobustChannel, + RobustConnection, + connect_robust, ) from aio_pika.abc import ( AbstractExchange, @@ -117,8 +118,7 @@ def broker_url(self) -> str: if amqp_userinfo: amqp_userinfo += "@" amqp_authority = ( - f"{amqp_userinfo or ''}{self.rabbitmq_host}" - f"{f':{self.rabbitmq_port}' or ''}" + f"{amqp_userinfo or ''}{self.rabbitmq_host}{f':{self.rabbitmq_port}' or ''}" ) amqp_uri = f"amqp://{amqp_authority}" if self.rabbitmq_vhost is not None: @@ -216,7 +216,7 @@ def channel(self) -> AbstractRobustChannel: return self._channel @property - def connection(self) -> AbstractRobustChannel: + def connection(self) -> AbstractRobustConnection: return self._connection @classmethod @@ -268,6 +268,15 @@ def health_routing(cls) -> Routing: queue_name=AMQP_HEALTH_QUEUE, ) + async def _connect(self): + connection_class = QpidRobustConnection if self._is_qpid else RobustConnection + self._connection_ = await connect_robust( + self._broker_url, + timeout=self._connection_timeout_s, + reconnect_interval=self._reconnection_wait_s, + connection_class=connection_class, + ) + async def _get_queue_iterator( self, routing: Routing, @@ -431,7 +440,7 @@ def health_policy(routing: Routing) -> AMQPPolicy: ) -class RobustChannel(RobustChannel_): +class QpidRobustChannel(RobustChannel): async def __close_callback(self, _: Any, exc: BaseException) -> None: # pylint: disable=unused-private-member timeout_exc = parse_consumer_timeout(exc) @@ -440,8 +449,8 @@ async def __close_callback(self, _: Any, exc: BaseException) -> None: raise timeout_exc from exc -class RobustConnection(RobustConnection_): - CHANNEL_CLASS: type[RobustChannel] = RobustChannel +class QpidRobustConnection(RobustConnection): + CHANNEL_CLASS: type[RobustChannel] = QpidRobustChannel # Defined async context manager attributes to be able to enter and exit this # in ExitStack diff --git a/icij-worker/icij_worker/worker/worker.py b/icij-worker/icij_worker/worker/worker.py index 02a043f..ee948e4 100644 --- a/icij-worker/icij_worker/worker/worker.py +++ b/icij-worker/icij_worker/worker/worker.py @@ -182,9 +182,9 @@ async def consume(self) -> Task: self._started_task_consumption_evt.set() task = await self._consume() msg = 'Task(id="%s") locked' - if task.max_retries is not None: + if task.max_retries is not None and task.retries_left is not None: msg += ( - f", tentative ({task.max_retries - task.retries_left}" + f", retry ({task.max_retries - task.retries_left}" f"/{task.max_retries})" ) self.info(msg, task.id) diff --git a/icij-worker/tests/utils/test_amqp.py b/icij-worker/tests/utils/test_amqp.py index 68a9f5c..3130052 100644 --- a/icij-worker/tests/utils/test_amqp.py +++ b/icij-worker/tests/utils/test_amqp.py @@ -4,7 +4,7 @@ import pytest from aio_pika import Message, connect_robust -from aio_pika.abc import AbstractRobustQueue +from aio_pika.abc import AbstractRobustConnection, AbstractRobustQueue from icij_worker.task_storage.postgres.postgres import logger from icij_worker.utils.amqp import ( @@ -12,8 +12,9 @@ AMQPMixin, AMQPPolicy, ApplyTo, - parse_consumer_timeout, + QpidRobustConnection, RobustConnection, + parse_consumer_timeout, worker_events_policy, ) @@ -104,3 +105,28 @@ async def test_worker_events_policy(): assert policy == expected worker_queue_name = "WORKER_EVENT-some-service" assert re.match(policy.pattern, worker_queue_name) + + +@pytest.mark.parametrize( + "is_qpid_,expected_type", [(True, QpidRobustConnection), (False, RobustConnection)] +) +async def test_should_handle_qpid_when_creating_connection( + is_qpid_, expected_type: type[AbstractRobustConnection], rabbit_mq: str +): + # Given + class SomeClass(AMQPMixin): + def __init__(self, broker_url: str, *, is_qpid: bool): + super().__init__(broker_url=broker_url, is_qpid=is_qpid) + + async def connect(self): + await self._connect() + + # When + instance = SomeClass(rabbit_mq, is_qpid=is_qpid_) + await instance.connect() + + # Then + assert ( + type(instance.connection) # pylint: disable=unidiomatic-typecheck + is expected_type + )