diff --git a/dal/classes/protocols/wsredissub.py b/dal/classes/protocols/wsredissub.py index 77f1b44c..ee5c91e3 100644 --- a/dal/classes/protocols/wsredissub.py +++ b/dal/classes/protocols/wsredissub.py @@ -48,6 +48,7 @@ def __init__(self, app: web.Application, _node_name: str, **_ignore): self.tasks = {} self.movaidb = MovaiDB() self.loop = asyncio.get_event_loop() + self.recovery_mode = False # API available actions self.actions = { @@ -68,110 +69,152 @@ async def connect(self): self.databases = await RedisClient.get_client() async def acquire(self, retries: int = 3): - _conn = None + """Acquire a Redis connection with retry logic, handle cancellations, health checks, and remove unhealthy connections""" + LOGGER.debug(f"Acquiring connection. Retries left: {retries}") if retries == 0: - return _conn + return None try: await self.connect() - _conn = await self.databases.slave_pubsub.acquire() + conn = await asyncio.wait_for(self.databases.slave_pubsub.acquire(), timeout=30) + + # Check if the connection is healthy + if await conn.execute('PING') != b'PONG': + raise Exception("Failed to receive PONG response") + return conn + except asyncio.CancelledError: + LOGGER.debug("Acquire operation was cancelled") + return None + except asyncio.TimeoutError: + LOGGER.error("Acquire operation timed out") + return None except Exception as e: - LOGGER.error(e) - await self.connect() - _conn = await self.acquire(retries - 1) - return _conn + LOGGER.error(f"Error acquiring connection: {e}. Retrying...") + # Close the connection if it's already acquired but unhealthy + if 'conn' in locals() and conn: + conn.close() + await conn.wait_closed() + await asyncio.sleep(2 ** (3 - retries)) # Exponential backoff + return await asyncio.wait_for(self.acquire(retries - 1), timeout=60) async def release(self, conn_id): conn = self.connections[conn_id]["subs"] - asyncio.create_task(conn.wait_closed()) + if conn: + self.loop.create_task(conn.wait_closed()) del self.connections[conn_id] - await self.databases.shutdown() + if self.recovery_mode: + await self.databases.shutdown() + self.recovery_mode = False - async def close_and_release(self, ws: web.WebSocketResponse, conn_id: str): - """Closes the socket, cancels active tasks and release db connections. - - Args: - ws (web.WebSocketResponse): The websocket to close - conn_id (str): the connection id. - """ - await ws.close() + async def close(self, ws: web.WebSocketResponse, conn_id: str): + """Close the WebSocket and clean up tasks and connections""" + for task in self.tasks[conn_id]: if not task.done(): task.cancel() + try: + await asyncio.wait_for(task, timeout=2) + except asyncio.CancelledError: + LOGGER.debug("[WSRedisSub.close] Task was properly cancelled "+ str(task)) + except asyncio.TimeoutError: + LOGGER.error("[WSRedisSub.close] Waited 2 seconds for it. Force cancelling! "+ str(task)) - if conn_id in self.tasks: - self.tasks.pop(conn_id) - - async def handler(self, request: web.Request) -> web.WebSocketResponse: - """handle websocket connections""" + self.tasks.pop(conn_id, None) + await self.release(conn_id) + await ws.close() - ws_resp = web.WebSocketResponse() - await ws_resp.prepare(request) + async def print_current_task_count(self): + tasks = asyncio.Task.all_tasks(loop=asyncio.get_event_loop()) + LOGGER.warning(f"Number of current tasks: {len(tasks)}") - # acquire db connection - conn = None + async def handler(self, request: web.Request) -> web.WebSocketResponse: + """Handle WebSocket connections""" connection_queue = asyncio.Queue() lock = asyncio.Lock() - try: - _conn = await self.acquire() - conn = aioredis.Redis(_conn) - except Exception as error: - LOGGER.error(str(error)) - await self.push_to_queue( - connection_queue, {"event": "", "patterns": None, "error": str(error)} - ) + conn = None + ws_resp = web.WebSocketResponse() + await ws_resp.prepare(request) + conn_id = uuid.uuid4().hex + self.connections[conn_id] = {"conn": connection_queue, "subs": conn, "patterns": []} - # add connection - self.connections.update({conn_id: {"conn": connection_queue, "subs": conn, "patterns": []}}) - - # wait for messages - write_task = asyncio.create_task(self.write_websocket_loop(ws_resp, connection_queue, lock)) + # Ensure the loop is running and get it + loop = asyncio.get_event_loop() + write_task = loop.create_task(self.write_websocket_loop(ws_resp, connection_queue, lock)) self.tasks[conn_id] = [write_task] - async for ws_msg in ws_resp: - # check if redis connection is active - if not conn or conn.closed: - print("redis connection not available") - if ws_msg.type == WSMsgType.TEXT: - # message should be json - try: - if ws_msg.data == "close": - break - data = ws_msg.json() - if "event" in data: - if data.get("event") == "execute": - _config = { - "conn_id": conn_id, - "conn": conn, - "callback": data.get("callback", None), - "func": data.get("func", None), - "data": data.get("data", None), - } - else: - _config = { - "conn_id": conn_id, - "conn": conn, - "_pattern": data.get("pattern", None), - } - await self.actions[data["event"]](**_config) - else: - raise KeyError("Not all required keys found") - except Exception as e: - LOGGER.error(e) - output = {"event": None, "patterns": None, "error": str(e)} - - await self.push_to_queue(connection_queue, output) - - elif ws_msg.type == WSMsgType.ERROR: - LOGGER.error("ws connection closed with exception %s" % ws_resp.exception()) - async with lock: - await self.close_and_release(ws_resp, conn_id) - await self.release(conn_id) + try: + try: + #_conn = await self.acquire() + _conn = await asyncio.wait_for(self.acquire(), timeout=160) + if _conn: + conn = aioredis.Redis(_conn) + self.connections[conn_id]["subs"] = conn + + else: + LOGGER.error(f"Error acquiring Redis connection:") + self.recovery_mode = True + raise Exception("Unable to acquire Redis connection") + except Exception as error: + LOGGER.error(f"Exception: {error}") + output = {"event": None, "patterns": None, "error": "{error}"} + await self.push_to_queue(connection_queue, output) + raise + + async for ws_msg in ws_resp: + if ws_msg.type == WSMsgType.TEXT: + try: + if ws_msg.data == "close": + break + data = ws_msg.json() + if "event" in data: + if data["event"] == "execute": + _config = { + "conn_id": conn_id, + "conn": conn, + "callback": data.get("callback"), + "func": data.get("func"), + "data": data.get("data"), + } + else: + _config = { + "conn_id": conn_id, + "conn": conn, + "_pattern": data.get("pattern"), + } + await self.actions[data["event"]](**_config) + else: + raise KeyError("Missing 'event' key in WebSocket message") + except aioredis.PoolClosedError: + LOGGER.debug("Connection closed while executing.") + output = {"event": None, "patterns": None, "error": "Connection was closed"} + await self.push_to_queue(connection_queue, output) + except Exception as e: + LOGGER.error(f"Error handling WebSocket message: {e}") + output = {"event": None, "patterns": None, "error": str(e)} + await self.push_to_queue(connection_queue, output) + elif ws_msg.type == WSMsgType.ERROR: + LOGGER.error(f"WebSocket connection closed with exception: {ws_resp.exception()}") + else: + LOGGER.error(f"Unexpected WebSocket message type: {ws_msg.type}") + finally: + # Drain the messaging queue before closing the connection + LOGGER.debug(f"Closing everything and cancelling ongoing async tasks!{conn_id}") + attempts = 0 + while connection_queue.qsize() > 0 and attempts < 5: + LOGGER.debug(f"Waiting for websocket writter to drain queued responses.Attempt [{attempts}]. Messaging queue size is :{connection_queue.qsize()}") + await asyncio.sleep(2) + attempts += 1 + + async with lock: + await self.close(ws_resp, conn_id) return ws_resp - + async def write_websocket_loop( - self, ws_resp: web.WebSocketResponse, connection_queue: asyncio.Queue, lock: asyncio.Lock + self, + ws_resp: web.WebSocketResponse, + connection_queue: asyncio.Queue, + lock: asyncio.Lock, ): """Write messages to websocket. args: @@ -181,19 +224,24 @@ async def write_websocket_loop( try: while True: msg = await connection_queue.get() + #used for debugging + #await self.print_current_task_count() async with lock: - if ws_resp is not None and not ws_resp.closed and not ws_resp._closing: + if ( + ws_resp is not None + and not ws_resp.closed + and not ws_resp._closing + ): await ws_resp.send_json(msg) else: break except asyncio.CancelledError: - LOGGER.debug("Write task is canceled, socket is closing") - + LOGGER.debug("Stopping websocket writter.") except Exception as err: - LOGGER.error(str(err)) + LOGGER.error("Writing back to websocket failed! " + str(err)) + - def convert_pattern(self, _pattern: dict): - """Convert pattern to redis pattern""" + def convert_pattern(self, _pattern: dict) -> str: try: pattern = _pattern.copy() scope = pattern.pop("Scope") @@ -212,6 +260,7 @@ async def add_pattern(self, conn_id, conn, _pattern, **ignore): """Add pattern to subscriber""" LOGGER.info(f"add_pattern{_pattern}") + self.connections[conn_id]["patterns"].append(_pattern) key_patterns = [] if isinstance(_pattern, list): @@ -224,7 +273,8 @@ async def add_pattern(self, conn_id, conn, _pattern, **ignore): for key_pattern in key_patterns: pattern = "__keyspace@*__:%s" % (key_pattern) channel = await conn.psubscribe(pattern) - read_task = asyncio.create_task(self.wait_message(conn_id, channel[0])) + loop = asyncio.get_event_loop() + read_task = loop.create_task(self.wait_message(conn_id, channel[0])) self.tasks[conn_id].append(read_task) # add a new get_keys task @@ -298,7 +348,6 @@ async def wait_message(self, conn_id, channel): await self.push_to_queue(ws, output) except asyncio.CancelledError: LOGGER.debug("Wait task was cancelled, socket is closing!") - except Exception as err: LOGGER.error(str(err)) @@ -335,18 +384,20 @@ async def get_value(self, keys): return output async def fetch_value(self, _conn, key): - # DEPRECATED - type_ = await _conn.type(key) - type_ = type_.decode("utf-8") - if type_ == "string": - value = await _conn.get(key) - value = self.movaidb.decode_value(value) - if type_ == "list": - value = await _conn.lrange(key, 0, -1) - value = self.movaidb.decode_list(value) - if type_ == "hash": - value = await _conn.hgetall(key) - value = self.movaidb.decode_hash(value) + try: + type_ = await _conn.type(key) + type_ = type_.decode("utf-8") + if type_ == "string": + value = await _conn.get(key) + value = self.movaidb.decode_value(value) + if type_ == "list": + value = await _conn.lrange(key, 0, -1) + value = self.movaidb.decode_list(value) + if type_ == "hash": + value = await _conn.hgetall(key) + value = self.movaidb.decode_hash(value) + except (aioredis.PoolClosedError, aioredis.ConnectionForcedCloseError): + LOGGER.debug("Connection closed while fetching value") try: # Json cannot dump ROS Messages json.dumps(value) except: @@ -438,7 +489,7 @@ async def execute(self, conn_id, conn, callback, data=None, **ignore): try: # get callback - callback = gdnode_modules["GD_Callback"]( + callback = GD_Callback( callback, self.node_name, "cloud", _update=False ) @@ -462,7 +513,13 @@ async def execute(self, conn_id, conn, callback, data=None, **ignore): except Exception: error = f"{str(sys.exc_info()[1])} {sys.exc_info()}" await self.push_to_queue( - ws, {"event": "execute", "callback": callback, "result": None, "error": error} + ws, + { + "event": "execute", + "callback": callback, + "result": None, + "error": error, + }, ) async def push_to_queue(self, conn: asyncio.Queue, data): @@ -470,4 +527,4 @@ async def push_to_queue(self, conn: asyncio.Queue, data): try: await conn.put(data) except Exception as e: - LOGGER.error(str(e)) + LOGGER.error(str(e)) \ No newline at end of file