"""
MQTT connection management for Navien Smart Control.
This module handles establishing and maintaining the MQTT connection to AWS
IoT Core, including credential management and connection state tracking.
"""
from __future__ import annotations
import asyncio
import json
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, cast
from awscrt import mqtt
from awscrt.exceptions import AwsCrtError
from awsiot import mqtt_connection_builder
from ..exceptions import (
MqttCredentialsError,
MqttNotConnectedError,
)
if TYPE_CHECKING:
from ..auth import NavienAuthClient
from .utils import MqttConnectionConfig
# Module metadata.
__author__ = "Emmanuel Levijarvi"
__copyright__ = "Emmanuel Levijarvi"
__license__ = "MIT"
# Module-level logger, keyed by this module's name.
_logger = logging.getLogger(__name__)
[docs]
class MqttConnection:
    """
    Manages MQTT connection lifecycle to AWS IoT Core.

    Handles:
    - Connection establishment with AWS credentials
    - Disconnection with cleanup
    - Connection state tracking
    - AWS credentials provider creation

    All SDK operations return ``concurrent.futures.Future`` objects; they are
    awaited through :meth:`_await_sdk_future`, which shields the underlying
    future from task cancellation.
    """

    def __init__(
        self,
        config: MqttConnectionConfig,
        auth_client: NavienAuthClient,
        on_connection_interrupted: (
            Callable[[mqtt.Connection, AwsCrtError], None] | None
        ) = None,
        on_connection_resumed: (
            Callable[[mqtt.Connection, Any, Any | None], None] | None
        ) = None,
    ):
        """
        Initialize connection manager.

        No network activity happens here; the SDK connection is built lazily
        by :meth:`connect`.

        Args:
            config: MQTT connection configuration
            auth_client: Authenticated Navien auth client with AWS credentials
            on_connection_interrupted: Callback for connection interruption
            on_connection_resumed: Callback for connection resumption

        Raises:
            ValueError: If auth client is not authenticated, or its tokens
                lack an AWS access key ID / secret key
            MqttCredentialsError: If the auth client has no tokens at all
        """
        if not auth_client.is_authenticated:
            raise ValueError(
                "Authentication client must be authenticated before "
                "creating connection manager."
            )
        if not auth_client.current_tokens:
            raise MqttCredentialsError("No tokens available from auth client")
        auth_tokens = auth_client.current_tokens
        if not auth_tokens.access_key_id or not auth_tokens.secret_key:
            raise ValueError(
                "AWS credentials not available in auth tokens. "
                "Ensure authentication provides AWS IoT credentials."
            )

        self.config = config
        self._auth_client = auth_client
        self._connection: mqtt.Connection | None = None
        self._connected = False
        self._on_connection_interrupted = on_connection_interrupted
        self._on_connection_resumed = on_connection_resumed
        _logger.info(
            f"Initialized connection manager with client ID: {config.client_id}"
        )

    @staticmethod
    async def _await_sdk_future(future: Any, cancel_message: str) -> Any:
        """
        Await a ``concurrent.futures.Future`` returned by the AWS SDK.

        The future is wrapped for asyncio and shielded, so cancelling the
        awaiting task does not cancel the underlying SDK future. This avoids
        ``InvalidStateError`` when AWS CRT callbacks later try to complete a
        future that asyncio has already cancelled.

        Args:
            future: Future returned by an SDK call
            cancel_message: Debug message logged if the await is cancelled

        Returns:
            The result of the SDK future

        Raises:
            asyncio.CancelledError: Re-raised after logging; the SDK
                operation keeps running in the background
        """
        try:
            return await asyncio.shield(asyncio.wrap_future(future))
        except asyncio.CancelledError:
            _logger.debug(cancel_message)
            raise

    async def connect(self) -> bool:
        """
        Establish connection to AWS IoT Core.

        Ensures tokens are valid before connecting and refreshes if necessary.
        If already connected, logs a warning and returns immediately.

        Returns:
            True if connection successful (or already connected)

        Raises:
            AwsCrtError: If the SDK fails to build or open the connection
            RuntimeError: If the SDK builder returned no connection
            MqttCredentialsError: If AWS credentials are unavailable
            asyncio.CancelledError: If cancelled while connecting
        """
        if self._connected:
            _logger.warning("Already connected")
            return True

        # Ensure we have valid tokens before connecting
        await self._auth_client.ensure_valid_token()

        _logger.info(f"Connecting to AWS IoT endpoint: {self.config.endpoint}")
        _logger.debug(f"Client ID: {self.config.client_id}")
        _logger.debug(f"Region: {self.config.region}")

        try:
            # Build WebSocket MQTT connection with AWS credentials.
            # Run blocking operations in a thread to avoid blocking the event
            # loop: the AWS IoT SDK performs synchronous file I/O operations
            # during connection setup.
            credentials_provider = await asyncio.to_thread(
                self._create_credentials_provider
            )
            self._connection = await asyncio.to_thread(
                mqtt_connection_builder.websockets_with_default_aws_signing,
                endpoint=self.config.endpoint,
                region=self.config.region,
                credentials_provider=credentials_provider,
                client_id=self.config.client_id,
                clean_session=self.config.clean_session,
                keep_alive_secs=self.config.keep_alive_secs,
                on_connection_interrupted=self._on_connection_interrupted,
                on_connection_resumed=self._on_connection_resumed,
            )

            _logger.info("Establishing MQTT connection...")
            if not self._connection:
                raise RuntimeError("Connection not initialized")

            connect_result = await self._await_sdk_future(
                self._connection.connect(),
                "Connect operation was cancelled but will complete "
                "in background",
            )

            self._connected = True
            _logger.info(
                f"Connected successfully: "
                f"session_present={connect_result['session_present']}"
            )
            return True
        except (AwsCrtError, RuntimeError, ValueError) as e:
            _logger.error(f"Failed to connect: {e}")
            raise

    def _create_credentials_provider(self) -> Any:
        """
        Create AWS credentials provider from auth tokens.

        Reads the auth client's current tokens at call time, so a token
        refresh done just before :meth:`connect` is picked up.

        Returns:
            Static AWS credentials provider for the MQTT connection

        Raises:
            MqttCredentialsError: If tokens, access key ID or secret key are
                not available
        """
        from awscrt.auth import (
            AwsCredentialsProvider,
        )

        # Get current tokens from auth client
        auth_tokens = self._auth_client.current_tokens
        if (
            not auth_tokens
            or not auth_tokens.access_key_id
            or not auth_tokens.secret_key
        ):
            raise MqttCredentialsError("AWS credentials not available")

        return AwsCredentialsProvider.new_static(
            access_key_id=auth_tokens.access_key_id,
            secret_access_key=auth_tokens.secret_key,
            session_token=auth_tokens.session_token,
        )

    async def disconnect(self) -> None:
        """
        Disconnect from AWS IoT Core.

        A no-op (with a warning) when not connected. On success the
        connection object is dropped and the connected flag is cleared; on
        failure or cancellation the state is left untouched.

        Raises:
            AwsCrtError: If the SDK reports a disconnect failure
            RuntimeError: If the SDK disconnect raises a runtime error
            asyncio.CancelledError: If cancelled while disconnecting
        """
        if not self._connected or not self._connection:
            _logger.warning("Not connected")
            return

        _logger.info("Disconnecting from AWS IoT...")
        try:
            await self._await_sdk_future(
                self._connection.disconnect(),
                "Disconnect operation was cancelled but will complete "
                "in background",
            )
            self._connected = False
            self._connection = None
            _logger.info("Disconnected successfully")
        except (AwsCrtError, RuntimeError) as e:
            _logger.error(f"Error during disconnect: {e}")
            raise

    async def close(self) -> None:
        """Unconditionally close the underlying SDK connection.

        Unlike :meth:`disconnect`, this method closes the connection
        regardless of the ``_connected`` flag. After a connection
        interruption, ``_connected`` is ``False`` but the SDK connection
        object is still alive and its built-in auto-reconnect can still
        fire. Calling ``close()`` ensures the SDK connection is fully
        torn down so its callbacks and auto-reconnect cannot interfere
        with a replacement connection.

        This method is safe to call multiple times or on already-closed
        connections.

        Raises:
            asyncio.CancelledError: If cancelled while closing; the SDK
                disconnect still completes in the background
        """
        # Detach state first so this instance reports "disconnected" even if
        # the SDK call below fails or is cancelled.
        connection = self._connection
        self._connection = None
        self._connected = False
        if connection is None:
            return

        _logger.debug("Closing underlying SDK connection...")
        try:
            await self._await_sdk_future(
                connection.disconnect(),
                "Close operation cancelled but SDK disconnect "
                "will complete in background",
            )
            _logger.debug("SDK connection closed")
        except (AwsCrtError, RuntimeError) as e:
            # Expected when connection is already dead or in bad state
            _logger.debug(f"SDK connection close (benign): {e}")

    async def subscribe(
        self,
        topic: str,
        qos: mqtt.QoS,
        callback: Callable[..., None] | None = None,
    ) -> tuple[Any, int]:
        """
        Subscribe to an MQTT topic.

        Args:
            topic: Topic pattern to subscribe to (supports wildcards)
            qos: Quality of Service level
            callback: Optional callback for received messages

        Returns:
            Tuple of (subscribe_future, packet_id); the future is the SDK's
            own future, already completed when this method returns

        Raises:
            MqttNotConnectedError: If not connected
            asyncio.CancelledError: If cancelled while subscribing
        """
        if not self._connected or not self._connection:
            raise MqttNotConnectedError("Not connected to MQTT broker")

        _logger.debug(f"Subscribing to topic: {topic}")
        subscribe_future, packet_id_raw = self._connection.subscribe(
            topic=topic, qos=qos, callback=callback
        )
        # Normalise to int, matching unsubscribe()/publish() and the
        # declared return type.
        packet_id = int(packet_id_raw)

        await self._await_sdk_future(
            subscribe_future,
            f"Subscribe to '{topic}' was cancelled but will complete "
            "in background",
        )
        _logger.info(f"Subscribed to '{topic}' with packet_id {packet_id}")
        return (subscribe_future, packet_id)

    async def unsubscribe(self, topic: str) -> int:
        """
        Unsubscribe from an MQTT topic.

        Args:
            topic: Topic to unsubscribe from

        Returns:
            Packet ID

        Raises:
            MqttNotConnectedError: If not connected
            asyncio.CancelledError: If cancelled while unsubscribing
        """
        if not self._connected or not self._connection:
            raise MqttNotConnectedError("Not connected to MQTT broker")

        _logger.debug(f"Unsubscribing from topic: {topic}")
        unsubscribe_future, packet_id_raw = self._connection.unsubscribe(
            topic=topic
        )
        packet_id = int(packet_id_raw)

        await self._await_sdk_future(
            unsubscribe_future,
            f"Unsubscribe from '{topic}' was cancelled but will "
            "complete in background",
        )
        _logger.info(f"Unsubscribed from '{topic}' with packet_id {packet_id}")
        return packet_id

    async def publish(
        self,
        topic: str,
        payload: str | bytes | dict[str, Any],
        qos: mqtt.QoS = mqtt.QoS.AT_LEAST_ONCE,
    ) -> int:
        """
        Publish a message to an MQTT topic.

        Args:
            topic: MQTT topic to publish to
            payload: Message payload. A dict is JSON-encoded, a string is
                UTF-8 encoded, and bytes/bytearray are sent as-is
            qos: Quality of Service level

        Returns:
            Publish packet ID

        Raises:
            MqttNotConnectedError: If not connected
            AwsCrtError: If the SDK reports a publish failure; when the
                error is ``AWS_ERROR_MQTT_CONNECTION_DESTROYED`` the
                connected flag is cleared before re-raising
            asyncio.CancelledError: If operation cancelled during disconnect
        """
        if not self._connected or not self._connection:
            raise MqttNotConnectedError("Not connected to MQTT broker")

        _logger.debug(f"Publishing to topic: {topic}")

        # Convert payload to bytes if needed
        if isinstance(payload, dict):
            payload_bytes = json.dumps(payload).encode("utf-8")
        elif isinstance(payload, (bytes, bytearray)):
            # Already binary: pass through instead of calling .encode(),
            # which bytes objects do not have.
            payload_bytes = bytes(payload)
        else:
            # payload is str
            payload_bytes = payload.encode("utf-8")

        # Publish and get the concurrent.futures.Future
        publish_future, packet_id_raw = self._connection.publish(
            topic=topic, payload=payload_bytes, qos=qos
        )
        packet_id = int(packet_id_raw)

        try:
            await self._await_sdk_future(
                publish_future,
                f"Publish to '{topic}' was cancelled but will complete "
                "in background",
            )
        except AwsCrtError as e:
            # Handle connection destruction during publish
            # This can happen when AWS IoT Core disconnects (e.g., 24-hour
            # timeout)
            error_name = getattr(e, "name", None)
            if error_name == "AWS_ERROR_MQTT_CONNECTION_DESTROYED":
                _logger.warning(
                    f"MQTT connection destroyed during publish to '{topic}'. "
                    "This can occur during AWS-initiated disconnections. "
                    "Reconnection will be attempted automatically."
                )
                # Mark as disconnected so reconnection handler can take over
                self._connected = False
            raise

        _logger.debug(f"Published to '{topic}' with packet_id {packet_id}")
        return packet_id

    @property
    def is_connected(self) -> bool:
        """Check if currently connected.

        Reflects this object's own connected flag: set after a successful
        :meth:`connect`, cleared by :meth:`disconnect`, :meth:`close`, or a
        publish that finds the connection destroyed.
        """
        return self._connected

    @property
    def connection(self) -> mqtt.Connection | None:
        """Get the underlying MQTT connection.

        Returns:
            The MQTT connection object, or None if no connection has been
            built yet or it was released by :meth:`disconnect` /
            :meth:`close`

        Note:
            This property is provided for advanced usage. Most operations
            should use the higher-level methods provided by this class.
        """
        return self._connection