feat: Enable adding files (#1864)
Co-authored-by: Matt Zhou <[email protected]>
mattzh72 and Matt Zhou authored Oct 14, 2024
1 parent cc616ef commit 9b34769
Showing 26 changed files with 570 additions and 228 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -2,6 +2,9 @@
 # Created by https://www.toptal.com/developers/gitignore/api/vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection
 # Edit at https://www.toptal.com/developers/gitignore?templates=vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection
 
+openapi_letta.json
+openapi_openai.json
+
 ### Eclipse ###
 .metadata
 bin/
2 changes: 1 addition & 1 deletion docs/generate_docs.py
@@ -70,7 +70,7 @@ def generate_modules(config):
         "Message",
         "Passage",
         "AgentState",
-        "Document",
+        "File",
         "Source",
         "LLMConfig",
         "EmbeddingConfig",
2 changes: 1 addition & 1 deletion examples/notebooks/data_connector.ipynb
@@ -270,7 +270,7 @@
 "outputs": [],
 "source": [
 "from letta.data_sources.connectors import DataConnector \n",
-"from letta.schemas.document import Document\n",
+"from letta.schemas.file import FileMetadata\n",
 "from llama_index.core import Document as LlamaIndexDocument\n",
 "from llama_index.core import SummaryIndex\n",
 "from llama_index.readers.web import SimpleWebPageReader\n",
2 changes: 1 addition & 1 deletion letta/__init__.py
@@ -7,9 +7,9 @@
 # imports for easier access
 from letta.schemas.agent import AgentState
 from letta.schemas.block import Block
-from letta.schemas.document import Document
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import JobStatus
+from letta.schemas.file import FileMetadata
 from letta.schemas.job import Job
 from letta.schemas.letta_message import LettaMessage
 from letta.schemas.llm_config import LLMConfig
25 changes: 18 additions & 7 deletions letta/agent_store/db.py
@@ -28,7 +28,7 @@
 from letta.base import Base
 from letta.config import LettaConfig
 from letta.constants import MAX_EMBEDDING_DIM
-from letta.metadata import EmbeddingConfigColumn, ToolCallColumn
+from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn
 
 # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
 from letta.schemas.message import Message
@@ -141,7 +141,7 @@ class PassageModel(Base):
     id = Column(String, primary_key=True)
     user_id = Column(String, nullable=False)
     text = Column(String)
-    doc_id = Column(String)
+    file_id = Column(String)
     agent_id = Column(String)
     source_id = Column(String)
 
@@ -160,7 +160,7 @@ class PassageModel(Base):
     # Add a datetime column, with default value as the current time
     created_at = Column(DateTime(timezone=True))
 
-    Index("passage_idx_user", user_id, agent_id, doc_id),
+    Index("passage_idx_user", user_id, agent_id, file_id),
 
     def __repr__(self):
         return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
@@ -170,7 +170,7 @@ def to_record(self):
             text=self.text,
             embedding=self.embedding,
             embedding_config=self.embedding_config,
-            doc_id=self.doc_id,
+            file_id=self.file_id,
             user_id=self.user_id,
             id=self.id,
             source_id=self.source_id,
@@ -365,12 +365,17 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None)
             self.uri = self.config.archival_storage_uri
             self.db_model = PassageModel
             if self.config.archival_storage_uri is None:
-                raise ValueError(f"Must specifiy archival_storage_uri in config {self.config.config_path}")
+                raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}")
         elif table_type == TableType.RECALL_MEMORY:
             self.uri = self.config.recall_storage_uri
             self.db_model = MessageModel
             if self.config.recall_storage_uri is None:
-                raise ValueError(f"Must specifiy recall_storage_uri in config {self.config.config_path}")
+                raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}")
+        elif table_type == TableType.FILES:
+            self.uri = self.config.metadata_storage_uri
+            self.db_model = FileMetadataModel
+            if self.config.metadata_storage_uri is None:
+                raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}")
         else:
             raise ValueError(f"Table type {table_type} not implemented")
 
@@ -487,8 +492,14 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None)
             # TODO: eventually implement URI option
             self.path = self.config.recall_storage_path
             if self.path is None:
-                raise ValueError(f"Must specifiy recall_storage_path in config {self.config.recall_storage_path}")
+                raise ValueError(f"Must specify recall_storage_path in config.")
             self.db_model = MessageModel
+        elif table_type == TableType.FILES:
+            self.path = self.config.metadata_storage_path
+            if self.path is None:
+                raise ValueError(f"Must specify metadata_storage_path in config.")
+            self.db_model = FileMetadataModel
+
         else:
             raise ValueError(f"Table type {table_type} not implemented")
 
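Configuration note (editorial sketch, not part of the commit): the new FILES branches read metadata_storage_uri for the Postgres-backed connector and metadata_storage_path for the SQLite-backed connector, mirroring the existing archival/recall settings. A minimal sketch of the fields involved; the LettaConfig.load() loader and the values are assumptions, only the metadata_storage_* field names come from this diff:

    # Editorial sketch: configuring storage for the new FILES table type.
    from letta.config import LettaConfig

    config = LettaConfig.load()                # assumed standard config loader
    config.metadata_storage_type = "sqlite"    # selects the SQLite-backed connector (placeholder value)
    config.metadata_storage_path = "~/.letta"  # required by the SQLite branch, else it raises
                                               # "Must specify metadata_storage_path in config."
    # A Postgres deployment would set metadata_storage_uri instead; if unset,
    # the connector raises "Must specify metadata_storage_uri in config ...".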
4 changes: 2 additions & 2 deletions letta/agent_store/lancedb.py
@@ -24,7 +24,7 @@ class PassageModel(LanceModel):
     id: uuid.UUID
     user_id: str
     text: str
-    doc_id: str
+    file_id: str
     agent_id: str
     data_source: str
     embedding: Vector(config.default_embedding_config.embedding_dim)
@@ -37,7 +37,7 @@ def to_record(self):
         return Passage(
             text=self.text,
             embedding=self.embedding,
-            doc_id=self.doc_id,
+            file_id=self.file_id,
             user_id=self.user_id,
             id=self.id,
             data_source=self.data_source,
2 changes: 1 addition & 1 deletion letta/agent_store/milvus.py
@@ -26,7 +26,7 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None)
             raise ValueError("Please set `archival_storage_uri` in the config file when using Milvus.")
 
         # need to be converted to strings
-        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
+        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]
 
     def _create_collection(self):
         schema = MilvusClient.create_schema(
2 changes: 1 addition & 1 deletion letta/agent_store/qdrant.py
@@ -38,7 +38,7 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None)
                 distance=models.Distance.COSINE,
             ),
         )
-        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
+        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]
 
     def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 10) -> Iterator[List[RecordType]]:
         from qdrant_client import grpc
22 changes: 12 additions & 10 deletions letta/agent_store/storage.py
@@ -10,7 +10,7 @@
 from pydantic import BaseModel
 
 from letta.config import LettaConfig
-from letta.schemas.document import Document
+from letta.schemas.file import FileMetadata
 from letta.schemas.message import Message
 from letta.schemas.passage import Passage
 from letta.utils import printd
@@ -22,7 +22,7 @@ class TableType:
     ARCHIVAL_MEMORY = "archival_memory" # recall memory table: letta_agent_{agent_id}
     RECALL_MEMORY = "recall_memory" # archival memory table: letta_agent_recall_{agent_id}
     PASSAGES = "passages" # TODO
-    DOCUMENTS = "documents" # TODO
+    FILES = "files"
 
 
 # table names used by Letta
@@ -33,17 +33,17 @@ class TableType:
 
 # external data source tables
 PASSAGE_TABLE_NAME = "letta_passages" # chunked/embedded passages (from source)
-DOCUMENT_TABLE_NAME = "letta_documents" # original documents (from source)
+FILE_TABLE_NAME = "letta_files" # original files (from source)
 
 
 class StorageConnector:
-    """Defines a DB connection that is user-specific to access data: Documents, Passages, Archival/Recall Memory"""
+    """Defines a DB connection that is user-specific to access data: files, Passages, Archival/Recall Memory"""
 
     type: Type[BaseModel]
 
     def __init__(
         self,
-        table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.DOCUMENTS],
+        table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
         config: LettaConfig,
         user_id,
         agent_id=None,
@@ -59,9 +59,9 @@ def __init__(
         elif table_type == TableType.RECALL_MEMORY:
             self.type = Message
             self.table_name = RECALL_TABLE_NAME
-        elif table_type == TableType.DOCUMENTS:
-            self.type = Document
-            self.table_name == DOCUMENT_TABLE_NAME
+        elif table_type == TableType.FILES:
+            self.type = FileMetadata
+            self.table_name = FILE_TABLE_NAME
         elif table_type == TableType.PASSAGES:
             self.type = Passage
             self.table_name = PASSAGE_TABLE_NAME
@@ -74,7 +74,7 @@ def __init__(
             # agent-specific table
             assert agent_id is not None, "Agent ID must be provided for agent-specific tables"
             self.filters = {"user_id": self.user_id, "agent_id": self.agent_id}
-        elif self.table_type == TableType.PASSAGES or self.table_type == TableType.DOCUMENTS:
+        elif self.table_type == TableType.PASSAGES or self.table_type == TableType.FILES:
             # setup base filters for user-specific tables
             assert agent_id is None, "Agent ID must not be provided for user-specific tables"
             self.filters = {"user_id": self.user_id}
@@ -83,7 +83,7 @@
 
     @staticmethod
     def get_storage_connector(
-        table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.DOCUMENTS],
+        table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
         config: LettaConfig,
         user_id,
         agent_id=None,
@@ -92,6 +92,8 @@ def get_storage_connector(
             storage_type = config.archival_storage_type
         elif table_type == TableType.RECALL_MEMORY:
             storage_type = config.recall_storage_type
+        elif table_type == TableType.FILES:
+            storage_type = config.metadata_storage_type
         else:
             raise ValueError(f"Table type {table_type} not implemented")
 
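Usage sketch (editorial addition, not part of the diff): with TableType.FILES wired into get_storage_connector, a files connector is requested the same way as the passage connector, and it is user-scoped, so no agent_id is passed. The LettaConfig.load() call and the user ID below are assumptions; the dispatch on config.metadata_storage_type comes from this diff:

    # Editorial sketch: obtaining the new files storage connector.
    from letta.agent_store.storage import StorageConnector, TableType
    from letta.config import LettaConfig

    config = LettaConfig.load()  # assumed standard config loader
    files_conn = StorageConnector.get_storage_connector(
        table_type=TableType.FILES,  # dispatches on config.metadata_storage_type
        config=config,
        user_id="user-123",          # placeholder; files tables are user-specific, agent_id must be None
    )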
2 changes: 1 addition & 1 deletion letta/cli/cli_load.py
@@ -106,7 +106,7 @@ def load_vector_database(
 # document_store=None,
 # passage_store=passage_storage,
 # )
-# print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
+# print(f"Loaded {num_passages} passages and {num_documents} files from {name}")
 # except Exception as e:
 # typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
 # ms.delete_source(source_id=source.id)
# ms.delete_source(source_id=source.id)
Expand Down
51 changes: 51 additions & 0 deletions letta/client/client.py
@@ -25,6 +25,7 @@
 
 # new schemas
 from letta.schemas.enums import JobStatus, MessageRole
+from letta.schemas.file import FileMetadata
 from letta.schemas.job import Job
 from letta.schemas.letta_request import LettaRequest
 from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
@@ -232,6 +233,9 @@ def list_sources(self) -> List[Source]:
     def list_attached_sources(self, agent_id: str) -> List[Source]:
         raise NotImplementedError
 
+    def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
+        raise NotImplementedError
+
     def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
         raise NotImplementedError
@@ -1016,6 +1020,12 @@ def get_job(self, job_id: str) -> Job:
             raise ValueError(f"Failed to get job: {response.text}")
         return Job(**response.json())
 
+    def delete_job(self, job_id: str) -> Job:
+        response = requests.delete(f"{self.base_url}/{self.api_prefix}/jobs/{job_id}", headers=self.headers)
+        if response.status_code != 200:
+            raise ValueError(f"Failed to delete job: {response.text}")
+        return Job(**response.json())
+
     def list_jobs(self):
         response = requests.get(f"{self.base_url}/{self.api_prefix}/jobs", headers=self.headers)
         return [Job(**job) for job in response.json()]
@@ -1088,6 +1098,30 @@ def list_attached_sources(self, agent_id: str) -> List[Source]:
             raise ValueError(f"Failed to list attached sources: {response.text}")
         return [Source(**source) for source in response.json()]
 
+    def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
+        """
+        List files from source with pagination support.
+
+        Args:
+            source_id (str): ID of the source
+            limit (int): Number of files to return
+            cursor (Optional[str]): Pagination cursor for fetching the next page
+
+        Returns:
+            List[FileMetadata]: List of files
+        """
+        # Prepare query parameters for pagination
+        params = {"limit": limit, "cursor": cursor}
+
+        # Make the request to the FastAPI endpoint
+        response = requests.get(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/files", headers=self.headers, params=params)
+
+        if response.status_code != 200:
+            raise ValueError(f"Failed to list files with source id {source_id}: [{response.status_code}] {response.text}")
+
+        # Parse the JSON response
+        return [FileMetadata(**metadata) for metadata in response.json()]
+
     def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
         """
         Update a source
@@ -2162,6 +2196,9 @@ def load_file_into_source(self, filename: str, source_id: str, blocking=True):
     def get_job(self, job_id: str):
         return self.server.get_job(job_id=job_id)
 
+    def delete_job(self, job_id: str):
+        return self.server.delete_job(job_id)
+
     def list_jobs(self):
         return self.server.list_jobs(user_id=self.user_id)
@@ -2261,6 +2298,20 @@ def list_attached_sources(self, agent_id: str) -> List[Source]:
         """
         return self.server.list_attached_sources(agent_id=agent_id)
 
+    def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
+        """
+        List files from source.
+
+        Args:
+            source_id (str): ID of the source
+            limit (int): The # of items to return
+            cursor (str): The cursor for fetching the next page
+
+        Returns:
+            files (List[FileMetadata]): List of files
+        """
+        return self.server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
+
     def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
         """
         Update a source
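End-to-end, the new client surface added by this commit looks roughly like this (editorial sketch: create_client is an assumed package-level factory, and the base URL and IDs are placeholders; only list_files_from_source and delete_job come from this diff):

    # Editorial sketch of the new client methods.
    from letta import create_client  # assumed client factory

    client = create_client(base_url="http://localhost:8283")  # placeholder URL

    # Page through a source's files; cursor=None starts from the beginning.
    files = client.list_files_from_source(source_id="source-123", limit=100, cursor=None)
    for file_metadata in files:
        print(file_metadata)

    # Jobs (e.g., from load_file_into_source) can now be cleaned up.
    client.delete_job(job_id="job-456")  # placeholder job ID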