diff --git a/querybook/server/const/ai_assistant.py b/querybook/server/const/ai_assistant.py
new file mode 100644
index 000000000..2c1b33070
--- /dev/null
+++ b/querybook/server/const/ai_assistant.py
@@ -0,0 +1,14 @@
+from enum import Enum
+
+
+# KEEP IT CONSISTENT WITH webapp/const/aiAssistant.ts
+class AICommandType(Enum):
+    SQL_FIX = "SQL_FIX"
+    SQL_TITLE = "SQL_TITLE"
+    TEXT_TO_SQL = "TEXT_TO_SQL"
+    RESET_MEMORY = "RESET_MEMORY"
+
+
+AI_ASSISTANT_NAMESPACE = "/ai_assistant"
+AI_ASSISTANT_REQUEST_EVENT = "ai_assistant_request"
+AI_ASSISTANT_RESPONSE_EVENT = "ai_assistant_response"
diff --git a/querybook/server/datasources/__init__.py b/querybook/server/datasources/__init__.py
index 65f990498..8701b5054
--- a/querybook/server/datasources/__init__.py
+++ b/querybook/server/datasources/__init__.py
@@ -16,7 +16,6 @@
 from . import event_log
 from . import data_element
 from . import comment
-from . import ai_assistant
 
 # Keep this at the end of imports to make sure the plugin APIs override the default ones
 try:
@@ -43,5 +42,4 @@
 event_log
 data_element
 comment
-ai_assistant
 api_plugin
diff --git a/querybook/server/datasources/ai_assistant.py b/querybook/server/datasources/ai_assistant.py
deleted file mode 100644
index d5ede985e..000000000
--- a/querybook/server/datasources/ai_assistant.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from flask import Response
-from flask_login import current_user
-from app.datasource import register
-
-from lib.ai_assistant import ai_assistant
-from logic import datadoc as datadoc_logic
-
-
-@register("/ai/query_title/", custom_response=True)
-def generate_query_title(data_cell_id: int):
-    data_cell = datadoc_logic.get_data_cell_by_id(data_cell_id)
-    query = data_cell.context if data_cell else None
-
-    if not query:
-        return Response(None)
-
-    title_stream = ai_assistant.generate_title_from_query(
-        query=query, user_id=current_user.id
-    )
-
-    return Response(title_stream, mimetype="text/event-stream")
-
-
-@register("/ai/query_auto_fix/", custom_response=True)
-def query_auto_fix(query_execution_id: int):
-    res_stream = ai_assistant.query_auto_fix(
-        query_execution_id=query_execution_id,
-        user_id=current_user.id,
-    )
-
-    return Response(res_stream, mimetype="text/event-stream")
-
-
-@register("/ai/generate_query/", custom_response=True)
-def generate_sql_query(
-    query_engine_id: int, tables: list[str], question: str, data_cell_id: int = None
-):
-    data_cell = datadoc_logic.get_data_cell_by_id(data_cell_id)
-    original_query = data_cell.context if data_cell else None
-    res_stream = ai_assistant.generate_sql_query(
-        query_engine_id=query_engine_id,
-        tables=tables,
-        question=question,
-        original_query=original_query,
-        user_id=current_user.id,
-    )
-
-    return Response(res_stream, mimetype="text/event-stream")
diff --git a/querybook/server/datasources_socketio/__init__.py b/querybook/server/datasources_socketio/__init__.py
index 3d6425a41..6c06442b2
--- a/querybook/server/datasources_socketio/__init__.py
+++ b/querybook/server/datasources_socketio/__init__.py
@@ -1,7 +1,9 @@
 from . import query_execution
 from . import datadoc
 from . import connect
+from . import ai_assistant
 
 connect
 query_execution
 datadoc
+ai_assistant
diff --git a/querybook/server/datasources_socketio/ai_assistant.py b/querybook/server/datasources_socketio/ai_assistant.py
new file mode 100644
index 000000000..c0a7d2e1f
--- /dev/null
+++ b/querybook/server/datasources_socketio/ai_assistant.py
@@ -0,0 +1,13 @@
+from const.ai_assistant import (
+    AI_ASSISTANT_NAMESPACE,
+    AI_ASSISTANT_REQUEST_EVENT,
+)
+
+from .helper import register_socket
+
+
+@register_socket(AI_ASSISTANT_REQUEST_EVENT, namespace=AI_ASSISTANT_NAMESPACE)
+def ai_assistant_request(command_type: str, payload={}):
+    from lib.ai_assistant import ai_assistant
+
+    ai_assistant.handle_ai_command(command_type, payload)
diff --git a/querybook/server/datasources_socketio/connect.py b/querybook/server/datasources_socketio/connect.py
index c1057a390..69f6df2b6
--- a/querybook/server/datasources_socketio/connect.py
+++ b/querybook/server/datasources_socketio/connect.py
@@ -1,6 +1,7 @@
 from flask_login import current_user
 from flask_socketio import ConnectionRefusedError
 
+from const.ai_assistant import AI_ASSISTANT_NAMESPACE
 from const.data_doc import DATA_DOC_NAMESPACE
 from const.query_execution import QUERY_EXECUTION_NAMESPACE
 
@@ -20,3 +21,8 @@ def connect_query_execution(auth):
 @register_socket("connect", namespace=DATA_DOC_NAMESPACE)
 def connect_datadoc(auth):
     on_connect()
+
+
+@register_socket("connect", namespace=AI_ASSISTANT_NAMESPACE)
+def connect_ai_assistant(auth):
+    on_connect()
diff --git a/querybook/server/lib/ai_assistant/__init__.py b/querybook/server/lib/ai_assistant/__init__.py
index 3ecedab7c..4b0f320f5
--- a/querybook/server/lib/ai_assistant/__init__.py
+++ b/querybook/server/lib/ai_assistant/__init__.py
@@ -1,11 +1,14 @@
 from env import QuerybookSettings
 
+from .all_ai_assistants import get_ai_assistant_class
+
 if QuerybookSettings.AI_ASSISTANT_PROVIDER:
-    from .ai_assistant import AIAssistant
+    ai_assistant = get_ai_assistant_class(QuerybookSettings.AI_ASSISTANT_PROVIDER)
+    ai_assistant.set_config(QuerybookSettings.AI_ASSISTANT_CONFIG)
 
-    ai_assistant = AIAssistant(
-        QuerybookSettings.AI_ASSISTANT_PROVIDER, QuerybookSettings.AI_ASSISTANT_CONFIG
-    )
 else:
     ai_assistant = None
+
+
+__all__ = ["ai_assistant"]
diff --git a/querybook/server/lib/ai_assistant/ai_assistant.py b/querybook/server/lib/ai_assistant/ai_assistant.py
deleted file mode 100644
index 6dc633161..000000000
--- a/querybook/server/lib/ai_assistant/ai_assistant.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import threading
-
-from .all_ai_assistants import get_ai_assistant_class
-from .base_ai_assistant import ChainStreamHandler, EventStream
-
-
-class AIAssistant:
-    def __init__(self, provider: str, config: dict = {}):
-        self._assisant = get_ai_assistant_class(provider)
-        self._assisant.set_config(config)
-
-    def _get_streaming_result(self, fn, kwargs):
-        event_stream = EventStream()
-        callback_handler = ChainStreamHandler(event_stream)
-        kwargs["callback_handler"] = callback_handler
-        thread = threading.Thread(target=fn, kwargs=kwargs)
-        thread.start()
-        return event_stream
-
-    def generate_title_from_query(self, query, user_id=None):
-        return self._get_streaming_result(
-            self._assisant.generate_title_from_query,
-            {"query": query, "user_id": user_id},
-        )
-
-    def query_auto_fix(self, query_execution_id, user_id=None):
-        return self._get_streaming_result(
-            self._assisant.query_auto_fix,
-            {
-                "query_execution_id": query_execution_id,
-                "user_id": user_id,
-            },
-        )
-
-    def generate_sql_query(
-        self,
-        query_engine_id: int,
-        tables: 
list[str], - question: str, - original_query: str = None, - user_id=None, - ): - return self._get_streaming_result( - self._assisant.generate_sql_query, - { - "query_engine_id": query_engine_id, - "tables": tables, - "question": question, - "original_query": original_query, - "user_id": user_id, - }, - ) diff --git a/querybook/server/lib/ai_assistant/assistants/openai_assistant.py b/querybook/server/lib/ai_assistant/assistants/openai_assistant.py index 94db5f76a..50b09d23c 100644 --- a/querybook/server/lib/ai_assistant/assistants/openai_assistant.py +++ b/querybook/server/lib/ai_assistant/assistants/openai_assistant.py @@ -3,11 +3,6 @@ from langchain.chat_models import ChatOpenAI from langchain.callbacks.manager import CallbackManager -from langchain.prompts.chat import ( - ChatPromptTemplate, - SystemMessage, - HumanMessagePromptTemplate, -) import openai @@ -29,174 +24,9 @@ def _get_error_msg(self, error) -> str: return super()._get_error_msg(error) - @property - def title_generation_prompt_template(self) -> ChatPromptTemplate: - system_message_prompt = SystemMessage( - content="You are a helpful assistant that can summerize SQL queries." - ) - human_template = ( - "Generate a brief 10-word-maximum title for the SQL query below. " - "===Query\n" - "{query}\n\n" - "===Response Guidelines\n" - "1. Only respond with the title without any explanation\n" - "2. Dont use double quotes to enclose the title\n" - "3. Dont add a final period to the title\n\n" - "===Example response\n" - "This is a title\n" - ) - human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) - return ChatPromptTemplate.from_messages( - [system_message_prompt, human_message_prompt] - ) - - @property - def query_auto_fix_prompt_template(self) -> ChatPromptTemplate: - system_message_prompt = SystemMessage( - content=( - "You are a SQL expert that can help fix SQL query errors.\n\n" - "Please follow the format below for your response:\n" - "<@key-1@>\n" - "value-1\n\n" - "<@key-2@>\n" - "value-2\n\n" - ) - ) - human_template = ( - "Please help fix the query below based on the given error message and table schemas. \n\n" - "===SQL dialect\n" - "{dialect}\n\n" - "===Query\n" - "{query}\n\n" - "===Error\n" - "{error}\n\n" - "===Table Schemas\n" - "{table_schemas}\n\n" - "===Response Format\n" - "<@key-1@>\n" - "value-1\n\n" - "<@key-2@>\n" - "value-2\n\n" - "===Example response:\n" - "<@explanation@>\n" - "This is an explanation about the error\n\n" - "<@fix_suggestion@>\n" - "This is a recommended fix for the error\n\n" - "<@fixed_query@>\n" - "The fixed SQL query\n\n" - "===Response Guidelines\n" - "1. For the <@fixed_query@> section, it can only be a valid SQL query without any explanation.\n" - "2. If there is insufficient context to address the query error, you may leave the fixed_query section blank and provide a general suggestion instead.\n" - "3. 
Maintain the original query format and case in the fixed_query section, including comments, except when correcting the erroneous part.\n" - ) - human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) - return ChatPromptTemplate.from_messages( - [system_message_prompt, human_message_prompt] - ) - - @property - def generate_sql_query_prompt_template(self) -> ChatPromptTemplate: - system_message_prompt = SystemMessage( - content=( - "You are a SQL expert that can help generating SQL query.\n\n" - "Please follow the key/value pair format below for your response:\n" - "<@key-1@>\n" - "value-1\n\n" - "<@key-2@>\n" - "value-2\n\n" - ) - ) - human_template = ( - "Please help to generate a new SQL query or modify the original query to answer the following question. Your response should ONLY be based on the given context.\n\n" - "===SQL Dialect\n" - "{dialect}\n\n" - "===Tables\n" - "{table_schemas}\n\n" - "===Original Query\n" - "{original_query}\n\n" - "===Question\n" - "{question}\n\n" - "===Response Format\n" - "<@key-1@>\n" - "value-1\n\n" - "<@key-2@>\n" - "value-2\n\n" - "===Example Response:\n" - "Example 1: Sufficient Context\n" - "<@query@>\n" - "A generated SQL query based on the provided context with the asked question at the beginning is provided here.\n\n" - "Example 2: Insufficient Context\n" - "<@explanation@>\n" - "An explanation of the missing context is provided here.\n\n" - "===Response Guidelines\n" - "1. If the provided context is sufficient, please respond only with a valid SQL query without any explanations in the <@query@> section. The query should start with a comment containing the question being asked.\n" - "2. If the provided context is insufficient, please explain what information is missing.\n" - "3. If the original query is provided, please modify the original query to answer the question. The original query may start with a comment containing a previously asked question. If you find such a comment, please use both the original question and the new question to generate the new query.\n" - "4. 
The <@key_name@> in the response can only be <@explanation@> or <@query@>.\n\n" - ) - human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) - return ChatPromptTemplate.from_messages( - [system_message_prompt, human_message_prompt] - ) - - def _generate_title_from_query( - self, query, stream=True, callback_handler=None, user_id=None - ): - """Generate title from SQL query using OpenAI's chat model.""" - messages = self.title_generation_prompt_template.format_prompt( - query=query - ).to_messages() - chat = ChatOpenAI( - **self._config, - streaming=stream, - callback_manager=CallbackManager([callback_handler]), - ) - ai_message = chat(messages) - return ai_message.content - - def _query_auto_fix( - self, - language, - query, - error, - table_schemas, - stream, - callback_handler, - user_id=None, - ): - """Query auto fix using OpenAI's chat model.""" - messages = self.query_auto_fix_prompt_template.format_prompt( - dialect=language, query=query, error=error, table_schemas=table_schemas - ).to_messages() - chat = ChatOpenAI( - **self._config, - streaming=stream, - callback_manager=CallbackManager([callback_handler]), - ) - ai_message = chat(messages) - return ai_message.content - - def _generate_sql_query( - self, - language: str, - table_schemas: str, - question: str, - original_query: str, - stream, - callback_handler, - user_id=None, - ): - """Generate SQL query using OpenAI's chat model.""" - messages = self.generate_sql_query_prompt_template.format_prompt( - dialect=language, - question=question, - table_schemas=table_schemas, - original_query=original_query, - ).to_messages() - chat = ChatOpenAI( + def _get_llm(self, callback_handler): + return ChatOpenAI( **self._config, - streaming=stream, + streaming=True, callback_manager=CallbackManager([callback_handler]), ) - ai_message = chat(messages) - return ai_message.content diff --git a/querybook/server/lib/ai_assistant/base_ai_assistant.py b/querybook/server/lib/ai_assistant/base_ai_assistant.py index 635e19f72..00b1b2c74 100644 --- a/querybook/server/lib/ai_assistant/base_ai_assistant.py +++ b/querybook/server/lib/ai_assistant/base_ai_assistant.py @@ -1,66 +1,29 @@ -from abc import ABC, abstractmethod import functools -import json -import queue +from abc import ABC, abstractmethod -from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler +from langchain.chains import LLMChain from pydantic.error_wrappers import ValidationError from app.db import with_session +from app.flask_app import socketio +from const.ai_assistant import AICommandType from lib.logger import get_logger -from logic import query_execution as qe_logic from lib.query_analysis.lineage import process_query from logic import admin as admin_logic from logic import metastore as m_logic -from models.query_execution import QueryExecution +from logic import query_execution as qe_logic from models.metastore import DataTableColumn +from models.query_execution import QueryExecution -LOG = get_logger(__file__) - - -class EventStream: - """Generator to facilitate streaming result from Langchain. 
- The stream format is based on Server-Sent Events (SSE).""" - - def __init__(self): - self.queue = queue.Queue() - - def __iter__(self): - return self - - def __next__(self): - item = self.queue.get() - if item is StopIteration: - raise item - return item - - def send(self, data: str): - self.queue.put("data: " + json.dumps({"data": data}) + "\n\n") - - def close(self): - # the empty data is to make sure the client receives the close event - self.queue.put("event: close\ndata: \n\n") - self.queue.put(StopIteration) - - def send_error(self, error: str): - self.queue.put("event: error\n") - data = json.dumps({"data": error}) - self.queue.put(f"data: {data}\n\n") - self.close() - - -class ChainStreamHandler(StreamingStdOutCallbackHandler): - """Callback handlder to stream the result to a generator.""" - - def __init__(self, stream: EventStream): - super().__init__() - self.stream = stream - - def on_llm_new_token(self, token: str, **kwargs): - self.stream.send(token) +from .prompts.sql_fix_prompt import SQL_FIX_PROMPT +from .prompts.sql_title_prompt import SQL_TITLE_PROMPT +from .prompts.text2sql_prompt import TEXT2SQL_PROMPT +from .streaming_web_socket_callback_handler import ( + WebSocketStream, + StreamingWebsocketCallbackHandler, +) - def on_llm_end(self, response, **kwargs): - self.stream.close() +LOG = get_logger(__file__) class BaseAIAssistant(ABC): @@ -87,6 +50,31 @@ def wrapper(self, *args, **kwargs): return wrapper + @abstractmethod + def _get_llm(self, callback_handler: StreamingWebsocketCallbackHandler): + """return the language model to use""" + + def _get_sql_title_prompt(self): + """Override this method to return specific prompt for your own assistant.""" + return SQL_TITLE_PROMPT + + def _get_text2sql_prompt(self): + """Override this method to return specific prompt for your own assistant.""" + return TEXT2SQL_PROMPT + + def _get_sql_fix_prompt(self): + """Override this method to return specific prompt for your own assistant.""" + return SQL_FIX_PROMPT + + def _get_ws_stream(self, command_type: str): + return WebSocketStream(socketio, command_type) + + def _get_llm_chain(self, command_type, prompt, memory=None): + ws_stream = self._get_ws_stream(command_type=command_type) + callback_handler = StreamingWebsocketCallbackHandler(ws_stream) + llm = self._get_llm(callback_handler=callback_handler) + return LLMChain(llm=llm, prompt=prompt, memory=memory) + def _get_error_msg(self, error) -> str: """Override this method to return specific error messages for your own assistant.""" if isinstance(error, ValidationError): @@ -167,6 +155,34 @@ def _get_query_execution_error(self, query_execution: QueryExecution) -> str: return error[:1000] + def handle_ai_command(self, command_type: str, payload: dict = {}): + try: + if command_type == AICommandType.SQL_TITLE.value: + query = payload["query"] + self.generate_title_from_query(query=query) + elif command_type == AICommandType.TEXT_TO_SQL.value: + original_query = payload["original_query"] + query_engine_id = payload["query_engine_id"] + tables = payload.get("tables") + question = payload["question"] + self.generate_sql_query( + query_engine_id=query_engine_id, + tables=tables, + question=question, + original_query=original_query, + ) + elif command_type == AICommandType.SQL_FIX.value: + query_execution_id = payload["query_execution_id"] + self.query_auto_fix( + query_execution_id=query_execution_id, + ) + else: + self._get_ws_stream(command_type=command_type).send_error( + "Unsupported command" + ) + except Exception as e: + 
                self._get_ws_stream(command_type=command_type).send_error(str(e))
+
     @catch_error
     @with_session
     def generate_sql_query(
@@ -175,9 +191,6 @@
         self,
         query_engine_id: int,
         tables: list[str],
         question: str,
         original_query: str = None,
-        stream=True,
-        callback_handler: ChainStreamHandler = None,
-        user_id=None,
         session=None,
     ):
         query_engine = admin_logic.get_query_engine_by_id(
@@ -186,36 +199,23 @@
         table_schemas = self._generate_table_schema_prompt(
             metastore_id=query_engine.metastore_id, table_names=tables, session=session
         )
-        return self._generate_sql_query(
-            language=query_engine.language,
-            table_schemas=table_schemas,
+
+        prompt = self._get_text2sql_prompt()
+        chain = self._get_llm_chain(
+            command_type=AICommandType.TEXT_TO_SQL.value,
+            prompt=prompt,
+        )
+        return chain.run(
+            dialect=query_engine.language,
             question=question,
+            table_schemas=table_schemas,
             original_query=original_query,
-            stream=stream,
-            callback_handler=callback_handler,
-            user_id=user_id,
         )
 
-    @abstractmethod
-    def _generate_sql_query(
-        self,
-        language: str,
-        table_schemas: str,
-        question: str,
-        original_query: str = None,
-        stream=True,
-        callback_handler: ChainStreamHandler = None,
-        user_id=None,
-    ):
-        raise NotImplementedError()
-
     @catch_error
     def generate_title_from_query(
        self,
        query,
-        stream=True,
-        callback_handler: ChainStreamHandler = None,
-        user_id=None,
     ):
         """Generate title from SQL query.
@@ -224,39 +224,24 @@
-            stream (bool, optional): Whether to stream the result. Defaults to True.
-            callback_handler (CallbackHandler, optional): Callback handler to handle the straming result. Required if stream is True.
         """
-        return self._generate_title_from_query(
-            query=query,
-            stream=stream,
-            callback_handler=callback_handler,
-            user_id=user_id,
+        prompt = self._get_sql_title_prompt()
+        chain = self._get_llm_chain(
+            command_type=AICommandType.SQL_TITLE.value,
+            prompt=prompt,
         )
-
-    @abstractmethod
-    def _generate_title_from_query(
-        self,
-        query,
-        stream,
-        callback_handler,
-        user_id=None,
-    ):
-        raise NotImplementedError()
+        return chain.run(query=query)
 
     @catch_error
     @with_session
     def query_auto_fix(
         self,
         query_execution_id: int,
-        stream: bool = True,
-        callback_handler: ChainStreamHandler = None,
-        user_id: int = None,
         session=None,
     ):
-        """Generate title from SQL query.
+        """Fix a SQL query based on the error of its failed execution.
 
         Args:
             query_execution_id (int): The failed query execution id
-            stream (bool, optional): Whether to stream the result. Defaults to True.
-            callback_handler (CallbackHandler, optional): Callback handler to handle the straming result. Required if stream is True.
         """
         query_execution = qe_logic.get_query_execution_by_id(
             query_execution_id, session=session
         )
@@ -276,25 +261,14 @@
             session=session,
         )
 
-        return self._query_auto_fix(
-            language=language,
+        prompt = self._get_sql_fix_prompt()
+        chain = self._get_llm_chain(
+            command_type=AICommandType.SQL_FIX.value,
+            prompt=prompt,
+        )
+        return chain.run(
+            dialect=language,
             query=query_execution.query,
             error=self._get_query_execution_error(query_execution),
             table_schemas=table_schemas,
-            stream=stream,
-            callback_handler=callback_handler,
-            user_id=user_id,
         )
-
-    @abstractmethod
-    def _query_auto_fix(
-        self,
-        language: str,
-        query: str,
-        error: str,
-        table_schemas: str,
-        stream: bool,
-        callback_handler: ChainStreamHandler,
-        user_id=None,
-    ):
-        raise NotImplementedError()
diff --git a/querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py b/querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py
new file mode 100644
index 000000000..82294c528
--- /dev/null
+++ b/querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py
@@ -0,0 +1,50 @@
+from langchain.prompts import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+)
+
+
+system_message_template = (
+    "You are a SQL expert that can help fix SQL query errors.\n\n"
+    "Please follow the format below for your response:\n"
+    "<@key-1@>\n"
+    "value-1\n\n"
+    "<@key-2@>\n"
+    "value-2\n\n"
+)
+
+human_message_template = (
+    "Please help fix the query below based on the given error message and table schemas. \n\n"
+    "===SQL dialect\n"
+    "{dialect}\n\n"
+    "===Query\n"
+    "{query}\n\n"
+    "===Error\n"
+    "{error}\n\n"
+    "===Table Schemas\n"
+    "{table_schemas}\n\n"
+    "===Response Format\n"
+    "<@key-1@>\n"
+    "value-1\n\n"
+    "<@key-2@>\n"
+    "value-2\n\n"
+    "===Example response:\n"
+    "<@explanation@>\n"
+    "This is an explanation about the error\n\n"
+    "<@fix_suggestion@>\n"
+    "This is a recommended fix for the error\n\n"
+    "<@fixed_query@>\n"
+    "The fixed SQL query\n\n"
+    "===Response Guidelines\n"
+    "1. For the <@fixed_query@> section, it can only be a valid SQL query without any explanation.\n"
+    "2. If there is insufficient context to address the query error, you may leave the fixed_query section blank and provide a general suggestion instead.\n"
+    "3. Maintain the original query format and case in the fixed_query section, including comments, except when correcting the erroneous part.\n"
+)
+
+SQL_FIX_PROMPT = ChatPromptTemplate.from_messages(
+    [
+        SystemMessagePromptTemplate.from_template(system_message_template),
+        HumanMessagePromptTemplate.from_template(human_message_template),
+    ]
+)
diff --git a/querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py b/querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py
new file mode 100644
index 000000000..1c8d7bbb0
--- /dev/null
+++ b/querybook/server/lib/ai_assistant/prompts/sql_title_prompt.py
@@ -0,0 +1,29 @@
+from langchain.prompts import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+)
+
+
+system_message_template = (
+    """You are a helpful assistant that can summarize SQL queries."""
+)
+
+human_message_template = (
+    "Generate a brief 10-word-maximum title for the SQL query below. "
+    "===Query\n"
+    "{query}\n\n"
+    "===Response Guidelines\n"
+    "1. Only respond with the title without any explanation\n"
+    "2. Don't use double quotes to enclose the title\n"
+    "3. Don't add a final period to the title\n\n"
+    "===Example response\n"
+    "This is a title\n"
+)
+
+SQL_TITLE_PROMPT = ChatPromptTemplate.from_messages(
+    [
+        SystemMessagePromptTemplate.from_template(system_message_template),
+        HumanMessagePromptTemplate.from_template(human_message_template),
+    ]
+)
diff --git a/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py b/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py
new file mode 100644
index 000000000..2b7a0f293
--- /dev/null
+++ b/querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py
@@ -0,0 +1,46 @@
+from langchain.prompts import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+)
+
+
+system_message_template = "You are a SQL expert that can help generate SQL queries."
+
+human_message_template = (
+    "Please help to generate a new SQL query or modify the original query to answer the following question. Your response should ONLY be based on the given context.\n\n"
+    "Please always follow the key/value pair format below for your response:\n"
+    "===Response Format\n"
+    "<@query@>\n"
+    "query\n\n"
+    "or\n\n"
+    "<@explanation@>\n"
+    "explanation\n\n"
+    "===Example Response:\n"
+    "Example 1: Sufficient Context\n"
+    "<@query@>\n"
+    "A generated SQL query based on the provided context with the asked question at the beginning is provided here.\n\n"
+    "Example 2: Insufficient Context\n"
+    "<@explanation@>\n"
+    "An explanation of the missing context is provided here.\n\n"
+    "===Response Guidelines\n"
+    "1. If the provided context is sufficient, please respond only with a valid SQL query without any explanations in the <@query@> section. The query should start with a comment containing the question being asked.\n"
+    "2. If the provided context is insufficient, please explain what information is missing.\n"
+    "3. If the original query is provided, please modify the original query to answer the question. The original query may start with a comment containing a previously asked question. If you find such a comment, please use both the original question and the new question to generate the new query.\n"
+    "4. Please always honor the table schemas for the query generation\n\n"
+    "===SQL Dialect\n"
+    "{dialect}\n\n"
+    "===Tables\n"
+    "{table_schemas}\n\n"
+    "===Original Query\n"
+    "{original_query}\n\n"
+    "===Question\n"
+    "{question}\n\n"
+)
+
+TEXT2SQL_PROMPT = ChatPromptTemplate.from_messages(
+    [
+        SystemMessagePromptTemplate.from_template(system_message_template),
+        HumanMessagePromptTemplate.from_template(human_message_template),
+    ]
+)
diff --git a/querybook/server/lib/ai_assistant/streaming_web_socket_callback_handler.py b/querybook/server/lib/ai_assistant/streaming_web_socket_callback_handler.py
new file mode 100644
index 000000000..e9610cb5a
--- /dev/null
+++ b/querybook/server/lib/ai_assistant/streaming_web_socket_callback_handler.py
@@ -0,0 +1,63 @@
+from flask import request
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+from const.ai_assistant import (
+    AI_ASSISTANT_NAMESPACE,
+    AI_ASSISTANT_RESPONSE_EVENT,
+)
+
+
+class WebSocketStream:
+    def __init__(self, socketio, command_type: str):
+        self.socketio = socketio
+        self.command_type = command_type
+        self.room = request.sid
+
+    def _send(self, payload: dict):
+        self.socketio.emit(
+            AI_ASSISTANT_RESPONSE_EVENT,
+            (
+                self.command_type,
+                payload,
+            ),
+            namespace=AI_ASSISTANT_NAMESPACE,
+            room=self.room,
+        )
+
+    def send(self, data: str):
+        self._send(
+            {
+                "event": "data",
+                "data": data,
+            }
+        )
+
+    def send_error(self, error: str):
+        self._send(
+            {
+                "event": "error",
+                "data": error,
+            }
+        )
+        self.close()
+
+    def close(self):
+        self._send(
+            {
+                "event": "close",
+            }
+        )
+
+
+class StreamingWebsocketCallbackHandler(StreamingStdOutCallbackHandler):
+    """Callback handler to stream the result through the web socket."""
+
+    def __init__(self, stream: WebSocketStream):
+        super().__init__()
+        self.stream = stream
+
+    def on_llm_new_token(self, token: str, **kwargs):
+        self.stream.send(token)
+
+    def on_llm_end(self, response, **kwargs):
+        self.stream.close()
diff --git a/querybook/webapp/components/AIAssistant/AutoFixButton.tsx b/querybook/webapp/components/AIAssistant/AutoFixButton.tsx
index 042472ae7..5c993c127
--- a/querybook/webapp/components/AIAssistant/AutoFixButton.tsx
+++ b/querybook/webapp/components/AIAssistant/AutoFixButton.tsx
@@ -1,6 +1,7 @@
 import React, { useState } from 'react';
 
 import { QueryComparison } from 'components/TranspileQueryModal/QueryComparison';
+import { AICommandType } from 'const/aiAssistant';
 import { ComponentType, ElementType } from 'const/analytics';
 import { StreamStatus, useStream } from 'hooks/useStream';
 import { trackClick } from 'lib/analytics';
@@ -26,7 +27,7 @@
     const [show, setShow] = useState(false);
 
     const { streamStatus, startStream, streamData, cancelStream } = useStream(
-        '/ds/ai/query_auto_fix/',
+        AICommandType.SQL_FIX,
         {
             query_execution_id: queryExecutionId,
         }
diff --git a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx
index 4156c6b2d..2ed6107c1
--- a/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx
+++ b/querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx
@@ -3,6 +3,7 @@
 import { QueryEngineSelector } from 'components/QueryRunButton/QueryRunButton';
 import { QueryComparison } from 'components/TranspileQueryModal/QueryComparison';
+import { AICommandType } from 'const/aiAssistant';
 import { ComponentType, ElementType } from 'const/analytics';
 import { IQueryEngine } from 'const/queryEngine';
 import { StreamStatus, useStream } from 'hooks/useStream';
@@ -12,12 +13,12 @@
 import { trimSQLQuery } from 'lib/stream';
 import { matchKeyPress } from 'lib/utils/keyboard';
 import { analyzeCode } from 'lib/web-worker';
 import { Button } from 'ui/Button/Button';
-import { DebouncedInput } from 'ui/DebouncedInput/DebouncedInput';
 import { Icon } from 'ui/Icon/Icon';
 import { Message } from 'ui/Message/Message';
 import { Modal } from 'ui/Modal/Modal';
 import { ResizableTextArea } from 'ui/ResizableTextArea/ResizableTextArea';
 import { StyledText } from 'ui/StyledText/StyledText';
+import { Tag } from 'ui/Tag/Tag';
 
 import { TableSelector } from './TableSelector';
 import { TextToSQLMode, TextToSQLModeSelector } from './TextToSQLModeSelector';
@@ -72,25 +73,29 @@
     const [textToSQLMode, setTextToSQLMode] = useState(
         !!query ? TextToSQLMode.EDIT : TextToSQLMode.GENERATE
     );
+    const [newQuery, setNewQuery] = useState('');
 
     useEffect(() => {
         setTables(uniq([...tablesInQuery, ...tables]));
     }, [tablesInQuery]);
 
     const { streamStatus, startStream, streamData, cancelStream } = useStream(
-        '/ds/ai/generate_query/',
+        AICommandType.TEXT_TO_SQL,
         {
             query_engine_id: engineId,
             tables: tables,
             question: question,
-            data_cell_id:
-                textToSQLMode === TextToSQLMode.EDIT ? dataCellId : undefined,
+            original_query: query,
         }
     );
-    const { explanation, query: rawNewQuery } = streamData;
+    const { explanation, query: rawNewQuery, data } = streamData;
 
-    const newQuery = trimSQLQuery(rawNewQuery);
+    useEffect(() => {
+        setNewQuery(trimSQLQuery(rawNewQuery));
+    }, [rawNewQuery]);
 
     const onKeyDown = useCallback(
         (event: React.KeyboardEvent) => {
@@ -258,8 +263,8 @@
             {tables.length > 0 && (
                 <>
                     {questionBarDOM}
-                    {explanation && (
-                        <div>{explanation}</div>
-                    )}
+                    {(explanation || data) && (
+                        <div>{explanation || data}</div>
+                    )}
                     {(query || newQuery) && (
@@ -272,7 +277,37 @@
                             }
                             toQuery={newQuery}
                             fromQueryTitle="Original Query"
-                            toQueryTitle="New Query"
+                            toQueryTitle={
+                                <div>
+                                    <Tag>New Query</Tag>
+                                </div>
+ } disableHighlight={ streamStatus === StreamStatus.STREAMING } diff --git a/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx b/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx index 531cf5caa..f8ed81bf4 100644 --- a/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx +++ b/querybook/webapp/components/QueryCellTitle/QueryCellTitle.tsx @@ -2,6 +2,7 @@ import React, { useCallback, useEffect } from 'react'; import { useDispatch } from 'react-redux'; import PublicConfig from 'config/querybook_public_config.yaml'; +import { AICommandType } from 'const/aiAssistant'; import { ComponentType, ElementType } from 'const/analytics'; import { StreamStatus, useStream } from 'hooks/useStream'; import { trackClick } from 'lib/analytics'; @@ -38,9 +39,9 @@ export const QueryCellTitle: React.FC = ({ query; const { streamStatus, startStream, streamData } = useStream( - '/ds/ai/query_title/', + AICommandType.SQL_TITLE, { - data_cell_id: cellId, + query, } ); const { data: title } = streamData; diff --git a/querybook/webapp/components/TranspileQueryModal/QueryComparison.scss b/querybook/webapp/components/TranspileQueryModal/QueryComparison.scss index a6be10ec4..6ee8654e2 100644 --- a/querybook/webapp/components/TranspileQueryModal/QueryComparison.scss +++ b/querybook/webapp/components/TranspileQueryModal/QueryComparison.scss @@ -1,9 +1,6 @@ .QueryComparison { display: flex; gap: 8px; - .Tag { - margin-bottom: 12px; - } .diff-side-view { flex: 1; diff --git a/querybook/webapp/components/TranspileQueryModal/QueryComparison.tsx b/querybook/webapp/components/TranspileQueryModal/QueryComparison.tsx index 759c28da1..ba9e2e0ba 100644 --- a/querybook/webapp/components/TranspileQueryModal/QueryComparison.tsx +++ b/querybook/webapp/components/TranspileQueryModal/QueryComparison.tsx @@ -10,8 +10,8 @@ import './QueryComparison.scss'; export const QueryComparison: React.FC<{ fromQuery: string; toQuery: string; - fromQueryTitle?: string; - toQueryTitle?: string; + fromQueryTitle?: string | React.ReactNode; + toQueryTitle?: string | React.ReactNode; disableHighlight?: boolean; hideEmptyQuery?: boolean; }> = ({ @@ -63,7 +63,15 @@ export const QueryComparison: React.FC<{
             {!(hideEmptyQuery && !fromQuery) && (
                 <div>
-                    {fromQueryTitle && <Tag>{fromQueryTitle}</Tag>}
+                    {fromQueryTitle && (
+                        <div>
+                            {typeof fromQueryTitle === 'string' ? (
+                                <Tag>{fromQueryTitle}</Tag>
+                            ) : (
+                                fromQueryTitle
+                            )}
+                        </div>
+                    )}
-                    {toQueryTitle && <Tag>{toQueryTitle}</Tag>}
+                    {toQueryTitle && (
+                        <div>
+                            {typeof toQueryTitle === 'string' ? (
+                                <Tag>{toQueryTitle}</Tag>
+                            ) : (
+                                toQueryTitle
+                            )}
+                        </div>
+                    )}
diff --git a/querybook/webapp/hooks/useStream.ts b/querybook/webapp/hooks/useStream.ts
--- a/querybook/webapp/hooks/useStream.ts
+++ b/querybook/webapp/hooks/useStream.ts
@@ ... @@
 export function useStream(
-    url: string,
+    commandType: AICommandType,
     params: Record<string, unknown> = {}
 ): {
     streamStatus: StreamStatus;
@@ -21,17 +22,17 @@
 } {
     const [streamStatus, setStreamStatus] = useState(StreamStatus.NOT_STARTED);
     const [data, setData] = useState<{ [key: string]: string }>({});
-    const streamRef = useRef(null);
+    const streamRef = useRef<{ close: () => void } | null>(null);
 
     const startStream = useCallback(() => {
         setStreamStatus(StreamStatus.STREAMING);
         setData({});
 
-        streamRef.current = ds.stream(url, params, setData, (data) => {
+        streamRef.current = ds.stream(commandType, params, setData, (data) => {
             setData(data);
             setStreamStatus(StreamStatus.FINISHED);
         });
-    }, [url, params]);
+    }, [commandType, params]);
 
     const resetStream = useCallback(() => {
         setStreamStatus(StreamStatus.NOT_STARTED);
diff --git a/querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts b/querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts
new file mode 100644
index 000000000..0ce6281a6
--- /dev/null
+++ b/querybook/webapp/lib/ai-assistant/ai-assistant-socketio.ts
@@ -0,0 +1,66 @@
+import type { Socket } from 'socket.io-client';
+
+import {
+    AI_ASSISTANT_REQUEST_EVENT,
+    AI_ASSISTANT_RESPONSE_EVENT,
+    AICommandType,
+} from 'const/aiAssistant';
+import SocketIOManager from 'lib/socketio-manager';
+
+class AIAssistantSocket {
+    private static NAME_SPACE = '/ai_assistant';
+
+    private socket: Socket = null;
+    private socketPromise: Promise<Socket> = null;
+
+    public constructor() {
+        this.setupSocket();
+    }
+
+    public onSocketConnect(socket: Socket) {
+        socket.emit('subscribe');
+    }
+
+    public requestAIAssistant = (command: AICommandType, payload: object) => {
+        this.socket.emit(AI_ASSISTANT_REQUEST_EVENT, command, payload);
+    };
+
+    public addAIListener = (
+        listener: (command: string, payload: object) => void
+    ) => {
+        this.socket.on(AI_ASSISTANT_RESPONSE_EVENT, listener);
+    };
+
+    public removeAIListener = (
+        listener: (command: string, payload: object) => void
+    ) => {
+        this.socket.off(AI_ASSISTANT_RESPONSE_EVENT, listener);
+    };
+
+    private setupSocket = async () => {
+        if (this.socket) {
+            return this.socket;
+        }
+        if (this.socketPromise) {
+            this.socket = await this.socketPromise;
+        } else {
+            // We need to setup our socket
+            this.socketPromise = SocketIOManager.getSocket(
+                AIAssistantSocket.NAME_SPACE,
+                this.onSocketConnect.bind(this)
+            );
+
+            // Setup socket's connection functions
+            this.socket = await this.socketPromise;
+        }
+
+        this.socket.on('error', (e) => {
+            console.error('Socket error', e);
+        });
+
+        return this.socket;
+    };
+}
+
+const defaultSocket = new AIAssistantSocket();
+export default defaultSocket;
diff --git a/querybook/webapp/lib/datasource.ts b/querybook/webapp/lib/datasource.ts
index f39d0696c..baca1f972
--- a/querybook/webapp/lib/datasource.ts
+++ b/querybook/webapp/lib/datasource.ts
@@ -1,6 +1,8 @@
+import aiAssistantSocket from './ai-assistant/ai-assistant-socketio';
 import axios, { AxiosRequestConfig, Canceler, Method } from 'axios';
 import toast from 'react-hot-toast';
 
+import { AICommandType } from 'const/aiAssistant';
 import { setSessionExpired } from 'lib/querybookUI';
 import { formatError } from 'lib/utils/error';
 
@@ -159,7 +161,7 @@
 }
 
 /**
- * Stream data from a datasource using EventSource
+ * Stream data from WebSocket
 *
 * The data is streamed in the form of deltas.
 * Each delta is a JSON object
 * ```
@@ -168,41 +170,48 @@
 * }
 * ```
 *
- * @param url The url to stream from
- * @param params The data to send to the url
+ * @param commandType The AI command type
+ * @param params The data to send
 * @param onStreaming Callback when data is received. The data is the accumulated data.
 * @param onStreamingEnd Callback when the stream ends
 */
 function streamDatasource(
-    url: string,
+    commandType: AICommandType,
     params?: Record<string, unknown>,
     onStreaming?: (data: { [key: string]: string }) => void,
     onStreamingEnd?: (data: { [key: string]: string }) => void
 ) {
-    const eventSource = new EventSource(
-        `${url}?params=${JSON.stringify(params)}`
-    );
     const parser = new DeltaStreamParser();
-    eventSource.addEventListener('message', (e) => {
-        const newToken = JSON.parse(e.data).data;
-        parser.parse(newToken);
-        onStreaming?.(parser.result);
-    });
-    eventSource.addEventListener('error', (e) => {
-        console.error(e);
-        eventSource.close();
-        onStreamingEnd?.(parser.result);
-        if (e instanceof MessageEvent) {
-            toast.error(JSON.parse(e.data).data);
-        }
-    });
-    eventSource.addEventListener('close', (e) => {
-        eventSource.close();
-        parser.close();
-        onStreamingEnd?.(parser.result);
-    });
-
-    return eventSource;
+
+    const onData = (command, payload) => {
+        if (command !== commandType) {
+            return;
+        }
+
+        if (payload.event === 'close') {
+            aiAssistantSocket.removeAIListener(onData);
+            parser.close();
+            onStreamingEnd?.(parser.result);
+            return;
+        } else if (payload.event === 'error') {
+            aiAssistantSocket.removeAIListener(onData);
+            toast.error(payload.data);
+            onStreamingEnd?.(parser.result);
+            return;
+        }
+
+        parser.parse(payload.data);
+        onStreaming?.(parser.result);
+    };
+
+    aiAssistantSocket.addAIListener(onData);
+
+    aiAssistantSocket.requestAIAssistant(commandType, params);
+    return {
+        close: () => {
+            aiAssistantSocket.removeAIListener(onData);
+        },
+    };
 }
 
 export default {
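
After this refactor, a custom provider plugin only needs to supply an LLM factory, plus optional prompt overrides. A minimal sketch of what that could look like, built only from the hooks shown in base_ai_assistant.py above; `MyAssistant` and `MY_TITLE_PROMPT` are hypothetical names, `ChatOpenAI` stands in for whatever chat model the provider wraps, and any other members a concrete provider must implement are omitted:

```python
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate

from lib.ai_assistant.base_ai_assistant import BaseAIAssistant

# Hypothetical replacement prompt; {query} matches the variable that
# generate_title_from_query passes to chain.run.
MY_TITLE_PROMPT = ChatPromptTemplate.from_template(
    "Reply with a short title (10 words max) for this SQL query:\n{query}"
)


class MyAssistant(BaseAIAssistant):
    def _get_llm(self, callback_handler):
        # streaming=True is what lets tokens flow through the
        # StreamingWebsocketCallbackHandler into the client's socket room.
        return ChatOpenAI(
            **self._config,
            streaming=True,
            callback_manager=CallbackManager([callback_handler]),
        )

    def _get_sql_title_prompt(self):
        # Optional override; the default is SQL_TITLE_PROMPT.
        return MY_TITLE_PROMPT
```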
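
On the wire, each `AI_ASSISTANT_RESPONSE_EVENT` emitted by `WebSocketStream` carries a `(command_type, payload)` tuple, where `payload["event"]` is one of `data`, `error`, or `close`. A hypothetical trace of a single SQL_TITLE round trip (the token values are invented for illustration):

```python
# Sequence of (command_type, payload) tuples a client would receive.
received = [
    ("SQL_TITLE", {"event": "data", "data": "Daily"}),
    ("SQL_TITLE", {"event": "data", "data": " Active Users"}),
    ("SQL_TITLE", {"event": "close"}),
]

# Clients accumulate "data" payloads until the "close" event arrives.
title = "".join(
    payload["data"] for _, payload in received if payload["event"] == "data"
)
assert title == "Daily Active Users"
```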
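
The prompts delimit multi-field responses with `<@key@>` markers, which the webapp's `DeltaStreamParser` (in `lib/stream`, not part of this diff) accumulates into `streamData` keys such as `explanation` and `query`; untagged text lands under the catch-all `data` key that `QueryCellTitle` reads. An illustrative re-implementation of that section format in Python — not the actual parser:

```python
import re


def parse_sections(buffer: str) -> dict:
    """Split accumulated stream text on <@key@> markers."""
    result = {}
    parts = re.split(r"<@([a-zA-Z_-]+)@>", buffer)
    if parts[0].strip():
        # Untagged prefix (e.g. a streamed title) goes under "data".
        result["data"] = parts[0].strip()
    # re.split with a capture group yields [prefix, key1, value1, key2, ...].
    for key, value in zip(parts[1::2], parts[2::2]):
        result[key] = value.strip()
    return result


assert parse_sections("A Short Title") == {"data": "A Short Title"}
assert parse_sections("<@explanation@>\nMissing table info\n") == {
    "explanation": "Missing table info"
}
```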
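
A hedged end-to-end check of the new namespace using Flask-SocketIO's test client. This assumes a `flask_app` object is importable from `app.flask_app` (only `socketio` is shown in this diff) and that the test session is already authenticated, since `connect_ai_assistant` goes through `on_connect`:

```python
from app.flask_app import flask_app, socketio
from const.ai_assistant import (
    AI_ASSISTANT_NAMESPACE,
    AI_ASSISTANT_REQUEST_EVENT,
)

# Connecting joins the /ai_assistant namespace.
client = socketio.test_client(flask_app, namespace=AI_ASSISTANT_NAMESPACE)

# Fire one command; the payload keys follow handle_ai_command's dispatch.
client.emit(
    AI_ASSISTANT_REQUEST_EVENT,
    "SQL_TITLE",
    {"query": "SELECT 1"},
    namespace=AI_ASSISTANT_NAMESPACE,
)

# Each received item wraps one AI_ASSISTANT_RESPONSE_EVENT:
# {"name": ..., "args": [command_type, payload], "namespace": ...}
for received in client.get_received(AI_ASSISTANT_NAMESPACE):
    print(received["name"], received["args"])
```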