# Source code for nwp500.auth

"""
Authentication module for Navien Smart Control API.

This module provides authentication functionality for the Navien Smart Control
REST API, including sign-in, token management, and token refresh capabilities.

The API uses JWT (JSON Web Tokens) for authentication with the following flow:
1. Sign in with email and password
2. Receive idToken, accessToken, and refreshToken
3. Use accessToken as Bearer token in subsequent requests
4. Refresh tokens when accessToken expires
"""

from __future__ import annotations

import json
import logging
from datetime import UTC, datetime, timedelta
from typing import Any, Self, cast

import aiohttp
from pydantic import (
    Field,
    PrivateAttr,
    field_validator,
    model_validator,
)

from . import __version__
from ._base import NavienBaseModel
from .config import API_BASE_URL, REFRESH_ENDPOINT, SIGN_IN_ENDPOINT
from .exceptions import (
    AuthenticationError,
    InvalidCredentialsError,
    TokenRefreshError,
)
from .unit_system import UnitSystemType

__author__ = __copyright__ = "Emmanuel Levijarvi"
__license__ = "MIT"

# Logger named after this module (logging.getLogger(__name__)).
_logger = logging.getLogger(__name__)


class UserInfo(NavienBaseModel):
    """Profile details of the user, as returned from authentication."""

    user_type: str = ""
    user_first_name: str = ""
    user_last_name: str = ""
    user_status: str = ""
    user_seq: int = 0

    @property
    def full_name(self) -> str:
        """First and last name separated by a space, outer whitespace trimmed."""
        first, last = self.user_first_name, self.user_last_name
        return "{} {}".format(first, last).strip()
class AuthTokens(NavienBaseModel):
    """Authentication tokens and AWS credentials returned from the API."""

    id_token: str = ""
    access_token: str = ""
    refresh_token: str = ""
    authentication_expires_in: int = 3600
    access_key_id: str | None = None
    secret_key: str | None = None
    session_token: str | None = None
    authorization_expires_in: int | None = None

    # Calculated fields
    issued_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

    # Cached in model_post_init; derived from issued_at + *_expires_in.
    _expires_at: datetime = PrivateAttr()
    _aws_expires_at: datetime | None = PrivateAttr(default=None)

    @field_validator("issued_at", mode="after")
    @classmethod
    def _normalize_issued_at_tz(cls, v: datetime) -> datetime:
        """Assume UTC for timezone-naive ``issued_at`` values.

        Old stored tokens may lack timezone info. Running after pydantic has
        parsed the value covers every accepted input form (datetime objects,
        ISO 8601 strings with either a ``T`` or a space separator, date-only
        strings). Without this, a naive value would make the aware/naive
        comparisons in ``is_expired`` raise ``TypeError``.
        """
        if v.tzinfo is None:
            return v.replace(tzinfo=UTC)
        return v

    @model_validator(mode="before")
    @classmethod
    def handle_empty_aliases(cls, data: Any) -> Any:
        """Handle empty camelCase aliases with snake_case fallbacks.

        If a camelCase key is present but falsy (empty string, None, 0) and
        the matching snake_case key is present, the snake_case value is used.
        A shallow copy is returned so the caller's dict is never mutated.
        """
        if not isinstance(data, dict):
            return data
        # Copy so validation has no side effects on the caller's mapping.
        d = dict(cast(dict[str, Any], data))
        fields_to_check: tuple[tuple[str, str], ...] = (
            ("accessToken", "access_token"),
            ("accessKeyId", "access_key_id"),
            ("secretKey", "secret_key"),
            ("refreshToken", "refresh_token"),
            ("sessionToken", "session_token"),
            ("authenticationExpiresIn", "authentication_expires_in"),
            ("authorizationExpiresIn", "authorization_expires_in"),
            ("idToken", "id_token"),
        )
        for camel, snake in fields_to_check:
            # If camel exists but is empty/None, and snake exists, use snake
            if camel in d and not d[camel] and snake in d:
                d[camel] = d[snake]
        return d

    def model_post_init(self, __context: Any) -> None:
        """Cache the expiration timestamps after initialization."""
        self._expires_at = self.issued_at + timedelta(
            seconds=self.authentication_expires_in
        )
        # A falsy authorization_expires_in (None or 0) means "unknown".
        if self.authorization_expires_in:
            self._aws_expires_at = self.issued_at + timedelta(
                seconds=self.authorization_expires_in
            )
        else:
            self._aws_expires_at = None

    @property
    def expires_at(self) -> datetime:
        """Get the cached expiration timestamp."""
        return self._expires_at

    @property
    def is_expired(self) -> bool:
        """Check if the access token has expired (cached calculation).

        The token counts as expired once it is within 5 minutes of expiry.
        """
        return datetime.now(UTC) >= (self._expires_at - timedelta(minutes=5))

    @property
    def are_aws_credentials_expired(self) -> bool:
        """Check if AWS credentials have expired.

        AWS credentials have a separate expiration time from JWT tokens.
        If AWS credentials are expired, a full re-authentication is needed
        since the token refresh endpoint doesn't provide new AWS credentials.

        Returns:
            True if AWS credentials are expired (or within 5 minutes of
            expiry), False if the expiration time is unknown or the
            credentials are still valid.
        """
        if not self._aws_expires_at:
            # Unknown expiry (authorization_expires_in not provided): treat
            # the credentials as valid.
            return False
        return datetime.now(UTC) >= (
            self._aws_expires_at - timedelta(minutes=5)
        )

    @property
    def time_until_expiry(self) -> timedelta:
        """Get time remaining until token expiration (may be negative)."""
        return self._expires_at - datetime.now(UTC)

    @property
    def bearer_token(self) -> str:
        """Get the formatted Bearer token for Authorization header."""
        return f"Bearer {self.access_token}"

    def to_dict(self) -> dict[str, Any]:
        """Convert tokens to a dictionary for serialization.

        This includes the ``issued_at`` timestamp, which is needed to maintain
        the correct expiration time when restoring tokens. ``issued_at`` is
        emitted as an ISO 8601 string in UTC with a trailing ``Z``.
        """
        data = self.model_dump()
        issued_at = data.get("issued_at")
        if isinstance(issued_at, datetime):
            # Convert to UTC before appending "Z". Otherwise an issued_at
            # carrying another offset would be serialized hours off.
            if issued_at.tzinfo is None:
                issued_at = issued_at.replace(tzinfo=UTC)
            else:
                issued_at = issued_at.astimezone(UTC)
            data["issued_at"] = (
                issued_at.strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z"
            )
        return data
