Skip to content

Commit

Permalink
Merge branch 'main' into weaveeval
Browse files Browse the repository at this point in the history
  • Loading branch information
ayulockin authored Sep 6, 2024
2 parents 9950b9f + 36f315e commit f2bba5a
Show file tree
Hide file tree
Showing 21 changed files with 2,330 additions and 1,874 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ Once these environment variables are set, you can start the Q&A bot application
(poetry run python -m wandbot.apps.discord > discord_app.log 2>&1)
```

You might need to then call the endpoint to trigger the final wandbot app initialisation:
```bash
curl http://localhost:8000/
```

For more detailed instructions on installing and running the bot, please refer to the [run.sh](./run.sh) file located in the root of the repository.

Executing these commands will launch the API, Slackbot, and Discord bot applications, enabling you to interact with the bot and ask questions related to the Weights & Biases documentation.
Expand Down
1 change: 1 addition & 0 deletions build-dev.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pip install fasttext && \
poetry install --all-extras && \
pip install protobuf==3.19.6 && \
poetry build && \
mkdir -p ./data/cache
3,323 changes: 1,749 additions & 1,574 deletions poetry.lock

Large diffs are not rendered by default.

13 changes: 8 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ repository = "https://github.com/wandb/wandbot"
include = ["src/**/*", "LICENSE", "README.md"]

[tool.poetry.dependencies]
python = ">=3.10.0,<3.12"
python = ">=3.10.0,<=3.12.4"
numpy = "^1.26.1"
pandas = "^2.1.2"
unstructured = "^0.12.3"
pydantic-settings = "^2.0.3"
gitpython = "^3.1.40"
giturlparse = "^0.12.0"
Expand All @@ -28,7 +27,7 @@ tree-sitter-languages = "^1.7.1"
markdownify = "^0.11.6"
uvicorn = "^0.24.0"
openai = "^1.3.2"
weave = "^0.50.3"
weave = "^0.50.12"
colorlog = "^6.8.0"
litellm = "^1.15.1"
google-cloud-bigquery = "^3.14.1"
Expand All @@ -37,11 +36,15 @@ python-frontmatter = "^1.1.0"
pymdown-extensions = "^10.5"
langchain = "^0.2.2"
langchain-openai = "^0.1.8"
chromadb = "^0.4.22"
langchain-experimental = "^0.0.60"
simsimd = "3.7.7"
langchain-core = "^0.2.2"
langchain-cohere = "^0.1.3"
langchain-chroma = "^0.1.2"
simsimd = "3.7.7"
nbformat = "^5.10.4"
nbconvert = "^7.16.4"
wandb = {extras = ["workspaces"], version = "^0.17.5"}
tree-sitter = "0.21.3"

[tool.poetry.dev-dependencies]

Expand Down
28 changes: 24 additions & 4 deletions src/wandbot/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from contextlib import asynccontextmanager
from datetime import datetime, timezone

import dotenv
import pandas as pd
import weave
from fastapi import BackgroundTasks, FastAPI
Expand All @@ -41,6 +42,7 @@
from wandbot.api.routers import chat as chat_router
from wandbot.api.routers import database as database_router
from wandbot.api.routers import retrieve as retrieve_router
from wandbot.chat.chat import ChatConfig
from wandbot.database.database import engine
from wandbot.database.models import Base
from wandbot.ingestion.config import VectorStoreConfig
Expand All @@ -50,20 +52,39 @@
logger = get_logger(__name__)
last_backup = datetime.now().astimezone(timezone.utc)

dotenv_path = os.path.join(os.path.dirname(__file__), "../../../.env")
dotenv.load_dotenv(dotenv_path)

# turn off chromadb telemetry
os.environ["ANONYMIZED_TELEMETRY"] = "false"

weave.init(f"{os.environ['WANDB_ENTITY']}/{os.environ['WANDB_PROJECT']}")

is_initialized = False


async def initialize():
logger.info(f"Initializing wandbot")
global is_initialized
if not is_initialized:
vector_store = VectorStore.from_config(VectorStoreConfig())
chat_router.chat = chat_router.Chat(vector_store=vector_store)
chat_config = ChatConfig()
chat_router.chat = chat_router.Chat(
vector_store=vector_store, config=chat_config
)
logger.info(f"Initialized chat router")
database_router.db_client = database_router.DatabaseClient()
logger.info(f"Initialized database client")

retrieve_router.retriever = retrieve_router.SimpleRetrievalEngine(
vector_store=vector_store
vector_store=vector_store,
rerank_models={
"english_reranker_model": chat_config.english_reranker_model,
"multilingual_reranker_model": chat_config.multilingual_reranker_model,
},
)
logger.info(f"Initialized retrieve router")
logger.info(f"wandbot initialization complete")
is_initialized = True


Expand Down Expand Up @@ -123,8 +144,7 @@ async def backup_db():

@app.get("/")
async def root(background_tasks: BackgroundTasks):
return {"message": "Initialization happened background"}

return {"message": "Initialization happened in the background"}

app.include_router(chat_router.router)
app.include_router(database_router.router)
Expand Down
33 changes: 17 additions & 16 deletions src/wandbot/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from typing import List

import weave
# from weave.monitoring import StreamTable

import wandb
from wandbot.chat.config import ChatConfig
Expand All @@ -48,32 +47,31 @@ class Chat:
run: An instance of wandb.Run for logging experiment information.
"""

config: ChatConfig = ChatConfig()

def __init__(
self, vector_store: VectorStore, config: ChatConfig | None = None
):
def __init__(self, vector_store: VectorStore, config: ChatConfig):
"""Initializes the Chat instance.
Args:
config: An instance of ChatConfig containing configuration settings.
"""
self.vector_store = vector_store
if config is not None:
self.config = config
self.config = config
self.run = wandb.init(
project=self.config.wandb_project,
entity=self.config.wandb_entity,
job_type="chat",
)
self.run._label(repo="wandbot")
# self.stream_table = StreamTable(
# table_name="chat_logs",
# project_name=self.config.wandb_project,
# entity_name=self.config.wandb_entity,
# )

self.rag_pipeline = RAGPipeline(vector_store=vector_store)
self.rag_pipeline = RAGPipeline(
vector_store=vector_store,
top_k=self.config.top_k,
english_reranker_model=self.config.english_reranker_model,
multilingual_reranker_model=self.config.multilingual_reranker_model,
response_synthesizer_model=self.config.response_synthesizer_model,
response_synthesizer_temperature=self.config.response_synthesizer_temperature,
response_synthesizer_fallback_model=self.config.response_synthesizer_fallback_model,
response_synthesizer_fallback_temperature=self.config.response_synthesizer_fallback_temperature,
)

def _get_answer(
self, question: str, chat_history: List[QuestionAnswer]
Expand Down Expand Up @@ -108,10 +106,13 @@ def __call__(self, chat_request: ChatRequest) -> ChatResponse:
"total_tokens": result.total_tokens,
"prompt_tokens": result.prompt_tokens,
"completion_tokens": result.completion_tokens,
"web_search_success": result.api_call_statuses[
"web_search_success"
],
}
result_dict.update({"application": chat_request.application})
self.run.log(usage_stats)
# self.stream_table.log(result_dict)

return ChatResponse(**result_dict)
except Exception as e:
with Timer() as timer:
Expand All @@ -133,5 +134,5 @@ def __call__(self, chat_request: ChatRequest) -> ChatResponse:
"end_time": timer.stop,
}
)
# self.stream_table.log(result)

return ChatResponse(**result)
11 changes: 11 additions & 0 deletions src/wandbot/chat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,14 @@ class ChatConfig(BaseSettings):
)
wandb_project: str | None = Field("wandbot_public", env="WANDB_PROJECT")
wandb_entity: str | None = Field("wandbot", env="WANDB_ENTITY")
# Retrieval settings
top_k: int = 15
search_type: str = "mmr"
# Cohere reranker models
english_reranker_model: str = "rerank-english-v2.0"
multilingual_reranker_model: str = "rerank-multilingual-v2.0"
# Response synthesis settings
response_synthesizer_model: str = "gpt-4-0125-preview"
response_synthesizer_temperature: float = 0.1
response_synthesizer_fallback_model: str = "gpt-4-0125-preview"
response_synthesizer_fallback_temperature: float = 0.1
24 changes: 22 additions & 2 deletions src/wandbot/chat/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class RAGPipelineOutput(BaseModel):
time_taken: float
start_time: datetime.datetime
end_time: datetime.datetime
api_call_statuses: dict = {}


class RAGPipeline:
Expand All @@ -52,13 +53,28 @@ def __init__(
vector_store: VectorStore,
top_k: int = 15,
search_type: str = "mmr",
english_reranker_model: str = "rerank-english-v2.0",
multilingual_reranker_model: str = "rerank-multilingual-v2.0",
response_synthesizer_model: str = "gpt-4-0125-preview",
response_synthesizer_temperature: float = 0.1,
response_synthesizer_fallback_model: str = "gpt-4-0125-preview",
response_synthesizer_fallback_temperature: float = 0.1,
):
self.vector_store = vector_store
self.query_enhancer = QueryEnhancer()
self.retrieval = FusionRetrieval(
vector_store=vector_store, top_k=top_k, search_type=search_type
vector_store=vector_store,
top_k=top_k,
search_type=search_type,
english_reranker_model=english_reranker_model,
multilingual_reranker_model=multilingual_reranker_model,
)
self.response_synthesizer = ResponseSynthesizer(
model=response_synthesizer_model,
temperature=response_synthesizer_temperature,
fallback_model=response_synthesizer_fallback_model,
fallback_temperature=response_synthesizer_fallback_temperature,
)
self.response_synthesizer = ResponseSynthesizer()

@weave.op()
def __call__(
Expand All @@ -74,6 +90,7 @@ def __call__(

with Timer() as retrieval_tb:
retrieval_results = self.retrieval(enhanced_query)
logger.debug(f"Retrieval results: {retrieval_results}")

with get_openai_callback() as response_cb, Timer() as response_tb:
response = self.response_synthesizer(retrieval_results)
Expand Down Expand Up @@ -101,6 +118,9 @@ def __call__(
+ response_tb.elapsed,
start_time=query_enhancer_tb.start,
end_time=response_tb.stop,
api_call_statuses={
"web_search_success": retrieval_results["web_search_success"],
},
)

return output
1 change: 1 addition & 0 deletions src/wandbot/chat/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,4 @@ class ChatResponse(BaseModel):
time_taken: float
start_time: datetime
end_time: datetime
api_call_statuses: dict = {}
Loading

0 comments on commit f2bba5a

Please sign in to comment.