Source code for aio_pika.robust_queue

import asyncio
import uuid
import warnings
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Union

import aiormq
from aiormq import ChannelInvalidStateError
from pamqp.common import Arguments

from .abc import (
    AbstractChannel,
    AbstractExchange,
    AbstractIncomingMessage,
    AbstractQueueIterator,
    AbstractRobustQueue,
    ConsumerTag,
    TimeoutType,
)
from .exchange import ExchangeParamType
from .log import get_logger
from .queue import Queue, QueueIterator


log = get_logger(__name__)


[docs] class RobustQueue(Queue, AbstractRobustQueue): __slots__ = ("_consumers", "_bindings") _consumers: Dict[ConsumerTag, Dict[str, Any]] _bindings: Dict[Tuple[Union[AbstractExchange, str], str], Dict[str, Any]] def __init__( self, channel: AbstractChannel, name: Optional[str], durable: bool = False, exclusive: bool = False, auto_delete: bool = False, arguments: Arguments = None, passive: bool = False, ): super().__init__( channel=channel, name=name or f"amq_{uuid.uuid4().hex}", durable=durable, exclusive=exclusive, auto_delete=auto_delete, arguments=arguments, passive=passive, ) self._consumers = {} self._bindings = {} async def restore(self, channel: Any = None) -> None: if channel is not None: warnings.warn( "Channel argument will be ignored because you " "don't need to pass this anymore.", DeprecationWarning, ) await self.declare() bindings = tuple(self._bindings.items()) consumers = tuple(self._consumers.items()) for (exchange, routing_key), kwargs in bindings: await self.bind(exchange, routing_key, **kwargs) for consumer_tag, kwargs in consumers: await self.consume(consumer_tag=consumer_tag, **kwargs)
[docs] async def bind( self, exchange: ExchangeParamType, routing_key: Optional[str] = None, *, arguments: Arguments = None, timeout: TimeoutType = None, robust: bool = True, ) -> aiormq.spec.Queue.BindOk: if routing_key is None: routing_key = self.name result = await super().bind( exchange=exchange, routing_key=routing_key, arguments=arguments, timeout=timeout, ) if robust: self._bindings[(exchange, routing_key)] = dict( arguments=arguments, ) return result
[docs] async def unbind( self, exchange: ExchangeParamType, routing_key: Optional[str] = None, arguments: Arguments = None, timeout: TimeoutType = None, ) -> aiormq.spec.Queue.UnbindOk: if routing_key is None: routing_key = self.name result = await super().unbind( exchange, routing_key, arguments, timeout, ) self._bindings.pop((exchange, routing_key), None) return result
[docs] async def consume( self, callback: Callable[[AbstractIncomingMessage], Awaitable[Any]], no_ack: bool = False, exclusive: bool = False, arguments: Arguments = None, consumer_tag: Optional[ConsumerTag] = None, timeout: TimeoutType = None, robust: bool = True, ) -> ConsumerTag: consumer_tag = await super().consume( consumer_tag=consumer_tag, timeout=timeout, callback=callback, no_ack=no_ack, exclusive=exclusive, arguments=arguments, ) if robust: self._consumers[consumer_tag] = dict( callback=callback, no_ack=no_ack, exclusive=exclusive, arguments=arguments, ) return consumer_tag
[docs] async def cancel( self, consumer_tag: ConsumerTag, timeout: TimeoutType = None, nowait: bool = False, ) -> aiormq.spec.Basic.CancelOk: result = await super().cancel(consumer_tag, timeout, nowait) self._consumers.pop(consumer_tag, None) return result
[docs] def iterator(self, **kwargs: Any) -> AbstractQueueIterator: return RobustQueueIterator(self, **kwargs)
class RobustQueueIterator(QueueIterator): """Queue iterator that survives channel reconnection. This iterator handles channel disconnection/reconnection gracefully by waiting for channel restoration instead of raising StopAsyncIteration. """ RETRY_DELAY: float = 0.5 def __init__(self, queue: Queue, **kwargs: Any): super().__init__(queue, **kwargs) # Remove close callback to survive reconnection self._amqp_queue.close_callbacks.discard(self._set_closed) # But listen to connection close to stop iteration when # connection is intentionally closed channel = self._amqp_queue.channel if hasattr(channel, "_connection"): connection = channel._connection connection.closed().add_done_callback(self._on_connection_closed) def _on_connection_closed(self, _: asyncio.Future) -> None: """Handle connection closed - set _closed to stop iteration.""" if not self._closed.done(): self._closed.set_result(True) self._message_or_closed.set() async def consume(self) -> None: """Consume with retry on channel errors. Waits for channel to be ready before consuming, with backoff delay between retries to prevent CPU spinning during reconnection. """ while True: try: # Wait for channel to be fully ready before consuming channel = self._amqp_queue.channel if hasattr(channel, "ready"): await channel.ready() return await super().consume() except ChannelInvalidStateError: log.debug( "Channel invalid state in %r, waiting for restoration", self, ) # Backoff to prevent CPU spinning during reconnection await asyncio.sleep(self.RETRY_DELAY) async def __anext__(self) -> AbstractIncomingMessage: """Get next message, handling reconnection gracefully. During reconnection, the queue may be empty and _message_or_closed may be set due to channel closure. In robust mode, we wait for reconnection instead of raising StopAsyncIteration. """ while True: # Check if explicitly closed if self._closed.done(): raise StopAsyncIteration if not hasattr(self, "_consumer_tag"): await self.consume() timeout: Optional[float] = self._consume_kwargs.get("timeout") if not self._message_or_closed.is_set(): coroutine: Awaitable[Any] = self._message_or_closed.wait() if timeout is not None and timeout > 0: coroutine = asyncio.wait_for(coroutine, timeout=timeout) try: await coroutine except (asyncio.TimeoutError, asyncio.CancelledError): # Handle timeout same as parent class if timeout is not None: timeout_val = ( timeout if timeout > 0 else self.DEFAULT_CLOSE_TIMEOUT ) log.info( "%r closing with timeout %d seconds", self, timeout_val, ) task = asyncio.create_task(self.close()) close_coro: Awaitable[Any] = task if timeout is not None: close_coro = asyncio.wait_for( asyncio.shield(task), timeout=timeout_val, ) try: await close_coro except asyncio.TimeoutError: self._QueueIterator__closing = task raise # Check queue for messages if not self._queue.empty(): msg = self._queue.get_nowait() if ( self._queue.empty() and not self._amqp_queue.channel.is_closed and not self._closed.done() ): self._message_or_closed.clear() return msg # Queue is empty - check if this is a reconnection scenario channel = self._amqp_queue.channel if hasattr(channel, "ready"): # This is a RobustChannel - check if connection is still alive if hasattr(channel, "_connection"): connection = channel._connection # Only wait for reconnection if connection wasn't # intentionally closed if not connection.is_closed and not connection.close_called: # Connection is alive, channel is being restored log.debug( "%r queue empty during channel restoration, " "waiting for reconnection", self, ) # Clear the event and wait for channel restoration self._message_or_closed.clear() try: # Wait for channel to become ready await asyncio.wait_for( channel.ready(), timeout=60.0, ) # Re-establish consumer if needed if not hasattr(self, "_consumer_tag"): await self.consume() # Continue loop to wait for new messages continue except asyncio.TimeoutError: log.error( "%r timeout waiting for channel reconnection", self, ) raise StopAsyncIteration # Truly empty and not reconnecting - stop iteration raise StopAsyncIteration __all__ = ("RobustQueue",)