# Source code for aio_pika.tools

from __future__ import annotations

import asyncio
import inspect
import warnings
from functools import wraps
from itertools import chain
from threading import Lock
from typing import (
    AbstractSet,
    Any,
    Awaitable,
    Callable,
    Coroutine,
    Generator,
    Generic,
    Iterator,
    List,
    MutableSet,
    Optional,
    ParamSpec,
    Protocol,
    TypeVar,
    Union,
)
from weakref import ReferenceType, WeakSet, ref

from aio_pika.log import get_logger


# Module-level logger, created through aio_pika's own logging helper.
log = get_logger(__name__)
# Generic result type used in the signatures of the helpers below
# (e.g. ``create_task`` and ``ensure_awaitable``).
T = TypeVar("T")


def iscoroutinepartial(fn: Callable[..., Any]) -> bool:
    """
    Deprecated: use :func:`inspect.iscoroutinefunction` instead.

    Emits a ``DeprecationWarning`` attributed to the caller and then
    delegates to :func:`inspect.iscoroutinefunction`, which is the
    replacement this function advertises and the check used elsewhere
    in this module.

    :param fn: callable to inspect
    :return: ``True`` when ``fn`` is a coroutine function
    """
    warnings.warn(
        "Use inspect.iscoroutinefunction() instead.",
        DeprecationWarning,
        # Point the warning at the code calling this helper, not at
        # this module, so users can find what needs to be migrated.
        stacklevel=2,
    )
    return inspect.iscoroutinefunction(fn)


def _task_done(future: asyncio.Future) -> None:
    """
    Done-callback that surfaces the outcome of a finished future.

    Cancellation is ignored silently. Any other exception stored on the
    future is re-raised from this callback; a successful result is a no-op.
    """
    try:
        error = future.exception()
    except asyncio.CancelledError:
        # The future itself was cancelled: nothing to report.
        return None

    if error is None or isinstance(error, asyncio.CancelledError):
        return None

    raise error


def create_task(
    func: Callable[..., Union[Coroutine[Any, Any, T], Awaitable[T]]],
    *args: Any,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    **kwargs: Any,
) -> Awaitable[T]:
    """
    Schedule ``func(*args, **kwargs)`` on an event loop.

    A coroutine function is wrapped in a task. Any other callable is
    executed on the next loop iteration and its return value (or the
    exception it raised) is stored on a plain future. In both cases
    ``_task_done`` is attached as a done-callback.

    :param func: coroutine function or regular callable to run
    :param args: positional arguments forwarded to ``func``
    :param loop: loop to schedule on; falls back to
        ``asyncio.get_event_loop()`` when falsy
    :param kwargs: keyword arguments forwarded to ``func``
    :return: the created task or future
    """
    if not loop:
        loop = asyncio.get_event_loop()

    if not inspect.iscoroutinefunction(func):

        def invoke(target: asyncio.Future) -> Optional[asyncio.Future]:
            # The future may already be resolved (e.g. cancelled) by
            # the time the loop gets around to running this.
            if target.done():
                return None

            try:
                outcome = func(*args, **kwargs)
                target.set_result(outcome)
            except Exception as err:
                target.set_exception(err)

            return target

        pending = loop.create_future()
        pending.add_done_callback(_task_done)
        loop.call_soon(invoke, pending)
        return pending

    scheduled = loop.create_task(func(*args, **kwargs))
    scheduled.add_done_callback(_task_done)
    return scheduled


# Type of the object passed to callbacks as their first (sender) argument.
_Sender = TypeVar("_Sender", contravariant=True)
# Parameter specification for the extra arguments a callback receives
# after the sender.
_Params = ParamSpec("_Params")
# Type of the value a callback returns, directly or through an awaitable.
_Return = TypeVar("_Return", covariant=True)


class CallbackType(Protocol[_Sender, _Params, _Return]):
    """
    Structural type of a callback accepted by ``CallbackCollection``.

    A callback is called with the sender as its first, positional-only
    argument (it may be ``None``), followed by the arguments described by
    ``_Params``. It may return its result directly or as an awaitable.
    """

    def __call__(
        self,
        __sender: Optional[_Sender],
        /,
        *args: _Params.args,
        **kwargs: _Params.kwargs,
    ) -> Union[_Return, Awaitable[_Return]]: ...
