# Source code for nwp500.mqtt.subscriptions

"""
MQTT Subscription Management for Navien devices.

This module handles all subscription-related operations including:
- Low-level subscribe/unsubscribe operations
- Topic pattern matching with MQTT wildcards
- Message routing and handler management
- Typed subscriptions (status, feature, energy usage, and reservation,
  weekly reservation, recirculation and TOU schedule responses)
- State change detection and event emission
"""

from __future__ import annotations

import asyncio
import json
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, cast

from awscrt import mqtt
from awscrt.exceptions import AwsCrtError
from pydantic import ValidationError

from ..events import EventEmitter
from ..exceptions import MqttNotConnectedError
from ..models import (
    Device,
    DeviceFeature,
    DeviceStatus,
    EnergyUsageResponse,
    RecirculationSchedule,
    ReservationSchedule,
    TOUReservationSchedule,
    WeeklyReservationSchedule,
)
from ..mqtt_events import FeatureReceivedEvent, StatusReceivedEvent
from ..topic_builder import MqttTopicBuilder
from .state_tracker import DeviceStateTracker
from .utils import get_response_data, redact_topic, topic_matches_pattern

if TYPE_CHECKING:
    from ..device_info_cache import MqttDeviceInfoCache

__author__ = "Emmanuel Levijarvi"

# Per-module logger, named after this module so its output can be filtered
# through the standard logging hierarchy.
_logger = logging.getLogger(name=__name__)


