Source code for aio_pika.pool

import abc
import asyncio
from types import TracebackType
from typing import (
    Any,
    AsyncContextManager,
    Awaitable,
    Callable,
    Generic,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
)

from aio_pika.log import get_logger
from aio_pika.tools import create_task


log = get_logger(__name__)


[docs] class PoolInstance(abc.ABC): @abc.abstractmethod def close(self) -> Awaitable[None]: raise NotImplementedError
T = TypeVar("T") ConstructorType = Callable[ ..., Awaitable[PoolInstance], ]
[docs] class PoolInvalidStateError(RuntimeError): pass
[docs] class Pool(Generic[T]): __slots__ = ( "loop", "__max_size", "__items", "__constructor", "__created", "__lock", "__constructor_args", "__item_set", "__closed", ) def __init__( self, constructor: ConstructorType, *args: Any, max_size: Optional[int] = None, loop: Optional[asyncio.AbstractEventLoop] = None, ): self.loop = loop or asyncio.get_event_loop() self.__closed = False self.__constructor: Callable[..., Awaitable[Any]] = constructor self.__constructor_args: Tuple[Any, ...] = args or () self.__created: int = 0 self.__item_set: Set[PoolInstance] = set() self.__items: asyncio.Queue = asyncio.Queue() self.__lock: asyncio.Lock = asyncio.Lock() self.__max_size: Optional[int] = max_size @property def is_closed(self) -> bool: return self.__closed def acquire(self) -> "PoolItemContextManager[T]": if self.__closed: raise PoolInvalidStateError("acquire operation on closed pool") return PoolItemContextManager[T](self) @property def _has_released(self) -> bool: return self.__items.qsize() > 0 @property def _is_overflow(self) -> bool: if self.__max_size: return self.__created >= self.__max_size or self._has_released return self._has_released async def _create_item(self) -> T: if self.__closed: raise PoolInvalidStateError("create item operation on closed pool") async with self.__lock: if self._is_overflow: return await self.__items.get() log.debug("Creating a new instance of %r", self.__constructor) item = await self.__constructor(*self.__constructor_args) self.__created += 1 self.__item_set.add(item) return item async def _get(self) -> T: if self.__closed: raise PoolInvalidStateError("get operation on closed pool") if self._is_overflow: return await self.__items.get() return await self._create_item() def put(self, item: T) -> None: if self.__closed: raise PoolInvalidStateError("put operation on closed pool") self.__items.put_nowait(item) async def close(self) -> None: async with self.__lock: self.__closed = True tasks = [] for item in self.__item_set: tasks.append(create_task(item.close)) if tasks: await asyncio.gather(*tasks, return_exceptions=True) async def __aenter__(self) -> "Pool": return self async def __aexit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: if self.__closed: return await asyncio.ensure_future(self.close())
[docs] class PoolItemContextManager(Generic[T], AsyncContextManager): __slots__ = "pool", "item" def __init__(self, pool: Pool): self.pool = pool self.item: T async def __aenter__(self) -> T: # noinspection PyProtectedMember self.item = await self.pool._get() return self.item async def __aexit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: if self.item is not None: self.pool.put(self.item)