diff --git a/src/wandbot/api/app.py b/src/wandbot/api/app.py index cdc6e54..d815372 100644 --- a/src/wandbot/api/app.py +++ b/src/wandbot/api/app.py @@ -35,12 +35,14 @@ import pandas as pd import weave -from fastapi import FastAPI +from fastapi import BackgroundTasks, FastAPI import wandb from wandbot.api.routers import chat as chat_router from wandbot.api.routers import database as database_router from wandbot.api.routers import retrieve as retrieve_router +from wandbot.database.database import engine +from wandbot.database.models import Base from wandbot.ingestion.config import VectorStoreConfig from wandbot.retriever import VectorStore from wandbot.utils import get_logger @@ -50,6 +52,20 @@ weave.init(f"{os.environ['WANDB_ENTITY']}/{os.environ['WANDB_PROJECT']}") +is_initialized = False + + +async def initialize(): + global is_initialized + if not is_initialized: + vector_store = VectorStore.from_config(VectorStoreConfig()) + chat_router.chat = chat_router.Chat(vector_store=vector_store) + database_router.db_client = database_router.DatabaseClient() + retrieve_router.retriever = retrieve_router.SimpleRetrievalEngine( + vector_store=vector_store + ) + is_initialized = True + @asynccontextmanager async def lifespan(app: FastAPI): @@ -61,12 +77,8 @@ async def lifespan(app: FastAPI): Returns: None """ - vector_store = VectorStore.from_config(VectorStoreConfig()) - chat_router.chat = chat_router.Chat(vector_store=vector_store) - database_router.db_client = database_router.DatabaseClient() - retrieve_router.retriever = retrieve_router.SimpleRetrievalEngine( - vector_store=vector_store - ) + + Base.metadata.create_all(bind=engine) async def backup_db(): """Periodically backs up the database to a table. @@ -107,11 +119,16 @@ async def backup_db(): ) +@app.get("/") +async def root(background_tasks: BackgroundTasks): + background_tasks.add_task(initialize) + return {"message": "Initialization started in the background"} + + app.include_router(chat_router.router) app.include_router(database_router.router) app.include_router(retrieve_router.router) - if __name__ == "__main__": import uvicorn diff --git a/src/wandbot/apps/utils.py b/src/wandbot/apps/utils.py index d75f249..59d99f4 100644 --- a/src/wandbot/apps/utils.py +++ b/src/wandbot/apps/utils.py @@ -62,7 +62,7 @@ def format_response( sources_list = deduplicate( [ item - for item in response.sources.split(",") + for item in response.sources.split("\n") if item.strip().startswith("http") ] ) @@ -72,13 +72,13 @@ def format_response( result = ( f"{result}\n\n*参考文献*\n\n>" + "\n> ".join(sources_list[:items]) - + "\n\n" + + "\n" ) else: result = ( f"{result}\n\n*References*\n\n>" + "\n> ".join(sources_list[:items]) - + "\n\n" + + "\n" ) if outro_message: result = f"{result}\n\n{outro_message}" diff --git a/src/wandbot/database/database.py b/src/wandbot/database/database.py index 2647e68..0ed0a93 100644 --- a/src/wandbot/database/database.py +++ b/src/wandbot/database/database.py @@ -14,7 +14,6 @@ from sqlalchemy.orm import sessionmaker from wandbot.database.config import DataBaseConfig -from wandbot.database.models import Base db_config = DataBaseConfig() @@ -22,4 +21,3 @@ db_config.SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -Base.metadata.create_all(bind=engine)