class StubAwaitable:
    """Stateless awaitable whose ``__await__`` performs one bare ``yield``."""

    __slots__ = ()

    def __await__(self) -> Generator[Any, Any, None]:
        yield


# Shared instance returned when a call produced nothing to wait for.
STUB_AWAITABLE = StubAwaitable()


class CallbackCollection(
    MutableSet[
        Union[
            CallbackType[_Sender, _Params, Any],
            "CallbackCollection[Any, _Params]",
        ],
    ],
    Generic[_Sender, _Params],
):
    """
    Set-like container of callbacks bound to a weakly referenced sender.

    Callbacks are held either strongly (in a plain set) or weakly (in a
    ``WeakSet``); nested ``CallbackCollection`` instances are always held
    weakly. Calling the collection invokes every callback with the
    dereferenced sender followed by the given arguments, while nested
    collections are called with the arguments only. Exceptions raised by
    callbacks are logged and suppressed.

    The collection can be frozen, after which every mutating method raises
    ``RuntimeError`` until ``unfreeze()`` is called.
    """

    __slots__ = (
        "__weakref__",
        "__sender",
        "__callbacks",
        "__weak_callbacks",
        "__lock",
    )

    def __init__(self, sender: Union[_Sender, ReferenceType[_Sender]]):
        """
        :param sender: object passed to callbacks as first argument, or an
            existing weak reference to it; only a weak reference is stored
        """
        self.__sender: ReferenceType
        if isinstance(sender, ReferenceType):
            self.__sender = sender
        else:
            self.__sender = ref(sender)

        self.__callbacks: CallbackSetType = set()
        self.__weak_callbacks: MutableSet[
            Union[
                CallbackType[_Sender, _Params, Any],
                CallbackCollection[Any, _Params],
            ],
        ] = WeakSet()
        self.__lock: Lock = Lock()

    def add(
        self,
        callback: Union[
            CallbackType[_Sender, _Params, Any],
            CallbackCollection[Any, _Params],
        ],
        weak: bool = False,
    ) -> None:
        """
        Register a callback.

        :param callback: callable or nested collection to register
        :param weak: store the callback in the weak set; nested
            collections are stored weakly regardless of this flag
        :raises RuntimeError: when the collection is frozen
        :raises ValueError: when ``callback`` is not callable
        """
        if self.is_frozen:
            raise RuntimeError("Collection frozen")
        if not callable(callback):
            raise ValueError("Callback is not callable")

        with self.__lock:
            if weak or isinstance(callback, CallbackCollection):
                self.__weak_callbacks.add(callback)
            else:
                self.__callbacks.add(callback)  # type: ignore

    def remove(
        self,
        callback: Union[
            CallbackType[_Sender, _Params, Any],
            CallbackCollection[Any, _Params],
        ],
    ) -> None:
        """
        Remove a callback, looking in the strong set first and then in
        the weak set.

        :raises RuntimeError: when the collection is frozen
        :raises KeyError: when the callback is in neither set
        """
        if self.is_frozen:
            raise RuntimeError("Collection frozen")

        with self.__lock:
            try:
                self.__callbacks.remove(callback)  # type: ignore
            except KeyError:
                self.__weak_callbacks.remove(callback)

    def discard(
        self,
        callback: Union[
            CallbackType[_Sender, _Params, Any],
            CallbackCollection[Any, _Params],
        ],
    ) -> None:
        """
        Remove a callback if present; do nothing when it is absent.

        :raises RuntimeError: when the collection is frozen
        """
        if self.is_frozen:
            raise RuntimeError("Collection frozen")

        with self.__lock:
            if callback in self.__callbacks:
                self.__callbacks.remove(callback)  # type: ignore
            elif callback in self.__weak_callbacks:
                self.__weak_callbacks.remove(callback)

    def clear(self) -> None:
        """
        Remove every strong and weak callback.

        :raises RuntimeError: when the collection is frozen
        """
        if self.is_frozen:
            raise RuntimeError("Collection frozen")

        with self.__lock:
            self.__callbacks.clear()  # type: ignore
            self.__weak_callbacks.clear()

    @property
    def is_frozen(self) -> bool:
        """``True`` while the strong callback set is a ``frozenset``."""
        return isinstance(self.__callbacks, frozenset)

    def freeze(self) -> None:
        """
        Make the collection read-only.

        :raises RuntimeError: when the collection is already frozen
        """
        if self.is_frozen:
            raise RuntimeError("Collection already frozen")

        with self.__lock:
            self.__callbacks = frozenset(self.__callbacks)
            self.__weak_callbacks = WeakSet(self.__weak_callbacks)

    def unfreeze(self) -> None:
        """
        Make a frozen collection mutable again.

        :raises RuntimeError: when the collection is not frozen
        """
        if not self.is_frozen:
            raise RuntimeError("Collection is not frozen")

        with self.__lock:
            self.__callbacks = set(self.__callbacks)
            self.__weak_callbacks = WeakSet(self.__weak_callbacks)

    def __contains__(self, x: object) -> bool:
        return x in self.__callbacks or x in self.__weak_callbacks

    def __len__(self) -> int:
        return len(self.__callbacks) + len(self.__weak_callbacks)

    def __iter__(
        self,
    ) -> Iterator[
        Union[
            CallbackType[_Sender, _Params, Any],
            CallbackCollection[_Sender, _Params],
        ],
    ]:
        # Strong callbacks first, then the weak ones.
        return iter(chain(self.__callbacks, self.__weak_callbacks))

    def __bool__(self) -> bool:
        return bool(self.__callbacks) or bool(self.__weak_callbacks)

    def __copy__(self) -> CallbackCollection[_Sender, _Params]:
        """
        Create a new collection for the same sender holding the same
        callbacks, preserving strong/weak storage and the frozen state.
        """
        instance = self.__class__(self.__sender)

        with self.__lock:
            for cb in self.__callbacks:
                instance.add(cb, weak=False)

            for cb in self.__weak_callbacks:
                instance.add(cb, weak=True)

        if self.is_frozen:
            instance.freeze()

        return instance

    def __call__(
        self,
        *args: _Params.args,
        **kwargs: _Params.kwargs,
    ) -> Awaitable[Any]:
        """
        Invoke every registered callback.

        Awaitable results are scheduled with ``asyncio.ensure_future`` and
        gathered with ``return_exceptions=True``. Exceptions raised while
        calling a callback are logged and do not stop the remaining ones.

        :return: ``STUB_AWAITABLE`` when no callback produced an awaitable,
            otherwise the gathering future
        """
        futures: List[asyncio.Future] = []

        # Take a snapshot under the lock and run the callbacks outside of
        # it. The lock is not re-entrant, so a callback that adds or
        # removes callbacks on this same collection would otherwise block
        # forever on the lock held by this call.
        with self.__lock:
            sender = self.__sender()
            callbacks = list(self)

        for cb in callbacks:
            try:
                if isinstance(cb, CallbackCollection):
                    # Nested collections supply their own sender.
                    result = cb(*args, **kwargs)
                else:
                    result = cb(sender, *args, **kwargs)
                if inspect.isawaitable(result):
                    futures.append(asyncio.ensure_future(result))
            except Exception as exc:
                log.error(
                    "Callback %r error: %s: %s",
                    cb,
                    type(exc).__name__,
                    exc,
                )
                log.debug(
                    "Full traceback for callback %r error",
                    cb,
                    exc_info=True,
                )

        if not futures:
            return STUB_AWAITABLE
        return asyncio.gather(*futures, return_exceptions=True)

    def __hash__(self) -> int:
        # Identity-based hash so a collection can itself be stored as a
        # member of another collection's sets.
        return id(self)


