From e96348849f7fe4504f11f004ba75d3de48985105 Mon Sep 17 00:00:00 2001 From: Bharat Ramanathan Date: Wed, 19 Jun 2024 16:00:02 +0530 Subject: [PATCH 1/3] fix: move sql database creation to start-up --- src/wandbot/api/app.py | 36 +++++++++++++++++++++++--------- src/wandbot/database/database.py | 3 --- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/wandbot/api/app.py b/src/wandbot/api/app.py index cdc6e54..7679ff9 100644 --- a/src/wandbot/api/app.py +++ b/src/wandbot/api/app.py @@ -34,13 +34,14 @@ from datetime import datetime, timezone import pandas as pd -import weave -from fastapi import FastAPI - import wandb +import weave +from fastapi import BackgroundTasks, FastAPI from wandbot.api.routers import chat as chat_router from wandbot.api.routers import database as database_router from wandbot.api.routers import retrieve as retrieve_router +from wandbot.database.database import engine +from wandbot.database.models import Base from wandbot.ingestion.config import VectorStoreConfig from wandbot.retriever import VectorStore from wandbot.utils import get_logger @@ -50,6 +51,20 @@ weave.init(f"{os.environ['WANDB_ENTITY']}/{os.environ['WANDB_PROJECT']}") +is_initialized = False + + +async def initialize(): + global is_initialized + if not is_initialized: + vector_store = VectorStore.from_config(VectorStoreConfig()) + chat_router.chat = chat_router.Chat(vector_store=vector_store) + database_router.db_client = database_router.DatabaseClient() + retrieve_router.retriever = retrieve_router.SimpleRetrievalEngine( + vector_store=vector_store + ) + is_initialized = True + @asynccontextmanager async def lifespan(app: FastAPI): @@ -61,12 +76,8 @@ async def lifespan(app: FastAPI): Returns: None """ - vector_store = VectorStore.from_config(VectorStoreConfig()) - chat_router.chat = chat_router.Chat(vector_store=vector_store) - database_router.db_client = database_router.DatabaseClient() - retrieve_router.retriever = retrieve_router.SimpleRetrievalEngine( - vector_store=vector_store - ) + + Base.metadata.create_all(bind=engine) async def backup_db(): """Periodically backs up the database to a table. @@ -107,11 +118,16 @@ async def backup_db(): ) +@app.get("/") +async def root(background_tasks: BackgroundTasks): + background_tasks.add_task(initialize) + return {"message": "Initialization started in the background"} + + app.include_router(chat_router.router) app.include_router(database_router.router) app.include_router(retrieve_router.router) - if __name__ == "__main__": import uvicorn diff --git a/src/wandbot/database/database.py b/src/wandbot/database/database.py index 2647e68..f4ea15a 100644 --- a/src/wandbot/database/database.py +++ b/src/wandbot/database/database.py @@ -12,9 +12,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker - from wandbot.database.config import DataBaseConfig -from wandbot.database.models import Base db_config = DataBaseConfig() @@ -22,4 +20,3 @@ db_config.SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -Base.metadata.create_all(bind=engine) From 3133cb2dfb6bbfa144d158a581bad61fe05be45f Mon Sep 17 00:00:00 2001 From: Bharat Ramanathan Date: Wed, 19 Jun 2024 16:00:35 +0530 Subject: [PATCH 2/3] fix: references processing in slack and discord formatter --- src/wandbot/apps/utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/wandbot/apps/utils.py b/src/wandbot/apps/utils.py index d75f249..b8d5ae8 100644 --- a/src/wandbot/apps/utils.py +++ b/src/wandbot/apps/utils.py @@ -16,7 +16,6 @@ from typing import Any, List from pydantic_settings import BaseSettings - from wandbot.api.routers.chat import APIQueryResponse @@ -62,7 +61,7 @@ def format_response( sources_list = deduplicate( [ item - for item in response.sources.split(",") + for item in response.sources.split("\n") if item.strip().startswith("http") ] ) @@ -72,13 +71,13 @@ def format_response( result = ( f"{result}\n\n*参考文献*\n\n>" + "\n> ".join(sources_list[:items]) - + "\n\n" + + "\n" ) else: result = ( f"{result}\n\n*References*\n\n>" + "\n> ".join(sources_list[:items]) - + "\n\n" + + "\n" ) if outro_message: result = f"{result}\n\n{outro_message}" From 55ccd118a75237f64655469155feccbf5bea3ecc Mon Sep 17 00:00:00 2001 From: Bharat Ramanathan Date: Wed, 19 Jun 2024 16:01:17 +0530 Subject: [PATCH 3/3] chore: fun formatters and linters --- src/wandbot/api/app.py | 3 ++- src/wandbot/apps/utils.py | 1 + src/wandbot/database/database.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/wandbot/api/app.py b/src/wandbot/api/app.py index 7679ff9..d815372 100644 --- a/src/wandbot/api/app.py +++ b/src/wandbot/api/app.py @@ -34,9 +34,10 @@ from datetime import datetime, timezone import pandas as pd -import wandb import weave from fastapi import BackgroundTasks, FastAPI + +import wandb from wandbot.api.routers import chat as chat_router from wandbot.api.routers import database as database_router from wandbot.api.routers import retrieve as retrieve_router diff --git a/src/wandbot/apps/utils.py b/src/wandbot/apps/utils.py index b8d5ae8..59d99f4 100644 --- a/src/wandbot/apps/utils.py +++ b/src/wandbot/apps/utils.py @@ -16,6 +16,7 @@ from typing import Any, List from pydantic_settings import BaseSettings + from wandbot.api.routers.chat import APIQueryResponse diff --git a/src/wandbot/database/database.py b/src/wandbot/database/database.py index f4ea15a..0ed0a93 100644 --- a/src/wandbot/database/database.py +++ b/src/wandbot/database/database.py @@ -12,6 +12,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker + from wandbot.database.config import DataBaseConfig db_config = DataBaseConfig()