-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Qdrant storage connector (#1023)
- Loading branch information
Showing
10 changed files
with
426 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
import os | ||
import uuid | ||
from copy import deepcopy | ||
from typing import Dict, Iterator, List, Optional, cast | ||
|
||
from memgpt.agent_store.storage import StorageConnector, TableType | ||
from memgpt.config import MemGPTConfig | ||
from memgpt.constants import MAX_EMBEDDING_DIM | ||
from memgpt.data_types import Passage, Record, RecordType | ||
from memgpt.utils import datetime_to_timestamp, timestamp_to_datetime | ||
|
||
TEXT_PAYLOAD_KEY = "text_content" | ||
METADATA_PAYLOAD_KEY = "metadata" | ||
|
||
|
||
class QdrantStorageConnector(StorageConnector):
    """Storage connector that keeps passages in a Qdrant collection.

    Every record is stored as one Qdrant point: the point ID is the record ID
    (as a string), the vector is the record embedding, and the payload holds
    the text under ``TEXT_PAYLOAD_KEY`` and all remaining fields under
    ``METADATA_PAYLOAD_KEY``.
    """

    def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None):
        """Create the Qdrant client and make sure the collection exists.

        When ``config.archival_storage_uri`` has the form ``host:port`` the
        client is created with that host and port, and the API key is read
        from the ``QDRANT_API_KEY`` environment variable. Otherwise
        ``config.archival_storage_path`` is handed to ``QdrantClient(path=...)``.

        Raises:
            ImportError: if ``qdrant-client`` is not installed.
            AssertionError: if ``table_type`` is neither archival memory nor passages.
            ValueError: if neither a ``host:port`` URI nor a path is configured.
        """
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
        try:
            from qdrant_client import QdrantClient, models
        except ImportError as e:
            raise ImportError("'qdrant-client' not installed. Run `pip install qdrant-client`.") from e
        assert table_type in [TableType.ARCHIVAL_MEMORY, TableType.PASSAGES], "Qdrant only supports archival memory"
        if config.archival_storage_uri and len(config.archival_storage_uri.split(":")) == 2:
            host, port = config.archival_storage_uri.split(":")
            self.qdrant_client = QdrantClient(host=host, port=port, api_key=os.getenv("QDRANT_API_KEY"))
        elif config.archival_storage_path:
            self.qdrant_client = QdrantClient(path=config.archival_storage_path)
        else:
            raise ValueError("Qdrant storage requires either a URI or a path to the storage configured")
        if not self.qdrant_client.collection_exists(self.table_name):
            # The collection is always created with MAX_EMBEDDING_DIM-sized
            # vectors and cosine distance.
            self.qdrant_client.create_collection(
                collection_name=self.table_name,
                vectors_config=models.VectorParams(
                    size=MAX_EMBEDDING_DIM,
                    distance=models.Distance.COSINE,
                ),
            )
        # Metadata keys whose values are UUIDs in records but strings in Qdrant payloads.
        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]

    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 10) -> Iterator[List[RecordType]]:
        """Yield all matching records, one page of up to ``page_size`` records at a time."""
        from qdrant_client import grpc

        qdrant_filter = self.get_qdrant_filters(filters)
        next_offset = None
        stop_scrolling = False
        while not stop_scrolling:
            results, next_offset = self.qdrant_client.scroll(
                collection_name=self.table_name,
                scroll_filter=qdrant_filter,
                limit=page_size,
                offset=next_offset,
                with_payload=True,
                with_vectors=True,
            )
            # Scrolling ends when no next offset is returned. A grpc PointId
            # with num == 0 and a blank uuid is treated the same way.
            stop_scrolling = next_offset is None or (
                isinstance(next_offset, grpc.PointId) and next_offset.num == 0 and next_offset.uuid == ""
            )
            yield self.to_records(results)

    def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[RecordType]:
        """Return at most ``limit`` matching records, or an empty list if none match."""
        if self.size(filters) == 0:
            return []
        qdrant_filter = self.get_qdrant_filters(filters)
        results, _ = self.qdrant_client.scroll(
            self.table_name,
            scroll_filter=qdrant_filter,
            limit=limit,
            with_payload=True,
            with_vectors=True,
        )
        return self.to_records(results)

    def get(self, id: uuid.UUID) -> Optional[RecordType]:
        """Return the record stored under ``id``, or ``None`` if there is none."""
        results = self.qdrant_client.retrieve(
            collection_name=self.table_name,
            ids=[str(id)],
            with_payload=True,
            with_vectors=True,
        )
        if not results:
            return None
        return self.to_records(results)[0]

    def insert(self, record: Record):
        """Upsert a single record into the collection."""
        points = self.to_points([record])
        self.qdrant_client.upsert(self.table_name, points=points)

    def insert_many(self, records: List[RecordType], show_progress=False):
        """Upsert ``records`` into the collection in one request.

        ``show_progress`` is accepted for interface compatibility and is not used.
        An empty batch is a no-op.
        """
        points = self.to_points(records)
        if not points:
            # Nothing to write; skip the request entirely.
            return
        self.qdrant_client.upsert(self.table_name, points=points)

    def delete(self, filters: Optional[Dict] = {}):
        """Delete every point matching the connector filters merged with ``filters``."""
        qdrant_filter = self.get_qdrant_filters(filters)
        self.qdrant_client.delete(self.table_name, points_selector=qdrant_filter)

    def delete_table(self):
        """Drop the whole collection and close the client."""
        self.qdrant_client.delete_collection(self.table_name)
        self.qdrant_client.close()

    def size(self, filters: Optional[Dict] = {}) -> int:
        """Return the number of points matching the connector filters merged with ``filters``."""
        qdrant_filter = self.get_qdrant_filters(filters)
        return self.qdrant_client.count(collection_name=self.table_name, count_filter=qdrant_filter).count

    def close(self):
        """Close the underlying Qdrant client."""
        self.qdrant_client.close()

    def query(
        self,
        query: str,
        query_vec: List[float],
        top_k: int = 10,
        filters: Optional[Dict] = {},
    ) -> List[RecordType]:
        """Return up to ``top_k`` records nearest to ``query_vec``.

        ``query`` (the raw query text) is not used; the search is driven by
        ``query_vec`` only.
        """
        # Build the filter with get_qdrant_filters, like every other method of
        # this class, so the search receives a models.Filter whose keys are
        # prefixed with the metadata payload key.
        qdrant_filter = self.get_qdrant_filters(filters)
        results = self.qdrant_client.search(
            self.table_name,
            query_vector=query_vec,
            query_filter=qdrant_filter,
            limit=top_k,
            with_payload=True,
            with_vectors=True,
        )
        return self.to_records(results)

    def to_records(self, records: list) -> List[RecordType]:
        """Convert Qdrant points back into ``self.type`` records.

        UUID fields are parsed from strings and ``created_at`` is converted
        with ``timestamp_to_datetime``.
        """
        parsed_records = []
        for record in records:
            # Work on a copy so popping "id" does not alter the client's result object.
            record = deepcopy(record)
            metadata = record.payload[METADATA_PAYLOAD_KEY]
            text = record.payload[TEXT_PAYLOAD_KEY]
            _id = metadata.pop("id")
            embedding = record.vector
            # Only values are reassigned here, so iterating the dict while
            # updating it is safe.
            for key, value in metadata.items():
                if key in self.uuid_fields:
                    metadata[key] = uuid.UUID(value)
                elif key == "created_at":
                    metadata[key] = timestamp_to_datetime(value)
            parsed_records.append(
                cast(
                    RecordType,
                    self.type(
                        text=text,
                        embedding=embedding,
                        id=uuid.UUID(_id),
                        **metadata,
                    ),
                )
            )
        return parsed_records

    def to_points(self, records: List[RecordType]):
        """Convert ``Passage`` records into Qdrant ``PointStruct`` objects.

        Duplicate records are collapsed via ``set``. The input records are not
        modified.

        Raises:
            AssertionError: if any record is not a ``Passage``.
        """
        from qdrant_client import models

        assert all(isinstance(r, Passage) for r in records)
        points = []
        for record in set(records):
            # vars() returns the instance's live __dict__; copy it so the pops
            # and the created_at conversion below do not strip or alter
            # attributes on the caller's record.
            fields = dict(vars(record))
            _id = fields.pop("id")
            text = fields.pop("text", "")
            embedding = fields.pop("embedding", {})
            record_metadata = fields.pop("metadata_", None) or {}
            if "created_at" in fields:
                fields["created_at"] = datetime_to_timestamp(fields["created_at"])
            # Drop unset fields, then let the record's own metadata override
            # them; "id" is always written last.
            metadata = {key: value for key, value in fields.items() if value is not None}
            metadata = {
                **metadata,
                **record_metadata,
                "id": str(_id),
            }
            # Store UUID-valued fields as strings in the payload.
            for key, value in metadata.items():
                if key in self.uuid_fields:
                    metadata[key] = str(value)
            points.append(
                models.PointStruct(
                    id=str(_id),
                    vector=embedding,
                    payload={
                        TEXT_PAYLOAD_KEY: text,
                        METADATA_PAYLOAD_KEY: metadata,
                    },
                )
            )
        return points

    def get_qdrant_filters(self, filters: Optional[Dict] = {}):
        """Build a Qdrant ``Filter`` from ``self.filters`` merged with ``filters``.

        Each key/value pair becomes an exact-match ``must`` condition on
        ``<METADATA_PAYLOAD_KEY>.<key>``; values of UUID fields are converted
        to strings first. Entries in ``filters`` override ``self.filters``.
        ``self.filters`` is not assigned in this class (presumably set by
        ``StorageConnector`` - confirm).
        """
        from qdrant_client import models

        filter_conditions = {**self.filters, **filters} if filters is not None else self.filters
        must_conditions = []
        for key, value in filter_conditions.items():
            match_value = str(value) if key in self.uuid_fields else value
            field_condition = models.FieldCondition(
                key=f"{METADATA_PAYLOAD_KEY}.{key}",
                match=models.MatchValue(value=match_value),
            )
            must_conditions.append(field_condition)
        return models.Filter(must=must_conditions)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.