Skip to content

Commit

Permalink
Emit different Exception types to differentiate between connection er…
Browse files Browse the repository at this point in the history
…rors (#102)

* Emit different Exception types to differentiate between connection errors

* Import in init
  • Loading branch information
OttoWinter authored Sep 14, 2021
1 parent 0660f1c commit 5c9e7ac
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 42 deletions.
12 changes: 11 additions & 1 deletion aioesphomeapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
# flake8: noqa
from .client import APIClient
from .connection import APIConnection, ConnectionParams
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from .core import (
MESSAGE_TYPE_TO_PROTO,
APIConnectionError,
HandshakeAPIError,
InvalidAuthAPIError,
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
RequiresEncryptionAPIError,
ResolveAPIError,
SocketAPIError,
)
from .model import *
from .reconnect_logic import ReconnectLogic
3 changes: 2 additions & 1 deletion aioesphomeapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def __init__(
client_info=client_info,
keepalive=keepalive,
zeroconf_instance=zeroconf_instance,
noise_psk=noise_psk,
# treat empty psk string as missing (like password)
noise_psk=noise_psk or None,
)
self._connection: Optional[APIConnection] = None
self._cached_name: Optional[str] = None
Expand Down
98 changes: 66 additions & 32 deletions aioesphomeapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import socket
import time
from contextlib import suppress
from dataclasses import astuple, dataclass
from typing import Any, Awaitable, Callable, List, Optional

Expand All @@ -23,7 +24,17 @@
PingRequest,
PingResponse,
)
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from .core import (
MESSAGE_TYPE_TO_PROTO,
APIConnectionError,
HandshakeAPIError,
InvalidAuthAPIError,
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
RequiresEncryptionAPIError,
ResolveAPIError,
SocketAPIError,
)
from .model import APIVersion
from .util import bytes_to_varuint, varuint_to_bytes

Expand All @@ -41,12 +52,6 @@ class ConnectionParams:
zeroconf_instance: hr.ZeroconfInstanceType
noise_psk: Optional[str]

@property
def noise_psk_bytes(self) -> Optional[bytes]:
if self.noise_psk is None:
return None
return base64.b64decode(self.noise_psk)


@dataclass
class Packet:
Expand Down Expand Up @@ -87,18 +92,18 @@ async def _write_frame_noise(self, frame: bytes) -> None:
self._writer.write(header + frame)
await self._writer.drain()
except OSError as err:
raise APIConnectionError(f"Error while writing data: {err}") from err
raise SocketAPIError(f"Error while writing data: {err}") from err

async def _read_frame_noise(self) -> bytes:
try:
async with self._read_lock:
header = await self._reader.readexactly(3)
if header[0] != 0x01:
raise APIConnectionError(f"Marker byte invalid: {header[0]}")
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
msg_size = (header[1] << 8) | header[2]
frame = await self._reader.readexactly(msg_size)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
raise APIConnectionError(f"Error while reading data: {err}") from err
raise SocketAPIError(f"Error while reading data: {err}") from err

