# Source code for aio_pika.patterns.master

import asyncio
import gzip
import json
import logging
from functools import partial
from types import MappingProxyType
from typing import Any, Awaitable, Mapping, Optional

import aiormq

from aio_pika.abc import (
    AbstractChannel, AbstractExchange, AbstractIncomingMessage, AbstractQueue,
    ConsumerTag, DeliveryMode,
)
from aio_pika.message import Message

from ..tools import create_task, ensure_awaitable
from .base import Base, CallbackType, Proxy, T


# Module-level logger, named after this module; used for the warnings and
# errors emitted by the Master/Worker helpers below.
log = logging.getLogger(__name__)


class MessageProcessingError(Exception):
    """Common base for the exceptions a worker callback raises to settle
    the message it is processing (see :class:`NackMessage` and
    :class:`RejectMessage`)."""


class NackMessage(MessageProcessingError):
    """
    Raised by a worker callback to ``nack`` the current message.

    :param requeue: Value forwarded as ``requeue`` to ``message.nack()``.
    """

    def __init__(self, requeue: bool = False):
        # Forward to Exception so ``args``, ``str()`` and ``repr()`` reflect
        # the flag; otherwise NackMessage(requeue=True).args would be ().
        super().__init__(requeue)
        self.requeue = requeue


class RejectMessage(MessageProcessingError):
    """
    Raised by a worker callback to ``reject`` the current message.

    :param requeue: Value forwarded as ``requeue`` to ``message.reject()``.
    """

    def __init__(self, requeue: bool = False):
        # Forward to Exception so ``args``, ``str()`` and ``repr()`` reflect
        # the flag; otherwise RejectMessage(requeue=True).args would be ().
        super().__init__(requeue)
        self.requeue = requeue


class Worker:
    """
    Handle returned by :meth:`Master.create_worker`.

    Stores the consumed queue, the consumer tag obtained from
    ``queue.consume`` and the event loop, so the subscription can be
    cancelled later with :meth:`close`.
    """

    __slots__ = ("queue", "consumer_tag", "loop")

    def __init__(
        self,
        queue: AbstractQueue,
        consumer_tag: ConsumerTag,
        loop: asyncio.AbstractEventLoop,
    ):
        self.loop = loop
        self.queue = queue
        self.consumer_tag = consumer_tag

    def close(self) -> Awaitable[None]:
        """
        Cancel this worker's subscription to its queue.

        Schedules ``queue.cancel(consumer_tag)`` via ``create_task`` and
        hands back the resulting awaitable.

        :return: :class:`asyncio.Task`
        """
        async def _cancel_consumer() -> None:
            await self.queue.cancel(self.consumer_tag)

        task = create_task(_cancel_consumer)
        return task
