diff --git a/.clang-format b/.clang-format index 5c891058..91f7bc2d 100644 --- a/.clang-format +++ b/.clang-format @@ -167,7 +167,6 @@ SpaceBeforeParens: Custom SpaceBeforeParensOptions: AfterControlStatements: true AfterForeachMacros: false - AfterFunctionDeclationName: false AfterIfMacros: false AfterOverloadedOperator: false AfterRequiresInClause: false diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 511ae819..59a52853 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -185,5 +185,5 @@ jobs: with: name: OpenShock_${{ matrix.board }} path: OpenShock.${{ matrix.board }}.bin - retention-days: 1 + retention-days: 7 if-no-files-found: error diff --git a/WebUI/src/lib/WebSocketClient.ts b/WebUI/src/lib/WebSocketClient.ts index b997a620..61ab8113 100644 --- a/WebUI/src/lib/WebSocketClient.ts +++ b/WebUI/src/lib/WebSocketClient.ts @@ -1,7 +1,5 @@ import { browser } from "$app/environment"; -import { getToastStore } from "@skeletonlabs/skeleton"; -import { WiFiStateStore } from "./stores"; -import type { WiFiNetwork } from "./types/WiFiNetwork"; +import { WebSocketMessageHandler } from "./WebSocketMessageHandler"; export enum ConnectionState { DISCONNECTED = 0, @@ -134,22 +132,8 @@ export class WebSocketClient { return; } - if (message.networks !== undefined) { - WiFiStateStore.setNetworks(message.networks as WiFiNetwork[]); - return; - } - - if (message.scanning !== undefined) { - const toastStore = getToastStore(); - if (message.scanning) { - toastStore.trigger({ message: 'Scanning for WiFi networks...', background: 'bg-blue-500' }); - } else { - toastStore.trigger({ message: 'Scanning for WiFi networks finished', background: 'bg-green-500' }); - } - - WiFiStateStore.setScanning(message.scanning as boolean); - return; - } + // Handle message + WebSocketMessageHandler(message); } private AbortWebSocket() { if (this._socket) { diff --git a/WebUI/src/lib/WebSocketMessageHandler.ts b/WebUI/src/lib/WebSocketMessageHandler.ts new file mode 100644 index 00000000..9cdb5839 --- /dev/null +++ b/WebUI/src/lib/WebSocketMessageHandler.ts @@ -0,0 +1,112 @@ +import { WebSocketClient } from './WebSocketClient'; +import { WiFiStateStore } from './stores'; +import type { WiFiNetwork } from './types/WiFiNetwork'; + +interface InvalidMessage { + type: undefined | null; +} + +interface PoggiesMessage { + type: 'poggies'; +} + +interface WiFiScanStartedMessage { + type: 'wifi'; + subject: 'scan'; + status: 'started'; +} + +interface WiFiScanDiscoveryMessage { + type: 'wifi'; + subject: 'scan'; + status: 'discovery'; + data: WiFiNetwork; +} + +interface WiFiScanCompletedMessage { + type: 'wifi'; + subject: 'scan'; + status: 'completed'; +} + +interface WiFiScanErrorMessage { + type: 'wifi'; + subject: 'scan'; + status: 'error'; +} + +export type WiFiScanMessage = WiFiScanStartedMessage | WiFiScanDiscoveryMessage | WiFiScanCompletedMessage | WiFiScanErrorMessage; +export type WiFiMessage = WiFiScanMessage; +export type WebSocketMessage = InvalidMessage | PoggiesMessage | WiFiMessage; + +export function WebSocketMessageHandler(message: WebSocketMessage) { + const type = message.type; + if (!type) { + console.warn('[WS] Received invalid message: ', message); + return; + } + + switch (type) { + case 'poggies': + handlePoggiesMessage(); + break; + case 'wifi': + handleWiFiMessage(message); + break; + default: + console.warn('[WS] Received invalid message: ', message); + return; + } +} + +function handlePoggiesMessage() { + WebSocketClient.Instance.Send('{ "type": "wifi", "action": "scan", "run": true }'); +} + +function handleWiFiMessage(message: WiFiMessage) { + switch (message.subject) { + case 'scan': + handleWiFiScanMessage(message); + break; + default: + console.warn('[WS] Received invalid wifi message: ', message); + return; + } +} + +function handleWiFiScanMessage(message: WiFiScanMessage) { + switch (message.status) { + case 'started': + handleWiFiScanStartedMessage(); + break; + case 'discovery': + handleWiFiScanDiscoveryMessage(message); + break; + case 'completed': + handleWiFiScanCompletedMessage(); + break; + case 'error': + handleWiFiScanErrorMessage(message); + break; + default: + console.warn('[WS] Received invalid scan message: ', message); + return; + } +} + +function handleWiFiScanStartedMessage() { + WiFiStateStore.setScanning(true); +} + +function handleWiFiScanDiscoveryMessage(message: WiFiScanDiscoveryMessage) { + WiFiStateStore.addNetwork(message.data); +} + +function handleWiFiScanCompletedMessage() { + WiFiStateStore.setScanning(false); +} + +function handleWiFiScanErrorMessage(message: WiFiScanErrorMessage) { + console.error('[WS] Received WiFi scan error message: ', message); + WiFiStateStore.setScanning(false); +} diff --git a/WebUI/src/lib/components/WiFiList.svelte b/WebUI/src/lib/components/WiFiList.svelte index 3b077abe..3eab0c9d 100644 --- a/WebUI/src/lib/components/WiFiList.svelte +++ b/WebUI/src/lib/components/WiFiList.svelte @@ -1,40 +1,42 @@
@@ -32,7 +34,7 @@
{#if item.saved} - + {:else} diff --git a/WebUI/src/lib/stores/WiFiStateStore.ts b/WebUI/src/lib/stores/WiFiStateStore.ts index d98043d3..b3b5c69e 100644 --- a/WebUI/src/lib/stores/WiFiStateStore.ts +++ b/WebUI/src/lib/stores/WiFiStateStore.ts @@ -5,7 +5,7 @@ import { writable } from 'svelte/store'; const { subscribe, update } = writable({ initialized: false, scanning: false, - networks: [], + networks: {}, }); export const WiFiStateStore = { @@ -22,16 +22,15 @@ export const WiFiStateStore = { return store; }); }, - setNetworks(networks: WiFiNetwork[]) { + addNetwork(network: WiFiNetwork) { update((store) => { - store.scanning = false; - store.networks = networks; + store.networks[network.bssid] = network; return store; }); }, clearNetworks() { update((store) => { - store.networks = []; + store.networks = {}; return store; }); }, diff --git a/WebUI/src/lib/types/WiFiNetwork.ts b/WebUI/src/lib/types/WiFiNetwork.ts index 68578859..33fb6d86 100644 --- a/WebUI/src/lib/types/WiFiNetwork.ts +++ b/WebUI/src/lib/types/WiFiNetwork.ts @@ -1,9 +1,8 @@ export type WiFiNetwork = { - index: number; ssid: string; bssid: string; rssi: number; channel: number; - secure: boolean; + security: 'Open' | 'WEP' | 'WPA PSK' | 'WPA2 PSK' | 'WPA/WPA2 PSK' | 'WPA2 Enterprise' | 'WPA3 PSK' | 'WPA2/WPA3 PSK' | 'WAPI PSK' | null; saved: boolean; }; diff --git a/WebUI/src/lib/types/WiFiState.ts b/WebUI/src/lib/types/WiFiState.ts index a78d900a..909be83d 100644 --- a/WebUI/src/lib/types/WiFiState.ts +++ b/WebUI/src/lib/types/WiFiState.ts @@ -1,7 +1,7 @@ -import type { WiFiNetwork } from "./WiFiNetwork"; +import type { WiFiNetwork } from './WiFiNetwork'; export type WiFiState = { initialized: boolean; scanning: boolean; - networks: WiFiNetwork[]; + networks: { [bssid: string]: WiFiNetwork }; }; diff --git a/include/AuthenticationManager.h b/include/AuthenticationManager.h index c704cd2d..069e5a28 100644 --- a/include/AuthenticationManager.h +++ b/include/AuthenticationManager.h @@ -5,9 +5,10 @@ #include namespace OpenShock::AuthenticationManager { - bool Authenticate(unsigned int pairCode); + bool IsPaired(); + bool Pair(unsigned int pairCode); + void UnPair(); - bool IsAuthenticated(); String GetAuthToken(); void ClearAuthToken(); } // namespace OpenShock::AuthenticationManager diff --git a/include/CaptivePortal.h b/include/CaptivePortal.h index 7e179ca4..b6f175dd 100644 --- a/include/CaptivePortal.h +++ b/include/CaptivePortal.h @@ -11,6 +11,17 @@ namespace OpenShock::CaptivePortal { bool IsRunning(); void Update(); + bool SendMessageTXT(std::uint8_t socketId, const char* data, std::size_t len); + bool SendMessageBIN(std::uint8_t socketId, const std::uint8_t* data, std::size_t len); + inline bool SendMessageTXT(std::uint8_t socketId, const String& message) { + return SendMessageTXT(socketId, message.c_str(), message.length()); + } + inline bool SendMessageJSON(std::uint8_t socketId, const DynamicJsonDocument& doc) { + String message; + serializeJson(doc, message); + return SendMessageTXT(socketId, message); + } + bool BroadcastMessageTXT(const char* data, std::size_t len); bool BroadcastMessageBIN(const std::uint8_t* data, std::size_t len); inline bool BroadcastMessageTXT(const String& message) { diff --git a/include/Mappers/EspWiFiTypesMapper.h b/include/Mappers/EspWiFiTypesMapper.h new file mode 100644 index 00000000..1c4cb344 --- /dev/null +++ b/include/Mappers/EspWiFiTypesMapper.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +namespace OpenShock::Mappers { + /* + * @brief Maps a WiFi auth mode to a human-readable string. + * @param authMode The auth mode to map. + * @return A human-readable string representing the auth mode, or nullptr if the auth mode is invalid. + */ + const char* GetWiFiAuthModeName(wifi_auth_mode_t authMode); +} diff --git a/include/Utils/HexUtils.h b/include/Utils/HexUtils.h new file mode 100644 index 00000000..3ff14dbe --- /dev/null +++ b/include/Utils/HexUtils.h @@ -0,0 +1,71 @@ +#pragma once + +#include + +#include +#include + +namespace OpenShock::HexUtils { + /// @brief Converts a single byte to a hex pair, and writes it to the output buffer. + /// @param data The byte to convert. + /// @param output The output buffer to write to. + /// @param upper Whether to use uppercase hex characters. + constexpr void ToHex(std::uint8_t data, char* output, bool upper = true) noexcept { + const char* hex = upper ? "0123456789ABCDEF" : "0123456789abcdef"; + output[0] = hex[data >> 4]; + output[1] = hex[data & 0x0F]; + } + + /// @brief Converts a byte array to a hex string. + /// @param data The byte array to convert. + /// @param output The output buffer to write to. + /// @param upper Whether to use uppercase hex characters. + /// @remark To use this you must specify the size of the array in the template parameter. (e.g. ToHexMac<6>(...)) + template + constexpr void ToHex(nonstd::span data, nonstd::span output, bool upper = true) noexcept { + for (std::size_t i = 0; i < data.size(); ++i) { + ToHex(data[i], &output[i * 2], upper); + } + } + + /// @brief Converts a byte array to a hex string. + /// @param data The byte array to convert. + /// @param upper Whether to use uppercase hex characters. + /// @return The hex string. + /// @remark To use this you must specify the size of the array in the template parameter. (e.g. ToHexMac<6>(...)) + template + constexpr std::array ToHex(nonstd::span data, bool upper = true) noexcept { + std::array output {}; + ToHex(data, output, upper); + output[N * 2] = '\0'; + return output; + } + + /// @brief Converts a byte array to a MAC address string. (hex pairs separated by colons) + /// @param data The byte array to convert. + /// @param output The output buffer to write to. + /// @param upper Whether to use uppercase hex characters. + /// @remark To use this you must specify the size of the array in the template parameter. (e.g. ToHexMac<6>(...)) + template + constexpr void ToHexMac(nonstd::span data, nonstd::span output, bool upper = true) noexcept { + const std::size_t Last = N - 1; + for (std::size_t i = 0; i < Last; ++i) { + ToHex(data[i], &output[i * 3], upper); + output[i * 3 + 2] = ':'; + } + ToHex(data[Last], &output[Last * 3], upper); + } + + /// @brief Converts a byte array to a MAC address string. (hex pairs separated by colons) + /// @param data The byte array to convert. + /// @param upper Whether to use uppercase hex characters. + /// @return The hex string in a MAC address format. + /// @remark To use this you must specify the size of the array in the template parameter. (e.g. ToHexMac<6>(...)) + template + constexpr std::array ToHexMac(nonstd::span data, bool upper = true) noexcept { + std::array output {}; + ToHexMac(data, nonstd::span(output.data(), output.size() - 1), upper); + output[(N * 3) - 1] = '\0'; + return output; + } +} // namespace OpenShock::HexUtils diff --git a/include/WiFiCredentials.h b/include/WiFiCredentials.h new file mode 100644 index 00000000..2952dd26 --- /dev/null +++ b/include/WiFiCredentials.h @@ -0,0 +1,50 @@ +#pragma once + +#include + +#include + +#include +#include +#include + +namespace fs { + class File; +} + +namespace OpenShock { + class WiFiCredentials { + WiFiCredentials() = default; + + public: + static bool Load(std::vector& credentials); + + WiFiCredentials(std::uint8_t id, const char* ssid, std::uint8_t ssidLength, const char* password, std::uint8_t passwordLength); + + constexpr std::uint8_t id() const noexcept { return _id; } + + constexpr nonstd::span bssid() const noexcept { return nonstd::span(_bssid); } + void setBSSID(const std::uint8_t* bssid); + + constexpr bool hasSSID() const noexcept { return _ssidLength > 0; } + constexpr nonstd::span ssid() const noexcept { return nonstd::span(_ssid, _ssidLength); } + void setSSID(const char* ssid, std::uint8_t ssidLength); + + constexpr bool hasPassword() const noexcept { return _passwordLength > 0; } + constexpr nonstd::span password() const noexcept { return nonstd::span(_password, _passwordLength); } + void setPassword(const char* password, std::uint8_t passwordLength); + + bool save() const; + bool erase() const; + + private: + bool _load(fs::File& file); + + std::uint8_t _id; + std::uint8_t _bssid[6]; + char _ssid[33]; + std::uint8_t _ssidLength; + char _password[64]; + std::uint8_t _passwordLength; + }; +} // namespace OpenShock diff --git a/include/WiFiManager.h b/include/WiFiManager.h index 3225a3c3..40b9ee5f 100644 --- a/include/WiFiManager.h +++ b/include/WiFiManager.h @@ -7,11 +7,9 @@ namespace OpenShock::WiFiManager { bool Init(); - WiFiState GetWiFiState(); - - void AddOrUpdateNetwork(const char* ssid, const char* password); - void RemoveNetwork(const char* ssid); - - bool StartScan(); + bool Authenticate(std::uint8_t (&bssid)[6], const char* password, std::uint8_t passwordLength); + void Forget(std::uint8_t wifiId); + void Connect(std::uint8_t wifiId); + void Disconnect(); } // namespace OpenShock::WiFiManager diff --git a/include/WiFiScanManager.h b/include/WiFiScanManager.h new file mode 100644 index 00000000..93a7e78f --- /dev/null +++ b/include/WiFiScanManager.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +#include +#include + +namespace OpenShock::WiFiScanManager { + bool Init(); + + bool IsScanning(); + + bool StartScan(); + void CancelScan(); + + typedef std::uint64_t CallbackHandle; + + typedef std::function ScanStartedHandler; + CallbackHandle RegisterScanStartedHandler(const ScanStartedHandler& handler); + void UnregisterScanStartedHandler(CallbackHandle id); + + enum class ScanCompletedStatus { + Success, + Cancelled, + Error, + }; + typedef std::function ScanCompletedHandler; + CallbackHandle RegisterScanCompletedHandler(const ScanCompletedHandler& handler); + void UnregisterScanCompletedHandler(CallbackHandle id); + + typedef std::function ScanDiscoveryHandler; + CallbackHandle RegisterScanDiscoveryHandler(const ScanDiscoveryHandler& handler); + void UnregisterScanDiscoveryHandler(CallbackHandle id); + + void Update(); +} // namespace OpenShock::WiFiScanManager diff --git a/include/WiFiState.h b/include/WiFiState.h index a60a2f4b..a1b5c054 100644 --- a/include/WiFiState.h +++ b/include/WiFiState.h @@ -7,4 +7,7 @@ namespace OpenShock { Connecting, Connected }; + + WiFiState GetWiFiState() noexcept; + void SetWiFiState(WiFiState state) noexcept; } // namespace OpenShock diff --git a/src/AuthenticationManager.cpp b/src/AuthenticationManager.cpp index fdcfa3c3..32ecbcde 100644 --- a/src/AuthenticationManager.cpp +++ b/src/AuthenticationManager.cpp @@ -8,77 +8,84 @@ static const char* const TAG = "AuthenticationManager"; static const char* const AUTH_TOKEN_FILE = "/authToken"; -static bool _isAuthenticated = false; +static bool _isPaired = false; static String _authToken; using namespace OpenShock; -bool AuthenticationManager::Authenticate(unsigned int pairCode) { +bool AuthenticationManager::IsPaired() { + if (_isPaired) { + return true; + } + + if (!FileUtils::TryReadFile(AUTH_TOKEN_FILE, _authToken)) { + return false; + } + HTTPClient http; - String uri = OPENSHOCK_API_URL("/1/device/pair/") + String(pairCode); + const char* const uri = OPENSHOCK_API_URL("/1/device/self"); - ESP_LOGD(TAG, "Contacting pair code url: %s", uri.c_str()); + ESP_LOGD(TAG, "Contacting self url: %s", uri); http.begin(uri); int responseCode = http.GET(); if (responseCode != 200) { - ESP_LOGE(TAG, "Error while getting auth token: [%d] %s", responseCode, http.getString().c_str()); - - _isAuthenticated = false; - return false; - } - - _authToken = http.getString(); - - if (!FileUtils::TryWriteFile(AUTH_TOKEN_FILE, _authToken)) { - ESP_LOGE(TAG, "Error while writing auth token to file"); - - _isAuthenticated = false; + ESP_LOGE(TAG, "Error while verifying auth token: [%d] %s", responseCode, http.getString().c_str()); + FileUtils::DeleteFile(AUTH_TOKEN_FILE); return false; } http.end(); - _isAuthenticated = true; + _isPaired = true; + + ESP_LOGD(TAG, "Successfully verified auth token"); return true; } -bool AuthenticationManager::IsAuthenticated() { - if (_isAuthenticated) { - return true; - } - - if (!FileUtils::TryReadFile(AUTH_TOKEN_FILE, _authToken)) { - return false; - } - +bool AuthenticationManager::Pair(unsigned int pairCode) { HTTPClient http; - const char* const uri = OPENSHOCK_API_URL("/1/device/self"); + String uri = OPENSHOCK_API_URL("/1/device/pair/") + String(pairCode); - ESP_LOGD(TAG, "Contacting self url: %s", uri); + ESP_LOGD(TAG, "Contacting pair code url: %s", uri.c_str()); http.begin(uri); int responseCode = http.GET(); if (responseCode != 200) { - ESP_LOGE(TAG, "Error while verifying auth token: [%d] %s", responseCode, http.getString().c_str()); - FileUtils::DeleteFile(AUTH_TOKEN_FILE); + ESP_LOGE(TAG, "Error while getting auth token: [%d] %s", responseCode, http.getString().c_str()); + + _isPaired = false; return false; } - http.end(); + _authToken = http.getString(); - _isAuthenticated = true; + if (!FileUtils::TryWriteFile(AUTH_TOKEN_FILE, _authToken)) { + ESP_LOGE(TAG, "Error while writing auth token to file"); - ESP_LOGD(TAG, "Successfully verified auth token"); + _isPaired = false; + return false; + } + + http.end(); + + _isPaired = true; return true; } +void AuthenticationManager::UnPair() { + _isPaired = false; + _authToken = ""; + + FileUtils::DeleteFile(AUTH_TOKEN_FILE); +} + String AuthenticationManager::GetAuthToken() { - if (_isAuthenticated) { + if (_isPaired) { return _authToken; } @@ -86,14 +93,7 @@ String AuthenticationManager::GetAuthToken() { return ""; } - _isAuthenticated = true; + _isPaired = true; return _authToken; } - -void AuthenticationManager::ClearAuthToken() { - _isAuthenticated = false; - _authToken = ""; - - FileUtils::DeleteFile(AUTH_TOKEN_FILE); -} diff --git a/src/CaptivePortal.cpp b/src/CaptivePortal.cpp index 776fff60..1c21eb21 100644 --- a/src/CaptivePortal.cpp +++ b/src/CaptivePortal.cpp @@ -1,7 +1,10 @@ #include "CaptivePortal.h" #include "AuthenticationManager.h" +#include "Mappers/EspWiFiTypesMapper.h" +#include "Utils/HexUtils.h" #include "WiFiManager.h" +#include "WiFiScanManager.h" #include #include @@ -24,8 +27,11 @@ struct CaptivePortalInstance { AsyncWebServer webServer; WebSocketsServer socketServer; + OpenShock::WiFiScanManager::CallbackHandle wifiScanStartedHandlerId; + OpenShock::WiFiScanManager::CallbackHandle wifiScanCompletedHandlerId; + OpenShock::WiFiScanManager::CallbackHandle wifiScanDiscoveryHandlerId; }; -std::unique_ptr s_webServices = nullptr; +static std::unique_ptr s_webServices = nullptr; void handleWebSocketEvent(std::uint8_t socketId, WStype_t type, std::uint8_t* data, std::size_t len); void handleHttpNotFound(AsyncWebServerRequest* request); @@ -68,6 +74,49 @@ bool CaptivePortal::Start() { s_webServices->webServer.onNotFound([](AsyncWebServerRequest* request) { request->send(404, "text/plain", "Not found"); }); s_webServices->webServer.begin(); + s_webServices->wifiScanStartedHandlerId = WiFiScanManager::RegisterScanStartedHandler([]() { + StaticJsonDocument<256> doc; + doc["type"] = "wifi"; + doc["subject"] = "scan"; + doc["status"] = "started"; + CaptivePortal::BroadcastMessageJSON(doc); + }); + s_webServices->wifiScanCompletedHandlerId = WiFiScanManager::RegisterScanCompletedHandler([](WiFiScanManager::ScanCompletedStatus status) { + StaticJsonDocument<256> doc; + doc["type"] = "wifi"; + doc["subject"] = "scan"; + switch (status) { + case WiFiScanManager::ScanCompletedStatus::Success: + doc["status"] = "success"; + break; + case WiFiScanManager::ScanCompletedStatus::Cancelled: + doc["status"] = "cancelled"; + break; + case WiFiScanManager::ScanCompletedStatus::Error: + doc["status"] = "error"; + break; + default: + doc["status"] = "unknown"; + break; + } + CaptivePortal::BroadcastMessageJSON(doc); + }); + s_webServices->wifiScanDiscoveryHandlerId = WiFiScanManager::RegisterScanDiscoveryHandler([](const wifi_ap_record_t* record) { + StaticJsonDocument<256> doc; + doc["type"] = "wifi"; + doc["subject"] = "scan"; + doc["status"] = "discovery"; + + auto data = doc.createNestedObject("data"); + data["ssid"] = reinterpret_cast(record->ssid); + data["bssid"] = HexUtils::ToHexMac<6>(record->bssid).data(); + data["rssi"] = record->rssi; + data["channel"] = record->primary; + data["security"] = Mappers::GetWiFiAuthModeName(record->authmode); + + CaptivePortal::BroadcastMessageJSON(doc); + }); + ESP_LOGD(TAG, "Started"); return true; @@ -83,6 +132,10 @@ void CaptivePortal::Stop() { s_webServices->webServer.end(); s_webServices->socketServer.close(); + WiFiScanManager::UnregisterScanStartedHandler(s_webServices->wifiScanStartedHandlerId); + WiFiScanManager::UnregisterScanCompletedHandler(s_webServices->wifiScanCompletedHandlerId); + WiFiScanManager::UnregisterScanDiscoveryHandler(s_webServices->wifiScanDiscoveryHandlerId); + s_webServices = nullptr; WiFi.softAPdisconnect(true); @@ -97,6 +150,24 @@ void CaptivePortal::Update() { s_webServices->socketServer.loop(); } +bool CaptivePortal::SendMessageTXT(std::uint8_t socketId, const char* data, std::size_t len) { + if (s_webServices == nullptr) { + return false; + } + + s_webServices->socketServer.sendTXT(socketId, data, len); + + return true; +} +bool CaptivePortal::SendMessageBIN(std::uint8_t socketId, const std::uint8_t* data, std::size_t len) { + if (s_webServices == nullptr) { + return false; + } + + s_webServices->socketServer.sendBIN(socketId, data, len); + + return true; +} bool CaptivePortal::BroadcastMessageTXT(const char* data, std::size_t len) { if (s_webServices == nullptr) { return false; @@ -117,12 +188,85 @@ bool CaptivePortal::BroadcastMessageBIN(const std::uint8_t* data, std::size_t le } void handleWebSocketClientConnected(std::uint8_t socketId) { - ESP_LOGD( - TAG, "WebSocket client #%u connected from %s", socketId, s_webServices->socketServer.remoteIP(socketId).toString().c_str()); + ESP_LOGD(TAG, "WebSocket client #%u connected from %s", socketId, s_webServices->socketServer.remoteIP(socketId).toString().c_str()); + + StaticJsonDocument<24> doc; + doc["type"] = "poggies"; + CaptivePortal::SendMessageJSON(socketId, doc); } void handleWebSocketClientDisconnected(std::uint8_t socketId) { ESP_LOGD(TAG, "WebSocket client #%u disconnected", socketId); } +void handleWebSocketClientWiFiScanMessage(const StaticJsonDocument<256>& doc) { + bool run = doc["run"]; + if (run) { + WiFiScanManager::StartScan(); + } else { + WiFiScanManager::CancelScan(); + } +} +void handleWebSocketClientWiFiAuthenticateMessage(const StaticJsonDocument<256>& doc) { + String bssidStr = doc["bssid"]; + if (bssidStr.isEmpty()) { + ESP_LOGE(TAG, "WiFi BSSID is missing"); + return; + } + if (bssidStr.length() != 17) { + ESP_LOGE(TAG, "WiFi BSSID is invalid"); + return; + } + + String password = doc["password"]; + + // Convert BSSID to byte array + // Uses sscanf to parse the max-style hex format, e.g. "AA:BB:CC:DD:EE:FF" where each pair is a byte, and %02X means to parse 2 characters as a hex byte + // We check the return value to ensure that we parsed all 6 arguments (6 pairs of hex bytes, or 6 bytes) + std::uint8_t bssid[6]; + if (sscanf(bssidStr.c_str(), "%02X:%02X:%02X:%02X:%02X:%02X", bssid + 0, bssid + 1, bssid + 2, bssid + 3, bssid + 4, bssid + 5) != 6) { + ESP_LOGE(TAG, "WiFi BSSID is invalid"); + return; + } + + std::size_t passwordLength = password.length(); + if (passwordLength > UINT8_MAX) { + ESP_LOGE(TAG, "WiFi password is too long"); + return; + } + + WiFiManager::Authenticate(bssid, password.c_str(), static_cast(passwordLength)); +} +void handleWebSocketClientWiFiConnectMessage(const StaticJsonDocument<256>& doc) { + std::uint16_t wifiId = doc["id"]; + + WiFiManager::Connect(wifiId); +} +void handleWebSocketClientWiFiDisconnectMessage(const StaticJsonDocument<256>& doc) { + WiFiManager::Disconnect(); +} +void handleWebSocketClientWiFiForgetMessage(const StaticJsonDocument<256>& doc) { + WiFiManager::Forget(doc["bssid"]); +} +void handleWebSocketClientWiFiMessage(StaticJsonDocument<256> doc) { + String actionStr = doc["action"]; + if (actionStr.isEmpty()) { + ESP_LOGE(TAG, "Received WiFi message with \"action\" property missing"); + return; + } + + if (actionStr == "scan") { + handleWebSocketClientWiFiScanMessage(doc); + } else if (actionStr == "authenticate") { + handleWebSocketClientWiFiAuthenticateMessage(doc); + } else if (actionStr == "connect") { + handleWebSocketClientWiFiConnectMessage(doc); + } else if (actionStr == "disconnect") { + handleWebSocketClientWiFiDisconnectMessage(doc); + } else if (actionStr == "forget") { + handleWebSocketClientWiFiForgetMessage(doc); + } else { + ESP_LOGE(TAG, "Received WiFi message with unknown action \"%s\"", actionStr.c_str()); + } +} void handleWebSocketClientMessage(std::uint8_t socketId, WStype_t type, std::uint8_t* data, std::size_t len) { if (type != WStype_t::WStype_TEXT) { ESP_LOGE(TAG, "Message type is not supported"); @@ -137,26 +281,24 @@ void handleWebSocketClientMessage(std::uint8_t socketId, WStype_t type, std::uin } String typeStr = doc["type"]; - if (typeStr.length() == 0) { + if (typeStr.isEmpty()) { ESP_LOGE(TAG, "Message type is missing"); return; } - if (typeStr == "startScan") { - WiFiManager::StartScan(); - } /* else if (typeStr == "connect") { - WiFiManager::Connect(doc["ssid"], doc["password"]); - } else if (typeStr == "disconnect") { - WiFiManager::Disconnect(); - } else if (typeStr == "authenticate") { - AuthenticationManager::Authenticate(doc["code"]); - } else if (typeStr == "pair") { - AuthenticationManager::Pair(doc["code"]); - } else if (typeStr == "unpair") { - AuthenticationManager::Unpair(); - } else if (typeStr == "setRmtPin") { - AuthenticationManager::SetRmtPin(doc["pin"]); - }*/ + if (typeStr == "wifi") { + handleWebSocketClientWiFiMessage(doc); + } else if (typeStr == "pair") { + if (!doc.containsKey("code")) { + ESP_LOGE(TAG, "Pair message is missing \"code\" property"); + return; + } + AuthenticationManager::Pair(doc["code"]); + } else if (typeStr == "unpair") { + AuthenticationManager::UnPair(); + } else if (typeStr == "tx_pin") { + //AuthenticationManager::SetRmtPin(doc["pin"]); + } } void handleWebSocketClientError(std::uint8_t socketId, std::uint16_t code, const char* message) { ESP_LOGE(TAG, "WebSocket client #%u error %u: %s", socketId, code, message); diff --git a/src/Mappers/EspWiFiTypesMapper.cpp b/src/Mappers/EspWiFiTypesMapper.cpp new file mode 100644 index 00000000..755b05fa --- /dev/null +++ b/src/Mappers/EspWiFiTypesMapper.cpp @@ -0,0 +1,28 @@ +#include "Mappers/EspWiFiTypesMapper.h" + +using namespace OpenShock; + +const char* Mappers::GetWiFiAuthModeName(wifi_auth_mode_t authMode) { + switch (authMode) { + case WIFI_AUTH_OPEN: + return "Open"; + case WIFI_AUTH_WEP: + return "WEP"; + case WIFI_AUTH_WPA_PSK: + return "WPA PSK"; + case WIFI_AUTH_WPA2_PSK: + return "WPA2 PSK"; + case WIFI_AUTH_WPA_WPA2_PSK: + return "WPA/WPA2 PSK"; + case WIFI_AUTH_WPA2_ENTERPRISE: + return "WPA2 Enterprise"; + case WIFI_AUTH_WPA3_PSK: + return "WPA3 PSK"; + case WIFI_AUTH_WPA2_WPA3_PSK: + return "WPA2/WPA3 PSK"; + case WIFI_AUTH_WAPI_PSK: + return "WAPI PSK"; + default: + return nullptr; + } +} diff --git a/src/VisualStateManager.cpp b/src/VisualStateManager.cpp index fcf216e8..86b5f02a 100644 --- a/src/VisualStateManager.cpp +++ b/src/VisualStateManager.cpp @@ -95,11 +95,6 @@ void VisualStateManager::SetCriticalError() { } void VisualStateManager::SetWiFiState(WiFiState state) { - static WiFiState _state = (WiFiState)-1; - if (_state == state) { - return; - } - ESP_LOGD(TAG, "SetWiFiStateState: %d", state); switch (state) { case WiFiState::Disconnected: @@ -111,8 +106,6 @@ void VisualStateManager::SetWiFiState(WiFiState state) { default: return; } - - _state = state; } #endif // OPENSHOCK_LED_GPIO diff --git a/src/WiFiCredentials.cpp b/src/WiFiCredentials.cpp new file mode 100644 index 00000000..fc2feca1 --- /dev/null +++ b/src/WiFiCredentials.cpp @@ -0,0 +1,183 @@ +#include "WiFiCredentials.h" + +#include "Utils/HexUtils.h" + +#include + +#include + +#include + +const char* const TAG = "WiFiCredentials"; + +using namespace OpenShock; + +const char* const WiFiDir = "/wifi/"; +const char* const WiFiCredsDir = "/wifi/creds/"; + +inline void GetWiFiCredsFilename(char (&filename)[15], std::uint8_t id) { + memcpy(filename, WiFiCredsDir, 12); + HexUtils::ToHex(id, filename + 12); +} + +template +std::uint8_t CopyString(const char* src, std::uint8_t srcLength, char (&dest)[N]) { + if (src == nullptr || srcLength == 0) { + ESP_LOGW(TAG, "String is null/empty, clearing"); + memset(dest, 0, N); + return 0; + } + + if (srcLength > N - 1) { + ESP_LOGW(TAG, "String is too long, truncating"); + srcLength = N - 1; + } + + memcpy(dest, src, srcLength); + dest[srcLength] = 0; + + return srcLength; +} + +template +bool WriteString(fs::File& file, const char (&str)[N], std::uint8_t length) { + if (length > N - 1) return false; + + file.write(&length, sizeof(length)); + file.write(reinterpret_cast(str), length); + + return true; +} +template +bool ReadString(fs::File& file, char (&str)[N], std::uint8_t& length) { + file.read(&length, sizeof(length)); + + if (length > N - 1) return false; + + file.read(reinterpret_cast(str), length); + + str[length] = 0; // Ensure null-terminated + + return true; +} + +bool WiFiCredentials::Load(std::vector& credentials) { + credentials.clear(); + + // Ensure the credentials directory exists + if (!LittleFS.exists(WiFiCredsDir)) { + if (!LittleFS.exists(WiFiDir)) { + if (!LittleFS.mkdir(WiFiDir)) { + ESP_LOGE(TAG, "Failed to create WiFi directory"); + return false; + } + } + if (!LittleFS.mkdir(WiFiCredsDir)) { + ESP_LOGE(TAG, "Failed to create WiFi credentials directory"); + return false; + } + ESP_LOGI(TAG, "No credentials directory found, created one"); + return true; + } + + File credsDir = LittleFS.open(WiFiCredsDir); + if (!credsDir) { + ESP_LOGE(TAG, "Failed to open WiFi credentials directory"); + return false; + } + + while (true) { + File file = credsDir.openNextFile(); + if (!file) { + break; + } + + WiFiCredentials creds; + if (!creds._load(file)) { + ESP_LOGE(TAG, "Failed to load credentials from %s", file.name()); + continue; + } + + credentials.push_back(creds); + } + + return true; +} + +WiFiCredentials::WiFiCredentials(std::uint8_t id, const char* ssid, std::uint8_t ssidLength, const char* password, std::uint8_t passwordLength) { + _id = id; + + setSSID(ssid, ssidLength); + setPassword(password, passwordLength); +} + +void WiFiCredentials::setSSID(const char* ssid, std::uint8_t ssidLength) { + _ssidLength = CopyString(ssid, ssidLength, _ssid); +} + +void WiFiCredentials::setPassword(const char* password, std::uint8_t passwordLength) { + _passwordLength = CopyString(password, passwordLength, _password); +} + +bool WiFiCredentials::save() const { + char filename[15]; + GetWiFiCredsFilename(filename, _id); + File file = LittleFS.open(filename, "wb"); + if (!file) { + ESP_LOGE(TAG, "Failed to open file %s for writing", filename); + return false; + } + + file.write(_id); + if (!WriteString(file, _ssid, _ssidLength)) return false; + if (!WriteString(file, _password, _passwordLength)) return false; + + file.close(); + + return true; +} + +bool WiFiCredentials::erase() const { + char filename[15]; + GetWiFiCredsFilename(filename, _id); + if (!LittleFS.remove(filename)) { + ESP_LOGE(TAG, "Failed to remove file %s", filename); + return false; + } + + return true; +} + +bool WiFiCredentials::_load(fs::File& file) { + if (!file) { + ESP_LOGE(TAG, "File is not open"); + return false; + } + + file.read(&_id, sizeof(_id)); + if (_id > 31) { + const char* filename = file.name(); + ESP_LOGE(TAG, "Loading credentials for %s failed: ID is too large (needs to fit into a uint32 by bitshifting)", filename); // Look in WiFiManager.cpp for the bitshifting + ESP_LOGW(TAG, "Deleting credentials for %s", filename); + LittleFS.remove(filename); + return false; + } + + if (!ReadString(file, _ssid, _ssidLength)) { + const char* filename = file.name(); + ESP_LOGE(TAG, "Loading credentials for %s failed: SSID length is too long", filename); + ESP_LOGW(TAG, "Deleting credentials for %s", filename); + LittleFS.remove(filename); + return false; + } + + if (!ReadString(file, _password, _passwordLength)) { + const char* filename = file.name(); + ESP_LOGE(TAG, "Loading credentials for %s failed: password length is too long", filename); + ESP_LOGW(TAG, "Deleting credentials for %s", filename); + LittleFS.remove(filename); + return false; + } + + return true; +} diff --git a/src/WiFiManager.cpp b/src/WiFiManager.cpp index a08b9691..e4691a5e 100644 --- a/src/WiFiManager.cpp +++ b/src/WiFiManager.cpp @@ -1,228 +1,207 @@ #include "WiFiManager.h" #include "CaptivePortal.h" +#include "Mappers/EspWiFiTypesMapper.h" #include "VisualStateManager.h" +#include "WiFiCredentials.h" +#include "WiFiScanManager.h" #include #include #include +#include + #include +#include + const char* const TAG = "WiFiManager"; using namespace OpenShock; -struct WifiCredentials { - String ssid; - String password; - std::uint16_t wifiIndex; - std::uint8_t attempts; -}; -static std::vector s_wifiCredentials; -static WiFiState s_wifiState; - -void SetWiFiState(WiFiState state) { - s_wifiState = state; - VisualStateManager::SetWiFiState(state); +void _broadcastWifiAddNetworkSuccess(const char* ssid) { + DynamicJsonDocument doc(64); + doc["type"] = "wifi"; + doc["subject"] = "add_network"; + doc["status"] = "success"; + doc["ssid"] = ssid; + CaptivePortal::BroadcastMessageJSON(doc); +} +void _broadcastWifiAddNetworkError(const char* error) { + DynamicJsonDocument doc(64); + doc["type"] = "wifi"; + doc["subject"] = "add_network"; + doc["status"] = "error"; + doc["error"] = error; + CaptivePortal::BroadcastMessageJSON(doc); } -bool SaveCredentials() { - File file = LittleFS.open("/networks", FILE_WRITE); - if (!file) { - ESP_LOGE(TAG, "Failed to open networks file for writing"); - return false; - } +struct WiFiNetwork { + char ssid[33]; + std::uint8_t bssid[6]; + std::uint8_t channel; + std::int8_t rssi; + wifi_auth_mode_t authMode; + std::uint16_t reconnectionCount; + std::uint8_t credentialsId; +}; - DynamicJsonDocument doc(1024); - JsonArray networks = doc.createNestedArray("networks"); +static std::vector s_wifiNetworks; +static std::vector s_wifiCredentials; +bool _addNetwork(const char* ssid, std::uint8_t ssidLength, const char* password, std::uint8_t passwordLength) { + // Bitmask representing available credential IDs (0-31) + std::uint32_t bits = 0; for (auto& cred : s_wifiCredentials) { - JsonObject network = networks.createNestedObject(); - network["ssid"] = cred.ssid; - network["password"] = cred.password; - } + if (strcmp(cred.ssid().data(), ssid) == 0) { + ESP_LOGE(TAG, "Failed to add WiFi credentials: credentials for %s already exist", ssid); + cred.setPassword(password, passwordLength); + cred.save(); + return true; + } - if (serializeJson(doc, file) == 0) { - ESP_LOGE(TAG, "Failed to serialize networks file"); - file.close(); - return false; + // Mark the credential ID as used + bits |= 1u << cred.id(); } - file.close(); - return true; -} -bool ReadCredentials() { - File file = LittleFS.open("/networks", FILE_READ); - if (!file) { - ESP_LOGE(TAG, "Failed to open networks file for reading"); + // If we have 31 credentials, we can't add any more + if (s_wifiCredentials.size() == 31) { + ESP_LOGE(TAG, "Cannot add WiFi credentials: too many credentials"); return false; } - DynamicJsonDocument doc(1024); - if (deserializeJson(doc, file) != DeserializationError::Ok) { - file.close(); - - ESP_LOGE(TAG, "Failed to deserialize networks file, overwriting"); - SaveCredentials(); - - return false; + std::uint8_t id = 0; + while (bits & (1u << id)) { + id++; } - s_wifiCredentials.clear(); + WiFiCredentials credentials(id, ssid, ssidLength, password, passwordLength); + credentials.save(); - JsonArray networks = doc["networks"]; - for (int i = 0; i < networks.size(); i++) { - JsonObject network = networks[i]; - String ssid = network["ssid"]; - String password = network["password"]; - - s_wifiCredentials.push_back({ssid, password, UINT16_MAX, 0}); - ESP_LOGD(TAG, "Read credentials for %s", ssid.c_str()); - } - - file.close(); + s_wifiCredentials.push_back(std::move(credentials)); return true; } -void _evScanCompleted(arduino_event_id_t event, arduino_event_info_t info) { - ESP_LOGD(TAG, "Scan completed"); - - std::uint16_t numNetworks = WiFi.scanComplete(); - if (numNetworks < 0) { - if (numNetworks != WIFI_SCAN_RUNNING) { - ESP_LOGE(TAG, "Scan failed"); - SetWiFiState(WiFiState::Disconnected); - CaptivePortal::Start(); - } else { - ESP_LOGE(TAG, "Scan is still running"); - } - return; - } - - DynamicJsonDocument doc(64 + numNetworks * 128); - JsonArray networks = doc.createNestedArray("networks"); - - if (numNetworks == 0) { - ESP_LOGD(TAG, "No networks found"); - SetWiFiState(WiFiState::Disconnected); - CaptivePortal::Start(); - CaptivePortal::BroadcastMessageJSON(doc); - return; - } - - for (auto& cred : s_wifiCredentials) { - cred.wifiIndex = UINT16_MAX; - } - - std::uint16_t recognizedNetworks = 0; - for (std::uint16_t i = 0; i < numNetworks; i++) { - String ssid = WiFi.SSID(i); - bool saved = false; - for (auto& cred : s_wifiCredentials) { - if (cred.ssid == ssid) { - cred.wifiIndex = i; - recognizedNetworks++; - saved = true; - break; - } - } - JsonObject network = networks.createNestedObject(); - network["index"] = i; - network["ssid"] = ssid; - network["bssid"] = WiFi.BSSIDstr(i); - network["rssi"] = WiFi.RSSI(i); - network["channel"] = WiFi.channel(i); - network["saved"] = saved; - } - - CaptivePortal::BroadcastMessageJSON(doc); - - if (recognizedNetworks == 0) { - ESP_LOGD(TAG, "No recognized networks found"); - SetWiFiState(WiFiState::Disconnected); - CaptivePortal::Start(); - return; - } - - // Attempt to connect to the first recognized network - for (auto& cred : s_wifiCredentials) { - if (cred.wifiIndex != UINT16_MAX) { - ESP_LOGD(TAG, "Attempting to connect to %s", cred.ssid.c_str()); - WiFi.begin(cred.ssid.c_str(), cred.password.c_str()); - SetWiFiState(WiFiState::Connecting); - return; - } - } -} void _evWiFiConnected(arduino_event_id_t event, arduino_event_info_t info) { ESP_LOGD(TAG, "WiFi connected"); - SetWiFiState(WiFiState::Connected); + OpenShock::SetWiFiState(WiFiState::Connected); CaptivePortal::Stop(); } void _evWiFiDisconnected(arduino_event_id_t event, arduino_event_info_t info) { ESP_LOGD(TAG, "WiFi disconnected"); - SetWiFiState(WiFiState::Disconnected); + OpenShock::SetWiFiState(WiFiState::Disconnected); CaptivePortal::Start(); } +void _evWiFiNetworkDiscovered(const wifi_ap_record_t* record) { + WiFiNetwork network { + .ssid = {0}, + .bssid = {0}, + .channel = record->primary, + .rssi = record->rssi, + .authMode = record->authmode, + .reconnectionCount = 0, + .credentialsId = UINT8_MAX, + }; + + static_assert(sizeof(network.ssid) == sizeof(record->ssid), "SSID size mismatch"); + memcpy(network.ssid, record->ssid, sizeof(network.ssid)); + + static_assert(sizeof(network.bssid) == sizeof(record->bssid), "BSSID size mismatch"); + memcpy(network.bssid, record->bssid, sizeof(network.bssid)); + + s_wifiNetworks.push_back(network); +} bool WiFiManager::Init() { - ReadCredentials(); + WiFiCredentials::Load(s_wifiCredentials); - WiFi.onEvent(_evScanCompleted, ARDUINO_EVENT_WIFI_SCAN_DONE); WiFi.onEvent(_evWiFiConnected, ARDUINO_EVENT_WIFI_STA_CONNECTED); WiFi.onEvent(_evWiFiDisconnected, ARDUINO_EVENT_WIFI_STA_DISCONNECTED); + WiFiScanManager::RegisterScanDiscoveryHandler(_evWiFiNetworkDiscovered); + + if (!WiFiScanManager::Init()) { + ESP_LOGE(TAG, "Failed to initialize WiFiScanManager"); + return false; + } WiFi.mode(WIFI_STA); WiFi.setHostname("OpenShock"); // TODO: Add the device name to the hostname (retrieve from API and store in LittleFS) if (s_wifiCredentials.size() > 0) { WiFi.scanNetworks(true); - SetWiFiState(WiFiState::Scanning); + OpenShock::SetWiFiState(WiFiState::Scanning); } else { CaptivePortal::Start(); - SetWiFiState(WiFiState::Disconnected); + OpenShock::SetWiFiState(WiFiState::Disconnected); } return true; } -WiFiState WiFiManager::GetWiFiState() { - return s_wifiState; -} - -void WiFiManager::AddOrUpdateNetwork(const char* ssid, const char* password) { - for (auto& cred : s_wifiCredentials) { - if (cred.ssid == ssid) { - cred.password = password; - cred.attempts = 0; - SaveCredentials(); - return; +bool WiFiManager::Authenticate(std::uint8_t (&bssid)[6], const char* password, std::uint8_t passwordLength) { + bool found = false; + char ssid[33]; + for (std::uint16_t i = 0; i < s_wifiNetworks.size(); i++) { + static_assert(sizeof(s_wifiNetworks[i].bssid) == sizeof(bssid), "BSSID size mismatch"); + if (memcmp(s_wifiNetworks[i].bssid, bssid, sizeof(bssid)) == 0) { + memcpy(ssid, s_wifiNetworks[i].ssid, sizeof(ssid)); + found = true; + break; } } - s_wifiCredentials.push_back({ssid, password, UINT16_MAX, 0}); - SaveCredentials(); + if (!found) { + ESP_LOGE(TAG, "Failed to find network with BSSID %02X:%02X:%02X:%02X:%02X:%02X", bssid[0], bssid[1], bssid[2], bssid[3], bssid[4], bssid[5]); + _broadcastWifiAddNetworkError("network_not_found"); + return false; + } + + if (!_addNetwork(ssid, strlen(ssid), password, passwordLength)) { + _broadcastWifiAddNetworkError("too_many_credentials"); + return false; + } + + _broadcastWifiAddNetworkSuccess(ssid); + + wl_status_t stat = WiFi.begin(ssid, password, 0, bssid, true); + if (stat != WL_CONNECTED) { + ESP_LOGE(TAG, "Failed to connect to network %s, error code %d", ssid, stat); + return false; + } + + return true; } -void WiFiManager::RemoveNetwork(const char* ssid) { +void WiFiManager::Forget(std::uint8_t wifiId) { for (auto it = s_wifiCredentials.begin(); it != s_wifiCredentials.end(); it++) { - if (it->ssid == ssid) { + if (it->id() == wifiId) { s_wifiCredentials.erase(it); - SaveCredentials(); + it->erase(); return; } } } -bool WiFiManager::StartScan() { - if (s_wifiState != WiFiState::Disconnected) return false; +void WiFiManager::Connect(std::uint8_t wifiId) { + if (OpenShock::GetWiFiState() != WiFiState::Disconnected) return; + + for (auto& creds : s_wifiCredentials) { + if (creds.id() == wifiId) { + WiFi.begin(creds.ssid().data(), creds.password().data()); + OpenShock::SetWiFiState(WiFiState::Connecting); + return; + } + } - CaptivePortal::BroadcastMessageTXT("{\"scanning\":true}"); + ESP_LOGE(TAG, "Failed to find credentials with ID %u", wifiId); +} - WiFi.scanNetworks(true); - SetWiFiState(WiFiState::Scanning); +void WiFiManager::Disconnect() { + if (OpenShock::GetWiFiState() != WiFiState::Connected) return; - return true; + WiFi.disconnect(true); + OpenShock::SetWiFiState(WiFiState::Disconnected); + CaptivePortal::Start(); } diff --git a/src/WiFiScanManager.cpp b/src/WiFiScanManager.cpp new file mode 100644 index 00000000..15912600 --- /dev/null +++ b/src/WiFiScanManager.cpp @@ -0,0 +1,198 @@ +#include "WiFiScanManager.h" + +#include + +#include + +#include + +const char* const TAG = "WiFiScanManager"; + +constexpr const std::uint8_t OPENSHOCK_WIFI_SCAN_MAX_CHANNEL = 13; +constexpr const std::uint32_t OPENSHOCK_WIFI_SCAN_MAX_MS_PER_CHANNEL = 300; // Adjusting this value will affect the scan rate, but may also affect the scan results + +using namespace OpenShock; + +static bool s_initialized = false; +static bool s_scanInProgress = false; +static bool s_channelScanDone = false; +static std::uint8_t s_currentChannel = 0; +static std::unordered_map s_scanStartedHandlers; +static std::unordered_map s_scanCompletedHandlers; +static std::unordered_map s_scanDiscoveryHandlers; + +void _setScanInProgress(bool inProgress) { + if (s_scanInProgress != inProgress) { + s_scanInProgress = inProgress; + if (inProgress) { + ESP_LOGD(TAG, "Scan started"); + for (auto& it : s_scanStartedHandlers) { + it.second(); + } + WiFi.scanDelete(); + } else { + ESP_LOGD(TAG, "Scan completed"); + for (auto& it : s_scanCompletedHandlers) { + it.second(WiFiScanManager::ScanCompletedStatus::Success); + } + } + } + + if (!inProgress) { + s_currentChannel = 0; + s_channelScanDone = false; + } +} + +void _handleScanError(std::int16_t retval) { + s_channelScanDone = true; + + if (retval == WIFI_SCAN_FAILED) { + ESP_LOGE(TAG, "Failed to start scan on channel %u", s_currentChannel); + for (auto& it : s_scanCompletedHandlers) { + it.second(WiFiScanManager::ScanCompletedStatus::Error); + } + return; + } + + ESP_LOGE(TAG, "Scan returned an unknown error"); +} + +void _iterateChannel() { + if (s_currentChannel-- <= 1) { + s_currentChannel = 0; + _setScanInProgress(false); + return; + } + + s_channelScanDone = false; + + std::int16_t retval = WiFi.scanNetworks(true, true, false, OPENSHOCK_WIFI_SCAN_MAX_MS_PER_CHANNEL, s_currentChannel); + + if (retval == WIFI_SCAN_RUNNING) { + _setScanInProgress(true); + return; + } + + _handleScanError(retval); +} + +void _evScanCompleted(arduino_event_id_t event, arduino_event_info_t info); +void _evSTAStopped(arduino_event_id_t event, arduino_event_info_t info); + +bool WiFiScanManager::Init() { + if (s_initialized) { + ESP_LOGE(TAG, "WiFiScanManager::Init() called twice"); + return false; + } + + WiFi.onEvent(_evScanCompleted, ARDUINO_EVENT_WIFI_SCAN_DONE); + WiFi.onEvent(_evSTAStopped, ARDUINO_EVENT_WIFI_STA_STOP); + + s_initialized = true; + + return true; +} + +bool WiFiScanManager::StartScan() { + if (s_scanInProgress) { + ESP_LOGE(TAG, "Cannot start scan: scan is already in progress"); + return false; + } + + WiFi.enableSTA(true); + s_currentChannel = OPENSHOCK_WIFI_SCAN_MAX_CHANNEL; + _iterateChannel(); + + return true; +} +void WiFiScanManager::CancelScan() { + if (!s_scanInProgress) { + ESP_LOGE(TAG, "Cannot cancel scan: no scan is in progress"); + return; + } + + s_currentChannel = 0; +} + +WiFiScanManager::CallbackHandle WiFiScanManager::RegisterScanStartedHandler(const WiFiScanManager::ScanStartedHandler& handler) { + static WiFiScanManager::CallbackHandle nextId = 0; + WiFiScanManager::CallbackHandle CallbackHandle = nextId++; + s_scanStartedHandlers[CallbackHandle] = handler; + return CallbackHandle; +} +void WiFiScanManager::UnregisterScanStartedHandler(WiFiScanManager::CallbackHandle id) { + auto it = s_scanStartedHandlers.find(id); + if (it == s_scanStartedHandlers.end()) { + ESP_LOGE(TAG, "Cannot unregister scan handler: no handler with ID %u", id); + return; + } + + s_scanStartedHandlers.erase(it); +} + +WiFiScanManager::CallbackHandle WiFiScanManager::RegisterScanCompletedHandler(const WiFiScanManager::ScanCompletedHandler& handler) { + static WiFiScanManager::CallbackHandle nextId = 0; + WiFiScanManager::CallbackHandle CallbackHandle = nextId++; + s_scanCompletedHandlers[CallbackHandle] = handler; + return CallbackHandle; +} +void WiFiScanManager::UnregisterScanCompletedHandler(WiFiScanManager::CallbackHandle id) { + auto it = s_scanCompletedHandlers.find(id); + if (it == s_scanCompletedHandlers.end()) { + ESP_LOGE(TAG, "Cannot unregister scan handler: no handler with ID %u", id); + return; + } + + s_scanCompletedHandlers.erase(it); +} + +WiFiScanManager::CallbackHandle WiFiScanManager::RegisterScanDiscoveryHandler(const WiFiScanManager::ScanDiscoveryHandler& handler) { + static WiFiScanManager::CallbackHandle nextId = 0; + WiFiScanManager::CallbackHandle CallbackHandle = nextId++; + s_scanDiscoveryHandlers[CallbackHandle] = handler; + return CallbackHandle; +} +void WiFiScanManager::UnregisterScanDiscoveryHandler(WiFiScanManager::CallbackHandle id) { + auto it = s_scanDiscoveryHandlers.find(id); + if (it == s_scanDiscoveryHandlers.end()) { + ESP_LOGE(TAG, "Cannot unregister scan handler: no handler with ID %u", id); + return; + } + + s_scanDiscoveryHandlers.erase(it); +} + +void WiFiScanManager::Update() { + if (!s_initialized) return; + + if (s_scanInProgress && s_channelScanDone) { + _iterateChannel(); + } +} + +void _evScanCompleted(arduino_event_id_t event, arduino_event_info_t info) { + std::uint16_t numNetworks = WiFi.scanComplete(); + if (numNetworks < 0) { + _handleScanError(numNetworks); + return; + } + + for (std::uint16_t i = 0; i < numNetworks; i++) { + wifi_ap_record_t* record = reinterpret_cast(WiFi.getScanInfoByIndex(i)); + if (record == nullptr) { + ESP_LOGE(TAG, "Failed to get scan info for network #%u", i); + return; + } + + for (auto& it : s_scanDiscoveryHandlers) { + it.second(record); + } + } + + s_channelScanDone = true; +} +void _evSTAStopped(arduino_event_id_t event, arduino_event_info_t info) { + ESP_LOGD(TAG, "STA stopped"); + _setScanInProgress(false); +} diff --git a/src/WiFiState.cpp b/src/WiFiState.cpp new file mode 100644 index 00000000..40bd912e --- /dev/null +++ b/src/WiFiState.cpp @@ -0,0 +1,15 @@ +#include "WiFiState.h" + +#include "VisualStateManager.h" + +static OpenShock::WiFiState s_wifiState = OpenShock::WiFiState::Disconnected; + +OpenShock::WiFiState OpenShock::GetWiFiState() noexcept { + return s_wifiState; +} +void OpenShock::SetWiFiState(WiFiState state) noexcept { + if (s_wifiState == state) return; + + s_wifiState = state; + VisualStateManager::SetWiFiState(state); +} diff --git a/src/main.cpp b/src/main.cpp index 0d058714..dd271fbe 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -5,6 +5,7 @@ #include "FileUtils.h" #include "SerialInputHandler.h" #include "WiFiManager.h" +#include "WiFiScanManager.h" #include #include @@ -43,4 +44,6 @@ void loop() { if (s_apiConnection != nullptr) { s_apiConnection->Update(); } + + OpenShock::WiFiScanManager::Update(); }