From 663caac1a97afa65a6d75e5065c69613e6eeb458 Mon Sep 17 00:00:00 2001 From: Cara Haas Date: Fri, 22 Mar 2024 15:42:50 -0400 Subject: [PATCH] example --- config.py | 2 + event_generator.py | 94 ++++++++++++++++++++++++++++++++++++++++++++-- main.py | 5 +-- 3 files changed, 94 insertions(+), 7 deletions(-) diff --git a/config.py b/config.py index 8255485..507af89 100644 --- a/config.py +++ b/config.py @@ -17,6 +17,8 @@ if config["MZ_SCHEMA"]: config["options"] += f' -c search_path={config["MZ_SCHEMA"]}' + +DSN = f'''postgresql://{config["MZ_USER"]}:{config["MZ_PASSWORD"]}@{config["MZ_HOST"]}:{config["MZ_PORT"]}/{config["MZ_DB"]}?options=--cluster%3D{config["MZ_CLUSTER"]}%20-csearch_path%3D{config["MZ_SCHEMA"]}''' if __name__=="__main__": print(config) \ No newline at end of file diff --git a/event_generator.py b/event_generator.py index 9ac671b..206be1a 100644 --- a/event_generator.py +++ b/event_generator.py @@ -3,6 +3,7 @@ Materialize will push events whenever someone's bid has won an auction. ''' import logging +from typing import List, Optional import psycopg from psycopg_pool import AsyncConnectionPool @@ -11,6 +12,8 @@ from fastapi import Request from pydantic import BaseModel +from config import DSN + _logger = logging.getLogger('uvicorn.error') def log_db_diagnosis_callback(diagnosis: psycopg.Error.diag): '''Include database diagnostic messages in server logs''' @@ -18,10 +21,15 @@ def log_db_diagnosis_callback(diagnosis: psycopg.Error.diag): class WinningBid(BaseModel): '''Bid for an item at an auction''' - auction_id: int - bid_id: int - item: str - amount: int + mz_timestamp: int + mz_progressed: Optional[bool] = None + auction_id: Optional[int] = None + bid_id: Optional[int] = None + item: Optional[str] = None + amount: Optional[int] = None + +class SubscribeProgress(BaseModel): + last_progress_mz_timestamp: int async def event_generator( request: Request, @@ -60,3 +68,81 @@ async def event_generator( break except Exception as err: _logger.error(err) + +async def notify_generator( + request: Request, + pool: AsyncConnectionPool, + amount: list[int] | None = None) -> WinningBid: + ''' + Generate events that will send a notification. + Materialize will push events whenever someone's bid has won an auction. + ''' + try: + while True: + if await request.is_disconnected(): + break + as_of_ts: int = None + async with pool.connection() as conn, conn.cursor(row_factory=class_row(SubscribeProgress)) as cur: + conn.add_notice_handler(log_db_diagnosis_callback) + # In this example we used a table in materialize to store the last_progress_mz_timestamp + # But the user could store this in their own internal system. + rows = await cur.execute(""" + SELECT last_progress_mz_timestamp + FROM subscribe_progress + WHERE subscribe_name = 'notify_winners' + ORDER BY last_progress_mz_timestamp desc + LIMIT 1 + """) + async for row in rows: + as_of_ts = row.last_progress_mz_timestamp + print("as of ts: ", as_of_ts) + # Asycronously get real-time updates from Materialize + async with pool.connection() as conn, conn.cursor(row_factory=class_row(WinningBid)) as cur: + # Format query + base_query = SQL("SELECT auction_id, bid_id, item, amount FROM winning_bids") + if as_of_ts: + query = SQL("SUBSCRIBE ({}) WITH (PROGRESS, SNAPSHOT false) AS OF {}").format(base_query, as_of_ts) + else: + query = SQL("SUBSCRIBE ({}) WITH (PROGRESS, SNAPSHOT true)").format(base_query) + # Subscribe to an endless stream of updates + # Todo: handle error where AS OF timestamp is past the retain history interval + rows = cur.stream(query, amount) + print("got rows, query: ", query) + staged_data: List[WinningBid] = [] + async for row in rows: + print("row: ", row) + if row.mz_progressed: + print("in mz_progressed") + for staged_row in staged_data: + yield staged_row + staged_data.clear() + last_progress_mz_timestamp = row.mz_timestamp + + # TODO: make recording `last_progress_mz_timestamp` an async task that + # happens periodically. + # Ideally we'd be able to do `INSERT ... ON CONFLICT UPDATE ...`, or + # the server stores `last_progress_mz_timestamp` somewhere in their own + # durable infrastructure (not within mz). + print("writing last_progress_mz_timestamp ") + insert_conn = await psycopg.AsyncConnection.connect(DSN, autocommit=True) + async with insert_conn: + insert_conn.add_notice_handler(log_db_diagnosis_callback) + async with insert_conn.cursor() as insert_cursor: + if as_of_ts: + await insert_cursor.execute( + SQL("UPDATE subscribe_progress SET last_progress_mz_timestamp = {} WHERE subscribe_name = 'notify_winners'").format(last_progress_mz_timestamp) + ) + else: + await insert_cursor.execute( + "INSERT INTO subscribe_progress (subscribe_name, last_progress_mz_timestamp) VALUES (%s, %s)", + ('notify_winners',last_progress_mz_timestamp) ) + print("wrote last_progress_mz_timestamp ") + else: + staged_data.append(row) + # Detect when user disconnects and exit the event loop + if await request.is_disconnected(): + await conn.close() + break + + except Exception as err: + _logger.error(err) diff --git a/main.py b/main.py index a9c46f5..9e33a5f 100644 --- a/main.py +++ b/main.py @@ -13,7 +13,7 @@ import uvicorn from config import config -from event_generator import event_generator, WinningBid +from event_generator import event_generator, notify_generator, WinningBid # Logging stuff _logger = logging.getLogger('uvicorn.error') @@ -48,7 +48,6 @@ def open_pool(): host = config["MZ_HOST"], password = config["MZ_PASSWORD"], port = 6875, - sslmode = 'require', application_name = 'FastAPI', options = config["options"], ), @@ -70,7 +69,7 @@ async def root(): @app.get("/subscribe/", response_model=WinningBid) async def message_stream(request: Request, amount: list[int] | None = Query(default=None)): '''Retrieve events from the event generator for SSE''' - return (EventSourceResponse(event_generator(request, app.state.pool, amount))) + return (EventSourceResponse(notify_generator(request, app.state.pool, amount))) if __name__ == "__main__": logger.setLevel(_logger.level)