_LOGGER.debug("Received frame %s", frame.hex())
return frame
Expand All @@ -110,16 +115,28 @@ async def perform_handshake(self) -> None:
prologue = b"NoiseAPIInit" + b"\x00\x00"
server_hello = await self._read_frame_noise() # ServerHello
if not server_hello:
raise APIConnectionError("ServerHello is empty")
raise HandshakeAPIError("ServerHello is empty")
chosen_proto = server_hello[0]
if chosen_proto != 0x01:
raise APIConnectionError(
raise HandshakeAPIError(
f"Unknown protocol selected by client {chosen_proto}"
)

self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
self._proto.set_as_initiator()
self._proto.set_psks(self._params.noise_psk_bytes)

try:
noise_psk_bytes = base64.b64decode(self._params.noise_psk)
except ValueError:
raise InvalidEncryptionKeyAPIError(
f"Malformed PSK {self._params.noise_psk}, expected base64-encoded value"
)
if len(noise_psk_bytes) != 32:
raise InvalidEncryptionKeyAPIError(
f"Malformed PSK {self._params.noise_psk}, expected 32-bytes of base64 data"
)

self._proto.set_psks(noise_psk_bytes)
self._proto.set_prologue(prologue)
self._proto.start_handshake()

Expand All @@ -131,8 +148,13 @@ async def perform_handshake(self) -> None:
await self._write_frame_noise(b"\x00" + msg)
else:
msg = await self._read_frame_noise()
if not msg or msg[0] != 0:
raise APIConnectionError(f"Handshake failure: {msg[1:].decode()}")
if not msg:
raise HandshakeAPIError("Handshake message too short")
if msg[0] != 0:
explanation = msg[1:].decode()
if explanation == "Handshake MAC failure":
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
raise HandshakeAPIError(f"Handshake failure: {explanation}")
self._proto.read_message(msg[1:])

do_write = not do_write
Expand Down Expand Up @@ -170,7 +192,7 @@ async def _write_packet_plaintext(self, packet: Packet) -> None:
self._writer.write(data)
await self._writer.drain()
except OSError as err:
raise APIConnectionError(f"Error while writing data: {err}") from err
raise SocketAPIError(f"Error while writing data: {err}") from err

async def write_packet(self, packet: Packet) -> None:
if self._params.noise_psk is None:
Expand All @@ -184,19 +206,21 @@ async def _read_packet_noise(self) -> Packet:
assert self._proto is not None
msg = self._proto.decrypt(frame)
if len(msg) < 4:
raise APIConnectionError(f"Bad packet frame: {msg}")
raise ProtocolAPIError(f"Bad packet frame: {msg}")
pkt_type = (msg[0] << 8) | msg[1]
data_len = (msg[2] << 8) | msg[3]
if data_len + 4 > len(msg):
raise APIConnectionError(f"Bad data len: {data_len} vs {len(msg)}")
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
data = msg[4 : 4 + data_len]
return Packet(type=pkt_type, data=data)

async def _read_packet_plaintext(self) -> Packet:
async with self._read_lock:
preamble = await self._reader.readexactly(1)
if preamble[0] != 0x00:
raise APIConnectionError("Invalid preamble")
if preamble[0] == 0x01:
raise RequiresEncryptionAPIError("Connection requires encryption")
raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}")

