From f0cf15b4dfa74479e131636a81bac5370786177f Mon Sep 17 00:00:00 2001 From: "J.C. Zhong" Date: Mon, 14 Aug 2023 18:47:23 +0000 Subject: [PATCH 1/5] feat: use websocket for ai assistant --- querybook/server/const/ai_assistant.py | 14 ++ querybook/server/datasources/__init__.py | 2 - querybook/server/datasources/ai_assistant.py | 48 ---- .../server/datasources_socketio/__init__.py | 2 + .../datasources_socketio/ai_assistant.py | 31 +++ .../server/datasources_socketio/connect.py | 6 + querybook/server/lib/ai_assistant/__init__.py | 11 +- .../server/lib/ai_assistant/ai_assistant.py | 52 ----- .../assistants/openai_assistant.py | 176 +------------- .../lib/ai_assistant/base_ai_assistant.py | 220 +++++++++--------- .../ai_assistant/prompts/sql_fix_prompt.py | 50 ++++ .../ai_assistant/prompts/sql_title_prompt.py | 29 +++ .../ai_assistant/prompts/text2sql_prompt.py | 51 ++++ .../redis_chat_history_storage.py | 17 ++ .../web_socket_callback_handler.py | 63 +++++ .../components/AIAssistant/AutoFixButton.tsx | 3 +- .../AIAssistant/QueryGenerationModal.tsx | 10 +- .../QueryCellTitle/QueryCellTitle.tsx | 3 +- querybook/webapp/const/aiAssistant.ts | 10 + querybook/webapp/hooks/useStream.ts | 9 +- .../lib/ai-assistant/ai-assistant-socketio.ts | 62 +++++ querybook/webapp/lib/datasource.ts | 58 ++--- 22 files changed, 497 insertions(+), 430 deletions(-) create mode 100644 querybook/server/const/ai_assistant.py delete mode 100644 querybook/server/datasources/ai_assistant.py create mode 100644 querybook/server/datasources_socketio/ai_assistant.py delete mode 100644 querybook/server/lib/ai_assistant/ai_assistant.py create mode 100644 querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py create mode 100644 querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py create mode 100644 querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py create mode 100644 querybook/server/lib/ai_assistant/redis_chat_history_storage.py create mode 100644 querybook/server/lib/ai_assistant/web_socket_callback_handler.py create mode 100644 querybook/webapp/const/aiAssistant.ts create mode 100644 querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts diff --git a/querybook/server/const/ai_assistant.py b/querybook/server/const/ai_assistant.py new file mode 100644 index 000000000..2c1b33070 --- /dev/null +++ b/querybook/server/const/ai_assistant.py @@ -0,0 +1,14 @@ +from enum import Enum + + +# KEEP IT CONSISTENT AS webapp/const/aiAssistant.ts +class AICommandType(Enum): + SQL_FIX = "SQL_FIX" + SQL_TITLE = "SQL_TITLE" + TEXT_TO_SQL = "TEXT_TO_SQL" + RESET_MEMORY = "RESET_MEMORY" + + +AI_ASSISTANT_NAMESPACE = "/ai_assistant" +AI_ASSISTANT_REQUEST_EVENT = "ai_assistant_request" +AI_ASSISTANT_RESPONSE_EVENT = "ai_assistant_response" diff --git a/querybook/server/datasources/__init__.py b/querybook/server/datasources/__init__.py index 65f990498..8701b5054 100644 --- a/querybook/server/datasources/__init__.py +++ b/querybook/server/datasources/__init__.py @@ -16,7 +16,6 @@ from . import event_log from . import data_element from . import comment -from . import ai_assistant # Keep this at the end of imports to make sure the plugin APIs override the default ones try: @@ -43,5 +42,4 @@ event_log data_element comment -ai_assistant api_plugin diff --git a/querybook/server/datasources/ai_assistant.py b/querybook/server/datasources/ai_assistant.py deleted file mode 100644 index d5ede985e..000000000 --- a/querybook/server/datasources/ai_assistant.py +++ /dev/null @@ -1,48 +0,0 @@ -from flask import Response -from flask_login import current_user -from app.datasource import register - -from lib.ai_assistant import ai_assistant -from logic import datadoc as datadoc_logic - - -@register("/ai/query_title/", custom_response=True) -def generate_query_title(data_cell_id: int): - data_cell = datadoc_logic.get_data_cell_by_id(data_cell_id) - query = data_cell.context if data_cell else None - - if not query: - return Response(None) - - title_stream = ai_assistant.generate_title_from_query( - query=query, user_id=current_user.id - ) - - return Response(title_stream, mimetype="text/event-stream") - - -@register("/ai/query_auto_fix/", custom_response=True) -def query_auto_fix(query_execution_id: int): - res_stream = ai_assistant.query_auto_fix( - query_execution_id=query_execution_id, - user_id=current_user.id, - ) - - return Response(res_stream, mimetype="text/event-stream") - - -@register("/ai/generate_query/", custom_response=True) -def generate_sql_query( - query_engine_id: int, tables: list[str], question: str, data_cell_id: int = None -): - data_cell = datadoc_logic.get_data_cell_by_id(data_cell_id) - original_query = data_cell.context if data_cell else None - res_stream = ai_assistant.generate_sql_query( - query_engine_id=query_engine_id, - tables=tables, - question=question, - original_query=original_query, - user_id=current_user.id, - ) - - return Response(res_stream, mimetype="text/event-stream") diff --git a/querybook/server/datasources_socketio/__init__.py b/querybook/server/datasources_socketio/__init__.py index 3d6425a41..6c06442b2 100644 --- a/querybook/server/datasources_socketio/__init__.py +++ b/querybook/server/datasources_socketio/__init__.py @@ -1,7 +1,9 @@ from . import query_execution from . import datadoc from . import connect +from . import ai_assistant connect query_execution datadoc +ai_assistant diff --git a/querybook/server/datasources_socketio/ai_assistant.py b/querybook/server/datasources_socketio/ai_assistant.py new file mode 100644 index 000000000..37d47955c --- /dev/null +++ b/querybook/server/datasources_socketio/ai_assistant.py @@ -0,0 +1,31 @@ +from flask import request +from flask_socketio import join_room, leave_room + +from const.ai_assistant import ( + AI_ASSISTANT_NAMESPACE, + AI_ASSISTANT_REQUEST_EVENT, +) + +from .helper import register_socket + + +@register_socket("subscribe", namespace=AI_ASSISTANT_NAMESPACE) +def on_join_room(): + join_room(request.sid) + + +@register_socket("unsubscribe", namespace=AI_ASSISTANT_NAMESPACE) +def on_leave_room(): + leave_room(request.sid) + + +@register_socket("disconnect", namespace=AI_ASSISTANT_NAMESPACE) +def disconnect(): + leave_room(request.sid) + + +@register_socket(AI_ASSISTANT_REQUEST_EVENT, namespace=AI_ASSISTANT_NAMESPACE) +def ai_assistant_request(command_type: str, payload={}): + from lib.ai_assistant import ai_assistant + + ai_assistant.handle_ai_command(command_type, payload) diff --git a/querybook/server/datasources_socketio/connect.py b/querybook/server/datasources_socketio/connect.py index c1057a390..69f6df2b6 100644 --- a/querybook/server/datasources_socketio/connect.py +++ b/querybook/server/datasources_socketio/connect.py @@ -1,6 +1,7 @@ from flask_login import current_user from flask_socketio import ConnectionRefusedError +from const.ai_assistant import AI_ASSISTANT_NAMESPACE from const.data_doc import DATA_DOC_NAMESPACE from const.query_execution import QUERY_EXECUTION_NAMESPACE @@ -20,3 +21,8 @@ def connect_query_execution(auth): @register_socket("connect", namespace=DATA_DOC_NAMESPACE) def connect_datadoc(auth): on_connect() + + +@register_socket("connect", namespace=AI_ASSISTANT_NAMESPACE) +def connect_ai_assistant(auth): + on_connect() diff --git a/querybook/server/lib/ai_assistant/__init__.py b/querybook/server/lib/ai_assistant/__init__.py index 3ecedab7c..4b0f320f5 100644 --- a/querybook/server/lib/ai_assistant/__init__.py +++ b/querybook/server/lib/ai_assistant/__init__.py @@ -1,11 +1,14 @@ from env import QuerybookSettings +from .all_ai_assistants import get_ai_assistant_class + if QuerybookSettings.AI_ASSISTANT_PROVIDER: - from .ai_assistant import AIAssistant + ai_assistant = get_ai_assistant_class(QuerybookSettings.AI_ASSISTANT_PROVIDER) + ai_assistant.set_config(QuerybookSettings.AI_ASSISTANT_CONFIG) - ai_assistant = AIAssistant( - QuerybookSettings.AI_ASSISTANT_PROVIDER, QuerybookSettings.AI_ASSISTANT_CONFIG - ) else: ai_assistant = None + + +__all__ = ["ai_assistant"] diff --git a/querybook/server/lib/ai_assistant/ai_assistant.py b/querybook/server/lib/ai_assistant/ai_assistant.py deleted file mode 100644 index 6dc633161..000000000 --- a/querybook/server/lib/ai_assistant/ai_assistant.py +++ /dev/null @@ -1,52 +0,0 @@ -import threading - -from .all_ai_assistants import get_ai_assistant_class -from .base_ai_assistant import ChainStreamHandler, EventStream - - -class AIAssistant: - def __init__(self, provider: str, config: dict = {}): - self._assisant = get_ai_assistant_class(provider) - self._assisant.set_config(config) - - def _get_streaming_result(self, fn, kwargs): - event_stream = EventStream() - callback_handler = ChainStreamHandler(event_stream) - kwargs["callback_handler"] = callback_handler - thread = threading.Thread(target=fn, kwargs=kwargs) - thread.start() - return event_stream - - def generate_title_from_query(self, query, user_id=None): - return self._get_streaming_result( - self._assisant.generate_title_from_query, - {"query": query, "user_id": user_id}, - ) - - def query_auto_fix(self, query_execution_id, user_id=None): - return self._get_streaming_result( - self._assisant.query_auto_fix, - { - "query_execution_id": query_execution_id, - "user_id": user_id, - }, - ) - - def generate_sql_query( - self, - query_engine_id: int, - tables: list[str], - question: str, - original_query: str = None, - user_id=None, - ): - return self._get_streaming_result( - self._assisant.generate_sql_query, - { - "query_engine_id": query_engine_id, - "tables": tables, - "question": question, - "original_query": original_query, - "user_id": user_id, - }, - ) diff --git a/querybook/server/lib/ai_assistant/assistants/openai_assistant.py b/querybook/server/lib/ai_assistant/assistants/openai_assistant.py index 94db5f76a..50b09d23c 100644 --- a/querybook/server/lib/ai_assistant/assistants/openai_assistant.py +++ b/querybook/server/lib/ai_assistant/assistants/openai_assistant.py @@ -3,11 +3,6 @@ from langchain.chat_models import ChatOpenAI from langchain.callbacks.manager import CallbackManager -from langchain.prompts.chat import ( - ChatPromptTemplate, - SystemMessage, - HumanMessagePromptTemplate, -) import openai @@ -29,174 +24,9 @@ def _get_error_msg(self, error) -> str: return super()._get_error_msg(error) - @property - def title_generation_prompt_template(self) -> ChatPromptTemplate: - system_message_prompt = SystemMessage( - content="You are a helpful assistant that can summerize SQL queries." - ) - human_template = ( - "Generate a brief 10-word-maximum title for the SQL query below. " - "===Query\n" - "{query}\n\n" - "===Response Guidelines\n" - "1. Only respond with the title without any explanation\n" - "2. Dont use double quotes to enclose the title\n" - "3. Dont add a final period to the title\n\n" - "===Example response\n" - "This is a title\n" - ) - human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) - return ChatPromptTemplate.from_messages( - [system_message_prompt, human_message_prompt] - ) - - @property - def query_auto_fix_prompt_template(self) -> ChatPromptTemplate: - system_message_prompt = SystemMessage( - content=( - "You are a SQL expert that can help fix SQL query errors.\n\n" - "Please follow the format below for your response:\n" - "<@key-1@>\n" - "value-1\n\n" - "<@key-2@>\n" - "value-2\n\n" - ) - ) - human_template = ( - "Please help fix the query below based on the given error message and table schemas. \n\n" - "===SQL dialect\n" - "{dialect}\n\n" - "===Query\n" - "{query}\n\n" - "===Error\n" - "{error}\n\n" - "===Table Schemas\n" - "{table_schemas}\n\n" - "===Response Format\n" - "<@key-1@>\n" - "value-1\n\n" - "<@key-2@>\n" - "value-2\n\n" - "===Example response:\n" - "<@explanation@>\n" - "This is an explanation about the error\n\n" - "<@fix_suggestion@>\n" - "This is a recommended fix for the error\n\n" - "<@fixed_query@>\n" - "The fixed SQL query\n\n" - "===Response Guidelines\n" - "1. For the <@fixed_query@> section, it can only be a valid SQL query without any explanation.\n" - "2. If there is insufficient context to address the query error, you may leave the fixed_query section blank and provide a general suggestion instead.\n" - "3. Maintain the original query format and case in the fixed_query section, including comments, except when correcting the erroneous part.\n" - ) - human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) - return ChatPromptTemplate.from_messages( - [system_message_prompt, human_message_prompt] - ) - - @property - def generate_sql_query_prompt_template(self) -> ChatPromptTemplate: - system_message_prompt = SystemMessage( - content=( - "You are a SQL expert that can help generating SQL query.\n\n" - "Please follow the key/value pair format below for your response:\n" - "<@key-1@>\n" - "value-1\n\n" - "<@key-2@>\n" - "value-2\n\n" - ) - ) - human_template = ( - "Please help to generate a new SQL query or modify the original query to answer the following question. Your response should ONLY be based on the given context.\n\n" - "===SQL Dialect\n" - "{dialect}\n\n" - "===Tables\n" - "{table_schemas}\n\n" - "===Original Query\n" - "{original_query}\n\n" - "===Question\n" - "{question}\n\n" - "===Response Format\n" - "<@key-1@>\n" - "value-1\n\n" - "<@key-2@>\n" - "value-2\n\n" - "===Example Response:\n" - "Example 1: Sufficient Context\n" - "<@query@>\n" - "A generated SQL query based on the provided context with the asked question at the beginning is provided here.\n\n" - "Example 2: Insufficient Context\n" - "<@explanation@>\n" - "An explanation of the missing context is provided here.\n\n" - "===Response Guidelines\n" - "1. If the provided context is sufficient, please respond only with a valid SQL query without any explanations in the <@query@> section. The query should start with a comment containing the question being asked.\n" - "2. If the provided context is insufficient, please explain what information is missing.\n" - "3. If the original query is provided, please modify the original query to answer the question. The original query may start with a comment containing a previously asked question. If you find such a comment, please use both the original question and the new question to generate the new query.\n" - "4. The <@key_name@> in the response can only be <@explanation@> or <@query@>.\n\n" - ) - human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) - return ChatPromptTemplate.from_messages( - [system_message_prompt, human_message_prompt] - ) - - def _generate_title_from_query( - self, query, stream=True, callback_handler=None, user_id=None - ): - """Generate title from SQL query using OpenAI's chat model.""" - messages = self.title_generation_prompt_template.format_prompt( - query=query - ).to_messages() - chat = ChatOpenAI( - **self._config, - streaming=stream, - callback_manager=CallbackManager([callback_handler]), - ) - ai_message = chat(messages) - return ai_message.content - - def _query_auto_fix( - self, - language, - query, - error, - table_schemas, - stream, - callback_handler, - user_id=None, - ): - """Query auto fix using OpenAI's chat model.""" - messages = self.query_auto_fix_prompt_template.format_prompt( - dialect=language, query=query, error=error, table_schemas=table_schemas - ).to_messages() - chat = ChatOpenAI( - **self._config, - streaming=stream, - callback_manager=CallbackManager([callback_handler]), - ) - ai_message = chat(messages) - return ai_message.content - - def _generate_sql_query( - self, - language: str, - table_schemas: str, - question: str, - original_query: str, - stream, - callback_handler, - user_id=None, - ): - """Generate SQL query using OpenAI's chat model.""" - messages = self.generate_sql_query_prompt_template.format_prompt( - dialect=language, - question=question, - table_schemas=table_schemas, - original_query=original_query, - ).to_messages() - chat = ChatOpenAI( + def _get_llm(self, callback_handler): + return ChatOpenAI( **self._config, - streaming=stream, + streaming=True, callback_manager=CallbackManager([callback_handler]), ) - ai_message = chat(messages) - return ai_message.content diff --git a/querybook/server/lib/ai_assistant/base_ai_assistant.py b/querybook/server/lib/ai_assistant/base_ai_assistant.py index 635e19f72..6311ac99e 100644 --- a/querybook/server/lib/ai_assistant/base_ai_assistant.py +++ b/querybook/server/lib/ai_assistant/base_ai_assistant.py @@ -1,66 +1,31 @@ -from abc import ABC, abstractmethod import functools -import json -import queue +from abc import ABC, abstractmethod -from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler +from flask_login import current_user +from langchain.chains import LLMChain +from langchain.memory import ConversationBufferMemory from pydantic.error_wrappers import ValidationError from app.db import with_session +from app.flask_app import socketio +from const.ai_assistant import AICommandType +from clients.redis_client import with_redis from lib.logger import get_logger -from logic import query_execution as qe_logic from lib.query_analysis.lineage import process_query from logic import admin as admin_logic +from logic import datadoc as datadoc_logic from logic import metastore as m_logic -from models.query_execution import QueryExecution +from logic import query_execution as qe_logic from models.metastore import DataTableColumn +from models.query_execution import QueryExecution -LOG = get_logger(__file__) - - -class EventStream: - """Generator to facilitate streaming result from Langchain. - The stream format is based on Server-Sent Events (SSE).""" - - def __init__(self): - self.queue = queue.Queue() - - def __iter__(self): - return self - - def __next__(self): - item = self.queue.get() - if item is StopIteration: - raise item - return item - - def send(self, data: str): - self.queue.put("data: " + json.dumps({"data": data}) + "\n\n") - - def close(self): - # the empty data is to make sure the client receives the close event - self.queue.put("event: close\ndata: \n\n") - self.queue.put(StopIteration) - - def send_error(self, error: str): - self.queue.put("event: error\n") - data = json.dumps({"data": error}) - self.queue.put(f"data: {data}\n\n") - self.close() - - -class ChainStreamHandler(StreamingStdOutCallbackHandler): - """Callback handlder to stream the result to a generator.""" - - def __init__(self, stream: EventStream): - super().__init__() - self.stream = stream - - def on_llm_new_token(self, token: str, **kwargs): - self.stream.send(token) +from .prompts.sql_fix_prompt import PROMPT as SQL_FIX_PROMPT +from .prompts.sql_title_prompt import PROMPT as SQL_TITLE_PROMPT +from .prompts.text2sql_prompt import PROMPT as TEXT2SQL_PROMPT +from .redis_chat_history_storage import RedisChatHistoryStorage +from .web_socket_callback_handler import WebSocketStream, WebSocketCallbackHandler - def on_llm_end(self, response, **kwargs): - self.stream.close() +LOG = get_logger(__file__) class BaseAIAssistant(ABC): @@ -87,6 +52,48 @@ def wrapper(self, *args, **kwargs): return wrapper + @abstractmethod + def _get_llm(self, callback_handler: WebSocketCallbackHandler): + """return the language model to use""" + + @with_redis + def _get_chat_memory( + self, + session_id, + memory_key="chat_history", + input_key="question", + ttl=600, + redis_conn=None, + ): + message_history_storage = RedisChatHistoryStorage( + redis_client=redis_conn, ttl=ttl, session_id=session_id + ) + + return ConversationBufferMemory( + memory_key=memory_key, + chat_memory=message_history_storage, + input_key=input_key, + return_messages=True, + ) + + def _get_sql_title_prompt(self): + """Override this method to return specific prompt for your own assistant.""" + return SQL_TITLE_PROMPT + + def _get_text2sql_prompt(self): + """Override this method to return specific prompt for your own assistant.""" + return TEXT2SQL_PROMPT + + def _get_sql_fix_prompt(self): + """Override this method to return specific prompt for your own assistant.""" + return SQL_FIX_PROMPT + + def _get_llm_chain(self, command_type, prompt, memory=None): + ws_stream = WebSocketStream(socketio, command_type) + callback_handler = WebSocketCallbackHandler(ws_stream) + llm = self._get_llm(callback_handler=callback_handler) + return LLMChain(llm=llm, prompt=prompt, memory=memory) + def _get_error_msg(self, error) -> str: """Override this method to return specific error messages for your own assistant.""" if isinstance(error, ValidationError): @@ -167,6 +174,30 @@ def _get_query_execution_error(self, query_execution: QueryExecution) -> str: return error[:1000] + def handle_ai_command(self, command_type: str, payload: dict = {}): + data_cell_id = payload.get("data_cell_id") + data_cell = datadoc_logic.get_data_cell_by_id(data_cell_id) + query = data_cell.context if data_cell else None + + if command_type == AICommandType.SQL_TITLE.value: + self.generate_title_from_query(query=query) + elif command_type == AICommandType.TEXT_TO_SQL.value: + query_engine_id = payload.get("query_engine_id") + tables = payload.get("tables") + question = payload.get("question") + self.generate_sql_query( + query_engine_id=query_engine_id, + tables=tables, + question=question, + original_query=query, + memory_session_id=f"{current_user.id}_{data_cell_id}", + ) + elif command_type == AICommandType.SQL_FIX.value: + query_execution_id = payload.get("query_execution_id") + self.query_auto_fix( + query_execution_id=query_execution_id, + ) + @catch_error @with_session def generate_sql_query( @@ -175,9 +206,7 @@ def generate_sql_query( tables: list[str], question: str, original_query: str = None, - stream=True, - callback_handler: ChainStreamHandler = None, - user_id=None, + memory_session_id=None, session=None, ): query_engine = admin_logic.get_query_engine_by_id( @@ -186,36 +215,25 @@ def generate_sql_query( table_schemas = self._generate_table_schema_prompt( metastore_id=query_engine.metastore_id, table_names=tables, session=session ) - return self._generate_sql_query( - language=query_engine.language, - table_schemas=table_schemas, + + prompt = self._get_text2sql_prompt() + memory = self._get_chat_memory(session_id=memory_session_id) + chain = self._get_llm_chain( + command_type=AICommandType.TEXT_TO_SQL.value, + prompt=prompt, + memory=memory, + ) + return chain.run( + dialect=query_engine.language, question=question, + table_schemas=table_schemas, original_query=original_query, - stream=stream, - callback_handler=callback_handler, - user_id=user_id, ) - @abstractmethod - def _generate_sql_query( - self, - language: str, - table_schemas: str, - question: str, - original_query: str = None, - stream=True, - callback_handler: ChainStreamHandler = None, - user_id=None, - ): - raise NotImplementedError() - @catch_error def generate_title_from_query( self, query, - stream=True, - callback_handler: ChainStreamHandler = None, - user_id=None, ): """Generate title from SQL query. @@ -224,39 +242,24 @@ def generate_title_from_query( stream (bool, optional): Whether to stream the result. Defaults to True. callback_handler (CallbackHandler, optional): Callback handler to handle the straming result. Required if stream is True. """ - return self._generate_title_from_query( - query=query, - stream=stream, - callback_handler=callback_handler, - user_id=user_id, + prompt = self._get_sql_title_prompt() + chain = self._get_llm_chain( + command_type=AICommandType.SQL_TITLE.value, + prompt=prompt, ) - - @abstractmethod - def _generate_title_from_query( - self, - query, - stream, - callback_handler, - user_id=None, - ): - raise NotImplementedError() + return chain.run(query=query) @catch_error @with_session def query_auto_fix( self, query_execution_id: int, - stream: bool = True, - callback_handler: ChainStreamHandler = None, - user_id: int = None, session=None, ): """Generate title from SQL query. Args: query_execution_id (int): The failed query execution id - stream (bool, optional): Whether to stream the result. Defaults to True. - callback_handler (CallbackHandler, optional): Callback handler to handle the straming result. Required if stream is True. """ query_execution = qe_logic.get_query_execution_by_id( query_execution_id, session=session @@ -276,25 +279,14 @@ def query_auto_fix( session=session, ) - return self._query_auto_fix( - language=language, + prompt = self._get_sql_fix_prompt() + chain = self._get_llm_chain( + command_type=AICommandType.SQL_FIX.value, + prompt=prompt, + ) + return chain.run( + dialect=language, query=query_execution.query, error=self._get_query_execution_error(query_execution), table_schemas=table_schemas, - stream=stream, - callback_handler=callback_handler, - user_id=user_id, ) - - @abstractmethod - def _query_auto_fix( - self, - language: str, - query: str, - error: str, - table_schemas: str, - stream: bool, - callback_handler: ChainStreamHandler, - user_id=None, - ): - raise NotImplementedError() diff --git a/querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py b/querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py new file mode 100644 index 000000000..59bd9d021 --- /dev/null +++ b/querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py @@ -0,0 +1,50 @@ +from langchain.prompts import ( + ChatPromptTemplate, + SystemMessagePromptTemplate, + HumanMessagePromptTemplate, +) + + +system_message_template = ( + "You are a SQL expert that can help fix SQL query errors.\n\n" + "Please follow the format below for your response:\n" + "<@key-1@>\n" + "value-1\n\n" + "<@key-2@>\n" + "value-2\n\n" +) + +human_message_template = ( + "Please help fix the query below based on the given error message and table schemas. \n\n" + "===SQL dialect\n" + "{dialect}\n\n" + "===Query\n" + "{query}\n\n" + "===Error\n" + "{error}\n\n" + "===Table Schemas\n" + "{table_schemas}\n\n" + "===Response Format\n" + "<@key-1@>\n" + "value-1\n\n" + "<@key-2@>\n" + "value-2\n\n" + "===Example response:\n" + "<@explanation@>\n" + "This is an explanation about the error\n\n" + "<@fix_suggestion@>\n" + "This is a recommended fix for the error\n\n" + "<@fixed_query@>\n" + "The fixed SQL query\n\n" + "===Response Guidelines\n" + "1. For the <@fixed_query@> section, it can only be a valid SQL query without any explanation.\n" + "2. If there is insufficient context to address the query error, you may leave the fixed_query section blank and provide a general suggestion instead.\n" + "3. Maintain the original query format and case in the fixed_query section, including comments, except when correcting the erroneous part.\n" +) + +PROMPT = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(system_message_template), + HumanMessagePromptTemplate.from_template(human_message_template), + ] +) diff --git a/querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py b/querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py new file mode 100644 index 000000000..9ea12e32f --- /dev/null +++ b/querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py @@ -0,0 +1,29 @@ +from langchain.prompts import ( + ChatPromptTemplate, + SystemMessagePromptTemplate, + HumanMessagePromptTemplate, +) + + +system_message_template = ( + """You are a helpful assistant that can summerize SQL queries.""" +) + +human_message_template = ( + "Generate a brief 10-word-maximum title for the SQL query below. " + "===Query\n" + "{query}\n\n" + "===Response Guidelines\n" + "1. Only respond with the title without any explanation\n" + "2. Dont use double quotes to enclose the title\n" + "3. Dont add a final period to the title\n\n" + "===Example response\n" + "This is a title\n" +) + +PROMPT = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(system_message_template), + HumanMessagePromptTemplate.from_template(human_message_template), + ] +) diff --git a/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py b/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py new file mode 100644 index 000000000..3f0bc233c --- /dev/null +++ b/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py @@ -0,0 +1,51 @@ +from langchain.prompts import ( + ChatPromptTemplate, + MessagesPlaceholder, + SystemMessagePromptTemplate, + HumanMessagePromptTemplate, +) + + +system_message_template = ( + "You are a SQL expert that can help generating SQL query.\n\n" + "Please help to generate a new SQL query or modify the original query to answer the following question. Your response should ONLY be based on the given context.\n\n" + "Please always follow the key/value pair format below for your response:\n" + "===Response Format\n" + "<@query@>\n" + "query\n\n" + "or\n\n" + "<@explanation@>\n" + "explanation\n\n" + "===Example Response:\n" + "Example 1: Sufficient Context\n" + "<@query@>\n" + "A generated SQL query based on the provided context with the asked question at the beginning is provided here.\n\n" + "Example 2: Insufficient Context\n" + "<@explanation@>\n" + "An explanation of the missing context is provided here.\n\n" + "===Response Guidelines\n" + "1. If the provided context is sufficient, please respond only with a valid SQL query without any explanations in the <@query@> section. The query should start with a comment containing the question being asked.\n" + "2. If the provided context is insufficient, please explain what information is missing.\n" + "3. If the original query is provided, please modify the original query to answer the question. The original query may start with a comment containing a previously asked question. If you find such a comment, please use both the original question and the new question to generate the new query.\n" + "4. Please always honor the table schmeas for the query generation\n\n" +) + +human_message_template = ( + "===SQL Dialect\n" + "{dialect}\n\n" + "===Tables\n" + "{table_schemas}\n\n" + "===Original Query\n" + "{original_query}\n\n" +) + +PROMPT = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(system_message_template), + HumanMessagePromptTemplate.from_template(human_message_template), + MessagesPlaceholder(variable_name="chat_history"), + HumanMessagePromptTemplate.from_template( + "{question}\nPlease remember always start your response with <@query@> or <@explanation@>.\n" + ), + ] +) diff --git a/querybook/server/lib/ai_assistant/redis_chat_history_storage.py b/querybook/server/lib/ai_assistant/redis_chat_history_storage.py new file mode 100644 index 000000000..bccdf0680 --- /dev/null +++ b/querybook/server/lib/ai_assistant/redis_chat_history_storage.py @@ -0,0 +1,17 @@ +from langchain.memory.chat_message_histories import RedisChatMessageHistory + + +class RedisChatHistoryStorage(RedisChatMessageHistory): + """Chat message history stored in a Redis database.""" + + def __init__( + self, + redis_client, + session_id: str, + key_prefix: str = "message_store:", + ttl=600, + ): + self.redis_client = redis_client + self.session_id = session_id + self.key_prefix = key_prefix + self.ttl = ttl diff --git a/querybook/server/lib/ai_assistant/web_socket_callback_handler.py b/querybook/server/lib/ai_assistant/web_socket_callback_handler.py new file mode 100644 index 000000000..5a8919af5 --- /dev/null +++ b/querybook/server/lib/ai_assistant/web_socket_callback_handler.py @@ -0,0 +1,63 @@ +from flask import request +from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler + +from const.ai_assistant import ( + AI_ASSISTANT_NAMESPACE, + AI_ASSISTANT_RESPONSE_EVENT, +) + + +class WebSocketStream: + def __init__(self, socketio, command_type: str): + self.socketio = socketio + self.command_type = command_type + self.room = request.sid + + def _send(self, payload: dict): + self.socketio.emit( + AI_ASSISTANT_RESPONSE_EVENT, + ( + self.command_type, + payload, + ), + namespace=AI_ASSISTANT_NAMESPACE, + room=self.room, + ) + + def send(self, data: str, end=False): + self._send( + { + "event": "data", + "data": data, + } + ) + + def send_error(self, error: str): + self._send( + { + "event": "error", + "data": error, + } + ) + self.close() + + def close(self): + self._send( + { + "event": "close", + } + ) + + +class WebSocketCallbackHandler(StreamingStdOutCallbackHandler): + """Callback handlder to stream the result through web socket.""" + + def __init__(self, stream: WebSocketStream): + super().__init__() + self.stream = stream + + def on_llm_new_token(self, token: str, **kwargs): + self.stream.send(token) + + def on_llm_end(self, response, **kwargs): + self.stream.close() diff --git a/querybook/webapp/components/AIAssistant/AutoFixButton.tsx b/querybook/webapp/components/AIAssistant/AutoFixButton.tsx index 042472ae7..5c993c127 100644 --- a/querybook/webapp/components/AIAssistant/AutoFixButton.tsx +++ b/querybook/webapp/components/AIAssistant/AutoFixButton.tsx @@ -1,6 +1,7 @@ import React, { useState } from 'react'; import { QueryComparison } from 'components/TranspileQueryModal/QueryComparison'; +import { AICommandType } from 'const/aiAssistant'; import { ComponentType, ElementType } from 'const/analytics'; import { StreamStatus, useStream } from 'hooks/useStream'; import { trackClick } from 'lib/analytics'; @@ -26,7 +27,7 @@ export const AutoFixButton = ({ const [show, setShow] = useState(false); const { streamStatus, startStream, streamData, cancelStream } = useStream( - '/ds/ai/query_auto_fix/', + AICommandType.SQL_FIX, { query_execution_id: queryExecutionId, } diff --git a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx index 4156c6b2d..7b383663c 100644 --- a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx +++ b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx @@ -3,6 +3,7 @@ import React, { useCallback, useEffect, useState } from 'react'; import { QueryEngineSelector } from 'components/QueryRunButton/QueryRunButton'; import { QueryComparison } from 'components/TranspileQueryModal/QueryComparison'; +import { AICommandType } from 'const/aiAssistant'; import { ComponentType, ElementType } from 'const/analytics'; import { IQueryEngine } from 'const/queryEngine'; import { StreamStatus, useStream } from 'hooks/useStream'; @@ -12,7 +13,6 @@ import { trimSQLQuery } from 'lib/stream'; import { matchKeyPress } from 'lib/utils/keyboard'; import { analyzeCode } from 'lib/web-worker'; import { Button } from 'ui/Button/Button'; -import { DebouncedInput } from 'ui/DebouncedInput/DebouncedInput'; import { Icon } from 'ui/Icon/Icon'; import { Message } from 'ui/Message/Message'; import { Modal } from 'ui/Modal/Modal'; @@ -78,7 +78,7 @@ export const QueryGenerationModal = ({ }, [tablesInQuery]); const { streamStatus, startStream, streamData, cancelStream } = useStream( - '/ds/ai/generate_query/', + AICommandType.TEXT_TO_SQL, { query_engine_id: engineId, tables: tables, @@ -88,7 +88,7 @@ export const QueryGenerationModal = ({ } ); - const { explanation, query: rawNewQuery } = streamData; + const { explanation, query: rawNewQuery, data } = streamData; const newQuery = trimSQLQuery(rawNewQuery); @@ -258,8 +258,8 @@ export const QueryGenerationModal = ({ {tables.length > 0 && ( <> {questionBarDOM} - {explanation && ( -
{explanation}
+ {(explanation || data) && ( +
{explanation || data}
)} {(query || newQuery) && ( diff --git a/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx b/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx index 531cf5caa..6da319fbc 100644 --- a/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx +++ b/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx @@ -2,6 +2,7 @@ import React, { useCallback, useEffect } from 'react'; import { useDispatch } from 'react-redux'; import PublicConfig from 'config/querybook_public_config.yaml'; +import { AICommandType } from 'const/aiAssistant'; import { ComponentType, ElementType } from 'const/analytics'; import { StreamStatus, useStream } from 'hooks/useStream'; import { trackClick } from 'lib/analytics'; @@ -38,7 +39,7 @@ export const QueryCellTitle: React.FC = ({ query; const { streamStatus, startStream, streamData } = useStream( - '/ds/ai/query_title/', + AICommandType.SQL_TITLE, { data_cell_id: cellId, } diff --git a/querybook/webapp/const/aiAssistant.ts b/querybook/webapp/const/aiAssistant.ts new file mode 100644 index 000000000..c67b300b5 --- /dev/null +++ b/querybook/webapp/const/aiAssistant.ts @@ -0,0 +1,10 @@ +// Keep it in sync with AICommandType in server/const/ai_assistant.py +export enum AICommandType { + SQL_FIX = 'SQL_FIX', + SQL_TITLE = 'SQL_TITLE', + TEXT_TO_SQL = 'TEXT_TO_SQL', +} + +export const AI_ASSISTANT_NAMESPACE = '/ai_assistant'; +export const AI_ASSISTANT_REQUEST_EVENT = 'ai_assistant_request'; +export const AI_ASSISTANT_RESPONSE_EVENT = 'ai_assistant_response'; diff --git a/querybook/webapp/hooks/useStream.ts b/querybook/webapp/hooks/useStream.ts index 4533320e4..3a9dbce7d 100644 --- a/querybook/webapp/hooks/useStream.ts +++ b/querybook/webapp/hooks/useStream.ts @@ -1,5 +1,6 @@ import { useCallback, useRef, useState } from 'react'; +import { AICommandType } from 'const/aiAssistant'; import ds from 'lib/datasource'; export enum StreamStatus { @@ -10,7 +11,7 @@ export enum StreamStatus { } export function useStream( - url: string, + commandType: AICommandType, params: Record = {} ): { streamStatus: StreamStatus; @@ -21,17 +22,17 @@ export function useStream( } { const [streamStatus, setSteamStatus] = useState(StreamStatus.NOT_STARTED); const [data, setData] = useState<{ [key: string]: string }>({}); - const streamRef = useRef(null); + const streamRef = useRef<{ close: () => void } | null>(null); const startStream = useCallback(() => { setSteamStatus(StreamStatus.STREAMING); setData({}); - streamRef.current = ds.stream(url, params, setData, (data) => { + streamRef.current = ds.stream(commandType, params, setData, (data) => { setData(data); setSteamStatus(StreamStatus.FINISHED); }); - }, [url, params]); + }, [commandType, params]); const resetStream = useCallback(() => { setSteamStatus(StreamStatus.NOT_STARTED); diff --git a/querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts b/querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts new file mode 100644 index 000000000..fc5449886 --- /dev/null +++ b/querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts @@ -0,0 +1,62 @@ +import type { Socket } from 'socket.io-client'; + +import { + AI_ASSISTANT_REQUEST_EVENT, + AI_ASSISTANT_RESPONSE_EVENT, + AICommandType, +} from 'const/aiAssistant'; +import SocketIOManager from 'lib/socketio-manager'; + +class AIAssistantSocket { + private static NAME_SPACE = '/ai_assistant'; + + private socket: Socket = null; + private socketPromise: Promise = null; + + constructor() { + this.setupSocket(); + } + + public onSocketConnect(socket: Socket) { + socket.emit('subscribe'); + } + + private setupSocket = async () => { + if (this.socket) { + return this.socket; + } + if (this.socketPromise) { + this.socket = await this.socketPromise; + } else { + // We need to setup our socket + this.socketPromise = SocketIOManager.getSocket( + AIAssistantSocket.NAME_SPACE, + this.onSocketConnect.bind(this) + ); + + // Setup socket's connection functions + this.socket = await this.socketPromise; + } + + this.socket.on('error', (e) => { + console.error('Socket error', e); + }); + + return this.socket; + }; + + public requestAIAssistant = (command: AICommandType, payload: object) => { + this.socket.emit(AI_ASSISTANT_REQUEST_EVENT, command, payload); + }; + + public addAIListener = (listener) => { + this.socket.on(AI_ASSISTANT_RESPONSE_EVENT, listener); + }; + + public removeAIListener = (listener) => { + this.socket.off(AI_ASSISTANT_RESPONSE_EVENT, listener); + }; +} + +const defaultSocket = new AIAssistantSocket(); +export default defaultSocket; diff --git a/querybook/webapp/lib/datasource.ts b/querybook/webapp/lib/datasource.ts index f39d0696c..f2a379f0a 100644 --- a/querybook/webapp/lib/datasource.ts +++ b/querybook/webapp/lib/datasource.ts @@ -1,6 +1,8 @@ +import aiAssistantSocket from './ai-assistant/ai-assistant-socketio'; import axios, { AxiosRequestConfig, Canceler, Method } from 'axios'; import toast from 'react-hot-toast'; +import { AICommandType } from 'const/aiAssistant'; import { setSessionExpired } from 'lib/querybookUI'; import { formatError } from 'lib/utils/error'; @@ -159,7 +161,7 @@ export function uploadDatasource( } /** - * Stream data from a datasource using EventSource + * Stream data from WebSocket * * The data is streamed in the form of deltas. Each delta is a JSON object * ``` @@ -168,41 +170,45 @@ export function uploadDatasource( * } * ``` * - * @param url The url to stream from - * @param params The data to send to the url + * @param commandType The ai command type + * @param params The data to send * @param onStraming Callback when data is received. The data is the accumulated data. * @param onStramingEnd Callback when the stream ends */ function streamDatasource( - url: string, + commandType: AICommandType, params?: Record, onStreaming?: (data: { [key: string]: string }) => void, onStreamingEnd?: (data: { [key: string]: string }) => void ) { - const eventSource = new EventSource( - `${url}?params=${JSON.stringify(params)}` - ); const parser = new DeltaStreamParser(); - eventSource.addEventListener('message', (e) => { - const newToken = JSON.parse(e.data).data; - parser.parse(newToken); - onStreaming?.(parser.result); - }); - eventSource.addEventListener('error', (e) => { - console.error(e); - eventSource.close(); - onStreamingEnd?.(parser.result); - if (e instanceof MessageEvent) { - toast.error(JSON.parse(e.data).data); + + const onData = (command, payload) => { + if (command !== commandType) { + return; } - }); - eventSource.addEventListener('close', (e) => { - eventSource.close(); - parser.close(); - onStreamingEnd?.(parser.result); - }); - - return eventSource; + + if (payload.event === 'close') { + aiAssistantSocket.removeAIListener(onData); + onStreamingEnd?.(parser.result); + return; + } else if (payload.event === 'error') { + aiAssistantSocket.removeAIListener(onData); + toast.error(payload.data); + } + + parser.parse(payload.data); + onStreaming?.(parser.result); + }; + + aiAssistantSocket.addAIListener(onData); + + aiAssistantSocket.requestAIAssistant(commandType, params); + return { + close: () => { + aiAssistantSocket.removeAIListener(onData); + }, + }; } export default { From 997c8eeb0a1d2231b42383bfe4263fcf69786ace Mon Sep 17 00:00:00 2001 From: "J.C. Zhong" Date: Mon, 14 Aug 2023 19:01:23 +0000 Subject: [PATCH 2/5] fix node test --- .../lib/ai-assistant/ai-assistant-socketio.ts | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts b/querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts index fc5449886..0b402e69f 100644 --- a/querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts +++ b/querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts @@ -13,7 +13,7 @@ class AIAssistantSocket { private socket: Socket = null; private socketPromise: Promise = null; - constructor() { + public constructor() { this.setupSocket(); } @@ -21,6 +21,18 @@ class AIAssistantSocket { socket.emit('subscribe'); } + public requestAIAssistant = (command: AICommandType, payload: object) => { + this.socket.emit(AI_ASSISTANT_REQUEST_EVENT, command, payload); + }; + + public addAIListener = (listener) => { + this.socket.on(AI_ASSISTANT_RESPONSE_EVENT, listener); + }; + + public removeAIListener = (listener) => { + this.socket.off(AI_ASSISTANT_RESPONSE_EVENT, listener); + }; + private setupSocket = async () => { if (this.socket) { return this.socket; @@ -44,18 +56,6 @@ class AIAssistantSocket { return this.socket; }; - - public requestAIAssistant = (command: AICommandType, payload: object) => { - this.socket.emit(AI_ASSISTANT_REQUEST_EVENT, command, payload); - }; - - public addAIListener = (listener) => { - this.socket.on(AI_ASSISTANT_RESPONSE_EVENT, listener); - }; - - public removeAIListener = (listener) => { - this.socket.off(AI_ASSISTANT_RESPONSE_EVENT, listener); - }; } const defaultSocket = new AIAssistantSocket(); From 7a551ec84ba7b61df2c1274dd5d178fd1ce4fd81 Mon Sep 17 00:00:00 2001 From: "J.C. Zhong" Date: Tue, 15 Aug 2023 23:50:31 +0000 Subject: [PATCH 3/5] comments --- .../datasources_socketio/ai_assistant.py | 18 ----- .../lib/ai_assistant/base_ai_assistant.py | 71 +++++++++++-------- .../ai_assistant/prompts/sql_fix_prompt.py | 2 +- .../ai_assistant/prompts/sql_title_prompt.py | 2 +- .../ai_assistant/prompts/text2sql_prompt.py | 2 +- ... streaming_web_socket_callback_handler.py} | 4 +- .../QueryCellTitle/QueryCellTitle.tsx | 2 +- .../lib/ai-assistant/ai-assistant-socketio.ts | 8 ++- querybook/webapp/lib/datasource.ts | 3 + 9 files changed, 57 insertions(+), 55 deletions(-) rename querybook/server/lib/ai_assistant/{web_socket_callback_handler.py => streaming_web_socket_callback_handler.py} (93%) diff --git a/querybook/server/datasources_socketio/ai_assistant.py b/querybook/server/datasources_socketio/ai_assistant.py index 37d47955c..c0a7d2e1f 100644 --- a/querybook/server/datasources_socketio/ai_assistant.py +++ b/querybook/server/datasources_socketio/ai_assistant.py @@ -1,6 +1,3 @@ -from flask import request -from flask_socketio import join_room, leave_room - from const.ai_assistant import ( AI_ASSISTANT_NAMESPACE, AI_ASSISTANT_REQUEST_EVENT, @@ -9,21 +6,6 @@ from .helper import register_socket -@register_socket("subscribe", namespace=AI_ASSISTANT_NAMESPACE) -def on_join_room(): - join_room(request.sid) - - -@register_socket("unsubscribe", namespace=AI_ASSISTANT_NAMESPACE) -def on_leave_room(): - leave_room(request.sid) - - -@register_socket("disconnect", namespace=AI_ASSISTANT_NAMESPACE) -def disconnect(): - leave_room(request.sid) - - @register_socket(AI_ASSISTANT_REQUEST_EVENT, namespace=AI_ASSISTANT_NAMESPACE) def ai_assistant_request(command_type: str, payload={}): from lib.ai_assistant import ai_assistant diff --git a/querybook/server/lib/ai_assistant/base_ai_assistant.py b/querybook/server/lib/ai_assistant/base_ai_assistant.py index 6311ac99e..e07010c7e 100644 --- a/querybook/server/lib/ai_assistant/base_ai_assistant.py +++ b/querybook/server/lib/ai_assistant/base_ai_assistant.py @@ -19,11 +19,14 @@ from models.metastore import DataTableColumn from models.query_execution import QueryExecution -from .prompts.sql_fix_prompt import PROMPT as SQL_FIX_PROMPT -from .prompts.sql_title_prompt import PROMPT as SQL_TITLE_PROMPT -from .prompts.text2sql_prompt import PROMPT as TEXT2SQL_PROMPT +from .prompts.sql_fix_prompt import SQL_FIX_PROMPT +from .prompts.sql_title_prompt import SQL_TITLE_PROMPT +from .prompts.text2sql_prompt import TEXT2SQL_PROMPT from .redis_chat_history_storage import RedisChatHistoryStorage -from .web_socket_callback_handler import WebSocketStream, WebSocketCallbackHandler +from .streaming_web_socket_callback_handler import ( + WebSocketStream, + StreamingWebsocketCallbackHandler, +) LOG = get_logger(__file__) @@ -53,7 +56,7 @@ def wrapper(self, *args, **kwargs): return wrapper @abstractmethod - def _get_llm(self, callback_handler: WebSocketCallbackHandler): + def _get_llm(self, callback_handler: StreamingWebsocketCallbackHandler): """return the language model to use""" @with_redis @@ -88,9 +91,12 @@ def _get_sql_fix_prompt(self): """Override this method to return specific prompt for your own assistant.""" return SQL_FIX_PROMPT + def _get_ws_stream(self, command_type: str): + return WebSocketStream(socketio, command_type) + def _get_llm_chain(self, command_type, prompt, memory=None): - ws_stream = WebSocketStream(socketio, command_type) - callback_handler = WebSocketCallbackHandler(ws_stream) + ws_stream = self._get_ws_stream(command_type=command_type) + callback_handler = StreamingWebsocketCallbackHandler(ws_stream) llm = self._get_llm(callback_handler=callback_handler) return LLMChain(llm=llm, prompt=prompt, memory=memory) @@ -175,28 +181,35 @@ def _get_query_execution_error(self, query_execution: QueryExecution) -> str: return error[:1000] def handle_ai_command(self, command_type: str, payload: dict = {}): - data_cell_id = payload.get("data_cell_id") - data_cell = datadoc_logic.get_data_cell_by_id(data_cell_id) - query = data_cell.context if data_cell else None - - if command_type == AICommandType.SQL_TITLE.value: - self.generate_title_from_query(query=query) - elif command_type == AICommandType.TEXT_TO_SQL.value: - query_engine_id = payload.get("query_engine_id") - tables = payload.get("tables") - question = payload.get("question") - self.generate_sql_query( - query_engine_id=query_engine_id, - tables=tables, - question=question, - original_query=query, - memory_session_id=f"{current_user.id}_{data_cell_id}", - ) - elif command_type == AICommandType.SQL_FIX.value: - query_execution_id = payload.get("query_execution_id") - self.query_auto_fix( - query_execution_id=query_execution_id, - ) + try: + if command_type == AICommandType.SQL_TITLE.value: + query = payload["query"] + self.generate_title_from_query(query=query) + elif command_type == AICommandType.TEXT_TO_SQL.value: + data_cell_id = payload["data_cell_id"] + data_cell = datadoc_logic.get_data_cell_by_id(data_cell_id) + query = data_cell.context if data_cell else None + query_engine_id = payload["query_engine_id"] + tables = payload.get("tables") + question = payload["question"] + self.generate_sql_query( + query_engine_id=query_engine_id, + tables=tables, + question=question, + original_query=query, + memory_session_id=f"{current_user.id}_{data_cell_id}", + ) + elif command_type == AICommandType.SQL_FIX.value: + query_execution_id = payload["query_execution_id"] + self.query_auto_fix( + query_execution_id=query_execution_id, + ) + else: + self._get_ws_stream(command_type=command_type).send_error( + "Unsupported command" + ) + except Exception as e: + self._get_ws_stream(command_type=command_type).send_error(str(e)) @catch_error @with_session diff --git a/querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py b/querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py index 59bd9d021..82294c528 100644 --- a/querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py +++ b/querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py @@ -42,7 +42,7 @@ "3. Maintain the original query format and case in the fixed_query section, including comments, except when correcting the erroneous part.\n" ) -PROMPT = ChatPromptTemplate.from_messages( +SQL_FIX_PROMPT = ChatPromptTemplate.from_messages( [ SystemMessagePromptTemplate.from_template(system_message_template), HumanMessagePromptTemplate.from_template(human_message_template), diff --git a/querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py b/querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py index 9ea12e32f..1c8d7bbb0 100644 --- a/querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py +++ b/querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py @@ -21,7 +21,7 @@ "This is a title\n" ) -PROMPT = ChatPromptTemplate.from_messages( +SQL_TITLE_PROMPT = ChatPromptTemplate.from_messages( [ SystemMessagePromptTemplate.from_template(system_message_template), HumanMessagePromptTemplate.from_template(human_message_template), diff --git a/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py b/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py index 3f0bc233c..0807d33e6 100644 --- a/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py +++ b/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py @@ -39,7 +39,7 @@ "{original_query}\n\n" ) -PROMPT = ChatPromptTemplate.from_messages( +TEXT2SQL_PROMPT = ChatPromptTemplate.from_messages( [ SystemMessagePromptTemplate.from_template(system_message_template), HumanMessagePromptTemplate.from_template(human_message_template), diff --git a/querybook/server/lib/ai_assistant/web_socket_callback_handler.py b/querybook/server/lib/ai_assistant/streaming_web_socket_callback_handler.py similarity index 93% rename from querybook/server/lib/ai_assistant/web_socket_callback_handler.py rename to querybook/server/lib/ai_assistant/streaming_web_socket_callback_handler.py index 5a8919af5..e9610cb5a 100644 --- a/querybook/server/lib/ai_assistant/web_socket_callback_handler.py +++ b/querybook/server/lib/ai_assistant/streaming_web_socket_callback_handler.py @@ -24,7 +24,7 @@ def _send(self, payload: dict): room=self.room, ) - def send(self, data: str, end=False): + def send(self, data: str): self._send( { "event": "data", @@ -49,7 +49,7 @@ def close(self): ) -class WebSocketCallbackHandler(StreamingStdOutCallbackHandler): +class StreamingWebsocketCallbackHandler(StreamingStdOutCallbackHandler): """Callback handlder to stream the result through web socket.""" def __init__(self, stream: WebSocketStream): diff --git a/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx b/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx index 6da319fbc..f8ed81bf4 100644 --- a/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx +++ b/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx @@ -41,7 +41,7 @@ export const QueryCellTitle: React.FC = ({ const { streamStatus, startStream, streamData } = useStream( AICommandType.SQL_TITLE, { - data_cell_id: cellId, + query, } ); const { data: title } = streamData; diff --git a/querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts b/querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts index 0b402e69f..0ce6281a6 100644 --- a/querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts +++ b/querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts @@ -25,11 +25,15 @@ class AIAssistantSocket { this.socket.emit(AI_ASSISTANT_REQUEST_EVENT, command, payload); }; - public addAIListener = (listener) => { + public addAIListener = ( + listener: (command: string, payload: object) => void + ) => { this.socket.on(AI_ASSISTANT_RESPONSE_EVENT, listener); }; - public removeAIListener = (listener) => { + public removeAIListener = ( + listener: (command: string, payload: object) => void + ) => { this.socket.off(AI_ASSISTANT_RESPONSE_EVENT, listener); }; diff --git a/querybook/webapp/lib/datasource.ts b/querybook/webapp/lib/datasource.ts index f2a379f0a..baca1f972 100644 --- a/querybook/webapp/lib/datasource.ts +++ b/querybook/webapp/lib/datasource.ts @@ -190,11 +190,14 @@ function streamDatasource( if (payload.event === 'close') { aiAssistantSocket.removeAIListener(onData); + parser.close(); onStreamingEnd?.(parser.result); return; } else if (payload.event === 'error') { aiAssistantSocket.removeAIListener(onData); toast.error(payload.data); + onStreamingEnd?.(parser.result); + return; } parser.parse(payload.data); From 771fd35f86831ea3458fbef9f8e42bf13d2fdde5 Mon Sep 17 00:00:00 2001 From: "J.C. Zhong" Date: Wed, 16 Aug 2023 02:03:28 +0000 Subject: [PATCH 4/5] remove memory and add keep button --- .../lib/ai_assistant/base_ai_assistant.py | 34 +-------------- .../ai_assistant/prompts/text2sql_prompt.py | 14 +++--- .../redis_chat_history_storage.py | 17 -------- .../AIAssistant/QueryGenerationModal.tsx | 43 +++++++++++++++++-- .../TranspileQueryModal/QueryComparison.scss | 3 -- .../TranspileQueryModal/QueryComparison.tsx | 24 +++++++++-- querybook/webapp/const/analytics.ts | 1 + 7 files changed, 67 insertions(+), 69 deletions(-) delete mode 100644 querybook/server/lib/ai_assistant/redis_chat_history_storage.py diff --git a/querybook/server/lib/ai_assistant/base_ai_assistant.py b/querybook/server/lib/ai_assistant/base_ai_assistant.py index e07010c7e..aed9ecc4d 100644 --- a/querybook/server/lib/ai_assistant/base_ai_assistant.py +++ b/querybook/server/lib/ai_assistant/base_ai_assistant.py @@ -3,17 +3,14 @@ from flask_login import current_user from langchain.chains import LLMChain -from langchain.memory import ConversationBufferMemory from pydantic.error_wrappers import ValidationError from app.db import with_session from app.flask_app import socketio from const.ai_assistant import AICommandType -from clients.redis_client import with_redis from lib.logger import get_logger from lib.query_analysis.lineage import process_query from logic import admin as admin_logic -from logic import datadoc as datadoc_logic from logic import metastore as m_logic from logic import query_execution as qe_logic from models.metastore import DataTableColumn @@ -22,7 +19,6 @@ from .prompts.sql_fix_prompt import SQL_FIX_PROMPT from .prompts.sql_title_prompt import SQL_TITLE_PROMPT from .prompts.text2sql_prompt import TEXT2SQL_PROMPT -from .redis_chat_history_storage import RedisChatHistoryStorage from .streaming_web_socket_callback_handler import ( WebSocketStream, StreamingWebsocketCallbackHandler, @@ -59,26 +55,6 @@ def wrapper(self, *args, **kwargs): def _get_llm(self, callback_handler: StreamingWebsocketCallbackHandler): """return the language model to use""" - @with_redis - def _get_chat_memory( - self, - session_id, - memory_key="chat_history", - input_key="question", - ttl=600, - redis_conn=None, - ): - message_history_storage = RedisChatHistoryStorage( - redis_client=redis_conn, ttl=ttl, session_id=session_id - ) - - return ConversationBufferMemory( - memory_key=memory_key, - chat_memory=message_history_storage, - input_key=input_key, - return_messages=True, - ) - def _get_sql_title_prompt(self): """Override this method to return specific prompt for your own assistant.""" return SQL_TITLE_PROMPT @@ -186,9 +162,7 @@ def handle_ai_command(self, command_type: str, payload: dict = {}): query = payload["query"] self.generate_title_from_query(query=query) elif command_type == AICommandType.TEXT_TO_SQL.value: - data_cell_id = payload["data_cell_id"] - data_cell = datadoc_logic.get_data_cell_by_id(data_cell_id) - query = data_cell.context if data_cell else None + original_query = payload["original_query"] query_engine_id = payload["query_engine_id"] tables = payload.get("tables") question = payload["question"] @@ -196,8 +170,7 @@ def handle_ai_command(self, command_type: str, payload: dict = {}): query_engine_id=query_engine_id, tables=tables, question=question, - original_query=query, - memory_session_id=f"{current_user.id}_{data_cell_id}", + original_query=original_query, ) elif command_type == AICommandType.SQL_FIX.value: query_execution_id = payload["query_execution_id"] @@ -219,7 +192,6 @@ def generate_sql_query( tables: list[str], question: str, original_query: str = None, - memory_session_id=None, session=None, ): query_engine = admin_logic.get_query_engine_by_id( @@ -230,11 +202,9 @@ def generate_sql_query( ) prompt = self._get_text2sql_prompt() - memory = self._get_chat_memory(session_id=memory_session_id) chain = self._get_llm_chain( command_type=AICommandType.TEXT_TO_SQL.value, prompt=prompt, - memory=memory, ) return chain.run( dialect=query_engine.language, diff --git a/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py b/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py index 0807d33e6..9f41ab6f5 100644 --- a/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py +++ b/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py @@ -6,8 +6,9 @@ ) -system_message_template = ( - "You are a SQL expert that can help generating SQL query.\n\n" +system_message_template = "You are a SQL expert that can help generating SQL query." + +human_message_template = ( "Please help to generate a new SQL query or modify the original query to answer the following question. Your response should ONLY be based on the given context.\n\n" "Please always follow the key/value pair format below for your response:\n" "===Response Format\n" @@ -28,24 +29,19 @@ "2. If the provided context is insufficient, please explain what information is missing.\n" "3. If the original query is provided, please modify the original query to answer the question. The original query may start with a comment containing a previously asked question. If you find such a comment, please use both the original question and the new question to generate the new query.\n" "4. Please always honor the table schmeas for the query generation\n\n" -) - -human_message_template = ( "===SQL Dialect\n" "{dialect}\n\n" "===Tables\n" "{table_schemas}\n\n" "===Original Query\n" "{original_query}\n\n" + "===Question\n" + "{question}\n\n" ) TEXT2SQL_PROMPT = ChatPromptTemplate.from_messages( [ SystemMessagePromptTemplate.from_template(system_message_template), HumanMessagePromptTemplate.from_template(human_message_template), - MessagesPlaceholder(variable_name="chat_history"), - HumanMessagePromptTemplate.from_template( - "{question}\nPlease remember always start your response with <@query@> or <@explanation@>.\n" - ), ] ) diff --git a/querybook/server/lib/ai_assistant/redis_chat_history_storage.py b/querybook/server/lib/ai_assistant/redis_chat_history_storage.py deleted file mode 100644 index bccdf0680..000000000 --- a/querybook/server/lib/ai_assistant/redis_chat_history_storage.py +++ /dev/null @@ -1,17 +0,0 @@ -from langchain.memory.chat_message_histories import RedisChatMessageHistory - - -class RedisChatHistoryStorage(RedisChatMessageHistory): - """Chat message history stored in a Redis database.""" - - def __init__( - self, - redis_client, - session_id: str, - key_prefix: str = "message_store:", - ttl=600, - ): - self.redis_client = redis_client - self.session_id = session_id - self.key_prefix = key_prefix - self.ttl = ttl diff --git a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx index 7b383663c..2ed6107c1 100644 --- a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx +++ b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx @@ -18,6 +18,7 @@ import { Message } from 'ui/Message/Message'; import { Modal } from 'ui/Modal/Modal'; import { ResizableTextArea } from 'ui/ResizableTextArea/ResizableTextArea'; import { StyledText } from 'ui/StyledText/StyledText'; +import { Tag } from 'ui/Tag/Tag'; import { TableSelector } from './TableSelector'; import { TextToSQLMode, TextToSQLModeSelector } from './TextToSQLModeSelector'; @@ -72,6 +73,7 @@ export const QueryGenerationModal = ({ const [textToSQLMode, setTextToSQLMode] = useState( !!query ? TextToSQLMode.EDIT : TextToSQLMode.GENERATE ); + const [newQuery, setNewQuery] = useState(''); useEffect(() => { setTables(uniq([...tablesInQuery, ...tables])); @@ -83,14 +85,17 @@ export const QueryGenerationModal = ({ query_engine_id: engineId, tables: tables, question: question, - data_cell_id: - textToSQLMode === TextToSQLMode.EDIT ? dataCellId : undefined, + original_query: query, } ); const { explanation, query: rawNewQuery, data } = streamData; - const newQuery = trimSQLQuery(rawNewQuery); + // const newQuery = trimSQLQuery(rawNewQuery); + + useEffect(() => { + setNewQuery(trimSQLQuery(rawNewQuery)); + }, [rawNewQuery]); const onKeyDown = useCallback( (event: React.KeyboardEvent) => { @@ -272,7 +277,37 @@ export const QueryGenerationModal = ({ } toQuery={newQuery} fromQueryTitle="Original Query" - toQueryTitle="New Query" + toQueryTitle={ +
+ {New Query} +
+ } disableHighlight={ streamStatus === StreamStatus.STREAMING } diff --git a/querybook/webapp/components/TranspileQueryModal/QueryComparison.scss b/querybook/webapp/components/TranspileQueryModal/QueryComparison.scss index a6be10ec4..6ee8654e2 100644 --- a/querybook/webapp/components/TranspileQueryModal/QueryComparison.scss +++ b/querybook/webapp/components/TranspileQueryModal/QueryComparison.scss @@ -1,9 +1,6 @@ .QueryComparison { display: flex; gap: 8px; - .Tag { - margin-bottom: 12px; - } .diff-side-view { flex: 1; diff --git a/querybook/webapp/components/TranspileQueryModal/QueryComparison.tsx b/querybook/webapp/components/TranspileQueryModal/QueryComparison.tsx index 759c28da1..ba9e2e0ba 100644 --- a/querybook/webapp/components/TranspileQueryModal/QueryComparison.tsx +++ b/querybook/webapp/components/TranspileQueryModal/QueryComparison.tsx @@ -10,8 +10,8 @@ import './QueryComparison.scss'; export const QueryComparison: React.FC<{ fromQuery: string; toQuery: string; - fromQueryTitle?: string; - toQueryTitle?: string; + fromQueryTitle?: string | React.ReactNode; + toQueryTitle?: string | React.ReactNode; disableHighlight?: boolean; hideEmptyQuery?: boolean; }> = ({ @@ -63,7 +63,15 @@ export const QueryComparison: React.FC<{
{!(hideEmptyQuery && !fromQuery) && (
- {fromQueryTitle && {fromQueryTitle}} + {fromQueryTitle && ( +
+ {typeof fromQueryTitle === 'string' ? ( + {fromQueryTitle} + ) : ( + fromQueryTitle + )} +
+ )} - {toQueryTitle && {toQueryTitle}} + {toQueryTitle && ( +
+ {typeof toQueryTitle === 'string' ? ( + {toQueryTitle} + ) : ( + toQueryTitle + )} +
+ )} Date: Wed, 16 Aug 2023 02:06:20 +0000 Subject: [PATCH 5/5] fix linter --- querybook/server/lib/ai_assistant/base_ai_assistant.py | 1 - querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py | 1 - 2 files changed, 2 deletions(-) diff --git a/querybook/server/lib/ai_assistant/base_ai_assistant.py b/querybook/server/lib/ai_assistant/base_ai_assistant.py index aed9ecc4d..00b1b2c74 100644 --- a/querybook/server/lib/ai_assistant/base_ai_assistant.py +++ b/querybook/server/lib/ai_assistant/base_ai_assistant.py @@ -1,7 +1,6 @@ import functools from abc import ABC, abstractmethod -from flask_login import current_user from langchain.chains import LLMChain from pydantic.error_wrappers import ValidationError diff --git a/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py b/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py index 9f41ab6f5..2b7a0f293 100644 --- a/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py +++ b/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py @@ -1,6 +1,5 @@ from langchain.prompts import ( ChatPromptTemplate, - MessagesPlaceholder, SystemMessagePromptTemplate, HumanMessagePromptTemplate, )