class OneShotCallback:
    """
    Wrap an awaitable-returning callback so that it is started at most once.

    The first call creates a task running the callback; later calls return
    ``STUB_AWAITABLE``. When the callback finishes (successfully or not) the
    ``finished`` event is scheduled to be set and the reference to the
    callback is dropped.
    """

    __slots__ = ("loop", "finished", "__lock", "callback", "__task")

    def __init__(self, callback: Callable[..., Awaitable[T]]):
        """
        :param callback: callable returning an awaitable; invoked with the
            arguments of the first ``__call__``
        """
        self.callback: Callable[..., Awaitable[T]] = callback
        self.loop = asyncio.get_event_loop()
        self.finished: asyncio.Event = asyncio.Event()
        self.__lock: asyncio.Lock = asyncio.Lock()
        self.__task: Optional[asyncio.Future] = None

    def __repr__(self) -> str:
        # ``callback`` is deleted once the task has run, so fall back to
        # ``None`` instead of raising AttributeError from repr().
        callback = getattr(self, "callback", None)
        return f"<{self.__class__.__name__}: cb={callback!r}>"

    def wait(self) -> Awaitable[Any]:
        """Return an awaitable that completes once ``finished`` is set."""
        # NOTE(review): ``finished.wait()`` only creates the coroutine here;
        # a CancelledError would be raised where the caller awaits it, so
        # the handler below does not appear to be reachable. Left unchanged
        # because making it effective would cancel the callback task
        # whenever a waiter is cancelled - confirm the intended behaviour.
        try:
            return self.finished.wait()
        except asyncio.CancelledError:
            if self.__task is not None:
                self.__task.cancel()
            raise

    async def __task_inner(self, *args: Any, **kwargs: Any) -> None:
        async with self.__lock:
            if self.finished.is_set():
                return

            try:
                await self.callback(*args, **kwargs)
            except Exception as exc:
                log.error(
                    "Callback %r error: %s: %s",
                    self,
                    type(exc).__name__,
                    exc,
                )
                log.debug(
                    "Full traceback for callback %r error",
                    self,
                    exc_info=True,
                )
            finally:
                # Signal completion on the next loop iteration and drop
                # the reference to the wrapped callback.
                self.loop.call_soon(self.finished.set)
                del self.callback

    def __call__(self, *args: Any, **kwargs: Any) -> Awaitable[Any]:
        """
        Start the callback on the first call.

        :return: the created task, or ``STUB_AWAITABLE`` when the callback
            has already been started or has finished
        """
        if self.finished.is_set() or self.__task is not None:
            return STUB_AWAITABLE

        self.__task = self.loop.create_task(
            self.__task_inner(*args, **kwargs),
        )
        return self.__task


