diff --git a/aioesphomeapi/__init__.py b/aioesphomeapi/__init__.py index 92d12b4d..1c649445 100644 --- a/aioesphomeapi/__init__.py +++ b/aioesphomeapi/__init__.py @@ -1,6 +1,16 @@ # flake8: noqa from .client import APIClient from .connection import APIConnection, ConnectionParams -from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError +from .core import ( + MESSAGE_TYPE_TO_PROTO, + APIConnectionError, + HandshakeAPIError, + InvalidAuthAPIError, + InvalidEncryptionKeyAPIError, + ProtocolAPIError, + RequiresEncryptionAPIError, + ResolveAPIError, + SocketAPIError, +) from .model import * from .reconnect_logic import ReconnectLogic diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 09a24dd9..08d13c3f 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -139,7 +139,8 @@ def __init__( client_info=client_info, keepalive=keepalive, zeroconf_instance=zeroconf_instance, - noise_psk=noise_psk, + # treat empty psk string as missing (like password) + noise_psk=noise_psk or None, ) self._connection: Optional[APIConnection] = None self._cached_name: Optional[str] = None diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index d3553240..41887104 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -3,6 +3,7 @@ import logging import socket import time +from contextlib import suppress from dataclasses import astuple, dataclass from typing import Any, Awaitable, Callable, List, Optional @@ -23,7 +24,17 @@ PingRequest, PingResponse, ) -from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError +from .core import ( + MESSAGE_TYPE_TO_PROTO, + APIConnectionError, + HandshakeAPIError, + InvalidAuthAPIError, + InvalidEncryptionKeyAPIError, + ProtocolAPIError, + RequiresEncryptionAPIError, + ResolveAPIError, + SocketAPIError, +) from .model import APIVersion from .util import bytes_to_varuint, varuint_to_bytes @@ -41,12 +52,6 @@ class ConnectionParams: zeroconf_instance: hr.ZeroconfInstanceType noise_psk: Optional[str] - @property - def noise_psk_bytes(self) -> Optional[bytes]: - if self.noise_psk is None: - return None - return base64.b64decode(self.noise_psk) - @dataclass class Packet: @@ -87,18 +92,18 @@ async def _write_frame_noise(self, frame: bytes) -> None: self._writer.write(header + frame) await self._writer.drain() except OSError as err: - raise APIConnectionError(f"Error while writing data: {err}") from err + raise SocketAPIError(f"Error while writing data: {err}") from err async def _read_frame_noise(self) -> bytes: try: async with self._read_lock: header = await self._reader.readexactly(3) if header[0] != 0x01: - raise APIConnectionError(f"Marker byte invalid: {header[0]}") + raise ProtocolAPIError(f"Marker byte invalid: {header[0]}") msg_size = (header[1] << 8) | header[2] frame = await self._reader.readexactly(msg_size) except (asyncio.IncompleteReadError, OSError, TimeoutError) as err: - raise APIConnectionError(f"Error while reading data: {err}") from err + raise SocketAPIError(f"Error while reading data: {err}") from err _LOGGER.debug("Received frame %s", frame.hex()) return frame @@ -110,16 +115,28 @@ async def perform_handshake(self) -> None: prologue = b"NoiseAPIInit" + b"\x00\x00" server_hello = await self._read_frame_noise() # ServerHello if not server_hello: - raise APIConnectionError("ServerHello is empty") + raise HandshakeAPIError("ServerHello is empty") chosen_proto = server_hello[0] if chosen_proto != 0x01: - raise APIConnectionError( + raise HandshakeAPIError( f"Unknown protocol selected by client {chosen_proto}" ) self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256") self._proto.set_as_initiator() - self._proto.set_psks(self._params.noise_psk_bytes) + + try: + noise_psk_bytes = base64.b64decode(self._params.noise_psk) + except ValueError: + raise InvalidEncryptionKeyAPIError( + f"Malformed PSK {self._params.noise_psk}, expected base64-encoded value" + ) + if len(noise_psk_bytes) != 32: + raise InvalidEncryptionKeyAPIError( + f"Malformed PSK {self._params.noise_psk}, expected 32-bytes of base64 data" + ) + + self._proto.set_psks(noise_psk_bytes) self._proto.set_prologue(prologue) self._proto.start_handshake() @@ -131,8 +148,13 @@ async def perform_handshake(self) -> None: await self._write_frame_noise(b"\x00" + msg) else: msg = await self._read_frame_noise() - if not msg or msg[0] != 0: - raise APIConnectionError(f"Handshake failure: {msg[1:].decode()}") + if not msg: + raise HandshakeAPIError("Handshake message too short") + if msg[0] != 0: + explanation = msg[1:].decode() + if explanation == "Handshake MAC failure": + raise InvalidEncryptionKeyAPIError("Invalid encryption key") + raise HandshakeAPIError(f"Handshake failure: {explanation}") self._proto.read_message(msg[1:]) do_write = not do_write @@ -170,7 +192,7 @@ async def _write_packet_plaintext(self, packet: Packet) -> None: self._writer.write(data) await self._writer.drain() except OSError as err: - raise APIConnectionError(f"Error while writing data: {err}") from err + raise SocketAPIError(f"Error while writing data: {err}") from err async def write_packet(self, packet: Packet) -> None: if self._params.noise_psk is None: @@ -184,11 +206,11 @@ async def _read_packet_noise(self) -> Packet: assert self._proto is not None msg = self._proto.decrypt(frame) if len(msg) < 4: - raise APIConnectionError(f"Bad packet frame: {msg}") + raise ProtocolAPIError(f"Bad packet frame: {msg}") pkt_type = (msg[0] << 8) | msg[1] data_len = (msg[2] << 8) | msg[3] if data_len + 4 > len(msg): - raise APIConnectionError(f"Bad data len: {data_len} vs {len(msg)}") + raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}") data = msg[4 : 4 + data_len] return Packet(type=pkt_type, data=data) @@ -196,7 +218,9 @@ async def _read_packet_plaintext(self) -> Packet: async with self._read_lock: preamble = await self._reader.readexactly(1) if preamble[0] != 0x00: - raise APIConnectionError("Invalid preamble") + if preamble[0] == 0x01: + raise RequiresEncryptionAPIError("Connection requires encryption") + raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}") length = b"" while not length or (length[-1] & 0x80) == 0x80: @@ -238,6 +262,7 @@ def __init__( self._message_handlers: List[Callable[[message.Message], None]] = [] self.log_name = params.address self._ping_task: Optional[asyncio.Task[None]] = None + self._read_exception_handlers: List[Callable[[Exception], None]] = [] def _start_ping(self) -> None: async def func() -> None: @@ -305,7 +330,7 @@ async def connect(self) -> None: raise err except asyncio.TimeoutError: await self._on_error() - raise APIConnectionError( + raise ResolveAPIError( f"Timeout while resolving IP address for {self.log_name}" ) @@ -328,10 +353,10 @@ async def connect(self) -> None: await asyncio.wait_for(coro2, 30.0) except OSError as err: await self._on_error() - raise APIConnectionError(f"Error connecting to {sockaddr}: {err}") + raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") except asyncio.TimeoutError: await self._on_error() - raise APIConnectionError(f"Timeout while connecting to {sockaddr}") + raise SocketAPIError(f"Timeout while connecting to {sockaddr}") _LOGGER.debug("%s: Opened socket for", self._params.address) reader, writer = await asyncio.open_connection(sock=self._socket) @@ -383,7 +408,7 @@ async def login(self) -> None: connect.password = self._params.password resp = await self.send_message_await_response(connect, ConnectResponse) if resp.invalid_password: - raise APIConnectionError("Invalid password!") + raise InvalidAuthAPIError("Invalid password!") self._authenticated = True @@ -444,20 +469,25 @@ def on_message(resp: message.Message) -> None: if do_stop(resp): fut.set_result(responses) + def on_read_exception(exc: Exception) -> None: + if not fut.done(): + fut.set_exception(exc) + self._message_handlers.append(on_message) + self._read_exception_handlers.append(on_read_exception) await self.send_message(send_msg) try: await asyncio.wait_for(fut, timeout) except asyncio.TimeoutError: if self._stopped: - raise APIConnectionError("Disconnected while waiting for API response!") - raise APIConnectionError("Timeout while waiting for API response!") - - try: - self._message_handlers.remove(on_message) - except ValueError: - pass + raise SocketAPIError("Disconnected while waiting for API response!") + raise SocketAPIError("Timeout while waiting for API response!") + finally: + with suppress(ValueError): + self._message_handlers.remove(on_message) + with suppress(ValueError): + self._read_exception_handlers.remove(on_read_exception) return responses @@ -491,7 +521,7 @@ async def _run_once(self) -> None: try: msg.ParseFromString(raw_msg) except Exception as e: - raise APIConnectionError("Invalid protobuf message: {}".format(e)) + raise ProtocolAPIError(f"Invalid protobuf message: {e}") from e _LOGGER.debug( "%s: Got message of type %s: %s", self._params.address, type(msg), msg ) @@ -509,15 +539,19 @@ async def run_forever(self) -> None: self.log_name, err, ) + for handler in self._read_exception_handlers[:]: + handler(err) await self._on_error() break except Exception as err: # pylint: disable=broad-except - _LOGGER.info( + _LOGGER.warning( "%s: Unexpected error while reading incoming messages: %s", self.log_name, err, exc_info=True, ) + for handler in self._read_exception_handlers[:]: + handler(err) await self._on_error() break diff --git a/aioesphomeapi/core.py b/aioesphomeapi/core.py index 83dc90ae..1d0b9503 100644 --- a/aioesphomeapi/core.py +++ b/aioesphomeapi/core.py @@ -63,6 +63,34 @@ class APIConnectionError(Exception): pass +class InvalidAuthAPIError(APIConnectionError): + pass + + +class ResolveAPIError(APIConnectionError): + pass + + +class ProtocolAPIError(APIConnectionError): + pass + + +class RequiresEncryptionAPIError(ProtocolAPIError): + pass + + +class SocketAPIError(APIConnectionError): + pass + + +class HandshakeAPIError(APIConnectionError): + pass + + +class InvalidEncryptionKeyAPIError(HandshakeAPIError): + pass + + MESSAGE_TYPE_TO_PROTO = { 1: HelloRequest, 2: HelloResponse, diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index d8a1e022..e2ea2f2b 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -13,7 +13,7 @@ except ImportError: ZC_ASYNCIO = False -from .core import APIConnectionError +from .core import APIConnectionError, ResolveAPIError ZeroconfInstanceType = Union[zeroconf.Zeroconf, "zeroconf.asyncio.AsyncZeroconf", None] @@ -56,7 +56,7 @@ def _sync_zeroconf_get_service_info( try: zc = zeroconf.Zeroconf() except Exception: - raise APIConnectionError( + raise ResolveAPIError( "Cannot start mDNS sockets, is this a docker container without " "host network mode?" ) @@ -72,7 +72,7 @@ def _sync_zeroconf_get_service_info( try: info = zc.get_service_info(service_type, service_name, int(timeout * 1000)) except Exception as exc: - raise APIConnectionError( + raise ResolveAPIError( f"Error resolving mDNS {service_name} via mDNS: {exc}" ) from exc finally: @@ -105,7 +105,7 @@ async def _async_zeroconf_get_service_info( try: zc = zeroconf.asyncio.AsyncZeroconf() except Exception: - raise APIConnectionError( + raise ResolveAPIError( "Cannot start mDNS sockets, is this a docker container without " "host network mode?" ) @@ -126,7 +126,7 @@ async def _async_zeroconf_get_service_info( service_type, service_name, int(timeout * 1000) ) except Exception as exc: - raise APIConnectionError( + raise ResolveAPIError( f"Error resolving mDNS {service_name} via mDNS: {exc}" ) from exc finally: @@ -240,9 +240,7 @@ async def async_resolve_host( if zc_error: # Only show ZC error if getaddrinfo also didn't work raise zc_error - raise APIConnectionError( - f"Could not resolve host {host} - got no results from OS" - ) + raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS") # Use first matching result # Future: return all matches and use first working one