Skip to content

Commit

Permalink
Merge pull request #234 from MOV-AI/bugfix/bp-1242
Browse files Browse the repository at this point in the history
bring changes from 2.3
  • Loading branch information
duartecoelhomovai authored Aug 5, 2024
2 parents ab3b554 + 0905c67 commit 661d1dc
Showing 1 changed file with 160 additions and 103 deletions.
263 changes: 160 additions & 103 deletions dal/classes/protocols/wsredissub.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, app: web.Application, _node_name: str, **_ignore):
self.tasks = {}
self.movaidb = MovaiDB()
self.loop = asyncio.get_event_loop()
self.recovery_mode = False

# API available actions
self.actions = {
Expand All @@ -68,110 +69,152 @@ async def connect(self):
self.databases = await RedisClient.get_client()

async def acquire(self, retries: int = 3):
    """Acquire a healthy Redis pub/sub connection, retrying on failure.

    Each attempt (re)connects through ``self.connect()``, waits up to 30
    seconds for ``self.databases.slave_pubsub.acquire()`` and then verifies
    the connection with a ``PING``.

    Args:
        retries: Number of attempts left. ``0`` means give up immediately.

    Returns:
        The acquired connection, or ``None`` when the retries are exhausted,
        the attempt timed out, or the operation was cancelled.

    Raises:
        asyncio.TimeoutError: If a nested retry takes longer than 60 seconds.
            It is raised from inside the generic handler, so it is not
            converted to ``None``.
    """
    LOGGER.debug(f"Acquiring connection. Retries left: {retries}")
    if retries == 0:
        return None

    # Bound up front so the error path can test it without inspecting locals().
    conn = None
    try:
        await self.connect()
        conn = await asyncio.wait_for(self.databases.slave_pubsub.acquire(), timeout=30)

        # Check if the connection is healthy
        if await conn.execute('PING') != b'PONG':
            raise Exception("Failed to receive PONG response")
        return conn
    except asyncio.CancelledError:
        # NOTE(review): cancellation is swallowed here and reported as None;
        # asyncio normally expects CancelledError to be re-raised - confirm
        # this is intentional.
        LOGGER.debug("Acquire operation was cancelled")
        return None
    except asyncio.TimeoutError:
        LOGGER.error("Acquire operation timed out")
        return None
    except Exception as e:
        LOGGER.error(f"Error acquiring connection: {e}. Retrying...")
        # Close the connection if it's already acquired but unhealthy
        if conn:
            conn.close()
            await conn.wait_closed()
        # Exponential backoff: 1s, 2s, 4s when starting from the default of 3 retries.
        await asyncio.sleep(2 ** (3 - retries))
        return await asyncio.wait_for(self.acquire(retries - 1), timeout=60)

async def release(self, conn_id):
    """Forget a connection and release the Redis resources tied to it.

    Removes ``conn_id`` from ``self.connections`` and, if a subscriber
    connection is stored for it, schedules ``wait_closed()`` on ``self.loop``.
    When ``self.recovery_mode`` is set, the database clients are shut down
    and the flag is cleared.

    Unknown or already released ids are ignored, so this can safely be
    called more than once during cleanup.

    Args:
        conn_id: Identifier of the connection, as stored in ``self.connections``.
    """
    # pop() instead of index + del: a missing id must not abort the cleanup.
    entry = self.connections.pop(conn_id, None)
    conn = entry["subs"] if entry else None
    if conn:
        # NOTE(review): only wait_closed() is scheduled; presumably the
        # connection is closed elsewhere - confirm against the aioredis API.
        self.loop.create_task(conn.wait_closed())

    if self.recovery_mode:
        await self.databases.shutdown()
        self.recovery_mode = False

async def close(self, ws: "web.WebSocketResponse", conn_id: str):
    """Close the WebSocket and clean up its tasks and connections.

    Every task registered for ``conn_id`` that is still running is cancelled
    and awaited for at most 2 seconds. Afterwards the task list is dropped,
    the Redis resources are released through ``self.release()`` and the
    socket is closed. That cleanup runs even if awaiting a task raises.

    Args:
        ws: The websocket to close.
        conn_id: The connection id.
    """
    try:
        # .get(): a connection without registered tasks must still be cleaned up.
        for task in self.tasks.get(conn_id, ()):
            if task.done():
                continue
            task.cancel()
            try:
                await asyncio.wait_for(task, timeout=2)
            except asyncio.CancelledError:
                LOGGER.debug("[WSRedisSub.close] Task was properly cancelled "+ str(task))
            except asyncio.TimeoutError:
                LOGGER.error("[WSRedisSub.close] Waited 2 seconds for it. Force cancelling! "+ str(task))
    finally:
        # Always run, so a failing task cannot leak the socket or the connection.
        self.tasks.pop(conn_id, None)
        await self.release(conn_id)
        await ws.close()
async def print_current_task_count(self):
    """Log, at warning level, the number of unfinished asyncio tasks (debug aid).

    Uses ``asyncio.all_tasks()``, which reports the not-yet-finished tasks of
    the running loop. The former ``asyncio.Task.all_tasks()`` was deprecated
    in Python 3.7 and removed in Python 3.9.
    """
    tasks = asyncio.all_tasks()
    LOGGER.warning(f"Number of current tasks: {len(tasks)}")

async def handler(self, request: web.Request) -> web.WebSocketResponse:
    """Handle one WebSocket connection for its whole lifetime.

    Prepares the WebSocket, registers the connection under a fresh id,
    starts the writer task and acquires a Redis connection. Every incoming
    text message is then parsed as JSON and dispatched to the entry of
    ``self.actions`` named by its ``"event"`` key. Per-message errors are
    reported back to the client through the outgoing queue and do not end
    the session.

    On exit (client sent ``"close"``, socket ended, or an error occurred),
    the outgoing queue is given up to 5 x 2 seconds to drain before tasks,
    Redis resources and the socket are closed.

    Args:
        request: The incoming HTTP request to upgrade.

    Returns:
        The (closed) WebSocket response.

    Raises:
        Exception: If no Redis connection could be acquired. The error is
            queued for the client first and cleanup still runs.
    """
    connection_queue = asyncio.Queue()
    lock = asyncio.Lock()
    conn = None

    ws_resp = web.WebSocketResponse()
    await ws_resp.prepare(request)

    conn_id = uuid.uuid4().hex
    self.connections[conn_id] = {"conn": connection_queue, "subs": conn, "patterns": []}

    # Ensure the loop is running and get it
    loop = asyncio.get_event_loop()
    write_task = loop.create_task(self.write_websocket_loop(ws_resp, connection_queue, lock))
    self.tasks[conn_id] = [write_task]

    try:
        try:
            _conn = await asyncio.wait_for(self.acquire(), timeout=160)
            if not _conn:
                LOGGER.error("Error acquiring Redis connection:")
                # Makes the next release() shut the database clients down.
                self.recovery_mode = True
                raise Exception("Unable to acquire Redis connection")
            conn = aioredis.Redis(_conn)
            self.connections[conn_id]["subs"] = conn
        except Exception as error:
            LOGGER.error(f"Exception: {error}")
            # Send the real error text (previously the literal "{error}").
            output = {"event": None, "patterns": None, "error": str(error)}
            await self.push_to_queue(connection_queue, output)
            raise

        async for ws_msg in ws_resp:
            if ws_msg.type == WSMsgType.TEXT:
                try:
                    if ws_msg.data == "close":
                        break
                    data = ws_msg.json()
                    if "event" not in data:
                        raise KeyError("Missing 'event' key in WebSocket message")
                    if data["event"] == "execute":
                        _config = {
                            "conn_id": conn_id,
                            "conn": conn,
                            "callback": data.get("callback"),
                            "func": data.get("func"),
                            "data": data.get("data"),
                        }
                    else:
                        _config = {
                            "conn_id": conn_id,
                            "conn": conn,
                            "_pattern": data.get("pattern"),
                        }
                    await self.actions[data["event"]](**_config)
                except aioredis.PoolClosedError:
                    LOGGER.debug("Connection closed while executing.")
                    output = {"event": None, "patterns": None, "error": "Connection was closed"}
                    await self.push_to_queue(connection_queue, output)
                except Exception as e:
                    LOGGER.error(f"Error handling WebSocket message: {e}")
                    output = {"event": None, "patterns": None, "error": str(e)}
                    await self.push_to_queue(connection_queue, output)
            elif ws_msg.type == WSMsgType.ERROR:
                LOGGER.error(f"WebSocket connection closed with exception: {ws_resp.exception()}")
            else:
                LOGGER.error(f"Unexpected WebSocket message type: {ws_msg.type}")
    finally:
        # Drain the messaging queue before closing the connection
        LOGGER.debug(f"Closing everything and cancelling ongoing async tasks!{conn_id}")
        attempts = 0
        while connection_queue.qsize() > 0 and attempts < 5:
            LOGGER.debug(f"Waiting for websocket writter to drain queued responses.Attempt [{attempts}]. Messaging queue size is :{connection_queue.qsize()}")
            await asyncio.sleep(2)
            attempts += 1

        # Hold the writer lock so no message is being sent while we close.
        async with lock:
            await self.close(ws_resp, conn_id)
    return ws_resp

async def write_websocket_loop(
self, ws_resp: web.WebSocketResponse, connection_queue: asyncio.Queue, lock: asyncio.Lock
self,
ws_resp: web.WebSocketResponse,
connection_queue: asyncio.Queue,
lock: asyncio.Lock,
):
"""Write messages to websocket.
args:
Expand All @@ -181,19 +224,24 @@ async def write_websocket_loop(
try:
while True:
msg = await connection_queue.get()
#used for debugging
#await self.print_current_task_count()
async with lock:
if ws_resp is not None and not ws_resp.closed and not ws_resp._closing:
if (
ws_resp is not None
and not ws_resp.closed
and not ws_resp._closing
):
await ws_resp.send_json(msg)
else:
break
except asyncio.CancelledError:
LOGGER.debug("Write task is canceled, socket is closing")

LOGGER.debug("Stopping websocket writter.")
except Exception as err:
LOGGER.error(str(err))
LOGGER.error("Writing back to websocket failed! " + str(err))


def convert_pattern(self, _pattern: dict):
"""Convert pattern to redis pattern"""
def convert_pattern(self, _pattern: dict) -> str:
try:
pattern = _pattern.copy()
scope = pattern.pop("Scope")
Expand All @@ -212,6 +260,7 @@ async def add_pattern(self, conn_id, conn, _pattern, **ignore):
"""Add pattern to subscriber"""
LOGGER.info(f"add_pattern{_pattern}")


self.connections[conn_id]["patterns"].append(_pattern)
key_patterns = []
if isinstance(_pattern, list):
Expand All @@ -224,7 +273,8 @@ async def add_pattern(self, conn_id, conn, _pattern, **ignore):
for key_pattern in key_patterns:
pattern = "__keyspace@*__:%s" % (key_pattern)
channel = await conn.psubscribe(pattern)
read_task = asyncio.create_task(self.wait_message(conn_id, channel[0]))
loop = asyncio.get_event_loop()
read_task = loop.create_task(self.wait_message(conn_id, channel[0]))
self.tasks[conn_id].append(read_task)

# add a new get_keys task
Expand Down Expand Up @@ -298,7 +348,6 @@ async def wait_message(self, conn_id, channel):
await self.push_to_queue(ws, output)
except asyncio.CancelledError:
LOGGER.debug("Wait task was cancelled, socket is closing!")

except Exception as err:
LOGGER.error(str(err))

Expand Down Expand Up @@ -335,18 +384,20 @@ async def get_value(self, keys):
return output

async def fetch_value(self, _conn, key):
# DEPRECATED
type_ = await _conn.type(key)
type_ = type_.decode("utf-8")
if type_ == "string":
value = await _conn.get(key)
value = self.movaidb.decode_value(value)
if type_ == "list":
value = await _conn.lrange(key, 0, -1)
value = self.movaidb.decode_list(value)
if type_ == "hash":
value = await _conn.hgetall(key)
value = self.movaidb.decode_hash(value)
try:
type_ = await _conn.type(key)
type_ = type_.decode("utf-8")
if type_ == "string":
value = await _conn.get(key)
value = self.movaidb.decode_value(value)
if type_ == "list":
value = await _conn.lrange(key, 0, -1)
value = self.movaidb.decode_list(value)
if type_ == "hash":
value = await _conn.hgetall(key)
value = self.movaidb.decode_hash(value)
except (aioredis.PoolClosedError, aioredis.ConnectionForcedCloseError):
LOGGER.debug("Connection closed while fetching value")
try: # Json cannot dump ROS Messages
json.dumps(value)
except:
Expand Down Expand Up @@ -438,7 +489,7 @@ async def execute(self, conn_id, conn, callback, data=None, **ignore):

try:
# get callback
callback = gdnode_modules["GD_Callback"](
callback = GD_Callback(
callback, self.node_name, "cloud", _update=False
)

Expand All @@ -462,12 +513,18 @@ async def execute(self, conn_id, conn, callback, data=None, **ignore):
except Exception:
error = f"{str(sys.exc_info()[1])} {sys.exc_info()}"
await self.push_to_queue(
ws, {"event": "execute", "callback": callback, "result": None, "error": error}
ws,
{
"event": "execute",
"callback": callback,
"result": None,
"error": error,
},
)

async def push_to_queue(self, conn: asyncio.Queue, data):
    """Put ``data`` on a connection's outgoing message queue.

    Best effort: a failure to enqueue is logged once and not propagated.

    Args:
        conn: The outgoing queue of the target connection.
        data: The JSON-serialisable payload to enqueue.
    """
    try:
        await conn.put(data)
    except Exception as e:
        LOGGER.error(str(e))

0 comments on commit 661d1dc

Please sign in to comment.