length = b""
while not length or (length[-1] & 0x80) == 0x80:
Expand Down Expand Up @@ -238,6 +262,7 @@ def __init__(
self._message_handlers: List[Callable[[message.Message], None]] = []
self.log_name = params.address
self._ping_task: Optional[asyncio.Task[None]] = None
self._read_exception_handlers: List[Callable[[Exception], None]] = []

def _start_ping(self) -> None:
async def func() -> None:
Expand Down Expand Up @@ -305,7 +330,7 @@ async def connect(self) -> None:
raise err
except asyncio.TimeoutError:
await self._on_error()
raise APIConnectionError(
raise ResolveAPIError(
f"Timeout while resolving IP address for {self.log_name}"
)

Expand All @@ -328,10 +353,10 @@ async def connect(self) -> None:
await asyncio.wait_for(coro2, 30.0)
except OSError as err:
await self._on_error()
raise APIConnectionError(f"Error connecting to {sockaddr}: {err}")
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}")
except asyncio.TimeoutError:
await self._on_error()
raise APIConnectionError(f"Timeout while connecting to {sockaddr}")
raise SocketAPIError(f"Timeout while connecting to {sockaddr}")

_LOGGER.debug("%s: Opened socket for", self._params.address)
reader, writer = await asyncio.open_connection(sock=self._socket)
Expand Down Expand Up @@ -383,7 +408,7 @@ async def login(self) -> None:
connect.password = self._params.password
resp = await self.send_message_await_response(connect, ConnectResponse)
if resp.invalid_password:
raise APIConnectionError("Invalid password!")
raise InvalidAuthAPIError("Invalid password!")

self._authenticated = True

Expand Down Expand Up @@ -444,20 +469,25 @@ def on_message(resp: message.Message) -> None:
if do_stop(resp):
fut.set_result(responses)

def on_read_exception(exc: Exception) -> None:
if not fut.done():
fut.set_exception(exc)

self._message_handlers.append(on_message)
self._read_exception_handlers.append(on_read_exception)
await self.send_message(send_msg)

try:
await asyncio.wait_for(fut, timeout)
except asyncio.TimeoutError:
if self._stopped:
raise APIConnectionError("Disconnected while waiting for API response!")
raise APIConnectionError("Timeout while waiting for API response!")

try:
self._message_handlers.remove(on_message)
except ValueError:
pass
raise SocketAPIError("Disconnected while waiting for API response!")
raise SocketAPIError("Timeout while waiting for API response!")
finally:
with suppress(ValueError):
self._message_handlers.remove(on_message)
with suppress(ValueError):
self._read_exception_handlers.remove(on_read_exception)

return responses

Expand Down Expand Up @@ -491,7 +521,7 @@ async def _run_once(self) -> None:
try:
msg.ParseFromString(raw_msg)
except Exception as e:
raise APIConnectionError("Invalid protobuf message: {}".format(e))
raise ProtocolAPIError(f"Invalid protobuf message: {e}") from e
_LOGGER.debug(
"%s: Got message of type %s: %s", self._params.address, type(msg), msg
)
Expand All @@ -509,15 +539,19 @@ async def run_forever(self) -> None:
self.log_name,
err,
)
for handler in self._read_exception_handlers[:]:
handler(err)
await self._on_error()
break
except Exception as err: # pylint: disable=broad-except
_LOGGER.info(
_LOGGER.warning(
"%s: Unexpected error while reading incoming messages: %s",
self.log_name,
err,
exc_info=True,
)
for handler in self._read_exception_handlers[:]:
handler(err)
await self._on_error()
break

Expand Down
28 changes: 28 additions & 0 deletions aioesphomeapi/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,34 @@ class APIConnectionError(Exception):
pass


class InvalidAuthAPIError(APIConnectionError):
pass


class ResolveAPIError(APIConnectionError):
pass


class ProtocolAPIError(APIConnectionError):
pass


class RequiresEncryptionAPIError(ProtocolAPIError):
pass


class SocketAPIError(APIConnectionError):
pass


class HandshakeAPIError(APIConnectionError):
pass


class InvalidEncryptionKeyAPIError(HandshakeAPIError):
pass


MESSAGE_TYPE_TO_PROTO = {
1: HelloRequest,
2: HelloResponse,
Expand Down
14 changes: 6 additions & 8 deletions aioesphomeapi/host_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
except ImportError:
ZC_ASYNCIO = False

from .core import APIConnectionError
from .core import APIConnectionError, ResolveAPIError

ZeroconfInstanceType = Union[zeroconf.Zeroconf, "zeroconf.asyncio.AsyncZeroconf", None]

Expand Down Expand Up @@ -56,7 +56,7 @@ def _sync_zeroconf_get_service_info(
try:
zc = zeroconf.Zeroconf()
except Exception:
raise APIConnectionError(
raise ResolveAPIError(
"Cannot start mDNS sockets, is this a docker container without "
"host network mode?"
)
Expand All @@ -72,7 +72,7 @@ def _sync_zeroconf_get_service_info(
try:
info = zc.get_service_info(service_type, service_name, int(timeout * 1000))
except Exception as exc:
raise APIConnectionError(
raise ResolveAPIError(
f"Error resolving mDNS {service_name} via mDNS: {exc}"
) from exc
finally:
Expand Down Expand Up @@ -105,7 +105,7 @@ async def _async_zeroconf_get_service_info(
try:
zc = zeroconf.asyncio.AsyncZeroconf()
except Exception:
raise APIConnectionError(
raise ResolveAPIError(
"Cannot start mDNS sockets, is this a docker container without "
"host network mode?"
)
Expand All @@ -126,7 +126,7 @@ async def _async_zeroconf_get_service_info(
service_type, service_name, int(timeout * 1000)
)
except Exception as exc:
raise APIConnectionError(
raise ResolveAPIError(
f"Error resolving mDNS {service_name} via mDNS: {exc}"
) from exc
finally:
Expand Down Expand Up @@ -240,9 +240,7 @@ async def async_resolve_host(
if zc_error:
# Only show ZC error if getaddrinfo also didn't work
raise zc_error
raise APIConnectionError(
f"Could not resolve host {host} - got no results from OS"
)
raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS")

# Use first matching result
# Future: return all matches and use first working one
Expand Down

0 comments on commit 5c9e7ac

Please sign in to comment.