feat: Qdrant storage connector #1023

Merged
merged 63 commits into from
Jun 5, 2024
7ec1711
feat: Qdrant storage connector
Anush008 Feb 17, 2024
2c84103
Merge branch 'main' into main
Anush008 Feb 23, 2024
9a86895
chore: poetry.lock
Anush008 Feb 23, 2024
1296059
Merge branch 'main' into main
Anush008 Feb 25, 2024
c2e0184
chore: poetry.lock
Anush008 Feb 25, 2024
d5e8274
Merge branch 'main' into main
Anush008 Mar 1, 2024
eaae0bb
docs: Qdrant reference
Anush008 Mar 1, 2024
f4349aa
ci: Qdrant test container
Anush008 Mar 1, 2024
b143d23
Merge branch 'main' into main
Anush008 Mar 3, 2024
68d38a5
chore: poetry.lock
Anush008 Mar 3, 2024
c2167a9
Merge branch 'main' into main
Anush008 Mar 5, 2024
b4db28b
chore: poetry.lock
Anush008 Mar 5, 2024
dfa6300
Merge remote-tracking branch 'origin' into Anush008/main
Anush008 Mar 6, 2024
882c3a2
chore: update latest changes
Anush008 Mar 6, 2024
0001206
Merge remote-tracking branch 'origin' into Anush008/main
Anush008 Mar 11, 2024
2b373cf
chore: poetry.lock
Anush008 Mar 11, 2024
9a081a2
Merge branch 'cpacker:main' into main
Anush008 Mar 15, 2024
12ae58d
Merge remote-tracking branch 'origin' into Anush008/main
Anush008 Mar 17, 2024
6d3db55
chore: update imports
Anush008 Mar 17, 2024
313e94a
Merge branch 'cpacker:main' into main
Anush008 Mar 18, 2024
c8bca8a
chore: Qdrant 1.8
Anush008 Mar 19, 2024
c6531a4
Merge branch 'main' of https://github.com/Anush008/MemGPT into Anush0…
Anush008 Mar 19, 2024
db6cbce
Merge branch 'cpacker:main' into main
Anush008 Mar 19, 2024
36702fb
chore: doc_id uuid, storage_uri check
Anush008 Mar 19, 2024
b8c369e
Merge branch 'cpacker:main' into main
Anush008 Mar 21, 2024
fafe3d0
ci: Bump image version tests.yml
Anush008 Mar 21, 2024
924715b
Merge branch 'cpacker:main' into main
Anush008 Mar 25, 2024
d83719c
Merge branch 'cpacker:main' into main
Anush008 Mar 26, 2024
1cc20c1
Merge branch 'main' into main
Anush008 Mar 28, 2024
6574a7e
chore: bump qdrant_client 1.8.2
Anush008 Mar 28, 2024
5ecb09d
Merge branch 'cpacker:main' into main
Anush008 Mar 29, 2024
b1540f2
Merge branch 'cpacker:main' into main
Anush008 Apr 2, 2024
a2990a7
Merge branch 'cpacker:main' into main
Anush008 Apr 3, 2024
815367f
Merge branch 'cpacker:main' into main
Anush008 Apr 4, 2024
ff5c7c6
Merge remote-tracking branch 'upstream/main'
Anush008 Apr 7, 2024
617704a
chore: Updated poetry lock
Anush008 Apr 7, 2024
6f3bfec
Merge branch 'cpacker:main' into main
Anush008 Apr 9, 2024
d9b139c
Merge branch 'cpacker:main' into main
Anush008 Apr 12, 2024
99c05fa
Merge branch 'cpacker:main' into main
Anush008 Apr 15, 2024
563fbc4
Merge branch 'main' into main
Anush008 Apr 18, 2024
9fed99b
chore: poetry.lock
Anush008 Apr 18, 2024
0f09fc5
Merge branch 'main' into main
Anush008 Apr 19, 2024
bf0c7f3
chore: poetry.lock update
Anush008 Apr 19, 2024
c827d4a
Merge branch 'main' into main
Anush008 Apr 22, 2024
60e5433
chore: poetry.lock
Anush008 Apr 22, 2024
bb766a3
Merge branch 'cpacker:main' into main
Anush008 Apr 23, 2024
798bac6
Merge branch 'cpacker:main' into main
Anush008 Apr 25, 2024
ea8aa63
Merge branch 'main' into main
Anush008 Apr 28, 2024
d1acb14
chore: poetry.lock
Anush008 Apr 28, 2024
36d2b06
Merge branch 'cpacker:main' into main
Anush008 Apr 29, 2024
67f5c07
Merge branch 'cpacker:main' into main
Anush008 May 1, 2024
dc7d616
Merge branch 'cpacker:main' into main
Anush008 May 5, 2024
1183411
Merge branch 'cpacker:main' into main
Anush008 May 7, 2024
44a448a
Merge branch 'cpacker:main' into main
Anush008 May 11, 2024
e7caf12
ci: Qdrant service for tests
Anush008 May 15, 2024
26c5cf1
Merge branch 'cpacker:main' into main
Anush008 May 22, 2024
5940f14
Merge branch 'cpacker:main' into main
Anush008 May 24, 2024
aba6052
test: Use Qdrant at localhost:6333 for test_storage too
Anush008 May 24, 2024
b05ef7d
chore: Make qdrant-client optional
Anush008 May 24, 2024
8c51663
Merge branch 'main' into main
Anush008 May 27, 2024
07b4992
chore: Resolve conflicts, poetry.lock
Anush008 May 27, 2024
2bdcf79
Merge branch 'main' of https://github.com/cpacker/MemGPT into Anush00…
Anush008 Jun 5, 2024
f8cc717
chore: poetry.lock, resolve conflicts
Anush008 Jun 5, 2024
7 changes: 7 additions & 0 deletions .github/workflows/tests.yml
@@ -13,6 +13,13 @@ jobs:
  test:
    runs-on: ubuntu-latest
    timeout-minutes: 15

    services:
      qdrant:
        image: qdrant/qdrant
        ports:
          - 6333:6333

    steps:
    - name: Checkout
      uses: actions/checkout@v4
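
The service definition above publishes Qdrant on port 6333 of the runner. A minimal sketch of a readiness check that a test fixture could run before connecting, assuming the tests reach Qdrant at `localhost:6333`; the helper name `wait_for_port` and its defaults are hypothetical and not part of this PR:

```python
import socket
import time


def wait_for_port(host: str, port: int, timeout: float = 30.0, interval: float = 0.5) -> bool:
    """Poll until a TCP connection to host:port succeeds or the timeout elapses."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            # A successful connect means something is listening on the port.
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # Connection refused or timed out: retry until the deadline passes.
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval)
```

A fixture might call `wait_for_port("localhost", 6333)` and skip or fail the Qdrant tests when it returns `False`. This only confirms that the port accepts connections, not that the Qdrant API is fully initialized.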
20 changes: 17 additions & 3 deletions docs/storage.md
@@ -36,19 +36,19 @@ To run the Postgres backend, you will need a URI to a Postgres database that sup

3. Configure the environment for `pgvector`. You can either:
- Add the following line to your shell profile (e.g., `~/.bashrc`, `~/.zshrc`):

```sh
export MEMGPT_PGURI=postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
```

- Or create a `.env` file in the root project directory with:

```sh
MEMGPT_PGURI=postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt
```

4. Run the script from the root project directory:

```sh
bash db/run_postgres.sh
```
@@ -105,6 +105,20 @@ memgpt configure

and selecting `lancedb` for archival storage, and a database URI (e.g. `./.lancedb`). An empty archival URI is also handled, and the default URI is set to `./.lancedb`. For more, check out the [LanceDB docs](https://lancedb.github.io/lancedb/).

## Qdrant

To enable the Qdrant backend, make sure to install the required dependencies with:

```sh
pip install 'pymemgpt[qdrant]'
```

You can configure Qdrant to use either a local, on-disk instance or a server with the `memgpt configure` command. To authenticate with a Qdrant server, set an API key in the `QDRANT_API_KEY` environment variable. See the [Qdrant installation guide](https://qdrant.tech/documentation/guides/installation/) for setup details.

```sh
? Select Qdrant backend: server
? Enter the Qdrant instance URI (Default: localhost:6333): localhost:6333
```

## Milvus

201 changes: 201 additions & 0 deletions memgpt/agent_store/qdrant.py
@@ -0,0 +1,201 @@
import os
import uuid
from copy import deepcopy
from typing import Dict, Iterator, List, Optional, cast

from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import MemGPTConfig
from memgpt.constants import MAX_EMBEDDING_DIM
from memgpt.data_types import Passage, Record, RecordType
from memgpt.utils import datetime_to_timestamp, timestamp_to_datetime

TEXT_PAYLOAD_KEY = "text_content"
METADATA_PAYLOAD_KEY = "metadata"


class QdrantStorageConnector(StorageConnector):
    """Storage via Qdrant"""

    def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None):
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
        try:
            from qdrant_client import QdrantClient, models
        except ImportError as e:
            raise ImportError("'qdrant-client' not installed. Run `pip install qdrant-client`.") from e
        assert table_type in [TableType.ARCHIVAL_MEMORY, TableType.PASSAGES], "Qdrant only supports archival memory"
        if config.archival_storage_uri and len(config.archival_storage_uri.split(":")) == 2:
            host, port = config.archival_storage_uri.split(":")
            self.qdrant_client = QdrantClient(host=host, port=port, api_key=os.getenv("QDRANT_API_KEY"))
        elif config.archival_storage_path:
            self.qdrant_client = QdrantClient(path=config.archival_storage_path)
        else:
            raise ValueError("Qdrant storage requires either a URI or a path to the storage configured")
        if not self.qdrant_client.collection_exists(self.table_name):
            self.qdrant_client.create_collection(
                collection_name=self.table_name,
                vectors_config=models.VectorParams(
                    size=MAX_EMBEDDING_DIM,
                    distance=models.Distance.COSINE,
                ),
            )
        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
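
The constructor picks server mode only when the URI splits into exactly two colon-separated parts, and otherwise falls back to the storage path. The following standalone sketch mirrors that branch order without importing `qdrant-client`; the function name `qdrant_connection_kwargs` is hypothetical, and converting the port to `int` is an addition of this sketch (the connector passes the port through as a string):

```python
from typing import Dict, Optional, Union


def qdrant_connection_kwargs(
    storage_uri: Optional[str],
    storage_path: Optional[str],
    api_key: Optional[str] = None,
) -> Dict[str, Union[str, int, None]]:
    """Return keyword arguments for a Qdrant client, following the branch order above."""
    # Server mode: the URI must be exactly "host:port" (two colon-separated parts).
    if storage_uri and len(storage_uri.split(":")) == 2:
        host, port = storage_uri.split(":")
        # int() is an addition in this sketch, not something the connector does.
        return {"host": host, "port": int(port), "api_key": api_key}
    # Local mode: on-disk storage under the configured path.
    if storage_path:
        return {"path": storage_path}
    raise ValueError("Qdrant storage requires either a URI or a path to the storage configured")


print(qdrant_connection_kwargs("localhost:6333", None))
print(qdrant_connection_kwargs(None, "/tmp/qdrant"))
```

One consequence of the two-part check is that a scheme-prefixed URI such as `http://localhost:6333` splits into three parts, so it is not treated as a server address; the path branch is used if a path is set, and a `ValueError` is raised otherwise.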

    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 10) -> Iterator[List[RecordType]]:
        from qdrant_client import grpc

        filters = self.get_qdrant_filters(filters)
        next_offset = None
        stop_scrolling = False
        while not stop_scrolling:
            results, next_offset = self.qdrant_client.scroll(
                collection_name=self.table_name,
                scroll_filter=filters,
                limit=page_size,
                offset=next_offset,
                with_payload=True,
                with_vectors=True,
            )
            stop_scrolling = next_offset is None or (
                isinstance(next_offset, grpc.PointId) and next_offset.num == 0 and next_offset.uuid == ""
            )
            yield self.to_records(results)
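
The pagination loop above follows a cursor pattern: each scroll call returns a page plus the offset for the next call, and the loop ends when no further offset is returned. A minimal sketch of that pattern against an in-memory list; `make_fake_scroll` and `iter_pages` are hypothetical stand-ins, not Qdrant APIs, and the extra check for an empty gRPC `PointId` in the connector is omitted here:

```python
from typing import Callable, Iterator, List, Optional, Tuple

# Hypothetical stand-in for a scroll call: returns (page, next_offset),
# where next_offset is None once the data is exhausted.
ScrollFn = Callable[[int, Optional[int]], Tuple[List[str], Optional[int]]]


def make_fake_scroll(items: List[str]) -> ScrollFn:
    def scroll(limit: int, offset: Optional[int]) -> Tuple[List[str], Optional[int]]:
        start = offset or 0
        page = items[start : start + limit]
        next_offset = start + limit if start + limit < len(items) else None
        return page, next_offset

    return scroll


def iter_pages(scroll: ScrollFn, page_size: int = 10) -> Iterator[List[str]]:
    next_offset: Optional[int] = None
    while True:
        page, next_offset = scroll(page_size, next_offset)
        yield page
        # No further offset means the last page has just been yielded.
        if next_offset is None:
            break


pages = list(iter_pages(make_fake_scroll(["a", "b", "c", "d", "e"]), page_size=2))
print(pages)
```

As in the connector, an empty collection still yields one empty page, so callers that iterate the generator should be prepared for an empty list.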

    def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[RecordType]:
        if self.size(filters) == 0:
            return []
        filters = self.get_qdrant_filters(filters)
        results, _ = self.qdrant_client.scroll(
            self.table_name,
            scroll_filter=filters,
            limit=limit,
            with_payload=True,
            with_vectors=True,
        )
        return self.to_records(results)

    def get(self, id: uuid.UUID) -> Optional[RecordType]:
        results = self.qdrant_client.retrieve(
            collection_name=self.table_name,
            ids=[str(id)],
            with_payload=True,
            with_vectors=True,
        )
        if not results:
            return None
        return self.to_records(results)[0]

    def insert(self, record: Record):
        points = self.to_points([record])
        self.qdrant_client.upsert(self.table_name, points=points)

    def insert_many(self, records: List[RecordType], show_progress=False):
        points = self.to_points(records)
        self.qdrant_client.upsert(self.table_name, points=points)

    def delete(self, filters: Optional[Dict] = {}):
        filters = self.get_qdrant_filters(filters)
        self.qdrant_client.delete(self.table_name, points_selector=filters)

    def delete_table(self):
        self.qdrant_client.delete_collection(self.table_name)
        self.qdrant_client.close()

    def size(self, filters: Optional[Dict] = {}) -> int:
        filters = self.get_qdrant_filters(filters)
        return self.qdrant_client.count(collection_name=self.table_name, count_filter=filters).count

    def close(self):
        self.qdrant_client.close()

    def query(
        self,
        query: str,
        query_vec: List[float],
        top_k: int = 10,
        filters: Optional[Dict] = {},
    ) -> List[RecordType]:
        filters = self.get_qdrant_filters(filters)
        results = self.qdrant_client.search(
            self.table_name,
            query_vector=query_vec,
            query_filter=filters,
            limit=top_k,
            with_payload=True,
            with_vectors=True,
        )
        return self.to_records(results)

    def to_records(self, records: list) -> List[RecordType]:
        parsed_records = []
        for record in records:
            record = deepcopy(record)
            metadata = record.payload[METADATA_PAYLOAD_KEY]
            text = record.payload[TEXT_PAYLOAD_KEY]
            _id = metadata.pop("id")
            embedding = record.vector
            for key, value in metadata.items():
                if key in self.uuid_fields:
                    metadata[key] = uuid.UUID(value)
                elif key == "created_at":
                    metadata[key] = timestamp_to_datetime(value)
            parsed_records.append(
                cast(
                    RecordType,
                    self.type(
                        text=text,
                        embedding=embedding,
                        id=uuid.UUID(_id),
                        **metadata,
                    ),
                )
            )
        return parsed_records

    def to_points(self, records: List[RecordType]):
        from qdrant_client import models

        assert all(isinstance(r, Passage) for r in records)
        points = []
        records = list(set(records))
        for record in records:
            record = vars(record)
            _id = record.pop("id")
            text = record.pop("text", "")
            embedding = record.pop("embedding", {})
            record_metadata = record.pop("metadata_", None) or {}
            if "created_at" in record:
                record["created_at"] = datetime_to_timestamp(record["created_at"])
            metadata = {key: value for key, value in record.items() if value is not None}
            metadata = {
                **metadata,
                **record_metadata,
                "id": str(_id),
            }
            for key, value in metadata.items():
                if key in self.uuid_fields:
                    metadata[key] = str(value)
            points.append(
                models.PointStruct(
                    id=str(_id),
                    vector=embedding,
                    payload={
                        TEXT_PAYLOAD_KEY: text,
                        METADATA_PAYLOAD_KEY: metadata,
                    },
                )
            )
        return points
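
`to_points` and `to_records` are inverses: UUID fields become strings and `created_at` becomes a timestamp on the way in, and both are restored on the way out. A simplified round-trip sketch using plain dicts in place of `Passage` objects; the helper names are hypothetical, the whole-second UTC timestamp encoding is an assumption about `datetime_to_timestamp`, and the merge of `metadata_` is left out. Unlike `vars(record)`, which returns the object's own `__dict__`, this sketch copies its input before popping keys:

```python
import uuid
from datetime import datetime, timezone
from typing import Any, Dict

TEXT_PAYLOAD_KEY = "text_content"
METADATA_PAYLOAD_KEY = "metadata"
UUID_FIELDS = ["id", "user_id", "agent_id", "source_id", "doc_id"]


def to_payload(record: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten a record dict into a JSON-friendly payload without mutating the input."""
    fields = dict(record)  # copy so the caller's dict keeps its keys
    text = fields.pop("text", "")
    fields.pop("embedding", None)  # the vector is stored separately from the payload
    metadata: Dict[str, Any] = {}
    for key, value in fields.items():
        if value is None:
            continue  # unset fields are dropped, as in the connector
        if key in UUID_FIELDS:
            value = str(value)
        elif key == "created_at":
            # Assumption: timestamps are stored as whole seconds since the epoch.
            value = int(value.timestamp())
        metadata[key] = value
    return {TEXT_PAYLOAD_KEY: text, METADATA_PAYLOAD_KEY: metadata}


def from_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Invert to_payload: restore UUID and datetime values."""
    fields: Dict[str, Any] = {"text": payload[TEXT_PAYLOAD_KEY]}
    for key, value in payload[METADATA_PAYLOAD_KEY].items():
        if key in UUID_FIELDS:
            value = uuid.UUID(value)
        elif key == "created_at":
            value = datetime.fromtimestamp(value, tz=timezone.utc)
        fields[key] = value
    return fields


record = {
    "id": uuid.uuid4(),
    "user_id": uuid.uuid4(),
    "agent_id": None,
    "text": "hello",
    "embedding": [0.1, 0.2],
    "created_at": datetime(2024, 6, 5, tzinfo=timezone.utc),
}
restored = from_payload(to_payload(record))
print(restored["id"] == record["id"], restored["created_at"] == record["created_at"])
```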

    def get_qdrant_filters(self, filters: Optional[Dict] = {}):
        from qdrant_client import models

        filter_conditions = {**self.filters, **filters} if filters is not None else self.filters
        must_conditions = []
        for key, value in filter_conditions.items():
            match_value = str(value) if key in self.uuid_fields else value
            field_condition = models.FieldCondition(
                key=f"{METADATA_PAYLOAD_KEY}.{key}",
                match=models.MatchValue(value=match_value),
            )
            must_conditions.append(field_condition)
        return models.Filter(must=must_conditions)
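
`get_qdrant_filters` turns a flat dict of field values into one exact-match condition per key, all combined under `must`, with keys namespaced under the `metadata` payload field. The sketch below builds the equivalent structure as plain dicts in the JSON shape Qdrant uses for filters, so it runs without `qdrant-client`; `build_filter` and the example `data_source` key are hypothetical:

```python
import json
import uuid
from typing import Any, Dict, Optional

METADATA_PAYLOAD_KEY = "metadata"
UUID_FIELDS = ["id", "user_id", "agent_id", "source_id", "doc_id"]


def build_filter(base: Dict[str, Any], extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Merge filters (extra wins on key clashes) and emit one exact-match condition per key."""
    merged = {**base, **extra} if extra is not None else base
    must = []
    for key, value in merged.items():
        # UUIDs are stored as strings in the payload, so they are matched as strings.
        match_value = str(value) if key in UUID_FIELDS else value
        must.append({"key": f"{METADATA_PAYLOAD_KEY}.{key}", "match": {"value": match_value}})
    return {"must": must}


user_id = uuid.UUID("00000000-0000-0000-0000-000000000001")
print(json.dumps(build_filter({"user_id": user_id}, {"data_source": "docs"}), indent=2))
```

Because every condition sits under `must`, a record has to satisfy all of them; passing `None` as the extra filters leaves only the connector's base filters in effect.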
4 changes: 4 additions & 0 deletions memgpt/agent_store/storage.py
@@ -100,6 +100,10 @@ def get_storage_connector(

            return ChromaStorageConnector(table_type, config, user_id, agent_id)

        elif storage_type == "qdrant":
            from memgpt.agent_store.qdrant import QdrantStorageConnector

            return QdrantStorageConnector(table_type, config, user_id, agent_id)
        # TODO: add back
        # elif storage_type == "lancedb":
        #     from memgpt.agent_store.db import LanceDBConnector
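
Importing the connector inside the branch keeps `qdrant-client` optional: the module is loaded only when the `qdrant` storage type is actually selected. A generic sketch of the same deferred-import idea with an install hint; `load_optional_backend` is a hypothetical helper, and the `pymemgpt[qdrant]` extra is the one named in the documentation change in this PR:

```python
import importlib
from types import ModuleType


def load_optional_backend(module_name: str, extra: str) -> ModuleType:
    """Import a backend module on first use, with an install hint if it is missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        raise ImportError(
            f"'{module_name}' is not installed. Run `pip install 'pymemgpt[{extra}]'`."
        ) from e


# A standard-library module stands in for an installed backend here.
print(load_optional_backend("json", "qdrant").__name__)
```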
15 changes: 14 additions & 1 deletion memgpt/cli/cli_config.py
@@ -913,7 +913,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden

def configure_archival_storage(config: MemGPTConfig, credentials: MemGPTCredentials):
    # Configure archival storage backend
    archival_storage_options = ["postgres", "chroma", "milvus"]
    archival_storage_options = ["postgres", "chroma", "milvus", "qdrant"]
    archival_storage_type = questionary.select(
        "Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
    ).ask()
@@ -950,6 +950,19 @@ def configure_archival_storage(config: MemGPTConfig, credentials: MemGPTCredenti
        if chroma_type == "persistent":
            archival_storage_path = os.path.join(MEMGPT_DIR, "chroma")

    if archival_storage_type == "qdrant":
        qdrant_type = questionary.select("Select Qdrant backend:", ["local", "server"], default="local").ask()
        if qdrant_type is None:
            raise KeyboardInterrupt
        if qdrant_type == "server":
            archival_storage_uri = questionary.text(
                "Enter the Qdrant instance URI (Default: localhost:6333):", default="localhost:6333"
            ).ask()
            if archival_storage_uri is None:
                raise KeyboardInterrupt
        if qdrant_type == "local":
            archival_storage_path = os.path.join(MEMGPT_DIR, "qdrant")
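
The prompt flow above reduces to a small mapping from the chosen backend to a URI or a path. A sketch of that mapping as a pure function, which is easier to unit-test than the interactive prompts; `resolve_qdrant_settings` is hypothetical, the `MEMGPT_DIR` value is an assumed location, and treating an empty answer as the default URI is a choice made for this sketch:

```python
import os
from typing import Optional, Tuple

# Assumed location for illustration; the real constant is defined elsewhere in the project.
MEMGPT_DIR = os.path.join(os.path.expanduser("~"), ".memgpt")


def resolve_qdrant_settings(
    qdrant_type: str, uri_answer: Optional[str] = None
) -> Tuple[Optional[str], Optional[str]]:
    """Return (archival_storage_uri, archival_storage_path) for a backend choice."""
    if qdrant_type == "server":
        # Fall back to the prompt's default when no answer is given.
        return (uri_answer or "localhost:6333"), None
    if qdrant_type == "local":
        return None, os.path.join(MEMGPT_DIR, "qdrant")
    raise ValueError(f"Unknown Qdrant backend: {qdrant_type!r}")


print(resolve_qdrant_settings("server", "qdrant.internal:6333"))
print(resolve_qdrant_settings("local"))
```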

    if archival_storage_type == "milvus":
        default_milvus_uri = archival_storage_path = os.path.join(MEMGPT_DIR, "milvus.db")
        archival_storage_uri = questionary.text(