Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions icij-worker/icij_worker/cli/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from icij_worker import TaskState
from icij_worker.cli.utils import AsyncTyper, eprint
from icij_worker.http_ import TaskClient
from icij_worker.objects import READY_STATES, Task, TaskError
from icij_worker.objects import ErrorEvent, READY_STATES, Task

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -155,7 +155,8 @@ async def _handle_alive(
await _handle_ready(task, client)


def _format_error(error: TaskError) -> str:
def _format_error(error: ErrorEvent) -> str:
error = error.error
stack = StackSummary.from_list(
[FrameSummary(f.name, f.lineno, f.name) for f in error.stacktrace]
)
Expand Down
29 changes: 11 additions & 18 deletions icij-worker/icij_worker/event_publisher/amqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,27 @@
from contextlib import AsyncExitStack
from functools import cached_property

from aio_pika import (
Exchange as AioPikaExchange,
RobustChannel,
connect_robust,
)
from aio_pika import Exchange as AioPikaExchange, RobustChannel
from aio_pika.abc import AbstractRobustConnection

from icij_common.logging_utils import LogWithNameMixin
from icij_worker import ManagerEvent
from . import EventPublisher
from ..routing_strategy import Routing
from ..utils.amqp import AMQPMixin, RobustConnection
from ..utils.amqp import AMQPMixin


class AMQPPublisher(AMQPMixin, EventPublisher, LogWithNameMixin):
def __init__(
self,
logger: Optional[logging.Logger] = None,
logger: logging.Logger | None = None,
*,
broker_url: str,
connection_timeout_s: float = 1.0,
reconnection_wait_s: float = 5.0,
is_qpid: bool = False,
app_id: str | None = None,
connection: Optional[AbstractRobustConnection] = None,
connection: AbstractRobustConnection | None = None,
):
super().__init__(
broker_url,
Expand All @@ -42,8 +38,8 @@ def __init__(
self._app_id = app_id
self._broker_url = broker_url
self._connection_ = connection
self._channel_: Optional[RobustChannel] = None
self._manager_evt_x: Optional[AioPikaExchange] = None
self._channel_: RobustChannel | None = None
self._manager_evt_x: AioPikaExchange | None = None
self._connection_timeout_s = connection_timeout_s
self._reconnection_wait_s = reconnection_wait_s
self._exit_stack = AsyncExitStack()
Expand Down Expand Up @@ -73,14 +69,11 @@ async def _publish_event(self, event: ManagerEvent):

async def _connection_workflow(self):
self.debug("creating connection...")
if self._connection_ is None:
self._connection_ = await connect_robust(
self._broker_url,
timeout=self._connection_timeout_s,
reconnect_interval=self._reconnection_wait_s,
connection_class=RobustConnection,
)
await self._exit_stack.enter_async_context(self._connection)
try:
_ = self.connection
except ValueError:
await self._connect()
await self._exit_stack.enter_async_context(self.connection)
self.debug("creating channel...")
self._channel_ = await self._connection.channel(
publisher_confirms=self._publisher_confirms,
Expand Down
12 changes: 4 additions & 8 deletions icij-worker/icij_worker/task_manager/amqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from functools import cached_property
from typing import TypeVar, cast

from aio_pika import connect_robust
from aio_pika.abc import AbstractExchange, AbstractQueueIterator
from aiormq import DeliveryError
from pydantic import Field
Expand All @@ -32,7 +31,6 @@
AMQPConfigMixin,
AMQPManagementClient,
AMQPMixin,
RobustConnection,
amqp_task_group_policy,
health_policy,
)
Expand Down Expand Up @@ -217,12 +215,10 @@ async def shutdown_workers(self):
async def _connection_workflow(self):
await self._exit_stack.enter_async_context(self._management_client)
logger.debug("creating connection...")
self._connection_ = await connect_robust(
self._broker_url,
timeout=self._connection_timeout_s,
reconnect_interval=self._reconnection_wait_s,
connection_class=RobustConnection,
)
try:
_ = self.connection
except ValueError:
await self._connect()
await self._exit_stack.enter_async_context(self._connection)
logger.debug("creating channel...")
self._channel_ = await self._connection.channel(
Expand Down
25 changes: 17 additions & 8 deletions icij-worker/icij_worker/utils/amqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from aio_pika import (
DeliveryMode,
Message as AioPikaMessage,
RobustChannel as RobustChannel_,
RobustConnection as RobustConnection_,
RobustChannel,
RobustConnection,
connect_robust,
)
from aio_pika.abc import (
AbstractExchange,
Expand Down Expand Up @@ -117,8 +118,7 @@ def broker_url(self) -> str:
if amqp_userinfo:
amqp_userinfo += "@"
amqp_authority = (
f"{amqp_userinfo or ''}{self.rabbitmq_host}"
f"{f':{self.rabbitmq_port}' or ''}"
f"{amqp_userinfo or ''}{self.rabbitmq_host}{f':{self.rabbitmq_port}' or ''}"
)
amqp_uri = f"amqp://{amqp_authority}"
if self.rabbitmq_vhost is not None:
Expand Down Expand Up @@ -216,7 +216,7 @@ def channel(self) -> AbstractRobustChannel:
return self._channel

@property
def connection(self) -> AbstractRobustChannel:
def connection(self) -> AbstractRobustConnection:
return self._connection

@classmethod
Expand Down Expand Up @@ -268,6 +268,15 @@ def health_routing(cls) -> Routing:
queue_name=AMQP_HEALTH_QUEUE,
)

async def _connect(self):
connection_class = QpidRobustConnection if self._is_qpid else RobustConnection
self._connection_ = await connect_robust(
self._broker_url,
timeout=self._connection_timeout_s,
reconnect_interval=self._reconnection_wait_s,
connection_class=connection_class,
)

async def _get_queue_iterator(
self,
routing: Routing,
Expand Down Expand Up @@ -431,7 +440,7 @@ def health_policy(routing: Routing) -> AMQPPolicy:
)


class RobustChannel(RobustChannel_):
class QpidRobustChannel(RobustChannel):
async def __close_callback(self, _: Any, exc: BaseException) -> None:
# pylint: disable=unused-private-member
timeout_exc = parse_consumer_timeout(exc)
Expand All @@ -440,8 +449,8 @@ async def __close_callback(self, _: Any, exc: BaseException) -> None:
raise timeout_exc from exc


class RobustConnection(RobustConnection_):
CHANNEL_CLASS: type[RobustChannel] = RobustChannel
class QpidRobustConnection(RobustConnection):
CHANNEL_CLASS: type[RobustChannel] = QpidRobustChannel

# Defined async context manager attributes to be able to enter and exit this
# in ExitStack
Expand Down
4 changes: 2 additions & 2 deletions icij-worker/icij_worker/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ async def consume(self) -> Task:
self._started_task_consumption_evt.set()
task = await self._consume()
msg = 'Task(id="%s") locked'
if task.max_retries is not None:
if task.max_retries is not None and task.retries_left is not None:
msg += (
f", tentative ({task.max_retries - task.retries_left}"
f", retry ({task.max_retries - task.retries_left}"
f"/{task.max_retries})"
)
self.info(msg, task.id)
Expand Down
30 changes: 28 additions & 2 deletions icij-worker/tests/utils/test_amqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@

import pytest
from aio_pika import Message, connect_robust
from aio_pika.abc import AbstractRobustQueue
from aio_pika.abc import AbstractRobustConnection, AbstractRobustQueue

from icij_worker.task_storage.postgres.postgres import logger
from icij_worker.utils.amqp import (
AMQPManagementClient,
AMQPMixin,
AMQPPolicy,
ApplyTo,
parse_consumer_timeout,
QpidRobustConnection,
RobustConnection,
parse_consumer_timeout,
worker_events_policy,
)

Expand Down Expand Up @@ -104,3 +105,28 @@ async def test_worker_events_policy():
assert policy == expected
worker_queue_name = "WORKER_EVENT-some-service"
assert re.match(policy.pattern, worker_queue_name)


@pytest.mark.parametrize(
"is_qpid_,expected_type", [(True, QpidRobustConnection), (False, RobustConnection)]
)
async def test_should_handle_qpid_when_creating_connection(
is_qpid_, expected_type: type[AbstractRobustConnection], rabbit_mq: str
):
# Given
class SomeClass(AMQPMixin):
def __init__(self, broker_url: str, *, is_qpid: bool):
super().__init__(broker_url=broker_url, is_qpid=is_qpid)

async def connect(self):
await self._connect()

# When
instance = SomeClass(rabbit_mq, is_qpid=is_qpid_)
await instance.connect()

# Then
assert (
type(instance.connection) # pylint: disable=unidiomatic-typecheck
is expected_type
)
Loading