From 819f78ba602470e0eac7ec40d9fd2f381f93da88 Mon Sep 17 00:00:00 2001 From: Rami Abdou <38056800+ramiAbdou@users.noreply.github.com> Date: Wed, 11 Sep 2024 14:44:21 -0700 Subject: [PATCH] =?UTF-8?q?feat:=20update=20ai=20answer=20for=20public=20q?= =?UTF-8?q?uestion=20=F0=9F=94=BC=20(#512)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- packages/core/src/modules/slack/slack.ts | 277 ++++++++++++++++------- 1 file changed, 190 insertions(+), 87 deletions(-) diff --git a/packages/core/src/modules/slack/slack.ts b/packages/core/src/modules/slack/slack.ts index b36d651e..e0647397 100644 --- a/packages/core/src/modules/slack/slack.ts +++ b/packages/core/src/modules/slack/slack.ts @@ -1,5 +1,7 @@ +import dayjs from 'dayjs'; import dedent from 'dedent'; import { type ExpressionBuilder } from 'kysely'; +import { match } from 'ts-pattern'; import { type DB, db } from '@oyster/db'; @@ -144,10 +146,8 @@ type AnswerPublicQuestionInput = { }; /** - * Answers a question asked in a public Slack message. - * - * This uses the underlying `getAnswerFromSlackHistory` function to answer the - * question, and then sends the answer to that thread. + * Answers a question asked in a public Slack message by linking to relevant + * threads in our Slack workspace. * * @param input - The message (public question) to answer. * @returns The result of the answer. 
@@ -192,24 +192,60 @@ export async function answerPublicQuestion({ return success({}); } - job('notification.slack.send', { - channel: channelId, - message: 'Searching our Slack history...', - threadId, - workspace: 'regular', + const threadsResult = await getMostRelevantThreads(text, { + exclude: [threadId], + topK: 10, }); - const answerResult = await getAnswerFromSlackHistory(text); + if (!threadsResult.ok) { + return threadsResult; + } - if (!answerResult.ok) { - return answerResult; + const threads = threadsResult.data + .filter((thread) => { + return thread.score >= 0.75; + }) + .map((thread, i) => { + const date = dayjs(thread.createdAt) + .tz('America/Los_Angeles') + .format('M/D/YY'); + + const emoji = match(i + 1) + .with(1, () => '1️⃣') + .with(2, () => '2️⃣') + .with(3, () => '3️⃣') + .with(4, () => '4️⃣') + .with(5, () => '5️⃣') + .with(6, () => '6️⃣') + .with(7, () => '7️⃣') + .with(8, () => '8️⃣') + .with(9, () => '9️⃣') + .with(10, () => '🔟') + .otherwise(() => ''); + + const message = + thread.message.length > 100 + ? thread.message.slice(0, 100) + '...' + : thread.message; + + return `${emoji}. [${date}] `; + }); + + if (!threads.length) { + // Though we didn't find any relevant threads, this is still a "success". + return success({}); } - const answerWithReferences = addThreadReferences(answerResult.data); + const message = + 'I found some threads in our workspace that _may_ be relevant to your question! 🔎' + + '\n\n' + + threads.join('\n') + + '\n\n' + + `_I'm a ColorStack AI assistant with the full context of our Slack workspace! I can answer your questions in detail -- just send me a DM !_`; job('notification.slack.send', { channel: channelId, - message: answerWithReferences, + message, threadId, workspace: 'regular', }); @@ -275,11 +311,8 @@ async function isQuestion(question: string): Promise> { * Ask a question to the Slack workspace. 
* * This is a RAG (Retrieval Augmented Generation) implementation that works - * as follows: - * - Create an embedding for the question. - * - Query the vector database for the most similar Slack messages. - * - Pass the most similar Slack threads found to an LLM. - * - Return the answer. + * by finding the most relevant Slack threads to the question and passing them + * to an LLM with additional instructions for answering. * * @param question - The question to ask. * @returns The answer to the question. @@ -287,79 +320,22 @@ async function isQuestion(question: string): Promise> { async function getAnswerFromSlackHistory( question: string ): Promise> { - const embeddingResult = await createEmbedding(question); - - if (!embeddingResult.ok) { - return fail(embeddingResult); - } - - const { matches } = await getPineconeIndex('slack-messages').query({ - includeMetadata: true, - topK: 50, - vector: embeddingResult.data, - }); - - const messages = await Promise.all( - matches.map(async (match) => { - const [thread, replies] = await Promise.all([ - db - .selectFrom('slackMessages') - .select(['channelId', 'createdAt', 'text']) - .where('id', '=', match.id) - .executeTakeFirst(), - - db - .selectFrom('slackMessages') - .select(['text']) - .where('threadId', '=', match.id) - .orderBy('createdAt', 'asc') - .limit(50) - .execute(), - ]); - - const formattedReplies = replies - .map((message) => message.text) - .join('\n'); - - return { - channelId: thread?.channelId || '', - createdAt: thread?.createdAt.toISOString() || '', - message: thread?.text || '', - replies: formattedReplies, - threadId: match.id, - }; - }) - ); - - // This next step is an important one -- we're going to rerank the messages - // based on their relevance to the question. This helps us get the most - // relevant threads to the LLM. Reranking models are different from - // vector search which are optimized for fast retrieval. 
Reranking models are - // more accurate at assessing relevance, but they are slower and more - // expensive to compute. - - const documents = messages.map((message) => { - return [message.createdAt, message.message, message.replies].join('\n'); - }); - - const rerankingResult = await rerankDocuments(question, documents, { + const threadsResult = await getMostRelevantThreads(question, { topK: 5, }); - if (!rerankingResult.ok) { - return fail(rerankingResult); + if (!threadsResult.ok) { + return threadsResult; } - const rerankedThreads = rerankingResult.data.map((document) => { - const message = messages[document.index]; - + const threads = threadsResult.data.map((thread) => { const parts = [ - '[Relevance Score]: ' + document.relevance_score, - '[Timestamp]: ' + message.createdAt, - '[Channel ID]: ' + message.channelId, - '[Thread ID]: ' + message.threadId, - '[Message]: ' + message.message, - '[Replies]: ' + message.replies, + '[Relevance Score]: ' + thread.score, + '[Timestamp]: ' + thread.createdAt, + '[Channel ID]: ' + thread.channelId, + '[Thread ID]: ' + thread.id, + '[Message]: ' + thread.message, + '[Replies]: ' + thread.replies, ]; return parts.join('\n'); @@ -368,7 +344,7 @@ async function getAnswerFromSlackHistory( const userPrompt = [ 'Please answer the following question based on the Slack context provided:', `${question}`, - `${rerankedThreads.join('\n\n')}`, + `${threads.join('\n\n')}`, ].join('\n'); const systemPrompt = dedent` @@ -437,6 +413,11 @@ async function getAnswerFromSlackHistory( particularly if the sentiment is a negative/speculative one. - Respond like you are an ambassador for the ColorStack community. + + + - MAINTAIN CONSISTENT THREAD NUMBERING: Each unique thread should always + be assigned the same reference number throughout the response. 
+ `; const completionResult = await getChatCompletion({ @@ -458,6 +439,128 @@ async function getAnswerFromSlackHistory( return success(completionResult.data); } +type GetMostRelevantThreadsOptions = { + /** + * The IDs of the threads to exclude from the search. + * + * The common use case for this is that if we are answering a question in a + * thread, we don't want to include the current thread in the search. + */ + exclude?: string[]; + + /** + * The maximum number of threads to return. Note that this refers to the final + * number of threads AFTER reranking, not the initial vector database + * retrieval. + */ + topK: number; +}; + +type RelevantThread = { + channelId: string; + createdAt: string; + id: string; + message: string; + replies: string; + score: number; +}; + +/** + * Finds the most relevant threads to a question. + * + * This works by: + * - Creating an embedding for the question. + * - Querying the vector database for the most similar Slack messages. + * - Populating the results with more metadata. + * - Reranking the results using a different model. + * + * @param question - The question to get the most relevant threads for. + * @param options - The options for the query. + * @returns The most relevant threads to the question.
+ */ +async function getMostRelevantThreads( + question: string, + options: GetMostRelevantThreadsOptions +): Promise> { + const embeddingResult = await createEmbedding(question); + + if (!embeddingResult.ok) { + return embeddingResult; + } + + const embedding = embeddingResult.data; + + const { matches } = await getPineconeIndex('slack-messages').query({ + includeMetadata: true, + topK: 50, + vector: embedding, + }); + + const filteredMatches = matches.filter((match) => { + return !options.exclude?.includes(match.id); + }); + + const messages = await Promise.all( + filteredMatches.map(async (match) => { + const [thread, replies] = await Promise.all([ + db + .selectFrom('slackMessages') + .select(['channelId', 'createdAt', 'text']) + .where('id', '=', match.id) + .executeTakeFirst(), + + db + .selectFrom('slackMessages') + .select(['text']) + .where('threadId', '=', match.id) + .orderBy('createdAt', 'asc') + .limit(50) + .execute(), + ]); + + const formattedReplies = replies + .map((message) => message.text) + .join('\n'); + + return { + channelId: thread?.channelId || '', + createdAt: thread?.createdAt.toISOString() || '', + id: match.id, + message: thread?.text || '', + replies: formattedReplies, + }; + }) + ); + + // This next step is an important one -- we're going to rerank the messages + // based on their relevance to the question. This helps us get the most + // relevant threads to the LLM. Reranking models are different from + // vector search, which is optimized for fast retrieval. Reranking models are + // more accurate at assessing relevance, but they are slower and more + // expensive to compute.
+ + const documents = messages.map((message) => { + return [message.createdAt, message.message, message.replies].join('\n'); + }); + + const rerankingResult = await rerankDocuments(question, documents, { + topK: options.topK, + }); + + if (!rerankingResult.ok) { + return rerankingResult; + } + + const threads = rerankingResult.data.map((document) => { + return { + ...messages[document.index], + score: document.relevance_score, + }; + }); + + return success(threads); +} + type SyncThreadInput = { /** * The action that was performed on the thread.