From 015e9c8d5e46e65077a5cab40b0703ed9e1db187 Mon Sep 17 00:00:00 2001 From: Otto Winter Date: Wed, 8 Sep 2021 23:12:07 +0200 Subject: [PATCH] Add noise API transport support (#100) --- aioesphomeapi/client.py | 5 + aioesphomeapi/connection.py | 283 ++++++++++++++++++++++++++++-------- aioesphomeapi/log_reader.py | 81 +++++++++++ requirements.txt | 1 + tests/test_connection.py | 1 + 5 files changed, 308 insertions(+), 63 deletions(-) create mode 100644 aioesphomeapi/log_reader.py diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index fb64ccbe..e1119d43 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -124,6 +124,7 @@ def __init__( client_info: str = "aioesphomeapi", keepalive: float = 15.0, zeroconf_instance: ZeroconfInstanceType = None, + noise_psk: Optional[str] = None, ): self._params = ConnectionParams( eventloop=eventloop, @@ -133,6 +134,7 @@ def __init__( client_info=client_info, keepalive=keepalive, zeroconf_instance=zeroconf_instance, + noise_psk=noise_psk, ) self._connection: Optional[APIConnection] = None self._cached_name: Optional[str] = None @@ -305,6 +307,7 @@ async def subscribe_logs( self, on_log: Callable[[SubscribeLogsResponse], None], log_level: Optional[LogLevel] = None, + dump_config: Optional[bool] = None, ) -> None: self._check_authenticated() @@ -315,6 +318,8 @@ def on_msg(msg: message.Message) -> None: req = SubscribeLogsRequest() if log_level is not None: req.level = log_level + if dump_config is not None: + req.dump_config = dump_config assert self._connection is not None await self._connection.send_message_callback_response(req, on_msg) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 60e6384e..d3553240 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -1,11 +1,13 @@ import asyncio +import base64 import logging import socket import time from dataclasses import astuple, dataclass -from typing import Any, Awaitable, Callable, List, Optional, cast +from typing import Any, Awaitable, Callable, List, Optional from google.protobuf import message +from noise.connection import NoiseConnection # type: ignore import aioesphomeapi.host_resolver as hr @@ -37,6 +39,185 @@ class ConnectionParams: client_info: str keepalive: float zeroconf_instance: hr.ZeroconfInstanceType + noise_psk: Optional[str] + + @property + def noise_psk_bytes(self) -> Optional[bytes]: + if self.noise_psk is None: + return None + return base64.b64decode(self.noise_psk) + + +@dataclass +class Packet: + type: int + data: bytes + + +class APIFrameHelper: + def __init__( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + params: ConnectionParams, + ): + self._reader = reader + self._writer = writer + self._params = params + self._write_lock = asyncio.Lock() + self._read_lock = asyncio.Lock() + self._ready_event = asyncio.Event() + self._proto: Optional[NoiseConnection] = None + + async def close(self) -> None: + async with self._write_lock: + self._writer.close() + + async def _write_frame_noise(self, frame: bytes) -> None: + try: + async with self._write_lock: + _LOGGER.debug("Sending frame %s", frame.hex()) + header = bytes( + [ + 0x01, + (len(frame) >> 8) & 0xFF, + len(frame) & 0xFF, + ] + ) + self._writer.write(header + frame) + await self._writer.drain() + except OSError as err: + raise APIConnectionError(f"Error while writing data: {err}") from err + + async def _read_frame_noise(self) -> bytes: + try: + async with self._read_lock: + header = await self._reader.readexactly(3) + if header[0] != 0x01: + raise APIConnectionError(f"Marker byte invalid: {header[0]}") + msg_size = (header[1] << 8) | header[2] + frame = await self._reader.readexactly(msg_size) + except (asyncio.IncompleteReadError, OSError, TimeoutError) as err: + raise APIConnectionError(f"Error while reading data: {err}") from err + + _LOGGER.debug("Received frame %s", frame.hex()) + return frame + + async def perform_handshake(self) -> None: + if self._params.noise_psk is None: + return + await self._write_frame_noise(b"") # ClientHello + prologue = b"NoiseAPIInit" + b"\x00\x00" + server_hello = await self._read_frame_noise() # ServerHello + if not server_hello: + raise APIConnectionError("ServerHello is empty") + chosen_proto = server_hello[0] + if chosen_proto != 0x01: + raise APIConnectionError( + f"Unknown protocol selected by client {chosen_proto}" + ) + + self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256") + self._proto.set_as_initiator() + self._proto.set_psks(self._params.noise_psk_bytes) + self._proto.set_prologue(prologue) + self._proto.start_handshake() + + _LOGGER.debug("Starting handshake...") + do_write = True + while not self._proto.handshake_finished: + if do_write: + msg = self._proto.write_message() + await self._write_frame_noise(b"\x00" + msg) + else: + msg = await self._read_frame_noise() + if not msg or msg[0] != 0: + raise APIConnectionError(f"Handshake failure: {msg[1:].decode()}") + self._proto.read_message(msg[1:]) + + do_write = not do_write + + _LOGGER.debug("Handshake complete!") + self._ready_event.set() + + async def _write_packet_noise(self, packet: Packet) -> None: + await self._ready_event.wait() + padding = 0 + data = ( + bytes( + [ + (packet.type >> 8) & 0xFF, + (packet.type >> 0) & 0xFF, + (len(packet.data) >> 8) & 0xFF, + (len(packet.data) >> 0) & 0xFF, + ] + ) + + packet.data + + b"\x00" * padding + ) + assert self._proto is not None + frame = self._proto.encrypt(data) + await self._write_frame_noise(frame) + + async def _write_packet_plaintext(self, packet: Packet) -> None: + data = b"\0" + data += varuint_to_bytes(len(packet.data)) + data += varuint_to_bytes(packet.type) + data += packet.data + try: + async with self._write_lock: + _LOGGER.debug("Sending frame %s", data.hex()) + self._writer.write(data) + await self._writer.drain() + except OSError as err: + raise APIConnectionError(f"Error while writing data: {err}") from err + + async def write_packet(self, packet: Packet) -> None: + if self._params.noise_psk is None: + await self._write_packet_plaintext(packet) + else: + await self._write_packet_noise(packet) + + async def _read_packet_noise(self) -> Packet: + await self._ready_event.wait() + frame = await self._read_frame_noise() + assert self._proto is not None + msg = self._proto.decrypt(frame) + if len(msg) < 4: + raise APIConnectionError(f"Bad packet frame: {msg}") + pkt_type = (msg[0] << 8) | msg[1] + data_len = (msg[2] << 8) | msg[3] + if data_len + 4 > len(msg): + raise APIConnectionError(f"Bad data len: {data_len} vs {len(msg)}") + data = msg[4 : 4 + data_len] + return Packet(type=pkt_type, data=data) + + async def _read_packet_plaintext(self) -> Packet: + async with self._read_lock: + preamble = await self._reader.readexactly(1) + if preamble[0] != 0x00: + raise APIConnectionError("Invalid preamble") + + length = b"" + while not length or (length[-1] & 0x80) == 0x80: + length += await self._reader.readexactly(1) + length_int = bytes_to_varuint(length) + assert length_int is not None + msg_type = b"" + while not msg_type or (msg_type[-1] & 0x80) == 0x80: + msg_type += await self._reader.readexactly(1) + msg_type_int = bytes_to_varuint(msg_type) + assert msg_type_int is not None + + raw_msg = b"" + if length_int != 0: + raw_msg = await self._reader.readexactly(length_int) + return Packet(type=msg_type_int, data=raw_msg) + + async def read_packet(self) -> Packet: + if self._params.noise_psk is None: + return await self._read_packet_plaintext() + return await self._read_packet_noise() class APIConnection: @@ -47,9 +228,7 @@ def __init__( self.on_stop = on_stop self._stopped = False self._socket: Optional[socket.socket] = None - self._socket_reader: Optional[asyncio.StreamReader] = None - self._socket_writer: Optional[asyncio.StreamWriter] = None - self._write_lock = asyncio.Lock() + self._frame_helper: Optional[APIFrameHelper] = None self._connected = False self._authenticated = False self._socket_connected = False @@ -58,15 +237,13 @@ def __init__( self._message_handlers: List[Callable[[message.Message], None]] = [] self.log_name = params.address + self._ping_task: Optional[asyncio.Task[None]] = None def _start_ping(self) -> None: async def func() -> None: - while self._connected: + while True: await asyncio.sleep(self._params.keepalive) - if not self._connected: - return - try: await self.ping() except APIConnectionError: @@ -74,18 +251,20 @@ async def func() -> None: await self._on_error() return - self._params.eventloop.create_task(func()) + self._ping_task = asyncio.create_task(func()) async def _close_socket(self) -> None: if not self._socket_connected: return - async with self._write_lock: - if self._socket_writer is not None: - self._socket_writer.close() - self._socket_writer = None - self._socket_reader = None + if self._frame_helper is not None: + await self._frame_helper.close() + self._frame_helper = None if self._socket is not None: self._socket.close() + self._socket = None + if self._ping_task is not None: + self._ping_task.cancel() + self._ping_task = None self._socket_connected = False self._connected = False self._authenticated = False @@ -106,6 +285,7 @@ async def stop(self, force: bool = False) -> None: async def _on_error(self) -> None: await self.stop(force=True) + # pylint: disable=too-many-statements async def connect(self) -> None: if self._stopped: raise APIConnectionError(f"Connection is closed for {self.log_name}!") @@ -154,19 +334,25 @@ async def connect(self) -> None: raise APIConnectionError(f"Timeout while connecting to {sockaddr}") _LOGGER.debug("%s: Opened socket for", self._params.address) - self._socket_reader, self._socket_writer = await asyncio.open_connection( - sock=self._socket - ) + reader, writer = await asyncio.open_connection(sock=self._socket) + self._frame_helper = APIFrameHelper(reader, writer, self._params) self._socket_connected = True + + try: + await self._frame_helper.perform_handshake() + except APIConnectionError: + await self._on_error() + raise + self._params.eventloop.create_task(self.run_forever()) hello = HelloRequest() hello.client_info = self._params.client_info try: resp = await self.send_message_await_response(hello, HelloResponse) - except APIConnectionError as err: + except APIConnectionError: await self._on_error() - raise err + raise _LOGGER.debug( "%s: Successfully connected ('%s' API=%s.%s)", self.log_name, @@ -213,21 +399,10 @@ def is_connected(self) -> bool: def is_authenticated(self) -> bool: return self._authenticated - async def _write(self, data: bytes) -> None: - # _LOGGER.debug("%s: Write: %s", self._params.address, - # ' '.join('{:02X}'.format(x) for x in data)) + async def send_message(self, msg: message.Message) -> None: if not self._socket_connected: raise APIConnectionError("Socket is not connected") - try: - async with self._write_lock: - if self._socket_writer is not None: - self._socket_writer.write(data) - await self._socket_writer.drain() - except OSError as err: - await self._on_error() - raise APIConnectionError("Error while writing data: {}".format(err)) - async def send_message(self, msg: message.Message) -> None: for message_type, klass in MESSAGE_TYPE_TO_PROTO.items(): if isinstance(msg, klass): break @@ -236,12 +411,14 @@ async def send_message(self, msg: message.Message) -> None: encoded = msg.SerializeToString() _LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg)) - req = bytes([0]) - req += varuint_to_bytes(len(encoded)) # pylint: disable=undefined-loop-variable - req += varuint_to_bytes(message_type) - req += encoded - await self._write(req) + assert self._frame_helper is not None + await self._frame_helper.write_packet( + Packet( + type=message_type, + data=encoded, + ) + ) async def send_message_callback_response( self, send_msg: message.Message, on_message: Callable[[Any], None] @@ -254,7 +431,7 @@ async def send_message_await_response_complex( send_msg: message.Message, do_append: Callable[[Any], bool], do_stop: Callable[[Any], bool], - timeout: float = 5.0, + timeout: float = 10.0, ) -> List[Any]: fut = self._params.eventloop.create_future() responses = [] @@ -285,7 +462,7 @@ def on_message(resp: message.Message) -> None: return responses async def send_message_await_response( - self, send_msg: message.Message, response_type: Any, timeout: float = 5.0 + self, send_msg: message.Message, response_type: Any, timeout: float = 10.0 ) -> Any: def is_response(msg: message.Message) -> bool: return isinstance(msg, response_type) @@ -298,33 +475,12 @@ def is_response(msg: message.Message) -> bool: return res[0] - async def _recv(self, amount: int) -> bytes: - if amount == 0: - return bytes() - - try: - assert self._socket_reader is not None - ret = await self._socket_reader.readexactly(amount) - except (asyncio.IncompleteReadError, OSError, TimeoutError) as err: - raise APIConnectionError("Error while receiving data: {}".format(err)) - - return ret - - async def _recv_varint(self) -> int: - raw = bytes() - while not raw or raw[-1] & 0x80: - raw += await self._recv(1) - return cast(int, bytes_to_varuint(raw)) - async def _run_once(self) -> None: - preamble = await self._recv(1) - if preamble[0] != 0x00: - raise APIConnectionError("Invalid preamble") - - length = await self._recv_varint() - msg_type = await self._recv_varint() + assert self._frame_helper is not None + pkt = await self._frame_helper.read_packet() - raw_msg = await self._recv(length) + msg_type = pkt.type + raw_msg = pkt.data if msg_type not in MESSAGE_TYPE_TO_PROTO: _LOGGER.debug( "%s: Skipping message type %s", self._params.address, msg_type @@ -360,6 +516,7 @@ async def run_forever(self) -> None: "%s: Unexpected error while reading incoming messages: %s", self.log_name, err, + exc_info=True, ) await self._on_error() break diff --git a/aioesphomeapi/log_reader.py b/aioesphomeapi/log_reader.py new file mode 100644 index 00000000..ac854c2f --- /dev/null +++ b/aioesphomeapi/log_reader.py @@ -0,0 +1,81 @@ +# Helper script and aioesphomeapi to view logs from an esphome device +import argparse +import asyncio +import logging +import sys +from datetime import datetime +from typing import List + +import zeroconf + +from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore +from aioesphomeapi.client import APIClient +from aioesphomeapi.core import APIConnectionError +from aioesphomeapi.model import LogLevel +from aioesphomeapi.reconnect_logic import ReconnectLogic + +_LOGGER = logging.getLogger(__name__) + + +async def main(argv: List[str]) -> None: + parser = argparse.ArgumentParser("aioesphomeapi-logs") + parser.add_argument("--port", type=int, default=6053) + parser.add_argument("--password", type=str) + parser.add_argument("--noise-psk", type=str) + parser.add_argument("-v", "--verbose", action="store_true") + parser.add_argument("address") + args = parser.parse_args(argv[1:]) + + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.DEBUG if args.verbose else logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", + ) + + cli = APIClient( + asyncio.get_event_loop(), + args.address, + args.port, + args.password or "", + noise_psk=args.noise_psk, + keepalive=10, + ) + + def on_log(msg: SubscribeLogsResponse) -> None: + time_ = datetime.now().time().strftime("[%H:%M:%S]") + text = msg.message + print(time_ + text.decode("utf8", "backslashreplace")) + + has_connects = False + + async def on_connect() -> None: + nonlocal has_connects + try: + await cli.subscribe_logs( + on_log, + log_level=LogLevel.LOG_LEVEL_VERY_VERBOSE, + dump_config=not has_connects, + ) + has_connects = True + except APIConnectionError: + cli.disconnect() + + async def on_disconnect() -> None: + _LOGGER.warning("Disconnected from API") + + logic = ReconnectLogic( + client=cli, + on_connect=on_connect, + on_disconnect=on_disconnect, + zeroconf_instance=zeroconf.Zeroconf(), + ) + await logic.start() + try: + while True: + await asyncio.sleep(60) + except KeyboardInterrupt: + await logic.stop() + + +if __name__ == "__main__": + sys.exit(asyncio.run(main(sys.argv)) or 0) diff --git a/requirements.txt b/requirements.txt index 1916b642..0c081b6e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ protobuf>=3.12.2,<4.0 zeroconf>=0.28.0,<1.0 +noiseprotocol>=0.3.1,<1.0 \ No newline at end of file diff --git a/tests/test_connection.py b/tests/test_connection.py index 1939887f..28a0c2fd 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -20,6 +20,7 @@ def connection_params() -> ConnectionParams: client_info="Tests client", keepalive=15.0, zeroconf_instance=None, + noise_psk=None, )