class AuthenticationResponse(NavienBaseModel):
    """Complete authentication response including user info and tokens."""

    user_info: UserInfo
    tokens: AuthTokens
    legal: list[Any] = Field(default_factory=list)
    code: int = 200
    message: str = Field(default="SUCCESS", alias="msg")

    @model_validator(mode="before")
    @classmethod
    def wrap_api_response(cls, data: Any) -> Any:
        """Handle nested 'data' wrapper in API responses.

        Fields inside ``data`` are lifted to the top level for validation.
        Top-level ``code`` and ``msg`` always win over same-named nested keys.
        A nested ``token`` key is mapped to ``tokens`` when ``tokens`` is
        absent, because the API is inconsistent about this name.
        """
        if not (isinstance(data, dict) and "data" in data):
            return data
        response_data = data.get("data", {})
        if not isinstance(response_data, dict):
            return data
        merged = {**data, **response_data}
        # Restore top-level status fields in case the nested payload
        # happened to contain keys with the same names.
        for key in ("code", "msg"):
            if key in data:
                merged[key] = data[key]
        # Handle 'token' vs 'tokens' inconsistency in API
        if "token" in response_data and "tokens" not in response_data:
            merged["tokens"] = response_data["token"]
        return merged
# Names exported by ``from nwp500.auth import *``.
__all__ = [
    "UserInfo",
    "AuthTokens",
    "AuthenticationResponse",
    "NavienAuthClient",
    "authenticate",
    "refresh_access_token",
]

# Convenience functions for one-off authentication
async def authenticate(user_id: str, password: str) -> AuthenticationResponse:
    """Authenticate user and obtain tokens.

    This is a convenience function that creates a temporary auth client,
    authenticates, and returns the response.

    Args:
        user_id: User email address
        password: User password

    Returns:
        AuthenticationResponse with user info and tokens

    Raises:
        AuthenticationError: If the client exposes no authentication response.

    Example:
        >>> response = await authenticate("user@example.com", "password")
        >>> print(f"Welcome {response.user_info.full_name}")
        >>> # Use the bearer token for API requests
        >>> # Do not print tokens in production code
    """
    async with NavienAuthClient(user_id, password) as client:
        auth_response = client.auth_response
        if auth_response is None:
            raise AuthenticationError(
                "Authentication failed: no response received"
            )
        return auth_response
async def refresh_access_token(refresh_token: str) -> AuthTokens:
    """Refresh an access token using a refresh token.

    This is a convenience function that creates a temporary session to
    perform the token refresh operation without requiring full
    authentication.

    Args:
        refresh_token: The refresh token

    Returns:
        New AuthTokens

    Raises:
        TokenRefreshError: If the server returns a non-200 code or HTTP
            status, a body that is not a JSON object, or no ``data`` payload.

    Example:
        >>> new_tokens = await refresh_access_token(old_tokens.refresh_token)

    Note:
        This function creates a temporary client without authentication
        to perform the token refresh operation.
    """
    url = f"{API_BASE_URL}{REFRESH_ENDPOINT}"
    payload = {"refreshToken": refresh_token}

    # Use ThreadedResolver for reliable DNS in containerized environments
    resolver = aiohttp.ThreadedResolver()
    connector = aiohttp.TCPConnector(resolver=resolver)
    async with (
        aiohttp.ClientSession(connector=connector) as session,
        session.post(url, json=payload) as response,
    ):
        try:
            response_data = await response.json()
        except (aiohttp.ContentTypeError, json.JSONDecodeError) as err:
            # e.g. an HTML error page from a proxy: surface it as a refresh
            # failure instead of a raw decoding error.
            raise TokenRefreshError(
                "Failed to refresh token: invalid response "
                f"(HTTP {response.status})",
                status_code=response.status,
            ) from err

        if not isinstance(response_data, dict):
            raise TokenRefreshError(
                "Failed to refresh token: unexpected response format",
                status_code=response.status,
            )

        code = response_data.get("code", response.status)
        msg = response_data.get("msg", "")
        if code != 200 or not response.ok:
            raise TokenRefreshError(
                f"Failed to refresh token: {msg}",
                status_code=code,
                response=response_data,
            )

        data = response_data.get("data")
        if not isinstance(data, dict):
            # Validating an empty payload would "succeed" with blank tokens.
            raise TokenRefreshError(
                "Failed to refresh token: response contained no token data",
                status_code=code,
                response=response_data,
            )
        return AuthTokens.model_validate(data)