feat: integrate weave with evaluation pipeline and experiments with finetuning #75

Open · wants to merge 6 commits into base: main
3,111 changes: 660 additions & 2,451 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pyproject.toml
@@ -12,7 +12,7 @@ include = ["src/**/*", "LICENSE", "README.md"]
[tool.poetry.dependencies]
python = ">=3.10.0,<3.12"
numpy = "^1.26.1"
wandb = "<=0.16.1"
wandb = "<=0.17.0"
tiktoken = "^0.5.1"
pandas = "^2.1.2"
unstructured = "^0.12.3"
@@ -31,9 +31,9 @@ markdownify = "^0.11.6"
uvicorn = "^0.24.0"
zenpy = "^2.0.46"
openai = "^1.3.2"
weave = "^0.31.0"
weave = "^0.50.0"
colorlog = "^6.8.0"
litellm = "^1.15.1"
litellm = "^1.31.6"
google-cloud-bigquery = "^3.14.1"
db-dtypes = "^1.2.0"
python-frontmatter = "^1.1.0"
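The pin bumps above (wandb, weave, litellm) track the newer Weave client API (weave.init, weave.op, weave.Evaluation) used by the new evaluation scripts in this PR. A minimal sanity check, not part of this change and assuming the usual poetry lock / poetry install re-lock workflow, for confirming what resolved in the active environment:

# Hedged sketch: confirm the bumped dependencies resolved as expected
# (assumes the environment was re-locked and reinstalled with poetry).
import importlib.metadata as metadata

for pkg in ("wandb", "weave", "litellm"):
    print(f"{pkg}=={metadata.version(pkg)}")
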
1 change: 0 additions & 1 deletion src/wandbot/api/routers/chat.py
@@ -8,7 +8,6 @@
logger = get_logger(__name__)

chat_config = ChatConfig()
logger.info(f"Chat config: {chat_config}")
chat: Chat | None = None

router = APIRouter(
16 changes: 8 additions & 8 deletions src/wandbot/chat/chat.py
@@ -26,7 +26,7 @@
"""
from typing import List

from weave.monitoring import StreamTable
# from weave.monitoring import StreamTable

import wandb
from wandbot.chat.config import ChatConfig
@@ -66,11 +66,11 @@ def __init__(
            job_type="chat",
        )
        self.run._label(repo="wandbot")
        self.stream_table = StreamTable(
            table_name="chat_logs",
            project_name=self.config.wandb_project,
            entity_name=self.config.wandb_entity,
        )
        # self.stream_table = StreamTable(
        #     table_name="chat_logs",
        #     project_name=self.config.wandb_project,
        #     entity_name=self.config.wandb_entity,
        # )

        self.rag_pipeline = RAGPipeline(vector_store=vector_store)

@@ -109,7 +109,7 @@ def __call__(self, chat_request: ChatRequest) -> ChatResponse:
            }
            result_dict.update({"application": chat_request.application})
            self.run.log(usage_stats)
            self.stream_table.log(result_dict)
            # self.stream_table.log(result_dict)
            return ChatResponse(**result_dict)
        except Exception as e:
            with Timer() as timer:
@@ -131,5 +131,5 @@ def __call__(self, chat_request: ChatRequest) -> ChatResponse:
"end_time": timer.stop,
}
)
self.stream_table.log(result)
# self.stream_table.log(result)
return ChatResponse(**result)
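Note that the StreamTable calls above are only commented out, not replaced, so chat results are no longer persisted from this class. A hedged sketch of one possible follow-up, not part of this diff, is to trace the result through a Weave op; the entity/project string below is an assumption:

# Hedged sketch, not in this PR: trace chat results with Weave instead of StreamTable.
import weave

weave.init("wandbot/wandbot-dev")  # assumed entity/project for chat traces


@weave.op()
def log_chat_result(result_dict: dict) -> dict:
    # Calling a weave.op records its inputs and outputs as a trace,
    # standing in for the removed row-level StreamTable logging.
    return result_dict
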
2 changes: 1 addition & 1 deletion src/wandbot/evaluation/config.py
@@ -33,4 +33,4 @@ class EvalConfig(BaseSettings):
        validation_alias="eval_judge_model",
    )
    wandb_entity: str = Field("wandbot", env="WANDB_ENTITY")
    wandb_project: str = Field("wandbot-eval", env="WANDB_PROJECT")
    wandb_project: str = Field("wandbot-eval")
41 changes: 41 additions & 0 deletions src/wandbot/evaluation/weave_eval/log_eval_data.py
@@ -0,0 +1,41 @@
import os
os.environ["WANDB_ENTITY"] = "wandbot"

import wandb
import weave
import pandas as pd
from weave import Dataset

from wandbot.evaluation.config import EvalConfig

config = EvalConfig()

wandb_project = config.wandb_project
wandb_entity = config.wandb_entity

eval_artifact = wandb.Api().artifact(config.eval_artifact)
eval_artifact_dir = eval_artifact.download(root=config.eval_artifact_root)

df = pd.read_json(
    f"{eval_artifact_dir}/{config.eval_annotations_file}",
    lines=True,
    orient="records",
)
df.insert(0, "id", df.index)

correct_df = df[
    (df["is_wandb_query"] == "YES") & (df["correctness"] == "correct")
]

data_rows = correct_df.to_dict('records')

weave.init(wandb_project)

# Create a dataset
dataset = Dataset(
    name='wandbot_eval_data',
    rows=data_rows,
)

# Publish the dataset
weave.publish(dataset)
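Once published, the dataset can be read back by object reference. A hedged read-back sketch, reusing the fully qualified ref that appears in main.py later in this PR (a fresh publish would produce a different digest):

# Hedged sketch: fetch the published eval dataset back by ref.
import weave

weave.init("wandbot/wandbot-eval")
dataset = weave.ref(
    "weave:///wandbot/wandbot-eval/object/wandbot_eval_data:eCQQ0GjM077wi4ykTWYhLPRpuGIaXbMwUGEB7IyHlFU"
).get()
print(len(list(dataset.rows)))
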
113 changes: 113 additions & 0 deletions src/wandbot/evaluation/weave_eval/main.py
@@ -0,0 +1,113 @@
import os
os.environ["WANDB_ENTITY"] = "wandbot"

import json
import httpx
import weave
import asyncio
from weave import Evaluation
from weave import Model
from llama_index.llms.openai import OpenAI

from wandbot.evaluation.config import EvalConfig
from wandbot.utils import get_logger

from wandbot.evaluation.eval.correctness import (
    CORRECTNESS_EVAL_TEMPLATE,
    WandbCorrectnessEvaluator,
)

logger = get_logger(__name__)
config = EvalConfig()

correctness_evaluator = WandbCorrectnessEvaluator(
    llm=OpenAI(config.eval_judge_model),
    eval_template=CORRECTNESS_EVAL_TEMPLATE,
)

wandb_project = config.wandb_project
wandb_entity = config.wandb_entity

weave.init(f"{wandb_entity}/{wandb_project}")


@weave.op()
async def get_answer(question: str, application: str = "api-eval") -> str:
    url = "http://0.0.0.0:8000/chat/query"
    payload = {
        "question": question,
        "application": application,
        "language": "en",
    }
    async with httpx.AsyncClient(timeout=200.0) as client:
        response = await client.post(url, json=payload)
        response_json = response.json()
    return json.dumps(response_json)


@weave.op()
async def get_eval_record(
    question: str,
) -> dict:
    response = await get_answer(question)
    response = json.loads(response)
    return {
        "system_prompt": response["system_prompt"],
        "generated_answer": response["answer"],
        "retrieved_contexts": response["source_documents"],
        "model": response["model"],
        "total_tokens": response["total_tokens"],
        "prompt_tokens": response["prompt_tokens"],
        "completion_tokens": response["completion_tokens"],
        "time_taken": response["time_taken"],
    }


class EvaluatorModel(Model):
    eval_judge_model: str = config.eval_judge_model

    @weave.op()
    async def predict(self, question: str) -> dict:
        # Model logic goes here
        prediction = await get_eval_record(question)
        return prediction


@weave.op()
async def get_answer_correctness(
    question: str,
    ground_truth: str,
    notes: str,
    model_output: dict
) -> dict:
    result = await correctness_evaluator.aevaluate(
        query=question,
        response=model_output["generated_answer"],
        reference=ground_truth,
        contexts=model_output["retrieved_contexts"],
        reference_notes=notes,
    )
    return {
        "answer_correctness": result.dict()["passing"]
    }


dataset_ref = weave.ref(
    "weave:///wandbot/wandbot-eval/object/wandbot_eval_data:eCQQ0GjM077wi4ykTWYhLPRpuGIaXbMwUGEB7IyHlFU"
).get()
question_rows = dataset_ref.rows
question_rows = [
    {
        "question": row["question"],
        "ground_truth": row["answer"],
        "notes": row["notes"],
    } for row in question_rows
]
logger.info("Number of evaluation samples: %s", len(question_rows))

evaluation = Evaluation(
    dataset=question_rows, scorers=[get_answer_correctness]
)

if __name__ == "__main__":
    asyncio.run(evaluation.evaluate(EvaluatorModel()))
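get_answer assumes a wandbot API already serving at http://0.0.0.0:8000. A hedged pre-flight sketch, not part of this PR, to confirm that endpoint responds before launching the full evaluation:

# Hedged sketch: ping the local chat endpoint that get_answer() targets.
import asyncio

import httpx


async def ping_chat_api() -> None:
    payload = {"question": "What is W&B?", "application": "api-eval", "language": "en"}
    async with httpx.AsyncClient(timeout=200.0) as client:
        resp = await client.post("http://0.0.0.0:8000/chat/query", json=payload)
        resp.raise_for_status()
        print("chat API reachable; response keys:", sorted(resp.json().keys()))


if __name__ == "__main__":
    asyncio.run(ping_chat_api())
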
127 changes: 127 additions & 0 deletions src/wandbot/evaluation/weave_eval/weave_correctness.py
@@ -0,0 +1,127 @@
import asyncio
from typing import Any, Optional, Sequence

import regex as re
from llama_index.core.evaluation import CorrectnessEvaluator, EvaluationResult

from wandbot.evaluation.eval.utils import (
    make_eval_template,
    safe_parse_eval_response,
)

import wandb
import weave

SYSTEM_TEMPLATE = """You are a Weights & Biases support expert tasked with evaluating the correctness of answers to questions asked by users to a technical support chatbot.

You are given the following information:
- a user query,
- the documentation used to generate the answer
- a reference answer
- the reason why the reference answer is correct, and
- a generated answer.


Your job is to judge the relevance and correctness of the generated answer.
- Consider whether the answer addresses all aspects of the question.
- The generated answer must provide only correct information according to the documentation.
- Compare the generated answer to the reference answer for completeness and correctness.
- Output a score and a decision that represents a holistic evaluation of the generated answer.
- You must return your response only in the below mentioned format. Do not return answers in any other format.

Follow these guidelines for scoring:
- Your score has to be between 1 and 3, where 1 is the worst and 3 is the best.
- If the generated answer is not correct in comparison to the reference, you should give a score of 1.
- If the generated answer is correct in comparison to the reference but contains mistakes, you should give a score of 2.
- If the generated answer is correct in comparison to the reference and completely answers the user's query, you should give a score of 3.

Output your final verdict by strictly following JSON format:
{{
"reason": <<Provide a brief explanation for your decision here>>,
"score": <<Provide a score as per the above guidelines>>,
"decision": <<Provide your final decision here, either 'correct', or 'incorrect'>>

}}

Example Response 1:
{{
"reason": "The generated answer has the exact details as the reference answer and completely answer's the user's query.",
"score": 3,
"decision": "correct"
}}

Example Response 2:
{{
"reason": "The generated answer doesn't match the reference answer, and deviates from the documentation provided",
"score": 1,
"decision": "incorrect"
}}

Example Response 3:
{{
"reason": "The generated answer follows the same steps as the reference answer. However, it includes assumptions about methods that are not mentioned in the documentation.",
"score": 2,
"decision": "incorrect"
}}
"""


USER_TEMPLATE = """
## User Query
{query}

## Documentation
{context_str}

## Reference Answer
{reference_answer}

## Reference Correctness Reason
{reference_notes}

## Generated Answer
{generated_answer}
"""

CORRECTNESS_EVAL_TEMPLATE = make_eval_template(SYSTEM_TEMPLATE, USER_TEMPLATE)


class WandbCorrectnessEvaluator(CorrectnessEvaluator):
    @weave.op()
    async def aevaluate(
        self,
        query: Optional[str] = None,
        response: Optional[str] = None,
        contexts: Optional[Sequence[str]] = None,
        reference: Optional[str] = None,
        sleep_time_in_seconds: int = 0,
        **kwargs: Any,
    ) -> EvaluationResult:
        await asyncio.sleep(sleep_time_in_seconds)

        if query is None or response is None or reference is None:
            print(query, response, reference, flush=True)
            raise ValueError("query, response, and reference must be provided")

        eval_response = await self._llm.apredict(
            prompt=self._eval_template,
            query=query,
            generated_answer=response,
            reference_answer=reference,
            context_str=re.sub(
                "\n+", "\n", "\n---\n".join(contexts) if contexts else ""
            ),
            reference_notes=kwargs.get("reference_notes", ""),
        )

        passing, reasoning, score = await safe_parse_eval_response(
            eval_response, "correct"
        )

        return EvaluationResult(
            query=query,
            response=response,
            passing=passing,
            score=score,
            feedback=reasoning,
        )
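A hedged sketch of exercising the evaluator on its own, outside the Weave Evaluation pipeline. The module path, judge model name, and example strings below are illustrative assumptions, and an OpenAI API key is assumed to be configured:

# Hedged standalone usage sketch (illustrative inputs, not from the eval set).
import asyncio

from llama_index.llms.openai import OpenAI

from wandbot.evaluation.weave_eval.weave_correctness import (
    CORRECTNESS_EVAL_TEMPLATE,
    WandbCorrectnessEvaluator,
)

evaluator = WandbCorrectnessEvaluator(
    llm=OpenAI("gpt-4-1106-preview"),  # assumed judge model name
    eval_template=CORRECTNESS_EVAL_TEMPLATE,
)

# weave.init(...) can be called first if this call should also be traced.
result = asyncio.run(
    evaluator.aevaluate(
        query="How do I log a metric with wandb?",
        response="Call wandb.log({'loss': 0.1}) inside your training loop.",
        reference="Use wandb.log to log metrics during training.",
        contexts=["wandb.log records a dictionary of metrics for the current run."],
        reference_notes="The reference correctly describes wandb.log usage.",
    )
)
print(result.passing, result.score, result.feedback)
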
10 changes: 6 additions & 4 deletions src/wandbot/ingestion/__main__.py
@@ -11,13 +11,15 @@ def main():
    project = os.environ.get("WANDB_PROJECT", "wandbot-dev")
    entity = os.environ.get("WANDB_ENTITY", "wandbot")

    raw_artifact = prepare_data.load(project, entity)
    preprocessed_artifact = preprocess_data.load(project, entity, raw_artifact)
    # raw_artifact = prepare_data.load(project, entity)
    raw_artifact = "wandbot/wandbot-dev/raw_dataset:v56"
    # preprocessed_artifact = preprocess_data.load(project, entity, raw_artifact)
    preprocessed_artifact = "wandbot/wandbot-dev/transformed_data:v23"
    vectorstore_artifact = vectorstores.load(
        project, entity, preprocessed_artifact
        project, entity, preprocessed_artifact, "chroma_index"
    )

    create_ingestion_report(project, entity, raw_artifact, vectorstore_artifact)
    # create_ingestion_report(project, entity, raw_artifact, vectorstore_artifact)
    print(vectorstore_artifact)

