Source code for nwp500.mqtt.reconnection
"""
MQTT reconnection handler for Navien Smart Control.
This module handles automatic reconnection with exponential backoff when
the MQTT connection is interrupted.
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
from awscrt.exceptions import AwsCrtError
if TYPE_CHECKING:
from .utils import MqttConnectionConfig
# Module metadata (author, copyright holder, license identifier).
__author__ = "Emmanuel Levijarvi"
__copyright__ = "Emmanuel Levijarvi"
__license__ = "MIT"
# Module-level logger, named after this module via __name__.
_logger = logging.getLogger(__name__)
[docs]
class MqttReconnectionHandler:
    """
    Handles automatic reconnection logic with exponential backoff.

    This class manages reconnection attempts when the MQTT connection is
    interrupted, implementing exponential backoff and configurable retry
    limits.

    The handler does not talk to the MQTT client directly. Everything it
    needs (connection state, scheduling, the reconnect operations and the
    optional event emitter) is injected as callables in ``__init__``.
    """

    def __init__(
        self,
        config: MqttConnectionConfig,
        is_connected_func: Callable[[], bool],
        schedule_coroutine_func: Callable[[Any], None],
        reconnect_func: Callable[[], Awaitable[None]],
        deep_reconnect_func: Callable[[], Awaitable[None]] | None = None,
        emit_event_func: Callable[..., Awaitable[Any]] | None = None,
    ):
        """
        Initialize reconnection handler.

        Args:
            config: MQTT connection configuration
            is_connected_func: Function to check if currently connected
            schedule_coroutine_func: Function to schedule coroutines from any
                thread
            reconnect_func: Async function to trigger active reconnection
            deep_reconnect_func: Optional async function to trigger deep
                reconnection (full rebuild)
            emit_event_func: Optional async function to emit events
                (e.g., EventEmitter.emit)
        """
        self.config = config
        self._is_connected_func = is_connected_func
        self._schedule_coroutine = schedule_coroutine_func
        self._reconnect_func = reconnect_func
        self._deep_reconnect_func = deep_reconnect_func
        self._emit_event = emit_event_func

        # Attempts made since the last reset / successful resume.
        self._reconnect_attempts = 0
        # Currently running backoff loop, if any.
        self._reconnect_task: asyncio.Task[None] | None = None
        self._manual_disconnect = False
        self._enabled = False

    def enable(self) -> None:
        """Enable automatic reconnection."""
        self._enabled = True
        self._manual_disconnect = False
        _logger.debug("Automatic reconnection enabled")

    def disable(self) -> None:
        """Disable automatic reconnection (e.g., for manual disconnect)."""
        self._enabled = False
        self._manual_disconnect = True
        _logger.debug("Automatic reconnection disabled")

        # Cancel any pending reconnection task
        self._cancel_pending_task()

    def on_connection_interrupted(self, error: Exception) -> None:
        """
        Handle connection interruption.

        Args:
            error: Error that caused the interruption
        """
        _logger.warning(f"Connection interrupted: {error}")

        # Start automatic reconnection if enabled.
        # Also guard against stale interruption events that arrive after the
        # connection has already been restored: these can be queued via
        # run_coroutine_threadsafe and fire after on_connection_resumed has
        # cancelled _reconnect_task (setting it to None), which would
        # otherwise bypass the task-existence check and spawn a new backoff
        # loop while the client is perfectly healthy.
        if (
            self.config.auto_reconnect
            and self._enabled
            and not self._manual_disconnect
            and not self._is_connected_func()
            and (not self._reconnect_task or self._reconnect_task.done())
        ):
            _logger.info("Starting automatic reconnection...")
            self._schedule_coroutine(self._start_reconnect_task())

    def on_connection_resumed(
        self, return_code: Any, session_present: Any
    ) -> None:
        """
        Handle connection resumption.

        Args:
            return_code: MQTT return code
            session_present: Whether session was present
        """
        _logger.info(
            f"Connection resumed: return_code={return_code}, "
            f"session_present={session_present}"
        )

        # Reset reconnection attempts on successful connection
        self._reconnect_attempts = 0

        # Cancel any pending reconnection task
        self._cancel_pending_task()

    def _cancel_pending_task(self) -> None:
        """Cancel the backoff task if it is still running and forget it."""
        task = self._reconnect_task
        if task is not None and not task.done():
            task.cancel()
            self._reconnect_task = None

    async def _start_reconnect_task(self) -> None:
        """
        Start the reconnect task within the event loop.

        This is a helper method to create the reconnect task from within
        a coroutine that's scheduled via _schedule_coroutine.

        The is_connected guard is re-checked here because this coroutine may
        be queued via run_coroutine_threadsafe and run after the connection
        has already been restored (e.g. by on_connection_resumed cancelling
        _reconnect_task), in which case starting a new backoff loop would
        incorrectly tear down a healthy connection.
        """
        if not self._is_connected_func() and (
            not self._reconnect_task or self._reconnect_task.done()
        ):
            self._reconnect_task = asyncio.create_task(
                self._reconnect_with_backoff()
            )

    def _should_use_deep_reconnect(self, attempt: int) -> bool:
        """
        Decide whether the given attempt should be a deep reconnection.

        A deep reconnection is used when a deep reconnect function was
        supplied and ``attempt`` is a positive multiple of
        ``config.deep_reconnect_threshold``. A non-positive threshold
        disables deep reconnection instead of raising ZeroDivisionError.

        Args:
            attempt: 1-based number of the current reconnection attempt

        Returns:
            True if this attempt should use the deep reconnect function.
        """
        if self._deep_reconnect_func is None:
            return False
        threshold = self.config.deep_reconnect_threshold
        if threshold <= 0:
            return False
        return attempt >= threshold and attempt % threshold == 0

    def _compute_backoff_delay(self, attempt: int) -> float:
        """
        Compute the delay before the given attempt.

        The delay is ``initial_reconnect_delay * multiplier ** (attempt - 1)``
        capped at ``max_reconnect_delay``.

        Args:
            attempt: 1-based number of the current reconnection attempt

        Returns:
            Delay in seconds, never larger than ``max_reconnect_delay``.
        """
        max_delay = self.config.max_reconnect_delay
        try:
            raw_delay = self.config.initial_reconnect_delay * (
                self.config.reconnect_backoff_multiplier ** (attempt - 1)
            )
        except OverflowError:
            # With unlimited retries the exponent eventually exceeds what a
            # float can represent. The uncapped value is then certainly
            # beyond the cap, so fall back to the cap instead of letting
            # the exception kill the reconnect loop.
            return max_delay
        return min(raw_delay, max_delay)

    async def _reconnect_with_backoff(self) -> None:
        """
        Attempt to reconnect with exponential backoff.

        This method is called automatically when connection is interrupted
        if auto_reconnect is enabled. Supports unlimited retries when
        max_reconnect_attempts is negative (e.g. -1).

        Uses a two-tier strategy:
        - Quick reconnects: reconnection via ``reconnect_func``
        - Deep reconnects (every ``deep_reconnect_threshold``-th attempt,
          when ``deep_reconnect_func`` was supplied): full rebuild

        When a finite retry budget is exhausted without reconnecting, an
        error is logged and, if an event emitter was supplied, a
        ``reconnection_failed`` event is emitted with the attempt count.

        Raises:
            asyncio.CancelledError: If the task is cancelled. Cancellation
                is propagated so that a cancelled task never reports a
                reconnection failure.
        """
        unlimited_retries = self.config.max_reconnect_attempts < 0

        while (
            not self._is_connected_func()
            and not self._manual_disconnect
            and (
                unlimited_retries
                or self._reconnect_attempts < self.config.max_reconnect_attempts
            )
        ):
            self._reconnect_attempts += 1

            # Determine if we should do a deep reconnection
            use_deep_reconnect = self._should_use_deep_reconnect(
                self._reconnect_attempts
            )

            # Calculate delay with exponential backoff
            delay = self._compute_backoff_delay(self._reconnect_attempts)

            if unlimited_retries:
                reconnect_type = "deep" if use_deep_reconnect else "quick"
                _logger.info(
                    "Reconnection attempt %d (%s) in %.1f seconds...",
                    self._reconnect_attempts,
                    reconnect_type,
                    delay,
                )
            else:
                _logger.info(
                    "Reconnection attempt %d/%d in %.1f seconds...",
                    self._reconnect_attempts,
                    self.config.max_reconnect_attempts,
                    delay,
                )

            try:
                await asyncio.sleep(delay)

                # Check if we're already connected (AWS SDK auto-reconnected)
                if self._is_connected_func():
                    _logger.info(
                        "AWS IoT SDK automatically reconnected during delay"
                    )
                    break

                # Trigger appropriate reconnection type
                if use_deep_reconnect and self._deep_reconnect_func is not None:
                    _logger.info(
                        "Triggering deep reconnection "
                        "(full rebuild with token refresh)..."
                    )
                    try:
                        await self._deep_reconnect_func()
                        if self._is_connected_func():
                            _logger.info(
                                "Successfully reconnected via deep reconnection"
                            )
                            break
                    except (AwsCrtError, RuntimeError, ValueError) as e:
                        _logger.warning(
                            f"Deep reconnection failed: {e}. Will retry..."
                        )
                else:
                    _logger.info("Triggering quick reconnection...")
                    try:
                        await self._reconnect_func()
                        if self._is_connected_func():
                            _logger.info(
                                "Successfully reconnected via "
                                "quick reconnection"
                            )
                            break
                    except (AwsCrtError, RuntimeError) as e:
                        _logger.warning(
                            f"Quick reconnection failed: {e}. Will retry..."
                        )

            except asyncio.CancelledError:
                # Propagate cancellation rather than falling through to the
                # final-state check below, which could otherwise report a
                # reconnection failure for a task that was merely cancelled.
                _logger.info("Reconnection task cancelled")
                raise
            except (AwsCrtError, RuntimeError) as e:
                _logger.error(
                    f"Error during reconnection attempt: {e}", exc_info=True
                )

        # Check final state (only if not unlimited retries)
        if (
            not unlimited_retries
            and self._reconnect_attempts >= self.config.max_reconnect_attempts
            and not self._is_connected_func()
        ):
            _logger.error(
                f"Failed to reconnect after "
                f"{self.config.max_reconnect_attempts} attempts. "
                "Manual reconnection required."
            )

            # Emit reconnection_failed event if event emitter is available
            if self._emit_event:
                try:
                    await self._emit_event(
                        "reconnection_failed", self._reconnect_attempts
                    )
                except (TypeError, RuntimeError) as e:
                    _logger.error(
                        f"Error emitting reconnection_failed event: {e}"
                    )

    async def cancel(self) -> None:
        """Cancel any pending reconnection task and wait for it to stop."""
        task = self._reconnect_task
        if task is not None and not task.done():
            task.cancel()
            # The task ends with CancelledError; that is the expected outcome
            # here, so it is suppressed.
            with contextlib.suppress(asyncio.CancelledError):
                await task
            self._reconnect_task = None

    @property
    def is_reconnecting(self) -> bool:
        """Check if currently attempting to reconnect."""
        return (
            self._reconnect_task is not None and not self._reconnect_task.done()
        )

    @property
    def attempt_count(self) -> int:
        """Get the number of reconnection attempts made."""
        return self._reconnect_attempts

    def reset_attempts(self) -> None:
        """Reset the reconnection attempt counter."""
        self._reconnect_attempts = 0

    def reset(self) -> None:
        """Reset reconnection state and enable reconnection."""
        self._reconnect_attempts = 0
        self.enable()