class Master(Base):
    """
    Implements Master/Worker pattern.

    Usage example:

    `worker.py` ::

        master = Master(channel)
        worker = await master.create_worker('test_worker', lambda x: print(x))

    `master.py` ::

        master = Master(channel)
        await master.proxy.test_worker(x='foo')

    Task arguments travel as keyword arguments: the worker callback is
    invoked as ``func(**kwargs)`` with the deserialized payload.
    """

    __slots__ = (
        "channel",
        "loop",
        "proxy",
        "_requeue",
        "_reject_on_redelivered",
    )

    # Delivery mode set on every message published by create_task.
    DELIVERY_MODE = DeliveryMode.PERSISTENT

    def __init__(
        self,
        channel: AbstractChannel,
        requeue: bool = True,
        reject_on_redelivered: bool = False,
    ):
        """
        Creates a new :class:`Master` instance.

        :param channel: Initialized instance of :class:`aio_pika.Channel`
        :param requeue: Forwarded as ``requeue`` to ``message.process()``
            for every consumed message.
        :param reject_on_redelivered: Forwarded as ``reject_on_redelivered``
            to ``message.process()`` for every consumed message.
        """
        self.channel: AbstractChannel = channel
        self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
        self.proxy = Proxy(self.create_task)

        # create_task publishes with mandatory=True; messages the broker
        # hands back arrive here and are logged as a warning.
        self.channel.return_callbacks.add(self.on_message_returned)

        self._requeue = requeue
        self._reject_on_redelivered = reject_on_redelivered

    @property
    def exchange(self) -> AbstractExchange:
        """The channel's default exchange, used by :meth:`create_task`."""
        return self.channel.default_exchange

    @staticmethod
    def on_message_returned(
        channel: Optional[AbstractChannel],
        message: AbstractIncomingMessage,
    ) -> None:
        """Log a warning for a published task the broker returned."""
        log.warning(
            "Message returned. Probably destination queue does not exists: %r",
            message,
        )

    def serialize(self, data: Any) -> bytes:
        """
        Serialize data to the bytes.

        Uses `pickle` by default. You should override this method when you
        want to change the serializer.

        :param data: Data which will be serialized
        :returns: bytes
        """
        return super().serialize(data)

    def deserialize(self, data: bytes) -> Any:
        """
        Deserialize data from bytes.

        Uses `pickle` by default. You should override this method when you
        want to change the serializer.

        :param data: Data which will be deserialized
        :returns: :class:`Any`
        """
        # SECURITY(review): with the pickle default stated above, message
        # bodies are unpickled, so anyone able to publish to a worker queue
        # can run arbitrary code in the worker. Prefer JsonMaster /
        # CompressedJsonMaster when publishers are not fully trusted.
        return super().deserialize(data)

    @classmethod
    async def execute(
        cls,
        func: CallbackType,
        kwargs: Any,
    ) -> T:
        """
        Call ``func`` with the deserialized payload as keyword arguments.

        :param func: Awaitable worker callback.
        :param kwargs: Deserialized message payload; a falsy value is
            treated as "no arguments".
        :raises RejectMessage: (with ``requeue=False``) when the payload is
            not a ``dict``.
        :returns: Whatever ``func`` returns.
        """
        kwargs = kwargs or {}

        if not isinstance(kwargs, dict):
            # Use the module logger rather than the root logger so the
            # record is attributed to this module.
            log.error("Bad kwargs %r received for the %r", kwargs, func)
            raise RejectMessage(requeue=False)

        return await func(**kwargs)

    async def on_message(
        self,
        func: CallbackType,
        message: AbstractIncomingMessage,
    ) -> None:
        """
        Consumer callback: deserialize ``message.body`` and run ``func``.

        Runs inside ``message.process()`` configured with this master's
        ``requeue`` / ``reject_on_redelivered`` settings. A
        :class:`RejectMessage` or :class:`NackMessage` raised by ``func``
        is turned into an explicit reject / nack using the exception's
        ``requeue`` flag.
        """
        # ignore_processed=True presumably stops process() from settling a
        # message that the handlers below already rejected/nacked
        # (aio_pika semantics — confirm).
        async with message.process(
            requeue=self._requeue,
            reject_on_redelivered=self._reject_on_redelivered,
            ignore_processed=True,
        ):
            try:
                await self.execute(func, self.deserialize(message.body))
            except RejectMessage as e:
                await message.reject(requeue=e.requeue)
            except NackMessage as e:
                await message.nack(requeue=e.requeue)

    async def create_queue(
        self,
        queue_name: str,
        **kwargs: Any,
    ) -> AbstractQueue:
        """Declare ``queue_name`` on the channel, forwarding ``kwargs``."""
        return await self.channel.declare_queue(queue_name, **kwargs)

    async def create_worker(
        self,
        queue_name: str,
        func: CallbackType,
        **kwargs: Any,
    ) -> Worker:
        """
        Creates a new :class:`Worker` instance.

        Declares ``queue_name`` (``kwargs`` are forwarded to
        :meth:`create_queue`) and starts consuming it, handling every
        message with :meth:`on_message` and ``func``.

        :returns: :class:`Worker` handle; call :meth:`Worker.close` to stop.
        """
        queue = await self.create_queue(queue_name, **kwargs)
        consumer_tag = await queue.consume(
            partial(self.on_message, ensure_awaitable(func)),
        )
        return Worker(queue, consumer_tag, self.loop)

    async def create_task(
        self,
        channel_name: str,
        kwargs: Mapping[str, Any] = MappingProxyType({}),
        **message_kwargs: Any,
    ) -> Optional[aiormq.abc.ConfirmationFrameType]:
        """
        Creates a new task for the worker.

        :param channel_name: Routing key used on the default exchange, i.e.
            the worker's queue name as passed to :meth:`create_worker`.
        :param kwargs: Keyword arguments for the worker callback.
        :param message_kwargs: Extra arguments forwarded to :class:`Message`.
        :returns: The result of ``exchange.publish``.
        """
        # Normalise any Mapping to a plain dict. The default mappingproxy
        # can be neither pickled nor JSON-encoded, and execute() on the
        # worker side only accepts a dict payload.
        payload = dict(kwargs) if isinstance(kwargs, Mapping) else kwargs

        message = Message(
            body=self.serialize(payload),
            content_type=self.CONTENT_TYPE,
            delivery_mode=self.DELIVERY_MODE,
            **message_kwargs,
        )

        # mandatory=True asks the broker to return unroutable messages,
        # which then reach on_message_returned via return_callbacks.
        return await self.exchange.publish(
            message, channel_name, mandatory=True,
        )
class JsonMaster(Master):
    """:class:`Master` that publishes task arguments as JSON.

    Only encoding is overridden here; decoding is inherited from Master.
    """

    SERIALIZER = json
    CONTENT_TYPE = "application/json"

    def serialize(self, data: Any) -> bytes:
        """JSON-encode ``data`` (non-ASCII left unescaped) as UTF-8 bytes."""
        document = self.SERIALIZER.dumps(data, ensure_ascii=False)
        return document.encode()


class CompressedJsonMaster(Master):
    """:class:`Master` that exchanges task arguments as gzip-compressed JSON."""

    SERIALIZER = json
    CONTENT_TYPE = "application/json;compression=gzip"
    # Passed to gzip.compress as ``compresslevel``.
    COMPRESS_LEVEL = 6

    def serialize(self, data: Any) -> bytes:
        """JSON-encode ``data`` to UTF-8 bytes, then gzip them."""
        plain = self.SERIALIZER.dumps(data, ensure_ascii=False).encode()
        return gzip.compress(plain, compresslevel=self.COMPRESS_LEVEL)

    def deserialize(self, data: bytes) -> Any:
        """Gunzip ``data`` and parse the JSON document inside it."""
        plain = gzip.decompress(data)
        return self.SERIALIZER.loads(plain)