Source code for aio_pika.robust_connection

import asyncio
from ssl import SSLContext
from typing import Any, Optional, Tuple, Type, Union
from weakref import WeakSet

import aiormq.abc
from aiormq.connection import parse_bool, parse_timeout
from pamqp.common import FieldTable
from yarl import URL

from .abc import (
    AbstractRobustChannel, AbstractRobustConnection, ConnectionParameter,
    SSLOptions, TimeoutType,
)
from .connection import Connection, make_url
from .exceptions import CONNECTION_EXCEPTIONS
from .log import get_logger
from .robust_channel import RobustChannel
from .tools import CallbackCollection


log = get_logger(__name__)


class RobustConnection(Connection, AbstractRobustConnection):
    """Robust connection.

    A :class:`Connection` that re-establishes itself after the underlying
    connection closes without :meth:`close` having been called, and that
    restores every channel opened through :meth:`channel` once a new
    connection is up.

    Extra connection parameters (see ``PARAMETERS``):

    * ``reconnect_interval`` -- seconds to wait between connection
      attempts (parsed with ``parse_timeout``; default ``"5"``).
    * ``fail_fast`` -- when true (default ``"1"``, parsed with
      ``parse_bool``), a connection error on the very first attempt is
      raised from :meth:`connect` and no further attempts are made.
      When false, the first attempt is retried like any later one.
    """

    # NOTE(review): not referenced in this class; presumably read by the
    # channel-restore machinery elsewhere -- confirm before changing.
    CHANNEL_REOPEN_PAUSE = 1
    CHANNEL_CLASS: Type[RobustChannel] = RobustChannel
    PARAMETERS: Tuple[ConnectionParameter, ...] = Connection.PARAMETERS + (
        ConnectionParameter(
            name="reconnect_interval",
            parser=parse_timeout,
            default="5",
        ),
        ConnectionParameter(
            name="fail_fast",
            parser=parse_bool,
            default="1",
        ),
    )

    def __init__(
        self,
        url: URL,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs: Any,
    ):
        super().__init__(url=url, loop=loop, **kwargs)

        self.reconnect_interval = self.kwargs.pop("reconnect_interval")
        # Number of successful connections so far; non-zero means the next
        # successful connection is a *re*-connection.
        self.connection_attempt: int = 0

        # Resolved by the first connection attempt (result on success,
        # exception on a connection error) and awaited by ``connect()``.
        # With fail_fast disabled it is resolved up-front, so ``connect()``
        # only waits on ``connected`` while the factory keeps retrying.
        self.__fail_fast_future = self.loop.create_future()
        self.fail_fast = self.kwargs.pop("fail_fast", True)
        if not self.fail_fast:
            self.__fail_fast_future.set_result(None)

        # Weak references: channels that have been garbage-collected are
        # simply not restored after a reconnect.
        self.__channels: WeakSet[AbstractRobustChannel] = WeakSet()
        # Set whenever a (re)connection is needed; the connection factory
        # waits on it and ``_on_connected`` clears it.
        self.__connection_close_event = asyncio.Event()
        self.__connect_timeout: Optional[TimeoutType] = None
        self.__reconnection_task: Optional[asyncio.Task] = None

        # NOTE(review): this lock is never acquired in this class, so
        # ``reconnecting`` only reports True if other code takes it.
        self._reconnect_lock = asyncio.Lock()

        self.reconnect_callbacks = CallbackCollection(self)
        # Start "closed" so the factory makes its first attempt immediately.
        self.__connection_close_event.set()

    @property
    def reconnecting(self) -> bool:
        """Whether ``_reconnect_lock`` is currently held."""
        return self._reconnect_lock.locked()

    def __repr__(self) -> str:
        return (
            f'<{self.__class__.__name__}: "{self}" '
            f"{len(self.__channels)} channels>"
        )

    async def _on_connection_close(self, closing: asyncio.Future) -> None:
        """Handle loss of the underlying connection.

        After the base-class handling, schedules a reconnect (by setting the
        close event the connection factory waits on) unless :meth:`close`
        was called or the connection is already closed.
        """
        await super()._on_connection_close(closing)

        if self._close_called or self.is_closed:
            return

        log.info(
            "Connection to %s closed. Reconnecting after %r seconds.",
            self,
            self.reconnect_interval,
        )
        self.__connection_close_event.set()

    async def _on_connected(self) -> None:
        """Restore registered channels on a freshly established connection.

        If restoring any channel fails, the close callbacks are fired with
        the error, the transport connection is closed (errors from that
        close are ignored), and the error is re-raised. On every connection
        after the first, ``reconnect_callbacks`` are fired.

        :raises RuntimeError: if there is no active transport.
        """
        await super()._on_connected()

        transport = self.transport
        if transport is None:
            # Format the message eagerly: exception constructors do not
            # interpolate logging-style "%r" arguments.
            raise RuntimeError(
                f"No active transport for connection {self!r}",
            )

        try:
            # Make a copy of the channels to iterate on, to guard from
            # concurrent updates to the set.
            for channel in tuple(self.__channels):
                try:
                    await channel.restore()
                except Exception:
                    log.exception("Failed to reopen channel")
                    raise
        except Exception as e:
            await self.close_callbacks(e)
            await asyncio.gather(
                transport.connection.close(e),
                return_exceptions=True,
            )
            raise

        if self.connection_attempt:
            await self.reconnect_callbacks()

        self.connection_attempt += 1
        # Connected: the factory should now wait for the next close.
        self.__connection_close_event.clear()

    async def __connection_factory(self) -> None:
        """Background loop that (re)establishes the connection.

        Waits for the close event, makes one connection attempt, then
        sleeps ``reconnect_interval`` seconds before the next iteration.
        It exits when the connection is closed, when :meth:`close` has been
        called, or when the first attempt fails with a connection error
        and fail-fast is still pending.
        """
        log.debug("Starting connection factory for %r", self)

        while not self.is_closed and not self._close_called:
            log.debug("Waiting for connection close event for %r", self)
            await self.__connection_close_event.wait()

            if self.is_closed or self._close_called:
                return

            # noinspection PyBroadException
            try:
                self.transport = None
                self.connected.clear()

                log.debug("Connection attempt for %r", self)
                await Connection.connect(self, self.__connect_timeout)
                if not self.__fail_fast_future.done():
                    self.__fail_fast_future.set_result(None)
                log.debug("Connection made on %r", self)
            except CONNECTION_EXCEPTIONS as e:
                if not self.__fail_fast_future.done():
                    # First attempt with fail_fast: report to connect()
                    # and stop retrying.
                    self.__fail_fast_future.set_exception(e)
                    return

                log.warning(
                    'Connection attempt to "%s" failed: %s. '
                    "Reconnecting after %r seconds.",
                    self,
                    e,
                    self.reconnect_interval,
                )
            except Exception:
                # NOTE(review): a non-connection error on the first
                # fail-fast attempt does not resolve the fail-fast future,
                # so connect() keeps waiting while this loop retries.
                log.exception(
                    "Reconnect attempt failed %s. "
                    "Retrying after %r seconds.",
                    self,
                    self.reconnect_interval,
                )

            await asyncio.sleep(self.reconnect_interval)

    async def connect(self, timeout: TimeoutType = None) -> None:
        """Start the reconnection loop (once) and wait until connected.

        :param timeout: timeout passed to every connection attempt.
        :raises RuntimeError: if the connection is closed or reconnecting.
        :raises: the first attempt's connection error when fail-fast is on.
        """
        self.__connect_timeout = timeout

        if self.is_closed:
            raise RuntimeError(f"{self!r} connection closed")

        if self.reconnecting:
            raise RuntimeError(
                (
                    "Connect method called but connection "
                    f"{self!r} is reconnecting right now."
                ),
                self,
            )

        if self.__reconnection_task is None:
            self.__reconnection_task = self.loop.create_task(
                self.__connection_factory(),
            )

        await self.__fail_fast_future
        await self.connected.wait()

    async def reconnect(self) -> None:
        """Close the current transport connection, reconnect, and fire
        ``reconnect_callbacks``.

        NOTE(review): ``_on_connected`` also fires ``reconnect_callbacks``
        on every non-first connection, so callbacks may run twice here.
        """
        if self.transport:
            await self.transport.connection.close()

        await self.connect()
        await self.reconnect_callbacks()

    def channel(
        self,
        channel_number: Optional[int] = None,
        publisher_confirms: bool = True,
        on_return_raises: bool = False,
    ) -> AbstractRobustChannel:
        """Create a channel and register it for restoration on reconnect."""
        channel: AbstractRobustChannel = super().channel(
            channel_number=channel_number,
            publisher_confirms=publisher_confirms,
            on_return_raises=on_return_raises,
        )  # type: ignore
        self.__channels.add(channel)
        return channel

    async def close(
        self,
        exc: Optional[aiormq.abc.ExceptionType] = asyncio.CancelledError,
    ) -> None:
        """Stop the reconnection loop, then close the connection."""
        if self.__reconnection_task is not None:
            self.__reconnection_task.cancel()
            await asyncio.gather(
                self.__reconnection_task,
                return_exceptions=True,
            )
            self.__reconnection_task = None
        return await super().close(exc)