class MqttSubscriptionManager:
    """
    Manages MQTT subscriptions, topic matching, and message routing.

    Handles:
    - Subscribe/unsubscribe to MQTT topics
    - Topic pattern matching with wildcards (+ and #)
    - Message handler registration and invocation
    - Typed subscriptions with automatic parsing
    - State change detection and event emission
    """

    def __init__(
        self,
        connection: Any,  # awsiot.mqtt_connection.Connection
        client_id: str,
        event_emitter: EventEmitter,
        schedule_coroutine: Callable[[Any], None],
        device_info_cache: MqttDeviceInfoCache | None = None,
    ):
        """
        Initialize subscription manager.

        Args:
            connection: MQTT connection object
            client_id: Client ID for response topics
            event_emitter: Event emitter for state changes
            schedule_coroutine: Function to schedule async tasks
            device_info_cache: Optional MqttDeviceInfoCache for caching
                device features
        """
        self._connection = connection
        self._client_id = client_id
        self._event_emitter = event_emitter
        self._schedule_coroutine = schedule_coroutine
        self._device_info_cache = device_info_cache

        # Broker-level subscriptions: topic pattern -> requested QoS.
        self._subscriptions: dict[str, mqtt.QoS] = {}
        # Local handlers per topic pattern; several handlers may share a
        # single broker subscription.
        self._message_handlers: dict[
            str, list[Callable[[str, dict[str, Any]], None]]
        ] = {}

        # Per-device state change detection
        self._state_tracker = DeviceStateTracker(event_emitter)

    @property
    def subscriptions(self) -> dict[str, mqtt.QoS]:
        """Get a copy of the current broker subscriptions (topic -> QoS)."""
        return self._subscriptions.copy()

    def update_connection(self, connection: Any) -> None:
        """
        Update the MQTT connection reference.

        This is used when the connection is recreated (e.g., after
        reconnection) to update the internal reference while preserving
        subscription bookkeeping.

        Args:
            connection: New MQTT connection object

        Note:
            This does not re-establish subscriptions at the broker. Call
            :meth:`resubscribe_all` (or the individual subscribe methods)
            to register them with the new connection if needed.
        """
        self._connection = connection
        _logger.debug("Updated subscription manager connection reference")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _device_command_topic(device: Device) -> str:
        """Build the wildcard command-response topic for ``device``."""
        return MqttTopicBuilder.command_topic(
            str(device.device_info.device_type),
            device.device_info.mac_address,
            "#",
        )

    def _device_response_topic(self, device: Device, suffix: str) -> str:
        """Build this client's response topic for ``device`` and ``suffix``."""
        return MqttTopicBuilder.response_topic(
            str(device.device_info.device_type),
            self._client_id,
            suffix,
        )

    def _discard_handler(
        self, topic: str, callback: Callable[[str, dict[str, Any]], None]
    ) -> None:
        """Remove ``callback`` from ``topic``; drop the entry if now empty."""
        handlers = self._message_handlers.get(topic)
        if handlers is None:
            return
        if callback in handlers:
            handlers.remove(callback)
        if not handlers:
            # An empty list would make a later unsubscribe(topic) issue a
            # broker UNSUBSCRIBE for a topic that was never subscribed.
            del self._message_handlers[topic]

    def _find_wrapped_handler(
        self,
        topic: str,
        callback: Callable[[Any], None],
        model: Any = None,
    ) -> Callable[[str, dict[str, Any]], None] | None:
        """Return the :meth:`_make_handler` wrapper for ``callback``.

        When ``model`` is given, only a wrapper built for that model
        matches. This matters when one callback is registered for several
        typed subscriptions that share a topic (e.g. status and feature).
        Returns ``None`` if no matching wrapper is registered on ``topic``.
        """
        for handler in self._message_handlers.get(topic, []):
            if getattr(handler, "_original_callback", None) == callback and (
                model is None or getattr(handler, "_model", None) is model
            ):
                return handler
        return None

    async def _unsubscribe_wrapped(
        self,
        topic: str,
        callback: Callable[[Any], None],
        model: Any = None,
    ) -> None:
        """Unsubscribe the typed wrapper for ``callback`` if one exists."""
        target_handler = self._find_wrapped_handler(topic, callback, model)
        if target_handler is not None:
            await self.unsubscribe(topic, target_handler)

    def _on_message_received(
        self, topic: str, payload: bytes, **kwargs: Any
    ) -> None:
        """Handle a received MQTT message.

        Decodes the payload as UTF-8 JSON and invokes every registered
        handler whose subscription pattern matches ``topic`` (wildcards are
        honoured via ``topic_matches_pattern``). A failing handler is logged
        and does not prevent the remaining handlers from running.

        Args:
            topic: MQTT topic the message was received on
            payload: Raw message payload (UTF-8 encoded JSON bytes)
            **kwargs: Additional MQTT metadata (ignored)
        """
        try:
            message = json.loads(payload.decode("utf-8"))
            _logger.debug("Received message on topic: %s", redact_topic(topic))

            # Dispatch over a snapshot. This callback is handed to the MQTT
            # connection and may run while (un)subscribe calls mutate the
            # routing table; iterating the live dict/lists could raise
            # "dictionary changed size during iteration" or skip handlers.
            routes = [
                (pattern, tuple(handlers))
                for pattern, handlers in list(self._message_handlers.items())
            ]
            for pattern, handlers in routes:
                if not topic_matches_pattern(topic, pattern):
                    continue
                for handler in handlers:
                    try:
                        handler(topic, message)
                    except (
                        TypeError,
                        AttributeError,
                        KeyError,
                        ValueError,
                    ) as e:
                        _logger.exception("Error in message handler: %s", e)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            _logger.error("Failed to parse message payload: %s", e)
        except (AttributeError, KeyError, TypeError) as e:
            _logger.error("Error processing message: %s", e)

    # ------------------------------------------------------------------
    # Low-level subscribe / unsubscribe
    # ------------------------------------------------------------------

    async def subscribe(
        self,
        topic: str,
        callback: Callable[[str, dict[str, Any]], None],
        qos: mqtt.QoS = mqtt.QoS.AT_LEAST_ONCE,
    ) -> int:
        """
        Subscribe to an MQTT topic.

        The handler is registered locally first. The broker is only contacted
        if the topic is not already subscribed.

        Args:
            topic: MQTT topic to subscribe to (can include wildcards)
            callback: Function to call when messages arrive (topic, message)
            qos: Quality of Service level

        Returns:
            Subscription packet ID, or 0 if the topic was already subscribed
            at the broker level (no SUBSCRIBE packet was sent).

        Raises:
            MqttNotConnectedError: If not connected to MQTT broker
            AwsCrtError, RuntimeError: If the broker subscription fails
        """
        if not self._connection:
            raise MqttNotConnectedError("Not connected to MQTT broker")

        # Track handler first
        handlers = self._message_handlers.setdefault(topic, [])
        if callback not in handlers:
            handlers.append(callback)

        if topic in self._subscriptions:
            # Already subscribed at the broker level. Overlapping handlers
            # share the existing subscription (no QoS upgrade is attempted).
            return 0

        _logger.info("Subscribing to topic: %s", redact_topic(topic))

        try:
            # Convert concurrent.futures.Future to asyncio.Future and await.
            # Shield so cancelling this coroutine does not cancel the
            # underlying CRT future.
            subscribe_future, packet_id = self._connection.subscribe(
                topic=topic, qos=qos, callback=self._on_message_received
            )
            try:
                subscribe_result = await asyncio.shield(
                    asyncio.wrap_future(subscribe_future)
                )
            except asyncio.CancelledError:
                # The underlying subscribe completes independently,
                # preventing InvalidStateError in AWS CRT callbacks.
                _logger.debug(
                    f"Subscribe to '{redact_topic(topic)}' was cancelled "
                    "but will complete in background"
                )
                raise

            _logger.info(
                "Subscription succeeded (topic redacted) with QoS %s",
                subscribe_result["qos"],
            )

            self._subscriptions[topic] = qos
            return int(packet_id)

        except (AwsCrtError, RuntimeError) as e:
            # Roll back the handler registered above so that no dangling
            # (possibly empty) handler entry survives a failed subscribe.
            self._discard_handler(topic, callback)
            _logger.error(
                "Failed to subscribe to '%s': %s", redact_topic(topic), e
            )
            raise

    async def unsubscribe(
        self,
        topic: str,
        callback: Callable[[str, dict[str, Any]], None] | None = None,
    ) -> int:
        """
        Unsubscribe from an MQTT topic.

        If a callback is provided, only that specific handler is removed.
        The broker is only unsubscribed if no handlers remain for the
        topic. If no callback is provided, all handlers are removed and the
        broker is unsubscribed immediately.

        Args:
            topic: MQTT topic to unsubscribe from
            callback: Optional specific handler to remove

        Returns:
            Unsubscribe packet ID (or 0 if no broker call was made)

        Raises:
            MqttNotConnectedError: If not connected to MQTT broker
            AwsCrtError, RuntimeError: If the broker unsubscribe fails
        """
        if not self._connection:
            raise MqttNotConnectedError("Not connected to MQTT broker")

        handlers = self._message_handlers.get(topic)
        if handlers is None:
            return 0

        if callback is not None:
            if callback in handlers:
                handlers.remove(callback)
            # Other handlers still listen; keep the broker subscription.
            if handlers:
                return 0

        _logger.info("Unsubscribing from topic (redacted)")

        try:
            # Shield so cancelling this coroutine does not cancel the
            # underlying CRT future.
            unsubscribe_future, packet_id = self._connection.unsubscribe(topic)
            try:
                await asyncio.shield(asyncio.wrap_future(unsubscribe_future))
            except asyncio.CancelledError:
                _logger.debug(
                    "Unsubscribe from topic (redacted) was "
                    "cancelled but will complete in background"
                )
                raise

            self._subscriptions.pop(topic, None)
            self._message_handlers.pop(topic, None)

            _logger.info("Unsubscribed from topic (redacted)")
            return int(packet_id)

        except (AwsCrtError, RuntimeError) as e:
            _logger.error("Failed to unsubscribe from topic (redacted): %s", e)
            raise

    async def resubscribe_all(self) -> None:
        """
        Re-establish all subscriptions after a connection rebuild.

        Each tracked topic is re-subscribed once with its original QoS. All
        of its handlers are re-registered. Topics with no remaining handlers
        are dropped.

        Topics whose re-subscription fails are logged and kept in the
        internal state so the next reconnection cycle can retry them. If the
        coroutine is cancelled, or an unexpected error escapes, the topics
        not yet restored are likewise put back before the exception
        propagates.

        Raises:
            MqttNotConnectedError: If not connected to MQTT broker
        """
        if not self._connection:
            raise MqttNotConnectedError("Not connected to MQTT broker")

        if not self._subscriptions:
            _logger.debug("No subscriptions to restore")
            return

        subscription_count = len(self._subscriptions)
        _logger.info(f"Re-establishing {subscription_count} subscription(s)...")

        # Snapshot state; subscribe() re-populates it as topics succeed.
        subscriptions_to_restore = list(self._subscriptions.items())
        handlers_to_restore = {
            topic: handlers.copy()
            for topic, handlers in self._message_handlers.items()
        }

        self._subscriptions.clear()
        self._message_handlers.clear()

        # Topics still awaiting a successful restore. Whatever is left here
        # on exit (failure, cancellation, unexpected error) is put back so
        # it is not silently lost.
        pending: dict[str, mqtt.QoS] = dict(subscriptions_to_restore)
        failed_subscriptions: set[str] = set()
        try:
            for topic, qos in subscriptions_to_restore:
                handlers = handlers_to_restore.get(topic, [])
                if not handlers:
                    # Nobody listens on this topic any more; drop it.
                    del pending[topic]
                    continue

                try:
                    # One network subscribe per topic, for the first handler
                    await self.subscribe(topic, handlers[0], qos)
                except (AwsCrtError, RuntimeError) as e:
                    _logger.error(
                        f"Failed to re-subscribe to '{redact_topic(topic)}': "
                        f"{e}"
                    )
                    failed_subscriptions.add(topic)
                    continue

                # Register remaining handlers without extra network calls
                registered = self._message_handlers.setdefault(topic, [])
                for handler in handlers[1:]:
                    if handler not in registered:
                        registered.append(handler)
                del pending[topic]
        finally:
            for topic, qos in pending.items():
                self._subscriptions[topic] = qos
                self._message_handlers[topic] = handlers_to_restore.get(
                    topic, []
                )

        if failed_subscriptions:
            _logger.warning(
                f"Failed to restore {len(failed_subscriptions)} "
                "subscription(s); will retry on next reconnection"
            )
        else:
            _logger.info("All subscriptions re-established successfully")

    # ------------------------------------------------------------------
    # Device-level and typed subscriptions
    # ------------------------------------------------------------------

    async def subscribe_device(
        self, device: Device, callback: Callable[[str, dict[str, Any]], None]
    ) -> int:
        """
        Subscribe to all messages from a specific device.

        Args:
            device: Device object
            callback: Message handler receiving (topic, message)

        Returns:
            Subscription packet ID, or 0 if already subscribed
        """
        # Device responses come on cmd/{device_type}/navilink-{device_id}/#
        return await self.subscribe(
            self._device_command_topic(device), callback
        )

    async def subscribe_device_status(
        self, device: Device, callback: Callable[[DeviceStatus], None]
    ) -> int:
        """Subscribe to device status messages with automatic parsing.

        Each parsed :class:`DeviceStatus` is first used to schedule a
        ``status_received`` event and a state-tracker update, and is then
        passed to ``callback``.

        Args:
            device: Device whose status messages to receive
            callback: Called with each parsed status

        Returns:
            Subscription packet ID, or 0 if already subscribed
        """
        device_mac = device.device_info.mac_address

        def post_parse(status: DeviceStatus) -> None:
            self._schedule_coroutine(
                self._event_emitter.emit(
                    "status_received",
                    StatusReceivedEvent(device_mac=device_mac, status=status),
                )
            )
            self._schedule_coroutine(
                self._state_tracker.process(device_mac, status)
            )

        handler = self._make_handler(
            DeviceStatus, callback, "status", post_parse, device_mac=device_mac
        )
        return await self.subscribe_device(device=device, callback=handler)

    async def unsubscribe_device_status(
        self, device: Device, callback: Callable[[DeviceStatus], None]
    ) -> None:
        """Unsubscribe a specific device status callback.

        Only the status wrapper for ``callback`` is removed, even if the
        same callback is also registered for device features.
        """
        await self._unsubscribe_wrapped(
            self._device_command_topic(device), callback, DeviceStatus
        )

    def _make_handler(
        self,
        model: Any,
        callback: Callable[[Any], None],
        key: str | None = None,
        post_parse: Callable[[Any], None] | None = None,
        device_mac: str | None = None,
    ) -> Callable[[str, dict[str, Any]], None]:
        """Build a raw MQTT handler that parses messages into ``model``.

        Args:
            model: Pydantic model class used to validate the response data
            callback: User callback receiving the parsed model instance
            key: Optional key passed to ``get_response_data`` to select the
                relevant part of the response
            post_parse: Optional hook run on the parsed instance before
                ``callback``
            device_mac: If given, assigned to the parsed instance's
                ``mac_address`` attribute when it has one

        Returns:
            A ``(topic, message)`` handler. It is tagged with
            ``_original_callback`` and ``_model`` so that unsubscribe helpers
            can locate it later.
        """

        def handler(topic: str, message: dict[str, Any]) -> None:
            try:
                data = get_response_data(message, key)
                if not data:
                    return
                parsed = model.model_validate(data)
                if device_mac and hasattr(parsed, "mac_address"):
                    parsed.mac_address = device_mac
                if post_parse:
                    post_parse(parsed)
                callback(parsed)
            except (
                ValidationError,
                KeyError,
                ValueError,
                TypeError,
                AttributeError,
            ) as e:
                _logger.warning(
                    f"Error parsing {model.__name__} on {topic}: {e}"
                )

        tagged = cast(Any, handler)
        tagged._original_callback = callback
        tagged._model = model
        return handler

    async def subscribe_device_feature(
        self, device: Device, callback: Callable[[DeviceFeature], None]
    ) -> int:
        """Subscribe to device feature/info messages with automatic parsing.

        Each parsed :class:`DeviceFeature` is cached (when a device info
        cache is configured). It is also used to schedule a
        ``feature_received`` event, and is then passed to ``callback``.

        Args:
            device: Device whose feature messages to receive
            callback: Called with each parsed feature

        Returns:
            Subscription packet ID, or 0 if already subscribed
        """
        device_mac = device.device_info.mac_address

        def post_parse(feature: DeviceFeature) -> None:
            if self._device_info_cache:
                self._schedule_coroutine(
                    self._device_info_cache.set(device_mac, feature)
                )
            self._schedule_coroutine(
                self._event_emitter.emit(
                    "feature_received",
                    FeatureReceivedEvent(
                        device_mac=device_mac, feature=feature
                    ),
                )
            )

        handler = self._make_handler(
            DeviceFeature,
            callback,
            "feature",
            post_parse,
            device_mac=device_mac,
        )
        return await self.subscribe_device(device=device, callback=handler)

    async def unsubscribe_device_feature(
        self, device: Device, callback: Callable[[DeviceFeature], None]
    ) -> None:
        """Unsubscribe a specific device feature callback.

        Only the feature wrapper for ``callback`` is removed, even if the
        same callback is also registered for device status.
        """
        await self._unsubscribe_wrapped(
            self._device_command_topic(device), callback, DeviceFeature
        )

    async def subscribe_energy_usage(
        self,
        device: Device,
        callback: Callable[[EnergyUsageResponse], None],
    ) -> int:
        """Subscribe to energy usage responses with automatic parsing.

        Returns:
            Subscription packet ID, or 0 if already subscribed
        """
        handler = self._make_handler(EnergyUsageResponse, callback)
        topic = self._device_response_topic(
            device, "energy-usage-daily-query/rd"
        )
        return await self.subscribe(topic, handler)

    async def unsubscribe_energy_usage(
        self,
        device: Device,
        callback: Callable[[EnergyUsageResponse], None],
    ) -> None:
        """Unsubscribe a specific energy usage callback."""
        await self._unsubscribe_wrapped(
            self._device_response_topic(device, "energy-usage-daily-query/rd"),
            callback,
            EnergyUsageResponse,
        )

    async def subscribe_reservation_response(
        self,
        device: Device,
        callback: Callable[[ReservationSchedule], None],
    ) -> int:
        """Subscribe to reservation read responses with automatic parsing.

        Subscribes to the ``rsv/rd`` response topic for the given device.
        The callback receives a fully-parsed
        :class:`~nwp500.models.ReservationSchedule` whenever the device
        responds to a reservation read request.

        Args:
            device: Device whose reservation responses to receive.
            callback: Called with the parsed schedule on each response.

        Returns:
            Subscription packet ID, or 0 if already subscribed.
        """
        handler = self._make_handler(ReservationSchedule, callback)
        topic = self._device_response_topic(device, "rsv/rd")
        return await self.subscribe(topic, handler)

    async def unsubscribe_reservation_response(
        self,
        device: Device,
        callback: Callable[[ReservationSchedule], None],
    ) -> None:
        """Unsubscribe a specific reservation response callback."""
        await self._unsubscribe_wrapped(
            self._device_response_topic(device, "rsv/rd"),
            callback,
            ReservationSchedule,
        )

    async def subscribe_weekly_reservation_response(
        self,
        device: Device,
        callback: Callable[[WeeklyReservationSchedule], None],
    ) -> int:
        """Subscribe to weekly reservation read responses.

        Subscribes to the ``rsv-weekly/rd`` response topic for the given
        device. The callback receives a
        :class:`~nwp500.models.WeeklyReservationSchedule` whenever the
        device responds to a weekly reservation read request.

        Args:
            device: Device whose weekly reservation responses to receive.
            callback: Called with the parsed schedule on each response.

        Returns:
            Subscription packet ID, or 0 if already subscribed.
        """
        handler = self._make_handler(WeeklyReservationSchedule, callback)
        topic = self._device_response_topic(device, "rsv-weekly/rd")
        return await self.subscribe(topic, handler)

    async def unsubscribe_weekly_reservation_response(
        self,
        device: Device,
        callback: Callable[[WeeklyReservationSchedule], None],
    ) -> None:
        """Unsubscribe a specific weekly reservation callback."""
        await self._unsubscribe_wrapped(
            self._device_response_topic(device, "rsv-weekly/rd"),
            callback,
            WeeklyReservationSchedule,
        )

    async def subscribe_recirculation_schedule_response(
        self,
        device: Device,
        callback: Callable[[RecirculationSchedule], None],
    ) -> int:
        """Subscribe to recirculation schedule read responses.

        Subscribes to the ``recirc-rsv/rd`` response topic for the given
        device. The callback receives a
        :class:`~nwp500.models.RecirculationSchedule` whenever the device
        responds to a recirculation schedule read request.

        Args:
            device: Device whose recirculation schedule responses to
                receive.
            callback: Called with the parsed schedule on each response.

        Returns:
            Subscription packet ID, or 0 if already subscribed.
        """
        handler = self._make_handler(RecirculationSchedule, callback)
        topic = self._device_response_topic(device, "recirc-rsv/rd")
        return await self.subscribe(topic, handler)

    async def unsubscribe_recirculation_schedule_response(
        self,
        device: Device,
        callback: Callable[[RecirculationSchedule], None],
    ) -> None:
        """Unsubscribe a specific recirculation schedule callback."""
        await self._unsubscribe_wrapped(
            self._device_response_topic(device, "recirc-rsv/rd"),
            callback,
            RecirculationSchedule,
        )

    async def subscribe_tou_response(
        self,
        device: Device,
        callback: Callable[[TOUReservationSchedule], None],
    ) -> int:
        """Subscribe to Time-of-Use schedule read responses.

        Subscribes to the ``tou/rd`` response topic for the given device.
        The callback receives a fully-parsed
        :class:`~nwp500.models.TOUReservationSchedule` whenever the device
        responds to a TOU read or configure request (triggered by
        :meth:`~nwp500.NavienMqttClient.request_tou_settings` or
        :meth:`~nwp500.NavienMqttClient.configure_tou_schedule`).

        Args:
            device: Device whose TOU responses to receive.
            callback: Called with the parsed schedule on each response.

        Returns:
            Subscription packet ID, or 0 if already subscribed.
        """
        handler = self._make_handler(TOUReservationSchedule, callback)
        topic = self._device_response_topic(device, "tou/rd")
        return await self.subscribe(topic, handler)

    async def unsubscribe_tou_response(
        self,
        device: Device,
        callback: Callable[[TOUReservationSchedule], None],
    ) -> None:
        """Unsubscribe a specific TOU response callback."""
        await self._unsubscribe_wrapped(
            self._device_response_topic(device, "tou/rd"),
            callback,
            TOUReservationSchedule,
        )

    def clear_subscriptions(self) -> None:
        """Clear all subscription tracking (called on disconnect).

        Only local bookkeeping and per-device state tracking are reset; no
        broker calls are made.
        """
        self._subscriptions.clear()
        self._message_handlers.clear()
        self._state_tracker.clear()