def ensure_awaitable(
    func: Callable[_Params, Union[T, Awaitable[T]]],
) -> Callable[_Params, Awaitable[T]]:
    """
    Return a coroutine-function view of ``func``.

    Coroutine functions are returned unchanged. Anything else is wrapped
    in an ``async`` function that calls ``func`` and awaits the result when
    it has an ``__await__`` attribute; otherwise the plain result is
    returned and a ``DeprecationWarning`` is emitted. Wrapping a plain
    Python function also emits a ``DeprecationWarning`` immediately.

    :param func: callable to adapt
    :return: ``func`` itself or the async wrapper around it
    """
    if inspect.iscoroutinefunction(func):
        return func

    if inspect.isfunction(func):
        warnings.warn(
            f"You probably registering the non-coroutine function {func!r}. "
            "This is deprecated and will be removed in future releases. "
            "Moreover, it can block the event loop",
            DeprecationWarning,
        )

    @wraps(func)
    async def wrapper(*args: _Params.args, **kwargs: _Params.kwargs) -> T:
        result = func(*args, **kwargs)
        if not hasattr(result, "__await__"):
            warnings.warn(
                f"Function {func!r} returned a non awaitable result. "
                "This may be bad for performance or may blocks the "
                "event loop, you should pay attention to this. This "
                "warning is here in an attempt to maintain backwards "
                "compatibility and will simply be removed in "
                "future releases.",
                DeprecationWarning,
            )
            return result

        return await result

    return wrapper


# Read-only view type for a set of callbacks and/or nested collections.
CallbackSetType = AbstractSet[
    Union[
        CallbackType[_Sender, _Params, None],
        CallbackCollection[_Sender, _Params],
    ],
]


__all__ = (
    "CallbackCollection",
    "CallbackSetType",
    "CallbackType",
    "OneShotCallback",
    "create_task",
    "ensure_awaitable",
    "iscoroutinepartial",
)