enable semantic searching during chatting #312

Open · wants to merge 4 commits into base: main
13 changes: 13 additions & 0 deletions backend/app/api/admin_routes/semantic_cache.py
@@ -18,6 +18,9 @@ async def add_semantic_cache(
user: CurrentSuperuserDep,
question: str,
answer: str,
chat_id: str,
user_message_id: int,
assistant_message_id: int,
namespace: str = "default",
chat_engine: str = "default",
metadata: Optional[dict] = Body(None),
@@ -29,6 +32,16 @@
dspy_llm=_dspy_lm,
)

# fill the chat related information into metadata
if metadata is None:
metadata = {}

metadata["chat_detail"] = {
"chat_id": chat_id,
"user_message_id": user_message_id,
"assistant_message_id": assistant_message_id,
}

try:
scm.add_cache(
session,
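Taken together, the admin route now requires the originating chat_id and both message IDs, and tucks them under metadata["chat_detail"] before the entry is written to the cache. A hedged usage sketch of calling this endpoint; the URL prefix, port, token, and IDs are assumptions, not taken from this diff:

import httpx

# Hypothetical call to the admin route above; the real mount path and auth
# scheme may differ in the deployed backend.
response = httpx.post(
    "http://localhost:8000/api/v1/admin/semantic_cache",  # assumed path
    params={
        "question": "How does TiDB handle vector search?",
        "answer": "TiDB supports vector columns and cosine distance queries.",
        "chat_id": "0190c0de-0000-0000-0000-000000000000",  # hypothetical chat UUID
        "user_message_id": 42,
        "assistant_message_id": 43,
        "namespace": "default",
        "chat_engine": "default",
    },
    # optional caller-supplied metadata (the Body param); chat_detail is filled in server-side
    json={"source": "manual-backfill"},
    headers={"Authorization": "Bearer <superuser token>"},
)
response.raise_for_status()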
2 changes: 2 additions & 0 deletions backend/app/core/config.py
@@ -70,6 +70,8 @@ def server_host(self) -> str:
TIDB_DATABASE: str
TIDB_SSL: bool = True

ENABLE_SEMANTIC_CACHE: bool = False

CELERY_BROKER_URL: str = "redis://redis:6379/0"
CELERY_RESULT_BACKEND: str = "redis://redis:6379/0"

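ENABLE_SEMANTIC_CACHE defaults to False, so the cache lookup added to chat.py below is strictly opt-in. A minimal sketch of flipping it on, assuming the usual pydantic BaseSettings behaviour of reading same-named environment variables and a module-level settings instance in app.core.config:

import os

# Assumed mapping: the ENABLE_SEMANTIC_CACHE field is populated from the
# environment variable of the same name (e.g. set it in the backend .env file).
os.environ["ENABLE_SEMANTIC_CACHE"] = "true"

from app.core.config import settings  # assumed module-level Settings instance

print(settings.ENABLE_SEMANTIC_CACHE)  # expected: True when the env var is set before import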
78 changes: 78 additions & 0 deletions backend/app/rag/chat.py
@@ -42,6 +42,7 @@
tidb_graph_editor as editor,
)
from app.rag.knowledge_graph import KnowledgeGraphIndex
from app.rag.semantic_cache import SemanticCacheManager
from app.rag.chat_config import ChatEngineConfig, get_default_embedding_model
from app.rag.types import (
MyCBEventType,
@@ -201,6 +202,83 @@ def _chat(self) -> Generator[ChatEvent, None, None]:
_fast_llm = self.chat_engine_config.get_fast_llama_llm(self.db_session)
_fast_dspy_lm = self.chat_engine_config.get_fast_dspy_lm(self.db_session)

if settings.ENABLE_SEMANTIC_CACHE:
    try:
        _semantic_cache_manager = SemanticCacheManager(
            _fast_dspy_lm,
            _embed_model,
        )
        cached_response = _semantic_cache_manager.search(
            self.db_session,
            self.user_question,
        )
        if cached_response['match_type'] == 'exact_match' and len(cached_response['items']) == 1 and 'meta' in cached_response['items'][0]:
            # simple cache hit, return the cached response
            cached_chat = cached_response['items'][0]['meta'].get('chat_detail', None)
            if cached_chat and cached_chat.get('user_message_id', None) and cached_chat.get('assistant_message_id', None):
                # get the identical user message from the db
                cached_db_user_message = chat_repo.get_message(self.db_session, cached_chat.get('user_message_id', None))
                # get the identical assistant message from the db
                cached_db_assistant_message = chat_repo.get_message(self.db_session, cached_chat.get('assistant_message_id', None))

                yield ChatEvent(
                    event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART,
                    payload=ChatStreamMessagePayload(
                        state=ChatMessageSate.TRACE,
                        display="Searching from semantic cache",
                        context={"langfuse_url": cached_db_assistant_message.trace_url},
                    ),
                )

                yield ChatEvent(
                    event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART,
                    payload=ChatStreamMessagePayload(
                        state=ChatMessageSate.SOURCE_NODES,
                        context=cached_db_assistant_message.sources,
                    ),
                )

                # replay the cached answer to the client one character at a time
                response_text = ""
                for word in cached_db_assistant_message.content:
                    response_text += word
                    yield ChatEvent(
                        event_type=ChatEventType.TEXT_PART,
                        payload=word,
                    )

                yield ChatEvent(
                    event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART,
                    payload=ChatStreamMessagePayload(
                        state=ChatMessageSate.FINISHED,
                    ),
                )

                # copy the cached answer onto the current chat's messages and persist
                db_assistant_message.sources = cached_db_assistant_message.sources
                db_assistant_message.graph_data = cached_db_assistant_message.graph_data
                db_assistant_message.content = cached_db_assistant_message.content
                db_assistant_message.post_verification_result_url = cached_db_assistant_message.post_verification_result_url
                db_assistant_message.updated_at = datetime.now(UTC)
                db_assistant_message.finished_at = datetime.now(UTC)
                self.db_session.add(db_assistant_message)
                db_user_message.graph_data = cached_db_user_message.graph_data
                db_user_message.updated_at = datetime.now(UTC)
                db_user_message.finished_at = datetime.now(UTC)
                self.db_session.add(db_user_message)
                self.db_session.commit()

                yield ChatEvent(
                    event_type=ChatEventType.DATA_PART,
                    payload=ChatStreamDataPayload(
                        chat=self.db_chat_obj,
                        user_message=db_user_message,
                        assistant_message=db_assistant_message,
                    ),
                )

                return
    except Exception as e:
        logger.error(f"Failed to search from semantic cache: {e}")

def _get_llamaindex_callback_manager():
# Why we don't use high-level decorator `observe()` as \
# `https://langfuse.com/docs/integrations/llama-index/get-started` suggested?
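The early return only happens for an exact match whose metadata still points at both stored messages; partial matches, a missing chat_detail, or any exception all fall through to the normal retrieval pipeline (errors are logged and swallowed). A sketch of the search result this branch expects, with field names inferred from the lookups above and values purely illustrative:

cached_response = {
    "match_type": "exact_match",  # anything else skips the cached-reply path
    "items": [
        {
            "meta": {
                "chat_detail": {
                    "chat_id": "0190c0de-0000-0000-0000-000000000000",  # hypothetical
                    "user_message_id": 42,
                    "assistant_message_id": 43,
                },
            },
            # other cached fields (question, answer, distance, ...) are not read here
        },
    ],
}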
2 changes: 1 addition & 1 deletion backend/app/rag/semantic_cache/base.py
@@ -145,7 +145,7 @@ def search(
SemanticCache,
SemanticCache.query_vec.cosine_distance(embedding).label("distance"),
)
.having(SemanticCache.query_vec.cosine_distance(embedding) < 0.5)
.having(SemanticCache.query_vec.cosine_distance(embedding) < 0.8)
.order_by("distance")
.limit(20)
)
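Relaxing the .having() cutoff from 0.5 to 0.8 lets the semantic cache surface looser candidates; the query still orders by distance and keeps only the 20 closest, and the exact-match check in chat.py remains the final gate. Since cosine distance is 1 minus cosine similarity, the new cutoff only requires a similarity above 0.2. A small illustration with toy vectors, independent of the project's embedding model:

import numpy as np

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # Standard cosine distance, assumed to match the helper used in the query above:
    # 1 - (a . b) / (|a| * |b|)
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

query = np.array([1.0, 0.0, 0.0])
close = np.array([0.6, 0.8, 0.0])
loose = np.array([0.3, 0.95, 0.0])

print(round(cosine_distance(query, close), 2))  # 0.4  -> passed the old 0.5 cutoff too
print(round(cosine_distance(query, loose), 2))  # ~0.7 -> only passes the relaxed 0.8 cutoff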