diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 61b6f6c..a941fb1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -7,7 +7,7 @@ We're very happy about contributions to aiomqtt! ✨ - Clone the aiomqtt repository - Install the Python version noted in `.python-version` via `pyenv` - Install poetry; Then run `./scripts/setup` to install the dependencies and aiomqtt itself -- Run ruff, and mypy with `./scripts/check` +- Run ruff and mypy with `./scripts/check` - Run the tests with `./scripts/test` During development, it's often useful to have a local MQTT broker running. You can spin up a local mosquitto broker with Docker via `./scripts/develop`. You can connect to this broker with `aiomqtt.Client("localhost", port=1883)`. diff --git a/README.md b/README.md index d97fcaa..7e01446 100644 --- a/README.md +++ b/README.md @@ -28,10 +28,9 @@ async with Client("test.mosquitto.org") as client: ```python async with Client("test.mosquitto.org") as client: - async with client.messages() as messages: - await client.subscribe("humidity/#") - async for message in messages: - print(message.payload) + await client.subscribe("humidity/#") + async for message in client.messages: + print(message.payload) ``` aiomqtt combines the stability of the time-proven [paho-mqtt](https://github.com/eclipse/paho.mqtt.python) library with an idiomatic asyncio interface: diff --git a/aiomqtt/__init__.py b/aiomqtt/__init__.py index bc56067..5f1fa98 100644 --- a/aiomqtt/__init__.py +++ b/aiomqtt/__init__.py @@ -1,17 +1,14 @@ # SPDX-License-Identifier: BSD-3-Clause from .client import ( Client, - Message, ProtocolVersion, ProxySettings, TLSParameters, - Topic, - TopicLike, - Wildcard, - WildcardLike, Will, ) -from .error import MqttCodeError, MqttError +from .exceptions import MqttCodeError, MqttError, MqttReentrantError +from .message import Message +from .topic import Topic, TopicLike, Wildcard, WildcardLike __all__ = [ "__version__", @@ -27,5 +24,6 @@ "WildcardLike", "Will", "MqttCodeError", + "MqttReentrantError", "MqttError", ] diff --git a/aiomqtt/client.py b/aiomqtt/client.py index d355007..35f8536 100644 --- a/aiomqtt/client.py +++ b/aiomqtt/client.py @@ -2,15 +2,15 @@ from __future__ import annotations import asyncio +import contextlib +import dataclasses +import enum import functools import logging import math import socket import ssl import sys -from contextlib import asynccontextmanager, contextmanager -from dataclasses import dataclass -from enum import IntEnum from types import TracebackType from typing import ( Any, @@ -26,33 +26,36 @@ ) import paho.mqtt.client as mqtt -from paho.mqtt.properties import Properties -from .error import MqttCodeError, MqttConnectError, MqttError, MqttReentrantError -from .types import PayloadType, T +from .exceptions import MqttCodeError, MqttConnectError, MqttError, MqttReentrantError +from .message import Message +from .types import ( + P, + PayloadType, + SocketOption, + SubscribeTopic, + T, + WebSocketHeaders, + _PahoSocket, +) if sys.version_info >= (3, 11): - from typing import Concatenate, ParamSpec, Self, TypeAlias + from typing import Concatenate, Self elif sys.version_info >= (3, 10): - from typing import Concatenate, ParamSpec, TypeAlias + from typing import Concatenate from typing_extensions import Self else: - from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias + from typing_extensions import Concatenate, Self -MAX_TOPIC_LENGTH = 65535 MQTT_LOGGER = logging.getLogger("mqtt") MQTT_LOGGER.setLevel(logging.WARNING) -_PahoSocket: TypeAlias = "socket.socket | ssl.SSLSocket | mqtt.WebsocketWrapper | Any" - -WebSocketHeaders: TypeAlias = ( - "dict[str, str] | Callable[[dict[str, str]], dict[str, str]]" -) +ClientT = TypeVar("ClientT", bound="Client") -class ProtocolVersion(IntEnum): +class ProtocolVersion(enum.IntEnum): """Map paho-mqtt protocol versions to an Enum for use in type hints.""" V31 = mqtt.MQTTv31 @@ -60,17 +63,7 @@ class ProtocolVersion(IntEnum): V5 = mqtt.MQTTv5 -@dataclass(frozen=True) -class Will: - topic: str - payload: PayloadType | None = None - qos: int = 0 - retain: bool = False - properties: mqtt.Properties | None = None - - -# TLS set parameter class -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class TLSParameters: ca_certs: str | None = None certfile: str | None = None @@ -81,7 +74,6 @@ class TLSParameters: keyfile_password: str | None = None -# Proxy parameters class class ProxySettings: def __init__( # noqa: PLR0913 self, @@ -101,15 +93,6 @@ def __init__( # noqa: PLR0913 } -# See the overloads of `socket.setsockopt` for details. -SocketOption: TypeAlias = "tuple[int, int, int | bytes] | tuple[int, int, None, int]" - -SubscribeTopic: TypeAlias = "str | tuple[str, mqtt.SubscribeOptions] | list[tuple[str, mqtt.SubscribeOptions]] | list[tuple[str, int]]" - -P = ParamSpec("P") -ClientT = TypeVar("ClientT", bound="Client") - - # TODO(frederik): Simplify the logic that surrounds `self._outgoing_calls_sem` with # `nullcontext` when we support Python 3.10 (`nullcontext` becomes async-aware in # 3.10). See: https://docs.python.org/3/library/contextlib.html#contextlib.nullcontext @@ -127,168 +110,13 @@ async def decorated(self: ClientT, /, *args: P.args, **kwargs: P.kwargs) -> T: return decorated -@dataclass(frozen=True) -class Wildcard: - """MQTT wildcard that can be subscribed to, but not published to. - - A wildcard is similar to a topic, but can optionally contain ``+`` and ``#`` - placeholders. - - Args: - value: The wildcard string. - - Attributes: - value: The wildcard string. - """ - - value: str - - def __str__(self) -> str: - return self.value - - def __post_init__(self) -> None: - """Validate the wildcard.""" - if not isinstance(self.value, str): - msg = "Wildcard must be of type str" - raise TypeError(msg) - if ( - len(self.value) == 0 - or len(self.value) > MAX_TOPIC_LENGTH - or "#/" in self.value - or any( - "+" in level or "#" in level - for level in self.value.split("/") - if len(level) > 1 - ) - ): - msg = f"Invalid wildcard: {self.value}" - raise ValueError(msg) - - -WildcardLike: TypeAlias = "str | Wildcard" - - -@dataclass(frozen=True) -class Topic(Wildcard): - """MQTT topic that can be published and subscribed to. - - Args: - value: The topic string. - - Attributes: - value: The topic string. - """ - - def __post_init__(self) -> None: - """Validate the topic.""" - if not isinstance(self.value, str): - msg = "Topic must be of type str" - raise TypeError(msg) - if ( - len(self.value) == 0 - or len(self.value) > MAX_TOPIC_LENGTH - or "+" in self.value - or "#" in self.value - ): - msg = f"Invalid topic: {self.value}" - raise ValueError(msg) - - def matches(self, wildcard: WildcardLike) -> bool: - """Check if the topic matches a given wildcard. - - Args: - wildcard: The wildcard to match against. - - Returns: - True if the topic matches the wildcard, False otherwise. - """ - if not isinstance(wildcard, Wildcard): - wildcard = Wildcard(wildcard) - # Split topics into levels to compare them one by one - topic_levels = self.value.split("/") - wildcard_levels = str(wildcard).split("/") - if wildcard_levels[0] == "$share": - # Shared subscriptions use the topic structure: $share// - wildcard_levels = wildcard_levels[2:] - - def recurse(tl: list[str], wl: list[str]) -> bool: - """Recursively match topic levels with wildcard levels.""" - if not tl: - if not wl or wl[0] == "#": - return True - return False - if not wl: - return False - if wl[0] == "#": - return True - if tl[0] == wl[0] or wl[0] == "+": - return recurse(tl[1:], wl[1:]) - return False - - return recurse(topic_levels, wildcard_levels) - - -TopicLike: TypeAlias = "str | Topic" - - -class Message: - """Wraps the paho-mqtt message class to allow using our own matching logic. - - This class is not meant to be instantiated by the user. Instead, it is yielded by - the async generator returned from ``Client.messages()``. - - Args: - topic: The topic the message was published to. - payload: The message payload. - qos: The quality of service level of the subscription that matched the message. - retain: Whether the message is a retained message. - mid: The message ID. - properties: (MQTT v5.0 only) The properties associated with the message. - - Attributes: - topic (aiomqtt.client.Topic): - The topic the message was published to. - payload (str | bytes | bytearray | int | float | None): - The message payload. - qos (int): - The quality of service level of the subscription that matched the message. - retain (bool): - Whether the message is a retained message. - mid (int): - The message ID. - properties (paho.mqtt.properties.Properties | None): - (MQTT v5.0 only) The properties associated with the message. - """ - - def __init__( # noqa: PLR0913 - self, - topic: TopicLike, - payload: PayloadType, - qos: int, - retain: bool, - mid: int, - properties: Properties | None, - ) -> None: - self.topic = Topic(topic) if not isinstance(topic, Topic) else topic - self.payload = payload - self.qos = qos - self.retain = retain - self.mid = mid - self.properties = properties - - @classmethod - def _from_paho_message(cls, message: mqtt.MQTTMessage) -> Message: - return cls( - topic=message.topic, - payload=message.payload, - qos=message.qos, - retain=message.retain, - mid=message.mid, - properties=message.properties if hasattr(message, "properties") else None, - ) - - def __lt__(self, other: Message) -> bool: - return self.mid < other.mid +@dataclasses.dataclass(frozen=True) +class Will: + topic: str + payload: PayloadType | None = None + qos: int = 0 + retain: bool = False + properties: mqtt.Properties | None = None class Client: @@ -300,12 +128,11 @@ class Client: username: The username to authenticate with. password: The password to authenticate with. logger: Custom logger instance. - client_id: The client ID to use. If ``None``, one will be generated - automatically. - tls_context: The SSL/TLS context. - tls_params: The SSL/TLS configuration to use. - tls_insecure: Enable/disable server hostname verification when using SSL/TLS. - proxy: Configure a proxy for the connection. + identifier: The client identifier. Generated automatically if ``None``. + queue_type: The class to use for the queue. The default is + ``asyncio.Queue``, which stores messages in FIFO order. For LIFO order, + you can use ``asyncio.LifoQueue``; For priority order you can subclass + ``asyncio.PriorityQueue``. protocol: The version of the MQTT protocol. will: The will message to publish if the client disconnects unexpectedly. clean_session: If ``True``, the broker will remove all information about this @@ -320,16 +147,29 @@ class Client: bind_port: The network port to bind this client to. clean_start: (MQTT v5.0 only) Set the clean start flag always, never, or only on the first successful connection to the broker. + max_queued_incoming_messages: Restricts the incoming message queue size. If the + queue is full, further incoming messages are discarded. ``0`` or less means + unlimited (the default). + max_queued_outgoing_messages: Resticts the outgoing message queue size. If the + queue is full, further outgoing messages are discarded. ``0`` means + unlimited (the default). + max_inflight_messages: The maximum number of messages with QoS > ``0`` that can + be part way through their network flow at once. + max_concurrent_outgoing_calls: The maximum number of concurrent outgoing calls. properties: (MQTT v5.0 only) The properties associated with the client. - message_retry_set: Deprecated. + tls_context: The SSL/TLS context. + tls_params: The SSL/TLS configuration to use. + tls_insecure: Enable/disable server hostname verification when using SSL/TLS. + proxy: Configure a proxy for the connection. socket_options: Options to pass to the underlying socket. - max_concurrent_outgoing_calls: The maximum number of concurrent outgoing calls. websocket_path: The path to use for websockets. websocket_headers: The headers to use for websockets. - max_inflight_messages: The maximum number of messages with QoS > ``0`` that can - be part way through their network flow at once. - max_queued_messages: The maximum number of messages in the outgoing message - queue. ``0`` means unlimited. + + Attributes: + messages (typing.AsyncGenerator[aiomqtt.client.Message, None]): + Async generator that yields messages from the underlying message queue. + identifier (str): + The client identifier. """ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 @@ -340,11 +180,8 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 username: str | None = None, password: str | None = None, logger: logging.Logger | None = None, - client_id: str | None = None, - tls_context: ssl.SSLContext | None = None, - tls_params: TLSParameters | None = None, - tls_insecure: bool | None = None, - proxy: ProxySettings | None = None, + identifier: str | None = None, + queue_type: type[asyncio.Queue[Message]] | None = None, protocol: ProtocolVersion | None = None, will: Will | None = None, clean_session: bool | None = None, @@ -354,14 +191,18 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 bind_address: str = "", bind_port: int = 0, clean_start: int = mqtt.MQTT_CLEAN_START_FIRST_ONLY, - properties: Properties | None = None, - message_retry_set: int = 20, - socket_options: Iterable[SocketOption] | None = None, + max_queued_incoming_messages: int | None = None, + max_queued_outgoing_messages: int | None = None, + max_inflight_messages: int | None = None, max_concurrent_outgoing_calls: int | None = None, + properties: mqtt.Properties | None = None, + tls_context: ssl.SSLContext | None = None, + tls_params: TLSParameters | None = None, + tls_insecure: bool | None = None, + proxy: ProxySettings | None = None, + socket_options: Iterable[SocketOption] | None = None, websocket_path: str | None = None, websocket_headers: WebSocketHeaders | None = None, - max_inflight_messages: int | None = None, - max_queued_messages: int | None = None, ) -> None: self._hostname = hostname self._port = port @@ -386,12 +227,15 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 self.pending_calls_threshold: int = 10 self._misc_task: asyncio.Task[None] | None = None - # List of all callbacks to call when a message is received - self._on_message_callbacks: list[Callable[[Message], None]] = [] - self._unfiltered_messages_callback: Callable[ - [mqtt.Client, Any, mqtt.MQTTMessage], None - ] | None = None + # Queue that holds incoming messages + if queue_type is None: + queue_type = cast("type[asyncio.Queue[Message]]", asyncio.Queue) + if max_queued_incoming_messages is None: + max_queued_incoming_messages = 0 + self._queue = queue_type(maxsize=max_queued_incoming_messages) + self.messages = self._messages() + # Semaphore to limit the number of concurrent outgoing calls self._outgoing_calls_sem: asyncio.Semaphore | None if max_concurrent_outgoing_calls is not None: self._outgoing_calls_sem = asyncio.Semaphore(max_concurrent_outgoing_calls) @@ -401,8 +245,9 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 if protocol is None: protocol = ProtocolVersion.V311 + # Create the underlying paho-mqtt client instance self._client: mqtt.Client = mqtt.Client( - client_id=client_id, + client_id=identifier, protocol=protocol, clean_session=clean_session, transport=transport, @@ -414,6 +259,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 self._client.on_unsubscribe = self._on_unsubscribe self._client.on_message = self._on_message self._client.on_publish = self._on_publish + # Callbacks for custom event loop self._client.on_socket_open = self._on_socket_open self._client.on_socket_close = self._on_socket_close @@ -422,8 +268,8 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 if max_inflight_messages is not None: self._client.max_inflight_messages_set(max_inflight_messages) - if max_queued_messages is not None: - self._client.max_queued_messages_set(max_queued_messages) + if max_queued_outgoing_messages is not None: + self._client.max_queued_messages_set(max_queued_outgoing_messages) if logger is None: logger = MQTT_LOGGER @@ -461,7 +307,6 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 will.topic, will.payload, will.qos, will.retain, will.properties ) - self._client.message_retry_set(message_retry_set) if socket_options is None: socket_options = () self._socket_options = tuple(socket_options) @@ -471,14 +316,11 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 self.timeout = timeout @property - def id( # noqa: A003 # TODO(jonathan): When doing BREAKING CHANGES rename to avoid shadowing builtin id - self, - ) -> str: - """Return the client ID. + def identifier(self) -> str: + """Return the client identifier. - Note that paho-mqtt stores the client ID as `bytes` internally. - We assume that the client ID is a UTF8-encoded string and decode - it first. + Note that paho-mqtt stores the client ID as `bytes` internally. We assume that + the client ID is a UTF8-encoded string and decode it first. """ return cast(bytes, self._client._client_id).decode() # type: ignore[attr-defined] # noqa: SLF001 @@ -489,89 +331,6 @@ def _pending_calls(self) -> Generator[int, None, None]: yield from self._pending_unsubscribes.keys() yield from self._pending_publishes.keys() - async def connect(self, *, timeout: float | None = None) -> None: - self._logger.warning( - "The manual `connect` and `disconnect` methods are deprecated and will be" - " removed in a future version. The preferred way to connect and disconnect" - " the client is to use the context manager interface via `async with`. In" - " case your use case needs to connect and disconnect manually, you can call" - " the context manager's `__aenter__` and `__aexit__` methods as an escape" - " hatch instead. `__aenter__` is equivalent to `connect`. `__aexit__` is" - " equivalent to `disconnect` except that it forces disconnection instead" - " of throwing an exception in case the client cannot disconnect cleanly." - " `__aexit__` expects three arguments: `exc_type`, `exc`, and `tb`. These" - " arguments describe the exception that caused the context manager to exit," - " if any. You can pass `None` to all of these arguments in a manual call to" - " `__aexit__`." - ) - try: - loop = asyncio.get_running_loop() - - # [3] Run connect() within an executor thread, since it blocks on socket - # connection for up to `keepalive` seconds: https://git.io/Jt5Yc - await loop.run_in_executor( - None, - self._client.connect, - self._hostname, - self._port, - self._keepalive, - self._bind_address, - self._bind_port, - self._clean_start, - self._properties, - ) - client_socket = self._client.socket() - _set_client_socket_defaults(client_socket, self._socket_options) - # paho.mqtt.Client.connect may raise one of several exceptions. - # We convert all of them to the common MqttError for user convenience. - # See: https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L1770 - except (OSError, mqtt.WebsocketConnectionError) as error: - raise MqttError(str(error)) from None - await self._wait_for(self._connected, timeout=timeout) - # If _disconnected is already completed after connecting, reset it. - if self._disconnected.done(): - self._disconnected = asyncio.Future() - - def _early_out_on_disconnected(self) -> bool: - # Early out if already disconnected... - if self._disconnected.done(): - disc_exc = self._disconnected.exception() - if disc_exc is not None: - # ...by raising the error that caused the disconnect - raise disc_exc - # ...by returning since the disconnect was intentional - return True - return False - - async def disconnect(self, *, timeout: float | None = None) -> None: - """Disconnect from the broker.""" - self._logger.warning( - "The manual `connect` and `disconnect` methods are deprecated and will be" - " removed in a future version. The preferred way to connect and disconnect" - " the client is to use the context manager interface via `async with`. In" - " case your use case needs to connect and disconnect manually, you can call" - " the context manager's `__aenter__` and `__aexit__` methods as an escape" - " hatch instead. `__aenter__` is equivalent to `connect`. `__aexit__` is" - " equivalent to `disconnect` except that it forces disconnection instead" - " of throwing an exception in case the client cannot disconnect cleanly." - " `__aexit__` expects three arguments: `exc_type`, `exc`, and `tb`. These" - " arguments describe the exception that caused the context manager to exit," - " if any. You can pass `None` to all of these arguments in a manual call to" - " `__aexit__`." - ) - if self._early_out_on_disconnected(): - return - # Try to gracefully disconnect from the broker - rc = self._client.disconnect() - # Early out on error - if rc != mqtt.MQTT_ERR_SUCCESS: - raise MqttCodeError(rc, "Could not disconnect") - # Wait for acknowledgement - await self._wait_for(self._disconnected, timeout=timeout) - # If _connected is still in the completed state after disconnection, reset it - if self._connected.done(): - self._connected = asyncio.Future() - @_outgoing_call async def subscribe( # noqa: PLR0913 self, @@ -579,7 +338,7 @@ async def subscribe( # noqa: PLR0913 topic: SubscribeTopic, qos: int = 0, options: mqtt.SubscribeOptions | None = None, - properties: Properties | None = None, + properties: mqtt.Properties | None = None, *args: Any, timeout: float | None = None, **kwargs: Any, @@ -618,7 +377,7 @@ async def unsubscribe( self, /, topic: str | list[str], - properties: Properties | None = None, + properties: mqtt.Properties | None = None, *args: Any, timeout: float | None = None, **kwargs: Any, @@ -653,7 +412,7 @@ async def publish( # noqa: PLR0913 payload: PayloadType = None, qos: int = 0, retain: bool = False, - properties: Properties | None = None, + properties: mqtt.Properties | None = None, *args: Any, timeout: float | None = None, **kwargs: Any, @@ -688,177 +447,31 @@ async def publish( # noqa: PLR0913 # Wait for confirmation await self._wait_for(confirmation.wait(), timeout=timeout) - @asynccontextmanager - async def filtered_messages( - self, topic_filter: str, *, queue_maxsize: int = 0 - ) -> AsyncGenerator[AsyncGenerator[mqtt.MQTTMessage, None], None]: - """Return async generator of messages that match the given filter.""" - self._logger.warning( - "filtered_messages() is deprecated and will be removed in a future version." - " Use messages() together with Topic.matches() instead." - ) - callback, generator = self._deprecated_callback_and_generator( - log_context=f'topic_filter="{topic_filter}"', queue_maxsize=queue_maxsize - ) - try: - self._client.message_callback_add(topic_filter, callback) - # Back to the caller (run whatever is inside the with statement) - yield generator - finally: - # We are exiting the with statement. Remove the topic filter. - self._client.message_callback_remove(topic_filter) - - @asynccontextmanager - async def unfiltered_messages( - self, *, queue_maxsize: int = 0 - ) -> AsyncGenerator[AsyncGenerator[mqtt.MQTTMessage, None], None]: - """Return async generator of all messages that are not caught in filters.""" - self._logger.warning( - "unfiltered_messages() is deprecated and will be removed in a future" - " version. Use messages() instead." - ) - # Early out - if self._unfiltered_messages_callback is not None: - msg = "Only a single unfiltered_messages generator can be used at a time" - raise RuntimeError(msg) - callback, generator = self._deprecated_callback_and_generator( - log_context="unfiltered", queue_maxsize=queue_maxsize - ) - try: - self._unfiltered_messages_callback = callback - # Back to the caller (run whatever is inside the with statement) - yield generator - finally: - # We are exiting the with statement. Unset the callback. - self._unfiltered_messages_callback = None - - @asynccontextmanager - async def messages( - self, - *, - queue_class: type[asyncio.Queue[Message]] = asyncio.Queue, - queue_maxsize: int = 0, - ) -> AsyncGenerator[AsyncGenerator[Message, None], None]: - """Async context manager that creates a queue for incoming messages. - - Args: - queue_class: The class to use for the queue. The default is - ``asyncio.Queue``, which returns messages in FIFO order. For LIFO order, - you can use ``asyncio.LifoQueue``; For priority order you can subclass - ``asyncio.PriorityQueue``. - queue_maxsize: Restricts the queue size. If the queue is full, incoming - messages will be discarded and a warning logged. If set to ``0`` or - less, the queue size is infinite. - - Returns: - An async generator that yields messages from the underlying queue. - """ - callback, generator = self._callback_and_generator( - queue_class=queue_class, queue_maxsize=queue_maxsize - ) - try: - # Add to the list of callbacks to call when a message is received - self._on_message_callbacks.append(callback) - # Back to the caller (run whatever is inside the with statement) - yield generator - finally: - # We are exiting the with statement. Remove the callback from the list. - self._on_message_callbacks.remove(callback) - - def _deprecated_callback_and_generator( - self, *, log_context: str, queue_maxsize: int = 0 - ) -> tuple[ - Callable[[mqtt.Client, Any, mqtt.MQTTMessage], None], - AsyncGenerator[mqtt.MQTTMessage, None], - ]: - # Queue to hold the incoming messages - messages: asyncio.Queue[mqtt.MQTTMessage] = asyncio.Queue(maxsize=queue_maxsize) - - # Callback for the underlying API - def _put_in_queue( - client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage - ) -> None: + async def _messages(self) -> AsyncGenerator[Message, None]: + """Async generator that yields messages from the underlying message queue.""" + while True: + # Wait until we either: + # 1. Receive a message + # 2. Disconnect from the broker + task = self._loop.create_task(self._queue.get()) try: - messages.put_nowait(message) - except asyncio.QueueFull: - self._logger.warning( - "[%s] Message queue is full. Discarding message.", log_context + done, _ = await asyncio.wait( + (task, self._disconnected), return_when=asyncio.FIRST_COMPLETED ) - - # The generator that we give to the caller - async def _message_generator() -> AsyncGenerator[mqtt.MQTTMessage, None]: - # Forward all messages from the queue - while True: - # Wait until we either: - # 1. Receive a message - # 2. Disconnect from the broker - get: asyncio.Task[mqtt.MQTTMessage] = self._loop.create_task( - messages.get() - ) - try: - done, _ = await asyncio.wait( - (get, self._disconnected), return_when=asyncio.FIRST_COMPLETED - ) - except asyncio.CancelledError: - # If the asyncio.wait is cancelled, we must make sure - # to also cancel the underlying tasks. - get.cancel() - raise - if get in done: - # We received a message. Return the result. - yield get.result() - else: - # We got disconnected from the broker. Cancel the "get" task. - get.cancel() - # Stop the generator with the following exception - msg = "Disconnected during message iteration" - raise MqttError(msg) - - return _put_in_queue, _message_generator() - - def _callback_and_generator( - self, - *, - queue_class: type[asyncio.Queue[Message]] = asyncio.Queue, - queue_maxsize: int = 0, - ) -> tuple[Callable[[Message], None], AsyncGenerator[Message, None]]: - # Queue to hold the incoming messages - messages: asyncio.Queue[Message] = queue_class(maxsize=queue_maxsize) - - def _callback(message: Message) -> None: - """Put the new message in the queue.""" - try: - messages.put_nowait(message) - except asyncio.QueueFull: - self._logger.warning("Message queue is full. Discarding message.") - - async def _generator() -> AsyncGenerator[Message, None]: - """Forward all messages from the message queue.""" - while True: - # Wait until we either: - # 1. Receive a message - # 2. Disconnect from the broker - get: asyncio.Task[Message] = self._loop.create_task(messages.get()) - try: - done, _ = await asyncio.wait( - (get, self._disconnected), return_when=asyncio.FIRST_COMPLETED - ) - except asyncio.CancelledError: - # If the asyncio.wait is cancelled, we must make sure - # to also cancel the underlying tasks. - get.cancel() - raise - if get in done: - # We received a message. Return the result. - yield get.result() - else: - # We got disconnected from the broker. Cancel the "get" task. - get.cancel() - # Stop the generator with the following exception - msg = "Disconnected during message iteration" - raise MqttError(msg) - - return _callback, _generator() + except asyncio.CancelledError: + # If the asyncio.wait is cancelled, we must make sure + # to also cancel the underlying tasks. + task.cancel() + raise + if task in done: + # We received a message. Return the result. + yield task.result() + else: + # We were disconnected from the broker + task.cancel() + # Stop the generator with an exception + msg = "Disconnected during message iteration" + raise MqttError(msg) async def _wait_for( self, fut: Awaitable[T], timeout: float | None, **kwargs: Any @@ -873,7 +486,7 @@ async def _wait_for( msg = "Operation timed out" raise MqttError(msg) from None - @contextmanager + @contextlib.contextmanager def _pending_call( self, mid: int, value: T, pending_dict: dict[int, T] ) -> Iterator[None]: @@ -986,13 +599,13 @@ def _on_unsubscribe( # noqa: PLR0913 def _on_message( self, client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage ) -> None: - # Call the deprecated unfiltered_messages callback - if self._unfiltered_messages_callback is not None: - self._unfiltered_messages_callback(client, userdata, message) # Convert the paho.mqtt message into our own Message type m = Message._from_paho_message(message) # noqa: SLF001 - for callback in self._on_message_callbacks: - callback(m) + # Put the message in the message queue + try: + self._queue.put_nowait(m) + except asyncio.QueueFull: + self._logger.warning("Message queue is full. Discarding message.") def _on_publish(self, client: mqtt.Client, userdata: Any, mid: int) -> None: try: diff --git a/aiomqtt/error.py b/aiomqtt/exceptions.py similarity index 100% rename from aiomqtt/error.py rename to aiomqtt/exceptions.py diff --git a/aiomqtt/message.py b/aiomqtt/message.py new file mode 100644 index 0000000..133b67e --- /dev/null +++ b/aiomqtt/message.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +import sys + +import paho.mqtt.client as mqtt + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +from .topic import Topic, TopicLike +from .types import PayloadType + + +class Message: + """Wraps the paho-mqtt message class to allow using our own matching logic. + + This class is not meant to be instantiated by the user. Instead, it is yielded by + the async generator ``Client.messages``. + + Args: + topic: The topic the message was published to. + payload: The message payload. + qos: The quality of service level of the subscription that matched the message. + retain: Whether the message is a retained message. + mid: The message ID. + properties: (MQTT v5.0 only) The properties associated with the message. + + Attributes: + topic (aiomqtt.client.Topic): + The topic the message was published to. + payload (str | bytes | bytearray | int | float | None): + The message payload. + qos (int): + The quality of service level of the subscription that matched the message. + retain (bool): + Whether the message is a retained message. + mid (int): + The message ID. + properties (paho.mqtt.properties.Properties | None): + (MQTT v5.0 only) The properties associated with the message. + """ + + def __init__( # noqa: PLR0913 + self, + topic: TopicLike, + payload: PayloadType, + qos: int, + retain: bool, + mid: int, + properties: mqtt.Properties | None, + ) -> None: + self.topic = Topic(topic) if not isinstance(topic, Topic) else topic + self.payload = payload + self.qos = qos + self.retain = retain + self.mid = mid + self.properties = properties + + @classmethod + def _from_paho_message(cls, message: mqtt.MQTTMessage) -> Self: + return cls( + topic=message.topic, + payload=message.payload, + qos=message.qos, + retain=message.retain, + mid=message.mid, + properties=message.properties if hasattr(message, "properties") else None, + ) + + def __lt__(self, other: Self) -> bool: + return self.mid < other.mid diff --git a/aiomqtt/topic.py b/aiomqtt/topic.py new file mode 100644 index 0000000..b111bd3 --- /dev/null +++ b/aiomqtt/topic.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +import dataclasses +import sys + +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias + + +MAX_TOPIC_LENGTH = 65535 + + +@dataclasses.dataclass(frozen=True) +class Wildcard: + """MQTT wildcard that can be subscribed to, but not published to. + + A wildcard is similar to a topic, but can optionally contain ``+`` and ``#`` + placeholders. You can access the ``value`` attribute directly to perform ``str`` + operations on a wildcard. + + Args: + value: The wildcard string. + + Attributes: + value: The wildcard string. + """ + + value: str + + def __str__(self) -> str: + return self.value + + def __post_init__(self) -> None: + """Validate the wildcard.""" + if not isinstance(self.value, str): + msg = "Wildcard must be of type str" + raise TypeError(msg) + if ( + len(self.value) == 0 + or len(self.value) > MAX_TOPIC_LENGTH + or "#/" in self.value + or any( + "+" in level or "#" in level + for level in self.value.split("/") + if len(level) > 1 + ) + ): + msg = f"Invalid wildcard: {self.value}" + raise ValueError(msg) + + +WildcardLike: TypeAlias = "str | Wildcard" + + +@dataclasses.dataclass(frozen=True) +class Topic(Wildcard): + """MQTT topic that can be published and subscribed to. + + Args: + value: The topic string. + + Attributes: + value: The topic string. + """ + + def __post_init__(self) -> None: + """Validate the topic.""" + if not isinstance(self.value, str): + msg = "Topic must be of type str" + raise TypeError(msg) + if ( + len(self.value) == 0 + or len(self.value) > MAX_TOPIC_LENGTH + or "+" in self.value + or "#" in self.value + ): + msg = f"Invalid topic: {self.value}" + raise ValueError(msg) + + def matches(self, wildcard: WildcardLike) -> bool: + """Check if the topic matches a given wildcard. + + Args: + wildcard: The wildcard to match against. + + Returns: + True if the topic matches the wildcard, False otherwise. + """ + if not isinstance(wildcard, Wildcard): + wildcard = Wildcard(wildcard) + # Split topics into levels to compare them one by one + topic_levels = self.value.split("/") + wildcard_levels = str(wildcard).split("/") + if wildcard_levels[0] == "$share": + # Shared subscriptions use the topic structure: $share// + wildcard_levels = wildcard_levels[2:] + + def recurse(tl: list[str], wl: list[str]) -> bool: + """Recursively match topic levels with wildcard levels.""" + if not tl: + if not wl or wl[0] == "#": + return True + return False + if not wl: + return False + if wl[0] == "#": + return True + if tl[0] == wl[0] or wl[0] == "+": + return recurse(tl[1:], wl[1:]) + return False + + return recurse(topic_levels, wildcard_levels) + + +TopicLike: TypeAlias = "str | Topic" diff --git a/aiomqtt/types.py b/aiomqtt/types.py index 96a1de6..039a76b 100644 --- a/aiomqtt/types.py +++ b/aiomqtt/types.py @@ -1,11 +1,27 @@ +# SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +import socket +import ssl import sys -from typing import TypeVar +from typing import Any, Callable, TypeVar + +import paho.mqtt.client as mqtt if sys.version_info >= (3, 10): - from typing import TypeAlias + from typing import ParamSpec, TypeAlias else: - from typing_extensions import TypeAlias + from typing_extensions import ParamSpec, TypeAlias + T = TypeVar("T") +P = ParamSpec("P") PayloadType: TypeAlias = "str | bytes | bytearray | int | float | None" +SubscribeTopic: TypeAlias = "str | tuple[str, mqtt.SubscribeOptions] | list[tuple[str, mqtt.SubscribeOptions]] | list[tuple[str, int]]" +WebSocketHeaders: TypeAlias = ( + "dict[str, str] | Callable[[dict[str, str]], dict[str, str]]" +) +_PahoSocket: TypeAlias = "socket.socket | ssl.SSLSocket | mqtt.WebsocketWrapper | Any" +# See the overloads of `socket.setsockopt` for details. +SocketOption: TypeAlias = "tuple[int, int, int | bytes] | tuple[int, int, None, int]" diff --git a/docs/alongside-fastapi-and-co.md b/docs/alongside-fastapi-and-co.md index c390c1a..14841a9 100644 --- a/docs/alongside-fastapi-and-co.md +++ b/docs/alongside-fastapi-and-co.md @@ -12,10 +12,8 @@ import fastapi async def listen(client): - async with client.messages() as messages: - await client.subscribe("humidity/#") - async for message in messages: - print(message.payload) + async for message in client.messages: + print(message.payload) client = None @@ -28,6 +26,7 @@ async def lifespan(app): # Make client globally available client = c # Listen for MQTT messages in (unawaited) asyncio task + await client.subscribe("humidity/#") loop = asyncio.get_event_loop() task = loop.create_task(listen(client)) yield diff --git a/docs/connecting-to-the-broker.md b/docs/connecting-to-the-broker.md index c106b3a..9e2578b 100644 --- a/docs/connecting-to-the-broker.md +++ b/docs/connecting-to-the-broker.md @@ -20,7 +20,7 @@ The connection to the broker is managed by the `Client` context manager. This co Context managers make it easier to manage resources like network connections or files by ensuring that their teardown logic is always executed -- even in case of an exception. ```{tip} -If your use case does not allow you to use a context manager, you can use the client's `__aenter__` and `__aexit__` methods directly as a workaround, similar to how you would use manual `connect` and `disconnect` methods. With this approach you need to make sure that `___aexit___` is also called in case of an exception. Avoid this workaround if you can, it's a bit tricky to get right. +If your use case does not allow you to use a context manager, you can use the client's `__aenter__` and `__aexit__` methods to connect and disconnect as a workaround. With this approach you need to ensure yourself that `___aexit___` is also called in case of an exception. Avoid this workaround if you can, it's a bit tricky to get right. ``` ```{note} diff --git a/docs/index.md b/docs/index.md index 167ef6e..d43600d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -27,10 +27,11 @@ alongside-fastapi-and-co ``` ```{toctree} -:caption: API reference +:caption: reference :hidden: developer-interface +migration-guide-v2 ``` ```{toctree} @@ -39,7 +40,7 @@ developer-interface GitHub Issue tracker -Discussions +Changelog Contributing PyPI ``` diff --git a/docs/migration-guide-v2.md b/docs/migration-guide-v2.md new file mode 100644 index 0000000..f31e3fa --- /dev/null +++ b/docs/migration-guide-v2.md @@ -0,0 +1,154 @@ +# Migration guide: v2.0.0 + +Version 2.0.0 introduces some breaking changes. This page aims to help you migrate to this new major version. The relevant changes are: + +- The deprecated `connect` and `disconnect` methods have been removed +- The deprecated `filtered_messages` and `unfiltered_messages` methods have been removed +- User-managed queues for incoming messages have been replaced with a single client-wide queue +- Some arguments to the `Client` have been renamed or removed + +## Changes to the client lifecycle + +The deprecated `connect` and `disconnect` methods have been removed. The best way to connect and disconnect from the broker is through the client's context manager: + +```python +import asyncio +import aiomqtt + + +async def main(): + async with aiomqtt.Client("test.mosquitto.org") as client: + await client.publish("temperature/outside", payload=28.4) + + +asyncio.run(main()) +``` + +If your use case does not allow you to use a context manager, you can use the client’s `__aenter__` and `__aexit__` methods almost interchangeably in place of the removed `connect` and `disconnect` methods. + +The `__aenter__` and `__aexit__` methods are designed to be called by the `async with` statement when the execution enters and exits the context manager. However, we can also execute them manually: + +```python +import asyncio +import aiomqtt + + +async def main(): + client = aiomqtt.Client("test.mosquitto.org") + await client.__aenter__() + try: + await client.publish("temperature/outside", payload=28.4) + finally: + await client.__aexit__(None, None, None) + + +asyncio.run(main()) +``` + +`__aenter__` is equivalent to `connect`. `__aexit__` is equivalent to `disconnect` except that it forces disconnection instead of throwing an exception in case the client cannot disconnect cleanly. + +```{note} +`__aexit__` expects three arguments: `exc_type`, `exc`, and `tb`. These arguments describe the exception that caused the context manager to exit, if any. You can pass `None` to all of these arguments in a manual call to `__aexit__`. +``` + +## Changes to the message queue + +The `filtered_messages`, `unfiltered_messages`, and `messages` methods have been removed and replaced with a single client-wide message queue. + +A minimal example of printing all messages (unfiltered) looks like this: + +```python +import asyncio +import aiomqtt + + +async def main(): + async with aiomqtt.Client("test.mosquitto.org") as client: + await client.subscribe("temperature/#") + async for message in client.messages: + print(message.payload) + + +asyncio.run(main()) +``` + +To handle messages from different topics differently, we can use `Topic.matches()`: + +```python +import asyncio +import aiomqtt + + +async def main(): + async with aiomqtt.Client("test.mosquitto.org") as client: + await client.subscribe("temperature/#") + await client.subscribe("humidity/#") + async for message in client.messages: + if message.topic.matches("humidity/inside"): + print(f"[humidity/inside] {message.payload}") + if message.topic.matches("+/outside"): + print(f"[+/outside] {message.payload}") + if message.topic.matches("temperature/#"): + print(f"[temperature/#] {message.payload}") + + +asyncio.run(main()) +``` + +```{note} +In our example, messages to `temperature/outside` are handled twice! +``` + +The `filtered_messages`, `unfiltered_messages`, and `messages` methods created isolated message queues underneath, such that you could invoke them multiple times. From Version 2.0.0 on, the client maintains a single queue that holds all incoming messages, accessible via `Client.messages`. + +If you continue to need multiple queues (e.g. because you have special concurrency requirements), you can build a "distributor" on top: + +```python +import asyncio +import aiomqtt + + +async def temperature_consumer(): + while True: + message = await temperature_queue.get() + print(f"[temperature/#] {message.payload}") + + +async def humidity_consumer(): + while True: + message = await humidity_queue.get() + print(f"[humidity/#] {message.payload}") + + +temperature_queue = asyncio.Queue() +humidity_queue = asyncio.Queue() + + +async def distributor(client): + # Sort messages into the appropriate queues + async for message in client.messages: + if message.topic.matches("temperature/#"): + temperature_queue.put_nowait(message) + elif message.topic.matches("humidity/#"): + humidity_queue.put_nowait(message) + + +async def main(): + async with aiomqtt.Client("test.mosquitto.org") as client: + await client.subscribe("temperature/#") + await client.subscribe("humidity/#") + # Use a task group to manage and await all tasks + async with asyncio.TaskGroup() as tg: + tg.create_task(distributor(client)) + tg.create_task(temperature_consumer()) + tg.create_task(humidity_consumer()) + + +asyncio.run(main()) +``` + +## Changes to client arguments + +- The `queue_class` and `queue_maxsize` arguments to `filtered_messages`, `unfiltered_messages`, and `messages` have been moved to the `Client` and have been renamed to `queue_type` and `max_queued_incoming_messages` +- The `max_queued_messages` client argument has been renamed to `max_queued_outgoing_messages` +- The deprecated `message_retry_set` client argument has been removed diff --git a/docs/reconnection.md b/docs/reconnection.md index 9ea92fb..c6b754a 100644 --- a/docs/reconnection.md +++ b/docs/reconnection.md @@ -17,10 +17,9 @@ async def main(): while True: try: async with client: - async with client.messages() as messages: - await client.subscribe("humidity/#") - async for message in messages: - print(message.payload) + await client.subscribe("humidity/#") + async for message in client.messages: + print(message.payload) except aiomqtt.MqttError: print(f"Connection lost; Reconnecting in {interval} seconds ...") await asyncio.sleep(interval) diff --git a/docs/subscribing-to-a-topic.md b/docs/subscribing-to-a-topic.md index bee45c1..74212af 100644 --- a/docs/subscribing-to-a-topic.md +++ b/docs/subscribing-to-a-topic.md @@ -1,6 +1,6 @@ # Subscribing to a topic -To receive messages for a topic, we need to subscribe to it and listen for messages. This is a minimal working example that listens for messages to the `temperature/#` wildcard: +To receive messages for a topic, we need to subscribe to it. Incoming messages are queued internally. You can use the `Client.message` generator to iterate over incoming messages. This is a minimal working example that listens for messages to the `temperature/#` wildcard: ```python import asyncio @@ -9,10 +9,9 @@ import aiomqtt async def main(): async with aiomqtt.Client("test.mosquitto.org") as client: - async with client.messages() as messages: - await client.subscribe("temperature/#") - async for message in messages: - print(message.payload) + await client.subscribe("temperature/#") + async for message in client.messages: + print(message.payload) asyncio.run(main()) @@ -37,16 +36,15 @@ import aiomqtt async def main(): async with aiomqtt.Client("test.mosquitto.org") as client: - async with client.messages() as messages: - await client.subscribe("temperature/#") - await client.subscribe("humidity/#") - async for message in messages: - if message.topic.matches("humidity/inside"): - print(f"[humidity/outside] {message.payload}") - if message.topic.matches("+/outside"): - print(f"[+/inside] {message.payload}") - if message.topic.matches("temperature/#"): - print(f"[temperature/#] {message.payload}") + await client.subscribe("temperature/#") + await client.subscribe("humidity/#") + async for message in client.messages: + if message.topic.matches("humidity/inside"): + print(f"[humidity/inside] {message.payload}") + if message.topic.matches("+/outside"): + print(f"[+/outside] {message.payload}") + if message.topic.matches("temperature/#"): + print(f"[temperature/#] {message.payload}") asyncio.run(main()) @@ -62,9 +60,9 @@ For details on the `+` and `#` wildcards and what topics they match, see the [OA ## The message queue -Messages are queued and returned sequentially from `Client.messages()`. +Messages are queued internally and returned sequentially from `Client.messages`. -The default queue is `asyncio.Queue` which returns messages on a FIFO ("first in first out") basis. You can pass [other types of asyncio queues](https://docs.python.org/3/library/asyncio-queue.html) as `queue_class` to `Client.messages()` to modify the order in which messages are returned, e.g. `asyncio.LifoQueue`. +The default queue is `asyncio.Queue` which returns messages on a FIFO ("first in first out") basis. You can pass [other types of asyncio queues](https://docs.python.org/3/library/asyncio-queue.html) as `queue_class` to the `Client` to modify the order in which messages are returned, e.g. `asyncio.LifoQueue`. You can subclass `asyncio.PriorityQueue` to queue based on priority. Messages are returned ascendingly by their priority values. In the case of ties, messages with lower message identifiers are returned first. @@ -87,12 +85,13 @@ class CustomPriorityQueue(asyncio.PriorityQueue): async def main(): - async with aiomqtt.Client("test.mosquitto.org") as client: - async with client.messages(queue_class=CustomPriorityQueue) as messages: - await client.subscribe("temperature/#") - await client.subscribe("humidity/#") - async for message in messages: - print(message.payload) + async with aiomqtt.Client( + "test.mosquitto.org", queue_class=CustomPriorityQueue + ) as client: + await client.subscribe("temperature/#") + await client.subscribe("humidity/#") + async for message in client.messages: + print(message.payload) asyncio.run(main()) @@ -104,7 +103,7 @@ By default, the size of the queue is unlimited. You can set a limit by passing t ## Processing concurrently -Messages are queued and returned sequentially from `Client.messages()`. If a message takes a long time to handle, it blocks the handling of other messages. +Messages are queued internally and returned sequentially from `Client.messages`. If a message takes a long time to handle, it blocks the handling of other messages. You can handle messages concurrently by using an `asyncio.TaskGroup` like so: @@ -120,11 +119,11 @@ async def handle(message): async def main(): async with aiomqtt.Client("test.mosquitto.org") as client: - async with client.messages() as messages: - await client.subscribe("temperature/#") - async with asyncio.TaskGroup() as tg: - async for message in messages: - tg.create_task(handle(message)) # Spawn new coroutine + await client.subscribe("temperature/#") + # Use a task group to manage and await all tasks + async with asyncio.TaskGroup() as tg: + async for message in client.messages: + tg.create_task(handle(message)) # Spawn new coroutine asyncio.run(main()) @@ -134,6 +133,8 @@ asyncio.run(main()) Coroutines only make sense if your message handling is I/O-bound. If it's CPU-bound, you should spawn multiple processes instead. ``` +## Multiple queues + The code snippet above handles each message in a new coroutine. Sometimes, we want to handle messages from different topics concurrently, but sequentially inside a single topic. The idea here is to implement a "distributor" that sorts incoming messages into multiple asyncio queues. Each queue is then processed by a different coroutine. Let's see how this works for our temperature and humidity messages: @@ -160,19 +161,18 @@ humidity_queue = asyncio.Queue() async def distributor(client): - async with client.messages() as messages: - await client.subscribe("temperature/#") - await client.subscribe("humidity/#") - # Sort messages into the appropriate queues - async for message in messages: - if message.topic.matches("temperature/#"): - temperature_queue.put_nowait(message) - elif message.topic.matches("humidity/#"): - humidity_queue.put_nowait(message) + # Sort messages into the appropriate queues + async for message in client.messages: + if message.topic.matches("temperature/#"): + temperature_queue.put_nowait(message) + elif message.topic.matches("humidity/#"): + humidity_queue.put_nowait(message) async def main(): async with aiomqtt.Client("test.mosquitto.org") as client: + await client.subscribe("temperature/#") + await client.subscribe("humidity/#") # Use a task group to manage and await all tasks async with asyncio.TaskGroup() as tg: tg.create_task(distributor(client)) @@ -208,13 +208,13 @@ async def sleep(seconds): async def listen(): async with aiomqtt.Client("test.mosquitto.org") as client: - async with client.messages() as messages: - await client.subscribe("temperature/#") - async for message in messages: - print(message.payload) + await client.subscribe("temperature/#") + async for message in client.messages: + print(message.payload) async def main(): + # Use a task group to manage and await all tasks async with asyncio.TaskGroup() as tg: tg.create_task(sleep(2)) tg.create_task(listen()) # Start the listener task @@ -240,10 +240,9 @@ import aiomqtt async def listen(): async with aiomqtt.Client("test.mosquitto.org") as client: - async with client.messages() as messages: - await client.subscribe("temperature/#") - async for message in messages: - print(message.payload) + await client.subscribe("temperature/#") + async for message in client.messages: + print(message.payload) background_tasks = set() @@ -280,10 +279,9 @@ import aiomqtt async def listen(): async with aiomqtt.Client("test.mosquitto.org") as client: - async with client.messages() as messages: - await client.subscribe("temperature/#") - async for message in messages: - print(message.payload) + await client.subscribe("temperature/#") + async for message in client.messages: + print(message.payload) async def main(): @@ -315,10 +313,9 @@ import aiomqtt async def listen(): async with aiomqtt.Client("test.mosquitto.org") as client: - async with client.messages() as messages: - await client.subscribe("temperature/#") - async for message in messages: - print(message.payload) + await client.subscribe("temperature/#") + async for message in client.messages: + print(message.payload) async def main(): diff --git a/tests/test_client.py b/tests/test_client.py index 82d4d5c..58c4443 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,9 +1,9 @@ from __future__ import annotations import logging +import pathlib import ssl import sys -from pathlib import Path import anyio import anyio.abc @@ -12,173 +12,50 @@ from anyio import TASK_STATUS_IGNORED from anyio.abc import TaskStatus -from aiomqtt import Client, ProtocolVersion, TLSParameters, Topic, Wildcard, Will -from aiomqtt.error import MqttError, MqttReentrantError +from aiomqtt import ( + Client, + MqttError, + MqttReentrantError, + ProtocolVersion, + TLSParameters, + Will, +) from aiomqtt.types import PayloadType +# This is the same as marking all tests in this file with @pytest.mark.anyio pytestmark = pytest.mark.anyio HOSTNAME = "test.mosquitto.org" OS_PY_VERSION = sys.platform + "_" + ".".join(map(str, sys.version_info[:2])) -TOPIC_HEADER = OS_PY_VERSION + "/tests/aiomqtt/" - - -async def test_topic_validation() -> None: - """Test that Topic raises Exceptions for invalid topics.""" - with pytest.raises(TypeError): - Topic(True) # type: ignore[arg-type] - with pytest.raises(TypeError): - Topic(1.0) # type: ignore[arg-type] - with pytest.raises(TypeError): - Topic(None) # type: ignore[arg-type] - with pytest.raises(TypeError): - Topic([]) # type: ignore[arg-type] - with pytest.raises(ValueError, match="Invalid topic: "): - Topic("a/b/#") - with pytest.raises(ValueError, match="Invalid topic: "): - Topic("a/+/c") - with pytest.raises(ValueError, match="Invalid topic: "): - Topic("#") - with pytest.raises(ValueError, match="Invalid topic: "): - Topic("") - with pytest.raises(ValueError, match="Invalid topic: "): - Topic("a" * 65536) - - -async def test_wildcard_validation() -> None: - """Test that Wildcard raises Exceptions for invalid wildcards.""" - with pytest.raises(TypeError): - Wildcard(True) # type: ignore[arg-type] - with pytest.raises(TypeError): - Wildcard(1.0) # type: ignore[arg-type] - with pytest.raises(TypeError): - Wildcard(None) # type: ignore[arg-type] - with pytest.raises(TypeError): - Wildcard([]) # type: ignore[arg-type] - with pytest.raises(ValueError, match="Invalid wildcard: "): - Wildcard("a/#/c") - with pytest.raises(ValueError, match="Invalid wildcard: "): - Wildcard("a/b+/c") - with pytest.raises(ValueError, match="Invalid wildcard: "): - Wildcard("a/b/#c") - with pytest.raises(ValueError, match="Invalid wildcard: "): - Wildcard("") - with pytest.raises(ValueError, match="Invalid wildcard: "): - Wildcard("a" * 65536) - - -async def test_topic_matches() -> None: - """Test that Topic.matches() does and doesn't match some test wildcards.""" - topic = Topic("a/b/c") - assert topic.matches("a/b/c") - assert topic.matches("a/+/c") - assert topic.matches("+/+/+") - assert topic.matches("+/#") - assert topic.matches("#") - assert topic.matches("a/b/c/#") - assert topic.matches("$share/group/a/b/c") - assert topic.matches("$share/group/a/b/+") - assert not topic.matches("abc") - assert not topic.matches("a/b") - assert not topic.matches("a/b/c/d") - assert not topic.matches("a/b/c/d/#") - assert not topic.matches("a/b/z") - assert not topic.matches("a/b/c/+") - assert not topic.matches("$share/a/b/c") - assert not topic.matches("$test/group/a/b/c") - - -@pytest.mark.network -async def test_multiple_messages_generators() -> None: - """Test that multiple Client.messages() generators can be used at the same time.""" - topic = TOPIC_HEADER + "multiple_messages_generators" - - async def handler(tg: anyio.abc.TaskGroup) -> None: - async with client.messages() as messages: - async for message in messages: - assert str(message.topic) == topic - tg.cancel_scope.cancel() - - async with Client(HOSTNAME) as client, anyio.create_task_group() as tg: - await client.subscribe(topic) - tg.start_soon(handler, tg) - tg.start_soon(handler, tg) - await anyio.wait_all_tasks_blocked() - await client.publish(topic) - - -@pytest.mark.network -async def test_client_filtered_messages() -> None: - topic_header = TOPIC_HEADER + "filtered_messages/" - good_topic = topic_header + "good" - bad_topic = topic_header + "bad" - - async def handle_messages(tg: anyio.abc.TaskGroup) -> None: - async with client.filtered_messages(good_topic) as messages: - async for message in messages: - assert message.topic == good_topic - tg.cancel_scope.cancel() - - async with Client(HOSTNAME) as client, anyio.create_task_group() as tg: - await client.subscribe(topic_header + "#") - tg.start_soon(handle_messages, tg) - await anyio.wait_all_tasks_blocked() - await client.publish(bad_topic, 2) - await client.publish(good_topic, 2) - - -@pytest.mark.network -async def test_client_unfiltered_messages() -> None: - topic_header = TOPIC_HEADER + "unfiltered_messages/" - topic_filtered = topic_header + "filtered" - topic_unfiltered = topic_header + "unfiltered" - - async def handle_unfiltered_messages(tg: anyio.abc.TaskGroup) -> None: - async with client.unfiltered_messages() as messages: - async for message in messages: - assert message.topic == topic_unfiltered - tg.cancel_scope.cancel() - - async def handle_filtered_messages() -> None: - async with client.filtered_messages(topic_filtered) as messages: - async for message in messages: - assert message.topic == topic_filtered - - async with Client(HOSTNAME) as client, anyio.create_task_group() as tg: - await client.subscribe(topic_header + "#") - tg.start_soon(handle_filtered_messages) - tg.start_soon(handle_unfiltered_messages, tg) - await anyio.wait_all_tasks_blocked() - await client.publish(topic_filtered, 2) - await client.publish(topic_unfiltered, 2) +TOPIC_PREFIX = OS_PY_VERSION + "/tests/aiomqtt/" @pytest.mark.network async def test_client_unsubscribe() -> None: - topic_header = TOPIC_HEADER + "unsubscribe/" - topic1 = topic_header + "1" - topic2 = topic_header + "2" - - async def handle_messages(tg: anyio.abc.TaskGroup) -> None: - async with client.unfiltered_messages() as messages: - is_first_message = True - async for message in messages: - if is_first_message: - assert message.topic == topic1 - is_first_message = False - else: - assert message.topic == topic2 - tg.cancel_scope.cancel() + """Test that messages are no longer received after unsubscribing from a topic.""" + topic_1 = TOPIC_PREFIX + "test_client_unsubscribe/1" + topic_2 = TOPIC_PREFIX + "test_client_unsubscribe/2" + + async def handle(tg: anyio.abc.TaskGroup) -> None: + is_first_message = True + async for message in client.messages: + if is_first_message: + assert message.topic.value == topic_1 + is_first_message = False + else: + assert message.topic.value == topic_2 + tg.cancel_scope.cancel() async with Client(HOSTNAME) as client, anyio.create_task_group() as tg: - await client.subscribe(topic1) - await client.subscribe(topic2) - tg.start_soon(handle_messages, tg) + await client.subscribe(topic_1) + await client.subscribe(topic_2) + tg.start_soon(handle, tg) await anyio.wait_all_tasks_blocked() - await client.publish(topic1, 2) - await client.unsubscribe(topic1) - await client.publish(topic1, 2) - await client.publish(topic2, 2) + await client.publish(topic_1, None) + await client.unsubscribe(topic_1) + await client.publish(topic_1, None) + # Test that other subscriptions still receive messages + await client.publish(topic_2, None) @pytest.mark.parametrize( @@ -187,12 +64,12 @@ async def handle_messages(tg: anyio.abc.TaskGroup) -> None: ) async def test_client_id(protocol: ProtocolVersion, length: int) -> None: client = Client(HOSTNAME, protocol=protocol) - assert len(client.id) == length + assert len(client.identifier) == length @pytest.mark.network async def test_client_will() -> None: - topic = TOPIC_HEADER + "will" + topic = TOPIC_PREFIX + "test_client_will" event = anyio.Event() async def launch_client() -> None: @@ -200,10 +77,9 @@ async def launch_client() -> None: async with Client(HOSTNAME) as client: await client.subscribe(topic) event.set() - async with client.filtered_messages(topic) as messages: - async for message in messages: - assert message.topic == topic - cs.cancel() + async for message in client.messages: + assert message.topic.value == topic + cs.cancel() async with anyio.create_task_group() as tg: tg.start_soon(launch_client) @@ -214,13 +90,12 @@ async def launch_client() -> None: @pytest.mark.network async def test_client_tls_context() -> None: - topic = TOPIC_HEADER + "tls_context" + topic = TOPIC_PREFIX + "test_client_tls_context" - async def handle_messages(tg: anyio.abc.TaskGroup) -> None: - async with client.filtered_messages(topic) as messages: - async for message in messages: - assert message.topic == topic - tg.cancel_scope.cancel() + async def handle(tg: anyio.abc.TaskGroup) -> None: + async for message in client.messages: + assert message.topic.value == topic + tg.cancel_scope.cancel() async with Client( HOSTNAME, @@ -228,49 +103,47 @@ async def handle_messages(tg: anyio.abc.TaskGroup) -> None: tls_context=ssl.SSLContext(protocol=ssl.PROTOCOL_TLS), ) as client, anyio.create_task_group() as tg: await client.subscribe(topic) - tg.start_soon(handle_messages, tg) + tg.start_soon(handle, tg) await anyio.wait_all_tasks_blocked() await client.publish(topic) @pytest.mark.network async def test_client_tls_params() -> None: - topic = TOPIC_HEADER + "tls_params" + topic = TOPIC_PREFIX + "tls_params" - async def handle_messages(tg: anyio.abc.TaskGroup) -> None: - async with client.filtered_messages(topic) as messages: - async for message in messages: - assert message.topic == topic - tg.cancel_scope.cancel() + async def handle(tg: anyio.abc.TaskGroup) -> None: + async for message in client.messages: + assert message.topic.value == topic + tg.cancel_scope.cancel() async with Client( HOSTNAME, 8883, tls_params=TLSParameters( - ca_certs=str(Path.cwd() / "tests" / "mosquitto.org.crt") + ca_certs=str(pathlib.Path.cwd() / "tests" / "mosquitto.org.crt") ), ) as client, anyio.create_task_group() as tg: await client.subscribe(topic) - tg.start_soon(handle_messages, tg) + tg.start_soon(handle, tg) await anyio.wait_all_tasks_blocked() await client.publish(topic) @pytest.mark.network async def test_client_username_password() -> None: - topic = TOPIC_HEADER + "username_password" + topic = TOPIC_PREFIX + "username_password" - async def handle_messages(tg: anyio.abc.TaskGroup) -> None: - async with client.filtered_messages(topic) as messages: - async for message in messages: - assert message.topic == topic - tg.cancel_scope.cancel() + async def handle(tg: anyio.abc.TaskGroup) -> None: + async for message in client.messages: + assert message.topic.value == topic + tg.cancel_scope.cancel() async with Client( HOSTNAME, username="", password="" ) as client, anyio.create_task_group() as tg: await client.subscribe(topic) - tg.start_soon(handle_messages, tg) + tg.start_soon(handle, tg) await anyio.wait_all_tasks_blocked() await client.publish(topic) @@ -286,7 +159,7 @@ async def test_client_logger() -> None: async def test_client_max_concurrent_outgoing_calls( monkeypatch: pytest.MonkeyPatch, ) -> None: - topic = TOPIC_HEADER + "max_concurrent_outgoing_calls" + topic = TOPIC_PREFIX + "max_concurrent_outgoing_calls" class MockPahoClient(mqtt.Client): def subscribe( @@ -332,13 +205,12 @@ def publish( # noqa: PLR0913 @pytest.mark.network async def test_client_websockets() -> None: - topic = TOPIC_HEADER + "websockets" + topic = TOPIC_PREFIX + "websockets" - async def handle_messages(tg: anyio.abc.TaskGroup) -> None: - async with client.filtered_messages(topic) as messages: - async for message in messages: - assert message.topic == topic - tg.cancel_scope.cancel() + async def handle(tg: anyio.abc.TaskGroup) -> None: + async for message in client.messages: + assert message.topic.value == topic + tg.cancel_scope.cancel() async with Client( HOSTNAME, @@ -346,11 +218,12 @@ async def handle_messages(tg: anyio.abc.TaskGroup) -> None: transport="websockets", websocket_path="/", websocket_headers={"foo": "bar"}, - ) as client, anyio.create_task_group() as tg: + ) as client: await client.subscribe(topic) - tg.start_soon(handle_messages, tg) - await anyio.wait_all_tasks_blocked() - await client.publish(topic) + async with anyio.create_task_group() as tg: + tg.start_soon(handle, tg) + await anyio.wait_all_tasks_blocked() + await client.publish(topic) @pytest.mark.network @@ -358,7 +231,7 @@ async def handle_messages(tg: anyio.abc.TaskGroup) -> None: async def test_client_pending_calls_threshold( pending_calls_threshold: int, caplog: pytest.LogCaptureFixture ) -> None: - topic = TOPIC_HEADER + "pending_calls_threshold" + topic = TOPIC_PREFIX + "pending_calls_threshold" async with Client(HOSTNAME) as client: client.pending_calls_threshold = pending_calls_threshold @@ -384,7 +257,7 @@ async def test_client_no_pending_calls_warnings_with_max_concurrent_outgoing_cal caplog: pytest.LogCaptureFixture, ) -> None: topic = ( - TOPIC_HEADER + "no_pending_calls_warnings_with_max_concurrent_outgoing_calls" + TOPIC_PREFIX + "no_pending_calls_warnings_with_max_concurrent_outgoing_calls" ) async with Client(HOSTNAME, max_concurrent_outgoing_calls=1) as client: @@ -399,32 +272,24 @@ async def test_client_no_pending_calls_warnings_with_max_concurrent_outgoing_cal @pytest.mark.network -async def test_client_not_reentrant() -> None: - """Test that the client raises an error when we try to reenter.""" - client = Client(HOSTNAME) - with pytest.raises(MqttReentrantError): # noqa: PT012 - async with client: - async with client: - pass - - -@pytest.mark.network -async def test_client_reusable() -> None: - """Test that an instance of the client context manager can be reused.""" +async def test_client_context_is_reusable() -> None: + """Test that a client context manager instance is reusable.""" + topic = TOPIC_PREFIX + "test_client_is_reusable" client = Client(HOSTNAME) async with client: - await client.publish("task/a", "task_a") + await client.publish(topic, "foo") async with client: - await client.publish("task/b", "task_b") + await client.publish(topic, "bar") @pytest.mark.network -async def test_client_connect_disconnect() -> None: +async def test_client_context_is_not_reentrant() -> None: + """Test that a client context manager instance is not reentrant.""" client = Client(HOSTNAME) - - await client.connect() - await client.publish("connect", "connect") - await client.disconnect() + async with client: + with pytest.raises(MqttReentrantError): + async with client: + pass @pytest.mark.network @@ -435,10 +300,10 @@ async def test_client_reusable_message() -> None: async def task_a_customer( task_status: TaskStatus[None] = TASK_STATUS_IGNORED, ) -> None: - async with custom_client, custom_client.messages() as messages: + async with custom_client: await custom_client.subscribe("task/a") task_status.started() - async for message in messages: + async for message in custom_client.messages: assert message.payload == b"task_a" return @@ -463,81 +328,9 @@ async def task_a_publisher() -> None: @pytest.mark.network -async def test_client_use_connect_disconnect_multiple_message() -> None: - custom_client = Client(HOSTNAME) - publish_client = Client(HOSTNAME) - - topic_a = TOPIC_HEADER + "task/a" - topic_b = TOPIC_HEADER + "task/b" - - await custom_client.connect() - await publish_client.connect() - - async def task_a_customer( - task_status: TaskStatus[None] = TASK_STATUS_IGNORED, - ) -> None: - await custom_client.subscribe(topic_a) - async with custom_client.messages() as messages: - task_status.started() - async for message in messages: - assert message.payload == b"task_a" - return - - async def task_b_customer( - task_status: TaskStatus[None] = TASK_STATUS_IGNORED, - ) -> None: - num = 0 - await custom_client.subscribe(topic_b) - async with custom_client.messages() as messages: - task_status.started() - async for message in messages: - assert message.payload in {b"task_a", b"task_b"} - num += 1 - if num == 2: # noqa: PLR2004 - return - - async def task_publisher(topic: str, payload: PayloadType) -> None: - await publish_client.publish(topic, payload) - - async with anyio.create_task_group() as tg: - await tg.start(task_a_customer) - await tg.start(task_b_customer) - tg.start_soon(task_publisher, topic_a, "task_a") - tg.start_soon(task_publisher, topic_b, "task_b") - - await custom_client.disconnect() - await publish_client.disconnect() - - -@pytest.mark.network -async def test_client_disconnected_exception() -> None: - client = Client(HOSTNAME) - await client.connect() - client._disconnected.set_exception(RuntimeError) - with pytest.raises(RuntimeError): - await client.disconnect() - - -@pytest.mark.network -async def test_client_disconnected_done() -> None: - client = Client(HOSTNAME) - await client.connect() - client._disconnected.set_result(None) - await client.disconnect() - - -@pytest.mark.network -async def test_client_connecting_disconnected_done() -> None: - client = Client(HOSTNAME) - client._disconnected.set_result(None) - await client.connect() - await client.disconnect() - - -@pytest.mark.network -async def test_client_aenter_error_lock_release() -> None: - """Test that the client's reusability lock is released on error in __aenter__.""" - client = Client(hostname="aenter_connect_error_lock_release") +async def test_aenter_error_lock_release() -> None: + """Test that the client's reusability lock is released on error in ``aenter``.""" + client = Client(hostname="invalid") with pytest.raises(MqttError): await client.__aenter__() assert not client._lock.locked() @@ -545,7 +338,7 @@ async def test_client_aenter_error_lock_release() -> None: @pytest.mark.network async def test_aexit_without_prior_aenter() -> None: - """Test that __aexit__ without prior (or unsuccessful) __aenter__ runs cleanly.""" + """Test that ``aexit`` without prior (or unsuccessful) ``aenter`` runs cleanly.""" client = Client(HOSTNAME) await client.__aexit__(None, None, None) @@ -559,16 +352,32 @@ async def test_aexit_consecutive_calls() -> None: @pytest.mark.network async def test_aexit_client_is_already_disconnected_success() -> None: - """Test that __aexit__ exits cleanly if client is already cleanly disconnected.""" + """Test that ``aexit`` runs cleanly if client is already cleanly disconnected.""" async with Client(HOSTNAME) as client: client._disconnected.set_result(None) @pytest.mark.network async def test_aexit_client_is_already_disconnected_failure() -> None: - """Test that __aexit__ reraises if client is already disconnected with an error.""" + """Test that ``aexit`` reraises if client is already disconnected with an error.""" client = Client(HOSTNAME) await client.__aenter__() client._disconnected.set_exception(RuntimeError) with pytest.raises(RuntimeError): await client.__aexit__(None, None, None) + + +@pytest.mark.network +async def test_messages_generator_is_reusable() -> None: + """Test that the messages generator is reusable and returns no duplicates.""" + topic = TOPIC_PREFIX + "test_messages_generator_is_reusable" + async with Client(HOSTNAME) as client: + await client.subscribe(topic) + await client.publish(topic, "foo") + await client.publish(topic, "bar") + async for message in client.messages: + assert message.payload == b"foo" + break + async for message in client.messages: + assert message.payload == b"bar" + break diff --git a/tests/test_error.py b/tests/test_exceptions.py similarity index 96% rename from tests/test_error.py rename to tests/test_exceptions.py index a4c90b8..4a24928 100644 --- a/tests/test_error.py +++ b/tests/test_exceptions.py @@ -2,7 +2,7 @@ import pytest from paho.mqtt.packettypes import PacketTypes -from aiomqtt.error import _CONNECT_RC_STRINGS, MqttCodeError, MqttConnectError +from aiomqtt.exceptions import _CONNECT_RC_STRINGS, MqttCodeError, MqttConnectError @pytest.mark.parametrize( diff --git a/tests/test_topic.py b/tests/test_topic.py new file mode 100644 index 0000000..023077e --- /dev/null +++ b/tests/test_topic.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import pytest + +from aiomqtt import Topic, Wildcard + + +def test_topic_validation() -> None: + """Test that Topic raises Exceptions for invalid topics.""" + with pytest.raises(TypeError): + Topic(True) # type: ignore[arg-type] + with pytest.raises(TypeError): + Topic(1.0) # type: ignore[arg-type] + with pytest.raises(TypeError): + Topic(None) # type: ignore[arg-type] + with pytest.raises(TypeError): + Topic([]) # type: ignore[arg-type] + with pytest.raises(ValueError, match="Invalid topic: "): + Topic("a/b/#") + with pytest.raises(ValueError, match="Invalid topic: "): + Topic("a/+/c") + with pytest.raises(ValueError, match="Invalid topic: "): + Topic("#") + with pytest.raises(ValueError, match="Invalid topic: "): + Topic("") + with pytest.raises(ValueError, match="Invalid topic: "): + Topic("a" * 65536) + + +def test_wildcard_validation() -> None: + """Test that Wildcard raises Exceptions for invalid wildcards.""" + with pytest.raises(TypeError): + Wildcard(True) # type: ignore[arg-type] + with pytest.raises(TypeError): + Wildcard(1.0) # type: ignore[arg-type] + with pytest.raises(TypeError): + Wildcard(None) # type: ignore[arg-type] + with pytest.raises(TypeError): + Wildcard([]) # type: ignore[arg-type] + with pytest.raises(ValueError, match="Invalid wildcard: "): + Wildcard("a/#/c") + with pytest.raises(ValueError, match="Invalid wildcard: "): + Wildcard("a/b+/c") + with pytest.raises(ValueError, match="Invalid wildcard: "): + Wildcard("a/b/#c") + with pytest.raises(ValueError, match="Invalid wildcard: "): + Wildcard("") + with pytest.raises(ValueError, match="Invalid wildcard: "): + Wildcard("a" * 65536) + + +def test_topic_matches() -> None: + """Test that Topic.matches() does and doesn't match some test wildcards.""" + topic = Topic("a/b/c") + assert topic.matches("a/b/c") + assert topic.matches("a/+/c") + assert topic.matches("+/+/+") + assert topic.matches("+/#") + assert topic.matches("#") + assert topic.matches("a/b/c/#") + assert topic.matches("$share/group/a/b/c") + assert topic.matches("$share/group/a/b/+") + assert not topic.matches("abc") + assert not topic.matches("a/b") + assert not topic.matches("a/b/c/d") + assert not topic.matches("a/b/c/d/#") + assert not topic.matches("a/b/z") + assert not topic.matches("a/b/c/+") + assert not topic.matches("$share/a/b/c") + assert not topic.matches("$test/group/a/b/c")