async def connect_robust(
    url: Union[str, URL, None] = None,
    *,
    host: str = "localhost",
    port: int = 5672,
    login: str = "guest",
    password: str = "guest",
    virtualhost: str = "/",
    ssl: bool = False,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    ssl_options: Optional[SSLOptions] = None,
    ssl_context: Optional[SSLContext] = None,
    timeout: TimeoutType = None,
    client_properties: Optional[FieldTable] = None,
    connection_class: Type[AbstractRobustConnection] = RobustConnection,
    **kwargs: Any,
) -> AbstractRobustConnection:
    """Make a robust (automatically reconnecting) connection to the broker.

    Example:

    .. code-block:: python

        import aio_pika

        async def main():
            connection = await aio_pika.connect_robust(
                "amqp://guest:guest@127.0.0.1/"
            )

    Connect to localhost with default credentials:

    .. code-block:: python

        import aio_pika

        async def main():
            connection = await aio_pika.connect_robust()

    .. note::

        The available keys for ssl_options parameter are:
            * cert_reqs
            * certfile
            * keyfile
            * ssl_version

        For information on what the ssl_options can be set to, refer to
        the `official Python documentation`_ .

    Set connection name for RabbitMQ admin panel:

    .. code-block:: python

        # As URL parameter method
        read_connection = await connect_robust(
            "amqp://guest:guest@localhost/?name=Read%20connection"
        )

        # keyword method
        write_connection = await connect_robust(
            client_properties={
                'connection_name': 'Write connection'
            }
        )

    .. note::

        ``client_properties`` argument requires ``aiormq>=2.9``

    URL string might contain ssl parameters e.g.
    `amqps://user:pass@host//?ca_certs=ca.pem&certfile=crt.pem&keyfile=key.pem`

    :param client_properties: add custom client capability.
    :param url:
        RFC3986_ formatted broker address. When :class:`None`,
        the keyword arguments are used instead.
    :param host: hostname of the broker
    :param port: broker port 5672 by default
    :param login: username string. `'guest'` by default.
    :param password: password string. `'guest'` by default.
    :param virtualhost: virtualhost parameter. `'/'` by default
    :param ssl: use SSL for connection. Should be used with additional
        kwargs.
    :param ssl_options: A dict of values for the SSL connection.
    :param timeout: connection timeout in seconds
    :param loop:
        Event loop (:func:`asyncio.get_event_loop()` when :class:`None`)
    :param ssl_context: ssl.SSLContext instance
    :param connection_class: Factory of a new connection
    :param kwargs: additional parameters; passed both to ``make_url``
        and to ``connection_class``.
    :return: an instance of ``connection_class``
        (:class:`RobustConnection` by default), already connected.

    .. _RFC3986: https://goo.gl/MzgYAs
    .. _official Python documentation: https://goo.gl/pty9xA
    """
    connection: AbstractRobustConnection = connection_class(
        make_url(
            url,
            host=host,
            port=port,
            login=login,
            password=password,
            virtualhost=virtualhost,
            ssl=ssl,
            ssl_options=ssl_options,
            client_properties=client_properties,
            **kwargs,
        ),
        loop=loop,
        ssl_context=ssl_context,
        **kwargs,
    )

    await connection.connect(timeout=timeout)
    return connection
# Public API of this module: what ``from ... import *`` exposes.
__all__ = ("RobustConnection", "connect_robust")