diff --git a/LICENSE_HEADER b/LICENSE_HEADER index b6d7fb1..034ee43 100644 --- a/LICENSE_HEADER +++ b/LICENSE_HEADER @@ -1,4 +1,4 @@ -Copyright 2023-present, Argilla, Inc. +Copyright 2024-present, Argilla, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -10,4 +10,4 @@ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and -limitations under the License. \ No newline at end of file +limitations under the License. diff --git a/README.md b/README.md index 24183e7..e010e58 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ If you already have deployed Argilla, you can skip this step. Otherwise, you can ## Basic Usage -To easily log your data into Argilla within your LlamaIndex workflow, you only need a simple step. Just call the Argilla global handler for Llama Index before starting production with your LLM. +To easily log your data into Argilla within your LlamaIndex workflow, you only need to initialize the handler and attach it to the LlamaIndex dispatcher. This ensures that the predictions obtained using LlamaIndex are automatically logged to the Argilla instance. - `dataset_name`: The name of the dataset. If the dataset does not exist, it will be created with the specified name. Otherwise, it will be updated. - `api_url`: The URL to connect to the Argilla instance. @@ -33,23 +33,24 @@ To easily log your data into Argilla within your LlamaIndex workflow, you only n > For more information about the credentials, check the documentation for [users](https://docs.argilla.io/latest/how_to_guides/user/) and [workspaces](https://docs.argilla.io/latest/how_to_guides/workspace/). 
```python -from llama_index.core import set_global_handler +from llama_index.core.instrumentation import get_dispatcher +from argilla_llama_index import ArgillaHandler -set_global_handler( - "argilla", - dataset_name="query_model", +argilla_handler = ArgillaHandler( + dataset_name="query_llama_index", api_url="http://localhost:6900", api_key="argilla.apikey", number_of_retrievals=2, ) +root_dispatcher = get_dispatcher() +root_dispatcher.add_span_handler(argilla_handler) +root_dispatcher.add_event_handler(argilla_handler) ``` -Let's log some data into Argilla. With the code below, you can create a basic LlamaIndex workflow. We will use GPT3.5 from OpenAI as our LLM ([OpenAI API key](https://openai.com/blog/openai-api)). Moreover, we will use an example `.txt` file obtained from the [Llama Index documentation](https://docs.llamaindex.ai/en/stable/getting_started/starter_example.html). - - +Let's log some data into Argilla. With the code below, you can create a basic LlamaIndex workflow. We will use GPT3.5 from OpenAI as our LLM ([OpenAI API key](https://openai.com/blog/openai-api)). Moreover, we will use an example `.txt` file obtained from the [LlamaIndex documentation](https://docs.llamaindex.ai/en/stable/getting_started/starter_example.html). ```python -import os +import os from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader from llama_index.llms.openai import OpenAI @@ -63,8 +64,8 @@ Settings.llm = OpenAI( documents = SimpleDirectoryReader("data").load_data() index = VectorStoreIndex.from_documents(documents) -# Create the query engine -query_engine = index.as_query_engine() +# Create the query engine with the same similarity top k as the number of retrievals +query_engine = index.as_query_engine(similarity_top_k=2) ``` Now, let's run the `query_engine` to have a response from the model. The generated response will be logged into Argilla. 
diff --git a/docs/assets/UI-screenshot-github.png b/docs/assets/UI-screenshot-github.png index 48c3997..d7e9c7e 100644 Binary files a/docs/assets/UI-screenshot-github.png and b/docs/assets/UI-screenshot-github.png differ diff --git a/docs/assets/UI-screenshot.png b/docs/assets/UI-screenshot.png index 1c3ea00..3ae8b16 100644 Binary files a/docs/assets/UI-screenshot.png and b/docs/assets/UI-screenshot.png differ diff --git a/docs/assets/UI-screeshot-workflow.png b/docs/assets/UI-screeshot-workflow.png new file mode 100644 index 0000000..0e23c95 Binary files /dev/null and b/docs/assets/UI-screeshot-workflow.png differ diff --git a/docs/assets/screenshot-workflow.png b/docs/assets/screenshot-workflow.png new file mode 100644 index 0000000..7958de3 Binary files /dev/null and b/docs/assets/screenshot-workflow.png differ diff --git a/docs/tutorials/getting_started.ipynb b/docs/tutorials/getting_started.ipynb index abd8e2c..ac9f4db 100644 --- a/docs/tutorials/getting_started.ipynb +++ b/docs/tutorials/getting_started.ipynb @@ -6,9 +6,9 @@ "source": [ "# ✨🦙 Getting started with Argilla's LlamaIndex Integration\n", "\n", - "In this tutorial, we will show the basic usage of this integration that allows the user to include the feedback loop that Argilla offers into the LlamaIndex ecosystem. It's based on a callback handler to be run within the LlamaIndex workflow. \n", + "In this tutorial, we will show the basic usage of this integration that allows the user to include the feedback loop that Argilla offers into the LlamaIndex ecosystem. 
It's based on the span and event handlers to be run within the LlamaIndex workflow.\n", "\n", - "Don't hesitate to check out both [LlamaIndex](https://github.com/run-llama/llama_index) and [Argilla](https://github.com/argilla-io/argilla)" + "Don't hesitate to check out both [LlamaIndex](https://github.com/run-llama/llama_index) and [Argilla](https://github.com/argilla-io/argilla)\n" ] }, { @@ -19,7 +19,7 @@ "\n", "### Deploy the Argilla server¶\n", "\n", - "If you already have deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/)." + "If you already have deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/).\n" ] }, { @@ -28,7 +28,7 @@ "source": [ "### Set up the environment¶\n", "\n", - "To complete this tutorial, you need to install this integration." + "To complete this tutorial, you need to install this integration.\n" ] }, { @@ -37,14 +37,14 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install argilla-llama-index" + "%pip install \"argilla-llama-index>=2.1.0\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Let's make the required imports:" + "Let's make the required imports:\n" ] }, { @@ -57,9 +57,11 @@ " Settings,\n", " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", - " set_global_handler,\n", ")\n", - "from llama_index.llms.openai import OpenAI" + "from llama_index.core.instrumentation import get_dispatcher\n", + "from llama_index.llms.openai import OpenAI\n", + "\n", + "from argilla_llama_index import ArgillaHandler" ] }, { @@ -87,7 +89,7 @@ "source": [ "## Set the Argilla's LlamaIndex handler\n", "\n", - "To easily log your data into Argilla within your LlamaIndex workflow, you only need a simple step. Just call the Argilla global handler for Llama Index before starting production with your LLM. 
This ensured that the predictions obtained using Llama Index are automatically logged to the Argilla instance.\n", + "To easily log your data into Argilla within your LlamaIndex workflow, you only need to initialize the Argilla handler and attach it to the Llama Index dispatcher for spans and events. This ensures that the predictions obtained using Llama Index are automatically logged to the Argilla instance, along with the useful metadata.\n", "\n", "- `dataset_name`: The name of the dataset. If the dataset does not exist, it will be created with the specified name. Otherwise, it will be updated.\n", "- `api_url`: The URL to connect to the Argilla instance.\n", @@ -95,7 +97,7 @@ "- `number_of_retrievals`: The number of retrieved documents to be logged. Defaults to 0.\n", "- `workspace_name`: The name of the workspace to log the data. By default, the first available workspace.\n", "\n", - "> For more information about the credentials, check the documentation for [users](https://docs.argilla.io/latest/how_to_guides/user/) and [workspaces](https://docs.argilla.io/latest/how_to_guides/workspace/)." 
+ "> For more information about the credentials, check the documentation for [users](https://docs.argilla.io/latest/how_to_guides/user/) and [workspaces](https://docs.argilla.io/latest/how_to_guides/workspace/).\n" ] }, { @@ -104,27 +106,29 @@ "metadata": {}, "outputs": [], "source": [ - "set_global_handler(\n", - " \"argilla\",\n", - " dataset_name=\"query_model\",\n", + "argilla_handler = ArgillaHandler(\n", + " dataset_name=\"query_llama_index\",\n", " api_url=\"http://localhost:6900\",\n", " api_key=\"argilla.apikey\",\n", " number_of_retrievals=2,\n", - ")" + ")\n", + "root_dispatcher = get_dispatcher()\n", + "root_dispatcher.add_span_handler(argilla_handler)\n", + "root_dispatcher.add_event_handler(argilla_handler)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Log the data to Argilla" + "## Log the data to Argilla\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "With the code below, you can create a basic LlamaIndex workflow. We will use an example `.txt` file obtained from the [Llama Index documentation](https://docs.llamaindex.ai/en/stable/getting_started/starter_example.html)." + "With the code below, you can create a basic LlamaIndex workflow. 
We will use an example `.txt` file obtained from the [Llama Index documentation](https://docs.llamaindex.ai/en/stable/getting_started/starter_example.html).\n" ] }, { @@ -145,21 +149,23 @@ "outputs": [], "source": [ "# LLM settings\n", - "Settings.llm = OpenAI(model=\"gpt-3.5-turbo\", temperature=0.8, openai_api_key=openai_api_key)\n", + "Settings.llm = OpenAI(\n", + " model=\"gpt-3.5-turbo\", temperature=0.8, openai_api_key=openai_api_key\n", + ")\n", "\n", "# Load the data and create the index\n", "documents = SimpleDirectoryReader(\"../../data\").load_data()\n", "index = VectorStoreIndex.from_documents(documents)\n", "\n", - "# Create the query engine\n", - "query_engine = index.as_query_engine()" + "# Create the query engine with the same similarity top k as the number of retrievals\n", + "query_engine = index.as_query_engine(similarity_top_k=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now, let's run the `query_engine` to have a response from the model." + "Now, let's run the `query_engine` to have a response from the model.\n" ] }, { @@ -178,7 +184,7 @@ "source": [ "The prompt given and the response obtained will be logged in as a chat, as well as the indicated number of retrieved documents.\n", "\n", - "![Argilla UI](../assets/UI-screenshot.png)" + "![Argilla UI](../assets/UI-screenshot.png)\n" ] } ], diff --git a/docs/tutorials/github_rag_llamaindex_argilla.ipynb b/docs/tutorials/github_rag_llamaindex_argilla.ipynb index eb47e33..75066fb 100644 --- a/docs/tutorials/github_rag_llamaindex_argilla.ipynb +++ b/docs/tutorials/github_rag_llamaindex_argilla.ipynb @@ -9,12 +9,13 @@ "In this tutorial, we'll show you how to create a RAG system that can answer questions about a specific GitHub repository. As example, we will target the [Argilla repository](https://github.com/argilla-io/argilla). 
This RAG system will target the docs of the repository, as that's where most of the natural language information about the repository can be found.\n", "\n", "This tutorial includes the following steps:\n", - "- Setting up the Argilla callback handler for LlamaIndex.\n", - "- Initializing a GitHub client\n", - "- Creating an index with a specific set of files from the GitHub repository of our choice.\n", - "- Create a RAG system out of the Argilla repository, ask questions, and automatically log the answers to Argilla.\n", "\n", - "This tutorial is based on the [Github Repository Reader](https://docs.llamaindex.ai/en/stable/examples/data_connectors/GithubRepositoryReaderDemo/) made by LlamaIndex." + "- Setting up the Argilla handler for LlamaIndex.\n", + "- Initializing a GitHub client\n", + "- Creating an index with a specific set of files from the GitHub repository of our choice.\n", + "- Create a RAG system out of the Argilla repository, ask questions, and automatically log the answers to Argilla.\n", + "\n", + "This tutorial is based on the [Github Repository Reader](https://docs.llamaindex.ai/en/stable/examples/data_connectors/GithubRepositoryReaderDemo/) made by LlamaIndex.\n" ] }, { @@ -25,7 +26,7 @@ "\n", "### Deploy the Argilla server¶\n", "\n", - "If you already have deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/)." + "If you already have deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/).\n" ] }, { @@ -34,7 +35,7 @@ "source": [ "### Set up the environment¶\n", "\n", - "To complete this tutorial, you need to install this integration and a third-party library via pip." 
+ "To complete this tutorial, you need to install this integration and a third-party library via pip.\n" ] }, { @@ -43,15 +44,15 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install \"argilla-llama-index>=2.0.0\"\n", - "!pip install \"llama-index-readers-github==0.1.9\"" + "%pip install \"argilla-llama-index>=2.1.0\"\n", + "%pip install \"llama-index-readers-github==0.1.9\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Let's make the required imports:" + "Let's make the required imports:\n" ] }, { @@ -63,25 +64,27 @@ "from llama_index.core import (\n", " Settings,\n", " VectorStoreIndex,\n", - " set_global_handler,\n", ")\n", + "from llama_index.core.instrumentation import get_dispatcher\n", "from llama_index.llms.openai import OpenAI\n", "from llama_index.readers.github import (\n", " GithubClient,\n", " GithubRepositoryReader,\n", - ")" + ")\n", + "\n", + "from argilla_llama_index import ArgillaHandler" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We need to set the OpenAI API key and the GitHub token. The OpenAI API key is required to run queries using GPT models, while the GitHub token ensures you have access to the repository you're using. Although the GitHub token might not be necessary for public repositories, it is still recommended." + "We need to set the OpenAI API key and the GitHub token. The OpenAI API key is required to run queries using GPT models, while the GitHub token ensures you have access to the repository you're using. 
Although the GitHub token might not be necessary for public repositories, it is still recommended.\n" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -90,7 +93,7 @@ "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n", "openai_api_key = os.getenv(\"OPENAI_API_KEY\")\n", "\n", - "os.environ[\"GITHUB_TOKEN\"] = \"ghp_...\"\n", + "os.environ[\"GITHUB_TOKEN\"] = \"github_pat_....\"\n", "github_token = os.getenv(\"GITHUB_TOKEN\")" ] }, @@ -100,7 +103,7 @@ "source": [ "## Set the Argilla's LlamaIndex handler\n", "\n", - "To easily log your data into Argilla within your LlamaIndex workflow, you only need a simple step. Just call the Argilla global handler for Llama Index before starting production with your LLM. This ensured that the predictions obtained using Llama Index are automatically logged to the Argilla instance.\n", + "To easily log your data into Argilla within your LlamaIndex workflow, you only need to initialize the Argilla handler and attach it to the Llama Index dispatcher for spans and events. This ensures that the predictions obtained using Llama Index are automatically logged to the Argilla instance, along with the useful metadata.\n", "\n", "- `dataset_name`: The name of the dataset. If the dataset does not exist, it will be created with the specified name. Otherwise, it will be updated.\n", "- `api_url`: The URL to connect to the Argilla instance.\n", @@ -108,7 +111,7 @@ "- `number_of_retrievals`: The number of retrieved documents to be logged. Defaults to 0.\n", "- `workspace_name`: The name of the workspace to log the data. By default, the first available workspace.\n", "\n", - "> For more information about the credentials, check the documentation for [users](https://docs.argilla.io/latest/how_to_guides/user/) and [workspaces](https://docs.argilla.io/latest/how_to_guides/workspace/)." 
+ "> For more information about the credentials, check the documentation for [users](https://docs.argilla.io/latest/how_to_guides/user/) and [workspaces](https://docs.argilla.io/latest/how_to_guides/workspace/).\n" ] }, { @@ -117,13 +120,15 @@ "metadata": {}, "outputs": [], "source": [ - "set_global_handler(\n", - " \"argilla\",\n", - " dataset_name=\"github_query_model\",\n", + "argilla_handler = ArgillaHandler(\n", + " dataset_name=\"github_query_llama_index\",\n", " api_url=\"http://localhost:6900\",\n", " api_key=\"argilla.apikey\",\n", " number_of_retrievals=2,\n", - ")" + ")\n", + "root_dispatcher = get_dispatcher()\n", + "root_dispatcher.add_span_handler(argilla_handler)\n", + "root_dispatcher.add_event_handler(argilla_handler)" ] }, { @@ -132,12 +137,12 @@ "source": [ "## Retrieve the data from GitHub\n", "\n", - "First, we need to initialize the GitHub client, which will include the GitHub token for repository access." + "First, we need to initialize the GitHub client, which will include the GitHub token for repository access.\n" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -148,12 +153,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Before creating our `GithubRepositoryReader` instance, we need to adjust the nesting. Since the Jupyter kernel operates on an event loop, we must prevent this loop from finishing before the repository is fully read." + "Before creating our `GithubRepositoryReader` instance, we need to adjust the nesting. Since the Jupyter kernel operates on an event loop, we must prevent this loop from finishing before the repository is fully read.\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -166,12 +171,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now, let’s create a GithubRepositoryReader instance with the necessary repository details. 
In this case, we'll target the `main` branch of the `argilla` repository. As we will focus on the documentation, we will focus on the `argilla/docs/` folder, excluding images, json files, and ipynb files." + "Now, let’s create a GithubRepositoryReader instance with the necessary repository details. In this case, we'll target the `main` branch of the `argilla` repository. As we will focus on the documentation, we will focus on the `argilla/docs/` folder, excluding images, json files, and ipynb files.\n" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -194,8 +199,7 @@ " \".svg\",\n", " \".ico\",\n", " \".json\",\n", - " \".ipynb\", # Erase this line if you want to include notebooks\n", - "\n", + " \".ipynb\", # Erase this line if you want to include notebooks\n", " ],\n", " GithubRepositoryReader.FilterType.EXCLUDE,\n", " ),\n", @@ -206,14 +210,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Create the index and make some queries" + "## Create the index and make some queries\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now, let's create a LlamaIndex index out of this document, and we can start querying the RAG system." + "Now, let's create a LlamaIndex index out of this document, and we can start querying the RAG system.\n" ] }, { @@ -230,8 +234,8 @@ "# Load the data and create the index\n", "index = VectorStoreIndex.from_documents(documents)\n", "\n", - "# Create the query engine\n", - "query_engine = index.as_query_engine()" + "# Create the query engine with the same similarity top k as the number of retrievals\n", + "query_engine = index.as_query_engine(similarity_top_k=2)" ] }, { @@ -248,7 +252,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The generated response will be automatically logged in our Argilla instance. Check it out! 
From Argilla you can quickly have a look at your predictions and annotate them, so you can combine both synthetic data and human feedback.\n", + "The generated response will be automatically logged in our Argilla instance. Check it out! From Argilla, you can quickly look at your predictions and annotate them so you can combine both synthetic data and human feedback.\n", "\n", "![Argilla UI](../assets/UI-screenshot-github.png)\n" ] @@ -257,7 +261,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Let's ask a couple of more questions to see the overall behavior of the RAG chatbot. Remember that the answers are automatically logged into your Argilla instance." + "Let's ask a couple of more questions to see the overall behavior of the RAG chatbot. Remember that the answers are automatically logged into your Argilla instance.\n" ] }, { @@ -270,16 +274,16 @@ "output_type": "stream", "text": [ "Question: How can I list the available datasets?\n", - "Answer: You can list all the datasets available in a workspace by utilizing the `datasets` attribute of the `Workspace` class. Additionally, you can determine the number of datasets in a workspace by using `len(workspace.datasets)`. To list the datasets, you can iterate over them and print out each dataset. Remember that dataset settings are not preloaded when listing datasets, and if you need to work with settings, you must load them explicitly for each dataset.\n", + "Answer: To list the available datasets, you can utilize the `datasets` attribute of the `Workspace` class. By importing `argilla as rg` and setting up the `client` with your API URL and key, you can access the datasets in a workspace. Simply loop through the datasets and print each one to display the list of available datasets. 
Remember that when listing datasets, the dataset settings are not preloaded, so you may need to load them separately if you want to work with settings while listing datasets.\n", "----------------------------\n", "Question: Which are the user credentials?\n", - "Answer: The user credentials in Argilla consist of a username, password, and API key.\n", + "Answer: The user credentials typically consist of a username, password, and an API key in Argilla.\n", "----------------------------\n", "Question: Can I use markdown in Argilla?\n", "Answer: Yes, you can use Markdown in Argilla.\n", "----------------------------\n", "Question: Could you explain how to annotate datasets in Argilla?\n", - "Answer: To annotate datasets in Argilla, users can manage their data annotation projects by setting up `Users`, `Workspaces`, `Datasets`, and `Records`. By deploying Argilla on the Hugging Face Hub or with `Docker`, installing the Python SDK with `pip`, and creating the first project, users can get started in just 5 minutes. The tool allows for interacting with data in a more engaging way through features like quick labeling with filters, AI feedback suggestions, and semantic search, enabling users to focus on training models and monitoring their performance effectively.\n", + "Answer: To annotate datasets in Argilla, users can deploy the tool for free on the Hugging Face Hub or with Docker. They can then install the Python SDK with pip and create their first project. By managing Users, Workspaces, Datasets, and Records, users can set up their data annotation projects in Argilla. 
Additionally, users can interact with their data through engaging labeling processes that involve filters, AI feedback suggestions, and semantic search to efficiently label the data while focusing on training models and monitoring their performance.\n", "----------------------------\n" ] }, diff --git a/docs/tutorials/workflow_step_back_llamaindex_argilla.ipynb b/docs/tutorials/workflow_step_back_llamaindex_argilla.ipynb new file mode 100644 index 0000000..64a2454 --- /dev/null +++ b/docs/tutorials/workflow_step_back_llamaindex_argilla.ipynb @@ -0,0 +1,418 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 🏔️ Step-back prompting with workflows for RAG" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This tutorial will show how to use step-back prompting with LlamaIndex workflows for RAG integrated with Argilla.\n", + "\n", + "This prompting approach is based on \"[Take a Step Back: Evoking Reasoning via Abstraction in Large Language Models](https://arxiv.org/abs/2310.06117)\". This paper suggests that the response can be improved by asking the model to take a step back and reason about the context in a more abstract way. This way, the original query is abstracted and used to retrieve the relevant information. Then, this context along with the original context and query are used to generate the final response. \n", + "\n", + "[Argilla](https://github.com/argilla-io/argilla) is a collaboration tool for AI engineers and domain experts to build high-quality datasets. By doing this, you can analyze and enhance the quality of your data, leading to improved model performance by incorporating human feedback into the loop. The integration will automatically log the query, response, retrieved contexts with their scores, and the full trace (including spans and events), along with relevant metadata in Argilla. 
By default, you'll have the ability to rate responses, provide feedback, and evaluate the retrieved contexts, ensuring accuracy and preventing any discrepancies.\n", + "\n", + "It includes the following steps:\n", + "\n", + "- Setting up the Argilla handler for LlamaIndex.\n", + "- Designing the step-back workflow.\n", + "- Run the step-back workflow with LlamaIndex and automatically log the responses to Argilla." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting started\n", + "\n", + "### Deploy the Argilla server¶\n", + "\n", + "If you already have deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Set up the environment¶\n", + "\n", + "To complete this tutorial, you need to install this integration.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install \"argilla-llama-index>=2.1.0\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's make the required imports:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core import (\n", + " Settings,\n", + " SimpleDirectoryReader,\n", + " VectorStoreIndex,\n", + ")\n", + "from llama_index.core.instrumentation import get_dispatcher\n", + "from llama_index.core.node_parser import SentenceSplitter\n", + "from llama_index.core.response_synthesizers import ResponseMode\n", + "from llama_index.core.schema import NodeWithScore\n", + "from llama_index.core.workflow import (\n", + " Context,\n", + " StartEvent,\n", + " StopEvent,\n", + " Workflow,\n", + " step,\n", + ")\n", + "\n", + "from llama_index.core import get_response_synthesizer\n", + "from llama_index.core.workflow import Event\n", + "from 
llama_index.utils.workflow import draw_all_possible_flows\n", + "from llama_index.llms.openai import OpenAI\n", + "\n", + "from argilla_llama_index import ArgillaHandler" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We need to set the OpenAI API key. The OpenAI API key is required to run queries using GPT models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set the Argilla's LlamaIndex handler\n", + "\n", + "To easily log your data into Argilla within your LlamaIndex workflow, you only need to initialize the Argilla handler and attach it to the Llama Index dispatcher for spans and events. This ensures that the predictions obtained using Llama Index are automatically logged to the Argilla instance, along with the useful metadata.\n", + "\n", + "- `dataset_name`: The name of the dataset. If the dataset does not exist, it will be created with the specified name. Otherwise, it will be updated.\n", + "- `api_url`: The URL to connect to the Argilla instance.\n", + "- `api_key`: The API key to authenticate with the Argilla instance.\n", + "- `number_of_retrievals`: The number of retrieved documents to be logged. Defaults to 0.\n", + "- `workspace_name`: The name of the workspace to log the data. 
By default, the first available workspace.\n", + "\n", + "> For more information about the credentials, check the documentation for [users](https://docs.argilla.io/latest/how_to_guides/user/) and [workspaces](https://docs.argilla.io/latest/how_to_guides/workspace/).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "argilla_handler = ArgillaHandler(\n", + " dataset_name=\"workflow_llama_index\",\n", + " api_url=\"http://localhost:6900\",\n", + " api_key=\"argilla.apikey\",\n", + " number_of_retrievals=2,\n", + ")\n", + "root_dispatcher = get_dispatcher()\n", + "root_dispatcher.add_span_handler(argilla_handler)\n", + "root_dispatcher.add_event_handler(argilla_handler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the step-back workflow" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we need to define the two events that will be used in the step-back workflow. The `StepBackEvent` that will receive the step-back query, and the `RetriverEvent` that will receive the relevant nodes for the original and step-back queries after the retrieval." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "class StepBackEvent(Event):\n", + " \"\"\"Get the step-back query\"\"\"\n", + "\n", + " step_back_query: str\n", + "\n", + "class RetrieverEvent(Event):\n", + " \"\"\"Result of running the retrievals\"\"\"\n", + "\n", + " nodes_original: list[NodeWithScore]\n", + " nodes_step_back: list[NodeWithScore]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we will define the prompts according to the original paper to get the step-back query and then get the final response." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "STEP_BACK_TEMPLATE = \"\"\"\n", + "You are an expert at world knowledge. 
Your task is to step back and\n", + "paraphrase a question to a more generic step-back question, which is\n", + "easier to answer. Here are a few examples:\n", + "\n", + "Original Question: Which position did Knox Cunningham hold from May 1955 to Apr 1956?\n", + "Stepback Question: Which positions have Knox Cunningham held in his career?\n", + "\n", + "Original Question: Who was the spouse of Anna Karina from 1968 to 1974?\n", + "Stepback Question: Who were the spouses of Anna Karina?\n", + "\n", + "Original Question: what is the biggest hotel in las vegas nv as of November 28, 1993\n", + "Stepback Question: what is the size of the hotels in las vegas nv as of November 28, 1993?\n", + "\n", + "Original Question: {original_query}\n", + "Stepback Question:\n", + "\"\"\"\n", + "\n", + "GENERATE_ANSWER_TEMPLATE = \"\"\"\n", + "You are an expert of world knowledge. I am going to ask you a question.\n", + "Your response should be comprehensive and not contradicted with the\n", + "following context if they are relevant. Otherwise, ignore them if they are\n", + "not relevant.\n", + "\n", + "{context_original}\n", + "{context_step_back}\n", + "\n", + "Original Question: {query}\n", + "Answer:\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we will define the step-back workflow. In this case, the workflow will be linear. First, we will prompt the LLM to make an abstraction of the original query (step-back prompting). Then, we will retrieve the relevant nodes for the original and step-back queries. Finally, we will prompt the LLM to generate the final response." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "class RAGWorkflow(Workflow):\n", + " @step\n", + " async def step_back(self, ctx: Context, ev: StartEvent) -> StepBackEvent | None:\n", + " \"\"\"Generate the step-back query.\"\"\"\n", + " query = ev.get(\"query\")\n", + " index = ev.get(\"index\")\n", + " \n", + " if not query:\n", + " return None\n", + " \n", + " if not index:\n", + " return None\n", + " \n", + " llm = Settings.llm\n", + " step_back_query = llm.complete(prompt =STEP_BACK_TEMPLATE.format(original_query=query), formatted=True)\n", + "\n", + " await ctx.set(\"query\", query)\n", + " await ctx.set(\"index\", index)\n", + " \n", + " return StepBackEvent(step_back_query=str(step_back_query))\n", + "\n", + " @step\n", + " async def retrieve(\n", + " self, ctx: Context, ev: StepBackEvent\n", + " ) -> RetrieverEvent | None:\n", + " \"Retrieve the relevant nodes for the original and step-back queries.\"\n", + " query = await ctx.get(\"query\", default=None)\n", + " index = await ctx.get(\"index\", default=None)\n", + " \n", + " await ctx.set(\"step_back_query\", ev.step_back_query)\n", + "\n", + " retriever = index.as_retriever(similarity_top_k=2)\n", + " nodes_step_back = await retriever.aretrieve(ev.step_back_query)\n", + " nodes_original = await retriever.aretrieve(query)\n", + "\n", + " return RetrieverEvent(nodes_original=nodes_original, nodes_step_back=nodes_step_back)\n", + "\n", + " @step\n", + " async def synthesize(self, ctx: Context, ev: RetrieverEvent) -> StopEvent:\n", + " \"\"\"Return a response using the contextualized prompt and retrieved nodes.\"\"\"\n", + " nodes_original = ev.nodes_original\n", + " nodes_step_back = ev.nodes_step_back\n", + " \n", + " context_original = max(nodes_original, key=lambda node: node.get_score()).get_text()\n", + " context_step_back = max(nodes_step_back, key=lambda node: node.get_score()).get_text()\n", + " \n", + " query = await ctx.get(\"query\", 
default=None)\n", + "    formatted_query = GENERATE_ANSWER_TEMPLATE.format(\n", + "        context_original=context_original,\n", + "        context_step_back=context_step_back,\n", + "        query=query\n", + "    )\n", + "    \n", + "    response_synthesizer = get_response_synthesizer(\n", + "        response_mode=ResponseMode.COMPACT\n", + "    )\n", + "\n", + "    response =response_synthesizer.synthesize(formatted_query, nodes=ev.nodes_original)\n", + "    return StopEvent(result=response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "draw_all_possible_flows(RAGWorkflow, filename=\"step_back_workflow.html\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Workflow](../assets/screenshot-workflow.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run the step-back workflow" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will use an example `.txt` file obtained from the [LlamaIndex documentation](https://docs.llamaindex.ai/en/stable/getting_started/starter_example.html). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Retrieve the data if needed\n", + "!mkdir -p ../../data\n", + "!curl https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt -o ../../data/paul_graham_essay.txt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, let's create a LlamaIndex index out of this document. 
As the highest-rated context for the original and step-back query will be included in the final prompt, we will lower the chunk size and use a `SentenceSplitter`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# LLM settings\n", + "Settings.llm = OpenAI(model=\"gpt-3.5-turbo\", temperature=0.8)\n", + "\n", + "# Load the data and create the index\n", + "transformations = [\n", + "    SentenceSplitter(chunk_size=256, chunk_overlap=75),\n", + "]\n", + "\n", + "documents = SimpleDirectoryReader(\n", + "    \"../../data\"\n", + ").load_data()\n", + "index = VectorStoreIndex.from_documents(\n", + "    documents=documents,\n", + "    transformations=transformations,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, let's run the step-back workflow and make a query." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "w = RAGWorkflow()\n", + "\n", + "result = await w.run(query=\"What's Paul's work\", index=index)\n", + "result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The generated response will be automatically logged in our Argilla instance. Check it out! 
From Argilla, you can quickly look at your predictions and annotate them so you can combine both synthetic data and human feedback.\n", + "\n", + "![UI](../assets/UI-screeshot-workflow.png)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "argilla-llama", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 06a916f..496a660 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "argilla-llama-index" -description = "Argilla-Llama Index Integration" +description = "Argilla-LlamaIndex Integration" readme = "README.md" requires-python = ">=3.8" license = "MIT" @@ -17,9 +17,8 @@ classifiers = [ "Programming Language :: Python :: 3.11", ] dependencies = [ - "argilla >= 2.0.0, < 3.0.0", - "llama-index >= 0.10.0, < 1.0", - "llama-index-callbacks-argilla >= 0.1.4", + "argilla >= 2.2.0, < 3.0.0", + "llama-index >= 0.10.20, < 1.0", "markdown >= 3.6.0", "packaging >= 23.2", "typing-extensions >= 4.3.0", diff --git a/src/argilla_llama_index/__init__.py b/src/argilla_llama_index/__init__.py index 090aaa1..e68593a 100644 --- a/src/argilla_llama_index/__init__.py +++ b/src/argilla_llama_index/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023-present, Argilla, Inc. +# Copyright 2024-present, Argilla, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "2.0.0" +__version__ = "2.1.0" -from argilla_llama_index.llama_index_handler import ArgillaCallbackHandler +from argilla_llama_index.llama_index_handler import ArgillaHandler -__all__ = ["ArgillaCallbackHandler"] +__all__ = ["ArgillaHandler"] diff --git a/src/argilla_llama_index/helpers.py b/src/argilla_llama_index/helpers.py index 20271b0..bf16531 100644 --- a/src/argilla_llama_index/helpers.py +++ b/src/argilla_llama_index/helpers.py @@ -1,62 +1,94 @@ -""" -Auxiliary methods for the Argilla Llama Index integration. -""" - -from datetime import datetime -from typing import Dict, List - -from llama_index.core.callbacks.schema import CBEvent - -def _get_time_diff(event_1_time_str: str, event_2_time_str: str) -> float: +# Copyright 2024-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Tuple + + +def _create_tree_structure( # noqa: C901 + span_buffer: List[Dict[str, Any]], event_buffer: List[Dict[str, Any]] +) -> List[Tuple]: """ - Get the time difference between two events Follows the American format (month, day, year). + Create a tree structure from the trace buffer using the parent_id and attach events as subnodes. Args: - event_1_time_str (str): The first event time. - event_2_time_str (str): The second event time. + span_buffer (List[Dict[str, Any]]): The trace buffer to create the tree structure from. + event_buffer (List[Dict[str, Any]]): The event buffer containing events related to spans. 
Returns: - float: The time difference between the two events. + List[Tuple]: The formatted tree structure as a list of tuples. """ - time_format = "%m/%d/%Y, %H:%M:%S.%f" + nodes = [] - event_1_time = datetime.strptime(event_1_time_str, time_format) - event_2_time = datetime.strptime(event_2_time_str, time_format) + node_dict = {item["id_"]: item.copy() for item in span_buffer} - return round((event_2_time - event_1_time).total_seconds(), 4) + for node in node_dict.values(): + node["children"] = [] -def _calc_time(events_data: Dict[str, List[CBEvent]], id: str) -> float: - """ - Calculate the time difference between the start and end of an event using the events_data. + for node in node_dict.values(): + parent_id = node["parent_id"] + if parent_id and parent_id in node_dict: + node_dict[parent_id]["children"].append(node) - Args: - events_data (Dict[str, List[CBEvent]]): The events data, stored in a dictionary. - id (str): The event id to calculate the time difference between start and finish timestamps. + event_dict = {} + for event in event_buffer: + span_id = event.get("span_id") + if span_id not in event_dict: + event_dict[span_id] = [] + event_dict[span_id].append(event) - Returns: - float: The time difference between the start and end of the event. 
- """ + def build_tree(node, depth=0): + node_name = node["id_"].split(".")[0] + node_time = node["duration"] + + row = len(nodes) + nodes.append((row, depth, node_name, node_time)) + + span_id = node["id_"] + if span_id in event_dict: + for event in event_dict[span_id]: + event_name = event.get("event_type", "Unknown Event") + event_row = len(nodes) + nodes.append((event_row, depth + 1, event_name, "")) - start_time = events_data[id][0].time - end_time = events_data[id][1].time - return _get_time_diff(start_time, end_time) + for child in node.get("children", []): + build_tree(child, depth + 1) + root_nodes = [ + node + for node in node_dict.values() + if node["parent_id"] is None or node["parent_id"] not in node_dict + ] + for root in root_nodes: + build_tree(root) -def _create_svg(data: List) -> str: + return nodes + + +def _create_svg(data: List[Tuple]) -> str: """ Create an SVG file from the data. Args: - data (List): The data to create the SVG file from. + data (List[Tuple]): The data to create the SVG file from. Returns: str: The SVG file. 
""" - svg_template = """ - - + + {node_name} @@ -66,21 +98,22 @@ def _create_svg(data: List) -> str: body = "".join( svg_template.format( - x=indent * 40, - y=row * 54, - width=47 * 8.65, # 47 is the height of the box - node_name_indent=47 * 0.35, - text_centering=47 * 0.6341, - font_size_node_name=47 * 0.4188, + x=indent * 30, + y=row * 45, + width=40 * 8, # 40 is the height of the box + node_name_indent=40 * 0.35, + text_centering=40 * 0.6341, + font_size_node_name=40 * 0.4188, node_name=node_name, - time_indent=47 * 7.15, - font_size_time=47 * 0.4188 - 4, + time_indent=40 * 6.5, + font_size_time=40 * 0.4188 - 4, node_time=node_time, + font_color="#cdf1f9" if "event" in node_name.lower() else "#fff", ) for row, indent, node_name, node_time in data ) return f""" - + {body} """ diff --git a/src/argilla_llama_index/llama_index_handler.py b/src/argilla_llama_index/llama_index_handler.py index c60fa8a..b0131bc 100644 --- a/src/argilla_llama_index/llama_index_handler.py +++ b/src/argilla_llama_index/llama_index_handler.py @@ -1,25 +1,86 @@ +# Copyright 2024-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect import logging import os -from collections import defaultdict +import uuid +from contextvars import ContextVar from datetime import datetime -from typing import Any, Dict, List, Optional - -import argilla as rg -from argilla.markdown import chat_to_html -from llama_index.core.callbacks.base_handler import BaseCallbackHandler -from llama_index.core.callbacks.schema import ( - CBEvent, - CBEventType, - EventPayload, +from itertools import islice +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +from argilla import ( + Argilla, + ChatField, + Dataset, + FloatMetadataProperty, + IntegerMetadataProperty, + RatingQuestion, + Record, + Settings, + TermsMetadataProperty, + TextField, + TextQuestion, ) -from packaging.version import parse +from llama_index.core.instrumentation.event_handlers import BaseEventHandler +from llama_index.core.instrumentation.events import BaseEvent +from llama_index.core.instrumentation.events.agent import ( + AgentChatWithStepEndEvent, + AgentChatWithStepStartEvent, +) +from llama_index.core.instrumentation.events.embedding import ( + EmbeddingStartEvent, +) +from llama_index.core.instrumentation.events.llm import ( + LLMChatInProgressEvent, + LLMChatStartEvent, + LLMCompletionEndEvent, + LLMCompletionStartEvent, + LLMPredictEndEvent, + LLMStructuredPredictEndEvent, +) +from llama_index.core.instrumentation.events.query import ( + QueryEndEvent, + QueryStartEvent, +) +from llama_index.core.instrumentation.events.rerank import ( + ReRankEndEvent, + ReRankStartEvent, +) +from llama_index.core.instrumentation.events.retrieval import ( + RetrievalEndEvent, + RetrievalStartEvent, +) +from llama_index.core.instrumentation.events.synthesis import ( + GetResponseStartEvent, + SynthesizeEndEvent, + SynthesizeStartEvent, +) +from llama_index.core.instrumentation.span.simple import SimpleSpan +from llama_index.core.instrumentation.span_handlers import BaseSpanHandler + +from argilla_llama_index.helpers import 
_create_svg, _create_tree_structure -from argilla_llama_index.helpers import _calc_time, _create_svg, _get_time_diff +context_root: ContextVar[Union[Tuple[str, str], Tuple[None, None]]] = ContextVar( + "context_root", default=(None, None) +) -class ArgillaCallbackHandler(BaseCallbackHandler): +class ArgillaHandler(BaseSpanHandler[SimpleSpan], BaseEventHandler, extra="allow"): """ - Callback handler that logs predictions to Argilla. + Handler that logs predictions to Argilla. This handler automatically logs the predictions made with LlamaIndex to Argilla, without the need to create a dataset and log the predictions manually. Events relevant @@ -30,73 +91,57 @@ class ArgillaCallbackHandler(BaseCallbackHandler): dataset_name (str): The name of the Argilla dataset. api_url (str): Argilla API URL. api_key (str): Argilla API key. - number_of_retrievals (int): The number of retrievals to log. By default, it is set to 0. + number_of_retrievals (int): The number of retrievals to log. By default, it is set to 2. workspace_name (str): The name of the Argilla workspace. By default, it will use the first available workspace. - event_starts_to_ignore (List[CBEventType]): List of event types to ignore at the start of the trace. - event_ends_to_ignore (List[CBEventType]): List of event types to ignore at the end of the trace. - handlers (List[BaseCallbackHandler]): List of extra handlers to include. - - Methods: - start_trace(trace_id: Optional[str] = None) -> None: - Logic to be executed at the beginning of the tracing process. - end_trace(trace_id: Optional[str] = None, trace_map: Optional[Dict[str, List[str]]] = None) -> None: - Logic to be executed at the end of the tracing process. 
+ Usage: + ```python + from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader + from llama_index.core.query_engine import RetrieverQueryEngine + from llama_index.core.instrumentation import get_dispatcher + from llama_index.core.retrievers import VectorIndexRetriever + from llama_index.llms.openai import OpenAI + + from argilla_llama_index import ArgillaHandler + + argilla_handler = ArgillaHandler( + dataset_name="query_llama_index", + api_url="http://localhost:6900", + api_key="argilla.apikey", + number_of_retrievals=2, + ) + root_dispatcher = get_dispatcher() + root_dispatcher.add_span_handler(argilla_handler) + root_dispatcher.add_event_handler(argilla_handler) - on_event_start(event_type: CBEventType, payload: Optional[Dict[str, Any]] = None, event_id: Optional[str] = None, parent_id: str = None) -> str: - Store event start data by event type. Executed at the start of an event. + Settings.llm = OpenAI(model="gpt-3.5-turbo", temperature=0.8, openai_api_key=os.getenv("OPENAI_API_KEY")) - on_event_end(event_type: CBEventType, payload: Optional[Dict[str, Any]] = None, event_id: str = None) -> None: - Store event end data by event type. Executed at the end of an event. 
+ documents = SimpleDirectoryReader("../../data").load_data() + index = VectorStoreIndex.from_documents(documents) + query_engine = index.as_query_engine(similarity_top_k=2) - Usage: - ```python - from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, set_global_handler - from llama_index.core.query_engine import RetrieverQueryEngine - from llama_index.core.retrievers import VectorIndexRetriever - from llama_index.llms.openai import OpenAI - - set_global_handler("argilla", - api_url="http://localhost:6900", - api_key="argilla.apikey", - dataset_name="query_model", - number_of_retrievals=2 - ) - - Settings.llm = OpenAI(model="gpt-3.5-turbo", temperature=0.8, openai_api_key=os.getenv("OPENAI_API_KEY")) - - documents = SimpleDirectoryReader("../../data").load_data() - index = VectorStoreIndex.from_documents(documents) - query_engine = index.as_query_engine() - - response = query_engine.query("What did the author do growing up?") - ``` + response = query_engine.query("What did the author do growing up?") + ``` """ - def __init__( # noqa: C901 + def __init__( self, dataset_name: str, - api_url: Optional[str] = None, - api_key: Optional[str] = None, - number_of_retrievals: int = 0, + api_url: str, + api_key: str, workspace_name: Optional[str] = None, - event_starts_to_ignore: Optional[List[CBEventType]] = None, - event_ends_to_ignore: Optional[List[CBEventType]] = None, - handlers: Optional[List[BaseCallbackHandler]] = None, - ) -> None: - self.event_starts_to_ignore = event_starts_to_ignore or [] - self.event_ends_to_ignore = event_ends_to_ignore or [] - self.handlers = handlers or [] - self.number_of_retrievals = number_of_retrievals + number_of_retrievals: int = 2, + ): + super().__init__() - self.ARGILLA_VERSION = rg.__version__ + self.dataset_name = dataset_name + self.workspace_name = workspace_name - if parse(self.ARGILLA_VERSION) < parse("2.0.0"): - raise ImportError( - f"The installed `argilla` version is {self.ARGILLA_VERSION} but " - 
"`ArgillaCallbackHandler` requires at least version 2.0.0. Please " - "upgrade `argilla` with `pip install --upgrade argilla`." + if number_of_retrievals < 0: + raise ValueError( + "The number of retrievals must be 0 (to show no retrieved documents) or a positive number." ) + self.number_of_retrievals = number_of_retrievals if (api_url is None and os.getenv("ARGILLA_API_URL") is None) or ( api_key is None and os.getenv("ARGILLA_API_KEY") is None @@ -105,72 +150,73 @@ def __init__( # noqa: C901 "Both `api_url` and `api_key` must be set. The current values are: " f"`api_url`={api_url} and `api_key`={api_key}." ) + self.client = Argilla(api_key=api_key, api_url=api_url) - client = rg.Argilla(api_key=api_key, api_url=api_url) + self.span_buffer: List[Dict[str, Any]] = [] + self.event_buffer: List[Dict[str, Any]] = [] + self.fields_info: Dict[str, Any] = {} - self.dataset_name = dataset_name - self.workspace_name = workspace_name - self.settings = rg.Settings( + self._initialize_dataset() + + def _initialize_dataset(self): + """Create the dataset in Argilla if it does not exist, or update it if it does.""" + + self.settings = Settings( fields=[ - rg.TextField( - name="chat", title="Chat", use_markdown=True, required=True - ), - rg.TextField( - name="time-details", title="Time Details", use_markdown=True - ), + ChatField(name="chat", title="Chat", use_markdown=False, required=True), ] - + self._add_context_fields(number_of_retrievals), + + self._add_context_fields(self.number_of_retrievals) + + [ + TextField( + name="time-details", title="Time Details", use_markdown=False + ), + ], questions=[ - rg.RatingQuestion( + RatingQuestion( name="response-rating", title="Rating for the response", description="How would you rate the quality of the response?", values=[1, 2, 3, 4, 5, 6, 7], required=True, ), - rg.TextQuestion( + TextQuestion( name="response-feedback", title="Feedback for the response", description="What feedback do you have for the response?", required=False, ), ] - 
+ self._add_context_questions(number_of_retrievals), + + self._add_context_questions(self.number_of_retrievals), guidelines="You're asked to rate the quality of the response and provide feedback.", allow_extra_metadata=True, ) # Either create a new dataset or use an existing one, updating it if necessary try: - dataset_names = [ds.name for ds in client.datasets] - + dataset_names = [ds.name for ds in self.client.datasets] if self.dataset_name not in dataset_names: - dataset = rg.Dataset( + dataset = Dataset( name=self.dataset_name, workspace=self.workspace_name, settings=self.settings, ) self.dataset = dataset.create() - self.is_new_dataset_created = True logging.info( f"A new dataset with the name '{self.dataset_name}' has been created.", ) - else: # Update the existing dataset. If the fields and questions do not match, # a new dataset will be created with the -updated flag in the name. - self.dataset = client.datasets( + self.dataset = self.client.datasets( name=self.dataset_name, workspace=self.workspace_name, ) - self.is_new_dataset_created = False - - if number_of_retrievals > 0: + if self.number_of_retrievals > 0: required_context_fields = self._add_context_fields( - number_of_retrievals + self.number_of_retrievals ) required_context_questions = self._add_context_questions( - number_of_retrievals + self.number_of_retrievals ) existing_fields = list(self.dataset.fields) existing_questions = list(self.dataset.questions) @@ -185,13 +231,15 @@ def __init__( # noqa: C901 for element in required_context_questions ) ): - self.dataset = rg.Dataset( + self.dataset = Dataset( name=f"{self.dataset_name}-updated", workspace=self.workspace_name, settings=self.settings, ) self.dataset = self.dataset.create() - + logging.info( + f"A new dataset with the name '{self.dataset_name}-updated' has been created.", + ) except Exception as e: raise FileNotFoundError( f"`Dataset` creation or update failed with exception `{e}`." 
@@ -199,10 +247,10 @@ def __init__( # noqa: C901 f"as an `integration` issue." ) from e - supported_context_fields = [ - f"retrieved_document_{i+1}" for i in range(number_of_retrievals) + supported_context_fields = ["retrieved_document_scores"] + [ + f"retrieved_document_{i+1}" for i in range(self.number_of_retrievals) ] - supported_fields = ["chat", "time-details"] + supported_context_fields + supported_fields = ["chat"] + supported_context_fields + ["time-details"] if supported_fields != [field.name for field in self.dataset.fields]: raise ValueError( f"`Dataset` with name={self.dataset_name} had fields that are not supported" @@ -210,406 +258,364 @@ def __init__( # noqa: C901 f" Current fields are {[field.name for field in self.dataset.fields]}." ) - self.events_data: Dict[str, List[CBEvent]] = defaultdict(list) - self.event_map_id_to_name = {} - self._ignore_components_in_tree = ["templating"] - self.components_to_log = set() - self.event_ids_traced = set() - def _add_context_fields(self, number_of_retrievals: int) -> List[Any]: """Create the context fields to be added to the dataset.""" + context_scores = [ + TextField( + name="retrieved_document_scores", + title="Retrieved document scores", + use_markdown=True, + required=False, + ) + ] context_fields = [ - rg.TextField( - name=f"retrieved_document_{doc + 1}", - title=f"Retrieved document {doc + 1}", + TextField( + name=f"retrieved_document_{doc+1}", + title=f"Retrieved document {doc+1}", use_markdown=True, required=False, ) for doc in range(number_of_retrievals) ] - return context_fields + return context_scores + context_fields def _add_context_questions(self, number_of_retrievals: int) -> List[Any]: """Create the context questions to be added to the dataset.""" rating_questions = [ - rg.RatingQuestion( + RatingQuestion( name=f"rating_retrieved_document_{doc + 1}", - title=f"Rate the relevance of the Retrieved document {doc + 1} (if present)", + title=f"Rate the relevance of the Retrieved document {doc + 1}, 
if present.", values=list(range(1, 8)), - description=f"Rate the relevance of the retrieved document {doc + 1}.", + description=f"Rate the relevance of the retrieved document {doc + 1}, if present.", required=False, ) for doc in range(number_of_retrievals) ] return rating_questions - def _create_root_and_other_nodes(self, trace_map: Dict[str, List[str]]) -> None: - """Create the root node and the other nodes in the tree.""" - self.root_node = self._get_event_name_by_id(trace_map["root"][0]) - self.event_ids_traced = set(trace_map.keys()) - {"root"} - self.event_ids_traced.update(*trace_map.values()) - for id in self.event_ids_traced: - self.components_to_log.add(self._get_event_name_by_id(id)) + def class_name(cls) -> str: + """Class name.""" + return "ArgillaHandler" - def _get_event_name_by_id(self, event_id: str) -> str: - """Get the name of the event by its id.""" - return str(self.events_data[event_id][0].event_type).split(".")[1].lower() - - # TODO: If we have a component more than once, properties currently don't account for those after the first one and get overwritten - - def _add_missing_metadata( - self, - dataset: rg.Dataset, - metadata: Optional[Dict[str, Any]] = None, - ) -> None: - """Add missing metadata properties to the dataset.""" + def handle(self, event: BaseEvent) -> None: + """ + Logic to handle different events. - for mt in metadata.keys(): - if mt not in [metadata.name for metadata in self.dataset.settings.metadata]: - if mt.endswith("_time"): - self.dataset.settings.metadata.add( - rg.FloatMetadataProperty(name=mt, title=mt) - ) - dataset.update() + Args: + event (BaseEvent): The event to be handled. - def _check_components_for_tree( - self, tree_structure_dict: Dict[str, List[str]] - ) -> Dict[str, List[str]]: - """ - Check whether the components in the tree are in the components to log. - Removes components that are not in the components to log so that they are not shown in the tree. 
- """ - final_components_in_tree = self.components_to_log.copy() - final_components_in_tree.add("root") - for component in self._ignore_components_in_tree: - if component in final_components_in_tree: - final_components_in_tree.remove(component) - for key in list(tree_structure_dict.keys()): - if key.strip("0") not in final_components_in_tree: - del tree_structure_dict[key] - for key, value in tree_structure_dict.items(): - if isinstance(value, list): - tree_structure_dict[key] = [ - element - for element in value - if element.strip("0") in final_components_in_tree - ] - return tree_structure_dict - - def _get_events_map_with_names( - self, events_data: Dict[str, List[CBEvent]], trace_map: Dict[str, List[str]] - ) -> Dict[str, List[str]]: - """ - Returns a dictionary where trace_map is mapped with the event names instead of the event ids. - Also returns a set of the event ids that were traced. + Returns: + None """ - self.event_map_id_to_name = {} - for event_id in self.event_ids_traced: - event_name = str(events_data[event_id][0].event_type).split(".")[1].lower() - while event_name in self.event_map_id_to_name.values(): - event_name = event_name + "0" - self.event_map_id_to_name[event_id] = event_name - events_trace_map = { - self.event_map_id_to_name.get(k, k): [ - self.event_map_id_to_name.get(v, v) for v in values - ] - for k, values in trace_map.items() + metadata = {} + + query_events = { + QueryStartEvent: "query", + AgentChatWithStepStartEvent: "user_msg", + RetrievalStartEvent: "str_or_query_bundle", + ReRankStartEvent: "query", + GetResponseStartEvent: "query_str", + SynthesizeStartEvent: "query", + LLMCompletionStartEvent: "prompt", + LLMChatInProgressEvent: "messages", } - return events_trace_map + response_events = { + QueryEndEvent: "response", + AgentChatWithStepEndEvent: "response", + LLMPredictEndEvent: "output", + LLMStructuredPredictEndEvent: "output", + LLMCompletionEndEvent: "response", + SynthesizeEndEvent: "response", + LLMChatInProgressEvent: 
"response", + } - def _extract_and_log_info( - self, events_data: Dict[str, List[CBEvent]], trace_map: Dict[str, List[str]] - ) -> None: - """ - Main function that extracts the information from the events and logs it to Argilla. - We currently log data if the root node is either "agent_step" or "query". - Otherwise, we do not log anything. - If we want to account for more root nodes, we just need to add them to the if statement. - """ - events_trace_map = self._get_events_map_with_names(events_data, trace_map) - root_node = trace_map.get("root") - - if not root_node or len(root_node) != 1: - return - - if self.root_node == "agent_step": - data_to_log = self._process_agent_step(events_data, root_node) - elif self.root_node == "query": - data_to_log = self._process_query(events_data, root_node) - else: - return - - self.event_ids_traced.remove(root_node[0]) - components_to_log = [ - comp for comp in self.components_to_log if comp != self.root_node - ] - number_of_components_used = defaultdict(int) - retrieval_metadata = {} + event_type = type(event) + + if event_type in query_events: + if "query" not in self.fields_info: + self.fields_info["query"] = str( + getattr(event, query_events[event_type]) + ) + if event_type == ReRankStartEvent: + metadata["reranker_model"] = event.model_name - for event_id in self.event_ids_traced: - event_name = self.event_map_id_to_name[event_id] - event_name_reduced = ( - event_name.rstrip("0") if event_name.endswith("0") else event_name + if event_type in response_events: + self.fields_info["response"] = str( + getattr(event, response_events[event_type]) ) - number_of_components_used[event_name_reduced] += event_name.endswith("0") - if event_name_reduced in components_to_log: - data_to_log[f"{event_name}_time"] = _calc_time(events_data, event_id) + if isinstance(event, EmbeddingStartEvent): + metadata["embedding_model"] = event.model_dict.get("model_name", "") - if event_name_reduced == "llm": - payload = 
events_data[event_id][0].payload - data_to_log.update( + if isinstance(event, LLMChatStartEvent): + metadata.update( + { + "llm_model": event.model_dict.get("model", ""), + "llm_temperature": event.model_dict.get("temperature", 0), + "llm_max_tokens": event.model_dict.get("max_tokens", 0), + } + ) + + if isinstance(event, (RetrievalEndEvent, ReRankEndEvent)): + for i, n in enumerate(event.nodes, start=1): + idx = f"retrieved_document_{i}" + metadata.update( { - f"{event_name}_system_prompt": payload.get( - EventPayload.MESSAGES - )[0].content, - f"{event_name}_model_name": payload.get( - EventPayload.SERIALIZED - )["model"], + f"{idx}_file_name": n.metadata.get("file_name", "unknown"), + f"{idx}_file_type": n.metadata.get("file_type", "unknown"), + f"{idx}_file_size": n.metadata.get("file_size", 0), + f"{idx}_start_char": getattr(n.node, "start_char_idx", -1), + f"{idx}_end_char": getattr(n.node, "end_char_idx", -1), + f"{idx}_score": getattr(n, "score", 0), } ) + text = getattr(n, "text", "") + self.fields_info[f"{idx}_score"] = metadata[f"{idx}_score"] + self.fields_info[f"{idx}_text"] = text - if event_name_reduced == "retrieve": - for idx, retrieval_node in enumerate( - events_data[event_id][1].payload.get(EventPayload.NODES), 1 - ): - if idx > self.number_of_retrievals: - break - retrieve_dict = retrieval_node.to_dict() - retrieval_metadata.update( - { - f"{event_name}_document_{idx}_score": retrieval_node.score, - f"{event_name}_document_{idx}_filename": retrieve_dict[ - "node" - ]["metadata"]["file_name"], - f"{event_name}_document_{idx}_text": retrieve_dict["node"][ - "text" - ], - f"{event_name}_document_{idx}_start_character": retrieve_dict[ - "node" - ][ - "start_char_idx" - ], - f"{event_name}_document_{idx}_end_character": retrieve_dict[ - "node" - ]["end_char_idx"], - } - ) - - metadata_to_log = { - key: data_to_log[key] - for key in data_to_log - if key.endswith("_time") or key not in ["query", "response"] - } - metadata_to_log["total_time"] = 
data_to_log.get( - "query_time", data_to_log.get("agent_step_time") - ) - metadata_to_log.update( + self.event_buffer.append( { - f"number_of_{key}_used": value + 1 - for key, value in number_of_components_used.items() + "id_": event.id_, + "event_type": event.class_name(), + "span_id": event.span_id, + "timestamp": event.timestamp.timestamp(), + "metadata": metadata, } ) - metadata_to_log.update(retrieval_metadata) - self._add_missing_metadata(self.dataset, metadata_to_log) + def new_span( + self, + id_: str, + bound_args: inspect.BoundArguments, + instance: Optional[Any] = None, + parent_span_id: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Optional[SimpleSpan]: + """ + Create a new span using the SimpleSpan class. If the span is the root span, it generates a new trace ID. - tree_structure = self._create_tree_structure(events_trace_map, data_to_log) - tree = _create_svg(tree_structure) + Args: + id_ (str): The unique identifier for the new span. + bound_args (inspect.BoundArguments): The arguments that were bound to when the span was created. + instance (Optional[Any], optional): The instance associated with the span, if present. Defaults to None. + parent_span_id (Optional[str], optional): The identifier of the parent span. Defaults to None. + tags (Optional[Dict[str, Any]], optional): Additional information about the span. Defaults to None. - message = [ - {"role": "user", "content": data_to_log["query"]}, - {"role": "assistant", "content": data_to_log["response"]}, - ] - fields = { - "chat": chat_to_html(message), - "time-details": tree, - } + Returns: + Optional[SimpleSpan]: The newly created SimpleSpan object if the span is successfully created. 
+ """ + trace_id, root_span_id = context_root.get() - if self.number_of_retrievals > 0: - for key in list(retrieval_metadata.keys()): - if key.endswith("_text"): - idx = key.split("_")[-2] - fields[f"retrieved_document_{idx}"] = ( - f"DOCUMENT SCORE: {retrieval_metadata[f'{key[:-5]}_score']}\n\n{retrieval_metadata[key]}" - ) - del metadata_to_log[key] + if not parent_span_id: + trace_id = str(uuid.uuid4()) + root_span_id = id_ + context_root.set((trace_id, root_span_id)) - valid_metadata_keys = [ - metadata.name for metadata in self.dataset.settings.metadata - ] - metadata_to_log = { - k: v - for k, v in metadata_to_log.items() - if k in valid_metadata_keys or not k.endswith("_time") - } + if "workflow.run" in id_.lower(): + self.fields_info["query"] = bound_args.kwargs["query"] + if "workflow._done" in id_.lower(): + self.fields_info["response"] = bound_args.kwargs["response"] - self.dataset.records.log( - records=[ - rg.Record( - fields=fields, - metadata=metadata_to_log, - ), - ] - ) + return SimpleSpan(id_=id_, parent_id=parent_span_id, tags=tags or {}) - def _process_agent_step( - self, events_data: Dict[str, List[CBEvent]], root_node: str - ) -> Dict: - """ - Processes events data for 'agent_step' root node. + def prepare_to_exit_span( + self, + id_: str, + bound_args: inspect.BoundArguments, + instance: Optional[Any] = None, + result: Optional[Any] = None, + **kwargs: Any, + ) -> Optional[SimpleSpan]: """ - data_to_log = {} + Logic to exit the span. It stores the span information in the trace buffer. + If the trace has ended and and belongs to specific components, it logs the buffered data to Argilla. + + Args: + id_ (str): The unique identifier of the span to be exited. + bound_args (inspect.BoundArguments): The arguments that were bound to the span's function during its invocation. + instance (Optional[Any], optional): The instance associated with the span, if applicable.. Defaults to None. 
+ result (Optional[Any], optional): The output or result produced by the span's execution.. Defaults to None. - event_start = events_data[root_node[0]][0] - data_to_log["query"] = event_start.payload.get(EventPayload.MESSAGES)[0] - query_start_time = event_start.time + Returns: + Optional[SimpleSpan]: The exited SimpleSpan object if the span exists and the trace is active; otherwise, None. + """ + trace_id, root_span_id = context_root.get() + if not trace_id: + return None - event_end = events_data[root_node[0]][1] - data_to_log["response"] = event_end.payload.get(EventPayload.RESPONSE).response - query_end_time = event_end.time + span = self.open_spans[id_] + span = cast(SimpleSpan, span) + span.end_time = datetime.now() + span.duration = round((span.end_time - span.start_time).total_seconds(), 4) - data_to_log["agent_step_time"] = _get_time_diff( - query_start_time, query_end_time + self.span_buffer.append( + { + "id_": span.id_, + "parent_id": span.parent_id, + "start_time": span.start_time.timestamp(), + "end_time": span.end_time.timestamp(), + "duration": span.duration, + } ) - return data_to_log + with self.lock: + self.completed_spans += [span] - def _process_query( - self, events_data: Dict[str, List[CBEvent]], root_node: str - ) -> Dict: - """ - Processes events data for 'query' root node. 
- """ - data_to_log = {} - - event_start = events_data[root_node[0]][0] - data_to_log["query"] = event_start.payload.get(EventPayload.QUERY_STR) - query_start_time = event_start.time - - event_end = events_data[root_node[0]][1] - data_to_log["response"] = event_end.payload.get(EventPayload.RESPONSE).response - query_end_time = event_end.time - - data_to_log["query_time"] = _get_time_diff(query_start_time, query_end_time) - - return data_to_log - - def _create_tree_structure( - self, events_trace_map: Dict[str, List[str]], data_to_log: Dict[str, Any] - ) -> List: - """Create the tree data to be converted to an SVG.""" - events_trace_map = self._check_components_for_tree(events_trace_map) - data = [] - data.append( - ( - 0, - 0, - self.root_node.strip("0").upper(), - data_to_log[f"{self.root_node}_time"], - ) - ) - current_row = 1 - for root_child in events_trace_map[self.root_node]: - data.append( - ( - current_row, - 1, - root_child.strip("0").upper(), - data_to_log[f"{root_child}_time"], - ) + if id_ == root_span_id and any( + term.lower() in id_.lower() for term in ["AgentRunner", "QueryEngine"] + ): + self._log_to_argilla( + trace_id=trace_id, + span_buffer=self.span_buffer, + event_buffer=self.event_buffer, + fields_info=self.fields_info, ) - current_row += 1 - for child in events_trace_map[root_child]: - data.append( - ( - current_row, - 2, - child.strip("0").upper(), - data_to_log[f"{child}_time"], - ) - ) - current_row += 1 - return data + self.span_buffer.clear() + self.event_buffer.clear() + self.fields_info.clear() + context_root.set((None, None)) + elif id_ == root_span_id and not any( + term.lower() in id_.lower() for term in ["Workflow.run", "Workflow._done"] + ): + self.span_buffer.clear() + self.event_buffer.clear() + self.fields_info.clear() + context_root.set((None, None)) - # The four methods required by the abstract class - # BaseCallbackHandler executed on the different events. 
+ return span - def start_trace(self, trace_id: Optional[str] = None) -> None: + def prepare_to_drop_span( + self, + id_: str, + bound_args: inspect.BoundArguments, + instance: Optional[Any] = None, + err: Optional[BaseException] = None, + **kwargs: Any, + ) -> None: """ - Start tracing events. + Logic to drop the span. If the trace has ended, it clears the data. Args: - trace_id (str, optional): The trace_id to start tracing. + id_ (str): The unique identifier of the span to be dropped. + bound_args (inspect.BoundArguments): The arguments that were bound to the span function during its invocation. + instance (Optional[Any], optional): The instance associated with the span, if applicable. Defaults to None. + err (Optional[BaseException], optional): An exception that caused the span to be dropped, if any. Defaults to None. + + Returns: + None: """ + trace_id, root_span_id = context_root.get() + if not trace_id: + return None + + if id_ in self.open_spans: + with self.lock: + span = self.open_spans[id_] + self.dropped_spans += [span] + + if "workflow.run" in root_span_id.lower(): + self._log_to_argilla( + trace_id=trace_id, + span_buffer=self.span_buffer, + event_buffer=self.event_buffer, + fields_info=self.fields_info, + ) + self.span_buffer.clear() + self.event_buffer.clear() + self.fields_info.clear() + context_root.set((None, None)) - self._trace_map = defaultdict(list) - self._cur_trace_id = trace_id - self._start_time = datetime.now() + if id_ == root_span_id: + self.span_buffer.clear() + self.event_buffer.clear() + self.fields_info.clear() + context_root.set((None, None)) - # Clear the events and the components prior to running the query. - # They are usually events related to creating the docs and indexing. 
- self.events_data.clear() - self.components_to_log.clear() + return None - def end_trace( + def _log_to_argilla( self, - trace_id: Optional[str] = None, - trace_map: Optional[Dict[str, List[str]]] = None, + trace_id: str, + span_buffer: List[Dict[str, Any]], + event_buffer: List[Dict[str, Any]], + fields_info: Dict[str, Any], ) -> None: - """ - End tracing events. - - Args: - trace_id (str, optional): The trace_id to end tracing. - trace_map (Dict[str, List[str]], optional): The trace_map to end. This map has been obtained from the parent class. - """ - - self._trace_map = trace_map or defaultdict(list) - self._end_time = datetime.now() - self._create_root_and_other_nodes(trace_map) - self._extract_and_log_info(self.events_data, trace_map) + """Logs the data in the trace buffer to Argilla.""" - def on_event_start( - self, - event_type: CBEventType, - payload: Optional[Dict[str, Any]] = None, - event_id: Optional[str] = None, - parent_id: str = None, - ) -> str: - """ - Store event start data by event type. Executed at the start of an event. + message = [ + {"role": "user", "content": fields_info["query"]}, + {"role": "assistant", "content": fields_info["response"]}, + ] + tree_structure = _create_tree_structure(span_buffer, event_buffer) + tree = _create_svg(tree_structure) - Args: - event_type (CBEventType): The event type to store. - payload (Dict[str, Any], optional): The payload to store. - event_id (str, optional): The event id to store. - parent_id (str, optional): The parent id to store. + fields = { + "chat": message, + "time-details": tree, + } + if self.number_of_retrievals > 0: + score_keys = filter(lambda k: k.endswith("_score"), fields_info.keys()) + text_keys = filter(lambda k: k.endswith("_text"), fields_info.keys()) - Returns: - str: The event id. 
- """ + scores = "\n".join( + f"{key.replace('_score', '').replace('_', ' ').capitalize()}: {fields_info[key]}" + for key in islice(score_keys, self.number_of_retrievals) + ) + fields["retrieved_document_scores"] = scores + + for key in islice(text_keys, self.number_of_retrievals): + idx = key.split("_")[-2] + fields[f"retrieved_document_{idx}"] = fields_info[key] + + metadata = self._process_metadata(span_buffer, event_buffer) + self._add_metadata_properties(metadata) + + records = [Record(id=trace_id, fields=fields, metadata=metadata)] + self.dataset.records.log(records=records) + + def _process_metadata( + self, span_buffer: List[Dict[str, Any]], event_buffer: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """Process the metadata to be logged to Argilla.""" + metadata_to_log = {} + + for span in span_buffer: + key_prefix = span["id_"].split(".")[0].lower() + metadata_to_log[f"{key_prefix}_start_time"] = span["start_time"] + metadata_to_log[f"{key_prefix}_end_time"] = span["end_time"] + metadata_to_log[f"{key_prefix}_duration"] = span["duration"] + + for event in event_buffer: + key_prefix = event["event_type"].lower() + metadata_to_log[f"{key_prefix}_timestamp"] = event["timestamp"] + if event["metadata"]: + metadata_to_log.update(event["metadata"]) + + metadata_to_log["total_duration"] = sum( + span["duration"] for span in span_buffer + ) + metadata_to_log["total_spans"] = len(span_buffer) + metadata_to_log["total_events"] = len(event_buffer) - event = CBEvent(event_type, payload=payload, id_=event_id) - self.events_data[event_id].append(event) + return metadata_to_log - return event.id_ + def _add_metadata_properties(self, metadata: Dict[str, Any]) -> None: + """Add metadata properties to the dataset if they do not exist.""" + existing_metadata = [ + existing_metadata.name + for existing_metadata in self.dataset.settings.metadata + ] + for mt in metadata.keys(): + if mt not in existing_metadata: + if isinstance(metadata[mt], str): + 
self.dataset.settings.metadata.add(TermsMetadataProperty(name=mt)) - def on_event_end( - self, - event_type: CBEventType, - payload: Optional[Dict[str, Any]] = None, - event_id: str = None, - ) -> None: - """ - Store event end data by event type. Executed at the end of an event. + elif isinstance(metadata[mt], int): + self.dataset.settings.metadata.add(IntegerMetadataProperty(name=mt)) - Args: - event_type (CBEventType): The event type to store. - payload (Dict[str, Any], optional): The payload to store. - event_id (str, optional): The event id to store. - """ + elif isinstance(metadata[mt], float): + self.dataset.settings.metadata.add(FloatMetadataProperty(name=mt)) - event = CBEvent(event_type, payload=payload, id_=event_id) - self.events_data[event_id].append(event) + self.dataset.update() diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 0000000..e201648 --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,54 @@ +# Copyright 2024-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from argilla_llama_index.helpers import _create_svg, _create_tree_structure + + +class TestHelpers(unittest.TestCase): + def test_create_tree_structure(self): + span_data = [ + {"id_": "A", "parent_id": None, "duration": "10s"}, + {"id_": "B", "parent_id": "A", "duration": "15s"}, + {"id_": "C", "parent_id": "A", "duration": "5s"}, + {"id_": "D", "parent_id": "B", "duration": "20s"}, + ] + event_data = [] + expected_output = [ + (0, 0, "A", "10s"), + (1, 1, "B", "15s"), + (2, 2, "D", "20s"), + (3, 1, "C", "5s"), + ] + + result = _create_tree_structure(span_data, event_data) + self.assertEqual(result, expected_output) + + def test_create_svg(self): + input_data = [(0, 1, "Node1", "10ms"), (1, 2, "Node2", "20ms")] + + result = _create_svg(input_data) + + self.assertIn('viewBox="0 0 750 90"', result) + self.assertIn('', result) + self.assertIn('Node1', result) + self.assertIn('10ms', result) + self.assertIn('', result) + self.assertIn('Node2', result) + self.assertIn('20ms', result) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_llama_index_callback.py b/tests/test_llama_index_callback.py deleted file mode 100644 index 3572983..0000000 --- a/tests/test_llama_index_callback.py +++ /dev/null @@ -1,167 +0,0 @@ -import unittest -import argilla as rg - -from collections import defaultdict -from datetime import datetime, timedelta -from unittest.mock import MagicMock, patch - -from argilla_llama_index import ArgillaCallbackHandler -from argilla_llama_index.helpers import _calc_time, _create_svg, _get_time_diff - - -class TestArgillaCallbackHandler(unittest.TestCase): - def setUp(self): - self.dataset_name = "test_dataset_llama_index" - self.api_url = "http://localhost:6900" - self.api_key = "argilla.apikey" - - self.client = rg.Argilla(api_url=self.api_url, api_key=self.api_key) - self.create_workspace("argilla") - - self.handler = ArgillaCallbackHandler( - dataset_name=self.dataset_name, - api_url=self.api_url, - 
api_key=self.api_key, - ) - - self.events_data = MagicMock() - self.data_to_log = MagicMock() - self.components_to_log = MagicMock() - self._ignore_components_in_tree = MagicMock() - self.trace_map = MagicMock() - - self.tree_structure_dict = { - "root": ["query"], - "query": ["retrieve", "synthesize"], - "synthesize": ["llm", "grandchild1"], - } - - def create_workspace(self, workspace_name): - workspace = self.client.workspaces(name=workspace_name) - if workspace is None: - workspace = rg.Workspace(name=workspace_name) - workspace.create() - - def test_init(self): - self.assertEqual(self.handler.dataset_name, self.dataset_name) - - @patch("argilla_llama_index.llama_index_handler.rg.Argilla") - def test_init_connection_error(self, mock_init): - mock_init.side_effect = ConnectionError("Connection failed") - with self.assertRaises(ConnectionError): - ArgillaCallbackHandler( - dataset_name=self.dataset_name, - api_url=self.api_url, - api_key=self.api_key, - ) - - @patch("argilla_llama_index.llama_index_handler.rg.Argilla.datasets") - @patch("argilla_llama_index.llama_index_handler.rg.Argilla._validate_connection") - def test_init_file_not_found_error(self, mock_validate_connection, mock_list): - mock_list.return_value = [] - mock_validate_connection.return_value = None - with self.assertRaises(FileNotFoundError): - ArgillaCallbackHandler( - dataset_name="test_dataset", - api_url="http://example.com", - api_key="test_key", - ) - - def test_check_components_for_tree(self): - self.handler._check_components_for_tree(self.tree_structure_dict) - - def test_get_events_map_with_names(self): - - trace_map = {"query": ["retrieve"], "llm": []} - events_map = self.handler._get_events_map_with_names( - self.events_data, trace_map - ) - self.assertIsInstance(events_map, dict) - self.assertEqual(len(events_map), 2) - - def test_extract_and_log_info(self): - - tree_structure_dict = self.handler._check_components_for_tree( - self.tree_structure_dict - ) - 
self.handler._extract_and_log_info(self.events_data, tree_structure_dict) - - def test_start_trace(self): - self.handler.start_trace() - self.assertIsNotNone(self.handler._start_time) - self.assertEqual(self.handler._trace_map, defaultdict(list)) - - @patch( - "argilla_llama_index.llama_index_handler.ArgillaCallbackHandler._create_root_and_other_nodes" - ) - @patch( - "argilla_llama_index.llama_index_handler.ArgillaCallbackHandler._extract_and_log_info" - ) - def test_end_trace( - self, mock_extract_and_log_info, mock_create_root_and_other_nodes - ): - self.handler.start_trace() - trace_id = "test_trace_id" - trace_map = {"test_key": ["test_value"]} - - self.handler.end_trace(trace_id=trace_id, trace_map=trace_map) - self.assertIsNotNone(self.handler._end_time) - self.assertAlmostEqual( - self.handler._end_time, datetime.now(), delta=timedelta(seconds=1) - ) - self.assertEqual(self.handler._trace_map, trace_map) - - mock_create_root_and_other_nodes.assert_called_once_with(trace_map) - mock_extract_and_log_info.assert_called_once_with( - self.handler.events_data, trace_map - ) - - def test_on_event_start(self): - event_type = "event1" - payload = {} - event_id = "123" - parent_id = "456" - self.handler.on_event_start(event_type, payload, event_id, parent_id) - - def test_on_event_end(self): - event_type = "event1" - payload = {} - event_id = "123" - self.handler.on_event_end(event_type, payload, event_id) - - def test_get_time_diff(self): - event_1_time_str = "01/11/2024, 17:01:04.328656" - event_2_time_str = "01/11/2024, 17:02:07.328523" - time_diff = _get_time_diff(event_1_time_str, event_2_time_str) - self.assertIsInstance(time_diff, float) - - def test_calc_time(self): - id = "event1" - self.events_data.__getitem__().__getitem__().time = ( - "01/11/2024, 17:01:04.328656" - ) - self.events_data.__getitem__().__getitem__().time = ( - "01/11/2024, 17:02:07.328523" - ) - time = _calc_time(self.events_data, id) - self.assertIsInstance(time, float) - - def 
test_create_svg(self): - data = [ - (0, 1, "Node1", "10ms"), - (1, 2, "Node2", "20ms") - ] - - result = _create_svg(data) - - self.assertIn('viewBox="0 0 750 108"', result) - self.assertIn('', result) - self.assertIn('Node1', result) - self.assertIn('10ms', result) - self.assertIn('', result) - self.assertIn('Node2', result) - self.assertIn('20ms', result) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_llama_index_handler.py b/tests/test_llama_index_handler.py new file mode 100644 index 0000000..cbce7bb --- /dev/null +++ b/tests/test_llama_index_handler.py @@ -0,0 +1,300 @@ +# Copyright 2024-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import unittest +from collections import namedtuple +from datetime import datetime, timedelta +from unittest.mock import MagicMock, Mock, call, patch +from uuid import uuid4 + +from argilla import ( + Argilla, + Workspace, +) +from argilla_llama_index.llama_index_handler import ArgillaHandler +from llama_index.core.base.base_query_engine import BaseQueryEngine +from llama_index.core.instrumentation.span.simple import SimpleSpan + +CommonData = namedtuple( + "CommonData", + [ + "trace_id", + "root_span_id", + "id_", + "bound_args", + "instance", + "parent_span_id", + "tags", + "span", + "result", + ], +) + + +class TestArgillaSpanHandlerLogToArgilla(unittest.TestCase): + def setUp(self): + self.dataset_name = f"test_dataset_llama_index_{uuid4()}" + self.api_url = "http://localhost:6900" + self.api_key = "argilla.apikey" + self.workspace_name = "argilla" + self.number_of_retrievals = 2 + self.client = Argilla(api_url=self.api_url, api_key=self.api_key) + self._create_workspace("argilla") + + self.handler = ArgillaHandler( + dataset_name=self.dataset_name, + api_url=self.api_url, + api_key=self.api_key, + workspace_name=self.workspace_name, + number_of_retrievals=self.number_of_retrievals, + ) + self.handler.open_spans = {} + self.handler.span_buffer = [] + self.handler.event_buffer = [] + self.handler.fields_info = {} + self.handler.completed_spans = [] + self.handler.dropped_spans = [] + + self.handler.dataset = MagicMock() + self.handler.dataset.records.log = MagicMock() + self.context_root_patcher = patch( + "argilla_llama_index.llama_index_handler.context_root" + ) + self.mock_context_root = self.context_root_patcher.start() + + def _create_workspace(self, workspace_name): + workspace = self.client.workspaces(name=workspace_name) + if workspace is None: + workspace = Workspace(name=workspace_name) + workspace.create() + + def _tearDown(self): + self.context_root_patcher.stop() + + def _create_common_data( + self, with_instance=False, 
with_span=False, with_result=False, **kwargs + ): + trace_id = kwargs.get("trace_id", "trace_id") + root_span_id = kwargs.get("root_span_id", "QueryEngine.query") + id_ = kwargs.get("id_", root_span_id) + parent_span_id = kwargs.get("parent_span_id", "parent_span_id") + tags = kwargs.get("tags", {"tag1": "value1"}) + bound_args = Mock(spec=inspect.BoundArguments) + bound_args.arguments = kwargs.get( + "arguments", {"message": "Test query message"} + ) + + instance = None + span = None + result = None + + if with_instance: + instance = Mock(spec=BaseQueryEngine) + instance.__class__ = BaseQueryEngine + if with_span: + span = SimpleSpan( + id_=id_, + parent_id=parent_span_id, + tags=tags, + start_time=datetime.now() - timedelta(seconds=5), + ) + if with_result: + result = Mock() + result.response = kwargs.get("response", "Test response") + + return CommonData( + trace_id, + root_span_id, + id_, + bound_args, + instance, + parent_span_id, + tags, + span, + result, + ) + + @patch("argilla_llama_index.llama_index_handler.Argilla") + @patch.object(ArgillaHandler, "_initialize_dataset") + def test_initialization(self, mock_initialize_dataset, mock_argilla): + dataset_name = "test_dataset" + api_url = "http://example.com" + api_key = "test_key" + workspace_name = "test_workspace" + number_of_retrievals = 5 + + handler = ArgillaHandler( + dataset_name=dataset_name, + api_url=api_url, + api_key=api_key, + workspace_name=workspace_name, + number_of_retrievals=number_of_retrievals, + ) + mock_argilla.assert_called_once_with(api_key=api_key, api_url=api_url) + self.assertEqual(handler.dataset_name, dataset_name) + self.assertEqual(handler.workspace_name, workspace_name) + self.assertEqual(handler.number_of_retrievals, number_of_retrievals) + self.assertEqual(handler.span_buffer, []) + self.assertEqual(handler.event_buffer, []) + self.assertEqual(handler.fields_info, {}) + self.assertIsInstance(handler.client, MagicMock) + mock_initialize_dataset.assert_called_once() + + def 
test_new_span(self): + data = self._create_common_data() + self.mock_context_root.get.return_value = (data.trace_id, data.root_span_id) + + span = self.handler.new_span( + id_=data.id_, + bound_args=data.bound_args, + instance=data.instance, + parent_span_id=data.parent_span_id, + tags=data.tags, + ) + + self.assertIsInstance(span, SimpleSpan) + self.assertEqual(span.id_, data.id_) + self.assertEqual(span.parent_id, data.parent_span_id) + self.assertEqual(span.tags, data.tags) + + def test_prepare_to_exit_span(self): + data = self._create_common_data(id_="test_id", with_span=True) + self.mock_context_root.get.return_value = (data.trace_id, data.root_span_id) + self.handler.open_spans[data.id_] = data.span + + self.handler.prepare_to_exit_span(data.id_, data.bound_args) + + self.assertIsNotNone(data.span.end_time) + self.assertAlmostEqual(data.span.duration, 5, delta=0.1) + self.assertEqual(len(self.handler.span_buffer), 1) + self.assertIn(data.span, self.handler.completed_spans) + + def test_prepare_to_drop_span(self): + data = self._create_common_data(with_span=True) + self.mock_context_root.get.return_value = (data.trace_id, data.root_span_id) + self.handler.open_spans[data.id_] = data.span + + self.handler.prepare_to_drop_span(id_=data.id_, bound_args=data.bound_args) + + self.assertEqual(self.handler.span_buffer, []) + self.assertEqual(self.handler.event_buffer, []) + self.assertEqual(self.handler.fields_info, {}) + self.mock_context_root.set.assert_called_once_with((None, None)) + self.assertIn(data.span, self.handler.dropped_spans) + + @patch("argilla_llama_index.llama_index_handler._create_tree_structure") + @patch("argilla_llama_index.llama_index_handler._create_svg") + def test_log_to_argilla(self, mock_create_svg, mock_create_tree_structure): + data = self._create_common_data() + span_buffer = [ + { + "id_": "span_1", + "parent_id": None, + "tags": {}, + "start_time": 0, + "end_time": 1, + "duration": 1, + } + ] + event_buffer = [ + { + "id-": "event_1", 
+ "span_id": "span_1", + "timestamp": 1, + "event_type": "test_event", + "metadata": {}, + } + ] + fields_info = { + "query": "test_query", + "response": "test_response", + "retrieved_document_1_text": "doc1", + "retrieved_document_2_text": "doc2", + } + mock_create_tree_structure.return_value = "tree_structure" + mock_create_svg.return_value = "svg_tree" + + self.handler._log_to_argilla( + trace_id=data.trace_id, + span_buffer=span_buffer, + event_buffer=event_buffer, + fields_info=fields_info, + ) + + self.handler.dataset.records.log.assert_called_once() + records = self.handler.dataset.records.log.call_args[1]["records"] + self.assertEqual(len(records), 1) + self.assertEqual(records[0].id, data.trace_id) + self.assertIn("chat", records[0].fields) + self.assertIn("time-details", records[0].fields) + for i in range(1, self.number_of_retrievals + 1): + self.assertIn(f"retrieved_document_{i}", records[0].fields) + self.assertNotIn("retrieved_document_3", records[0].fields) + + @patch("argilla_llama_index.llama_index_handler.TermsMetadataProperty") + @patch("argilla_llama_index.llama_index_handler.IntegerMetadataProperty") + @patch("argilla_llama_index.llama_index_handler.FloatMetadataProperty") + def test_add_metadata_properties( + self, + mock_float_prop_class, + mock_int_prop_class, + mock_terms_prop_class, + ): + existing_metadata_property = Mock() + existing_metadata_property.name = "existing_metadata" + existing_metadata = [existing_metadata_property] + self.handler.dataset.settings.metadata = MagicMock() + self.handler.dataset.settings.metadata.__iter__.return_value = iter( + existing_metadata + ) + self.handler.dataset.settings.metadata.add = Mock() + self.handler.dataset.update = Mock() + + metadata = { + "new_string_property": "test", + "new_int_property": 42, + "new_float_property": 3.14, + } + + property_classes = [ + ("new_string_property", mock_terms_prop_class), + ("new_int_property", mock_int_prop_class), + ("new_float_property", 
mock_float_prop_class), + ] + mock_properties = [] + for prop_name, mock_class in property_classes: + mock_instance = Mock() + mock_instance.name = prop_name + mock_class.return_value = mock_instance + mock_properties.append((prop_name, mock_class, mock_instance)) + + self.handler._add_metadata_properties(metadata) + + for prop_name, mock_class, _ in mock_properties: + mock_class.assert_called_once_with(name=prop_name) + expected_calls = [ + call(mock_instance) for _, _, mock_instance in mock_properties + ] + self.handler.dataset.settings.metadata.add.assert_has_calls( + expected_calls, any_order=True + ) + self.assertEqual( + self.handler.dataset.settings.metadata.add.call_count, len(expected_calls) + ) + self.handler.dataset.update.assert_called_once() + + +if __name__ == "__main__": + unittest.main()