diff --git a/docs/api.md b/docs/api.md index ffcb3d8..6e1764a 100644 --- a/docs/api.md +++ b/docs/api.md @@ -62,10 +62,12 @@ docs.upsert( ## Deleting vectors -Deleting records removes them from the collection. To delete records, specify a list of `ids` to the `delete` method. The ids of the sucessfully deleted records are returned from the method. Note that attempting to delete non-existent records does not raise an error. +Deleting records removes them from the collection. To delete records, specify a list of `ids` or metadata filters to the `delete` method. The ids of the successfully deleted records are returned from the method. Note that attempting to delete non-existent records does not raise an error. ```python docs.delete(ids=["vec0", "vec1"]) +# or delete by a metadata filter +docs.delete(filters={"year": {"$eq": 2012}}) ``` ## Create an index diff --git a/src/tests/test_collection.py b/src/tests/test_collection.py index ad30685..e7c4b38 100644 --- a/src/tests/test_collection.py +++ b/src/tests/test_collection.py @@ -1,3 +1,4 @@ +import itertools import random import numpy as np @@ -91,24 +92,44 @@ def test_delete(client: vecs.Client) -> None: f"vec{ix}", vec, { - "genre": random.choice(["action", "rom-com", "drama"]), + "genre": genre, "year": int(50 * random.random()) + 1970, }, ) - for ix, vec in enumerate(np.random.random((n_records, dim))) + for (ix, vec), genre in zip( + enumerate(np.random.random((n_records, dim))), + itertools.cycle(["action", "rom-com", "drama"]), + ) ] # insert works movies.upsert(records) + # delete by IDs. 
delete_ids = ["vec0", "vec15", "vec99"] movies.delete(ids=delete_ids) assert len(movies) == n_records - len(delete_ids) + # insert works + movies.upsert(records) + + # delete with filters + genre_to_delete = "action" + deleted_ids_by_genre = movies.delete(filters={"genre": {"$eq": genre_to_delete}}) + assert len(deleted_ids_by_genre) == 34 + # bad input with pytest.raises(vecs.exc.ArgError): movies.delete(ids="should_be_a_list") + # bad input: neither ids nor filters provided. + with pytest.raises(vecs.exc.ArgError): + movies.delete() + + # bad input: should only provide either ids or filters, not both + with pytest.raises(vecs.exc.ArgError): + movies.delete(ids=["vec0"], filters={"genre": {"$eq": genre_to_delete}}) + def test_repr(client: vecs.Client) -> None: movies = client.get_or_create_collection(name="movies", dimension=99) diff --git a/src/vecs/collection.py b/src/vecs/collection.py index d236d7a..7ff7884 100644 --- a/src/vecs/collection.py +++ b/src/vecs/collection.py @@ -375,33 +375,52 @@ def fetch(self, ids: Iterable[str]) -> List[Record]: records.extend(chunk_records) return records - def delete(self, ids: Iterable[str]) -> List[str]: + def delete( + self, ids: Optional[Iterable[str]] = None, filters: Optional[Metadata] = None + ) -> List[str]: """ - Deletes vectors from the collection by their identifiers. + Deletes vectors from the collection by matching filters or ids. Args: - ids (Iterable[str]): An iterable of vector identifiers. + ids (Iterable[str], optional): An iterable of vector identifiers. + filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None. Returns: List[str]: A list of the identifiers of the deleted vectors. 
""" + if ids is None and filters is None: + raise ArgError("Either ids or filters must be provided.") + + if ids is not None and filters is not None: + raise ArgError("Either ids or filters must be provided, not both.") + if isinstance(ids, str): raise ArgError("ids must be a list of strings") - chunk_size = 12 + ids = ids or [] + filters = filters or {} + del_ids = [] - del_ids = list(ids) - ids = [] with self.client.Session() as sess: with sess.begin(): - for id_chunk in flu(del_ids).chunk(chunk_size): + if ids: + for id_chunk in flu(ids).chunk(12): + stmt = ( + delete(self.table) + .where(self.table.c.id.in_(id_chunk)) + .returning(self.table.c.id) + ) + del_ids.extend(sess.execute(stmt).scalars() or []) + + if filters: + meta_filter = build_filters(self.table.c.metadata, filters) stmt = ( - delete(self.table) - .where(self.table.c.id.in_(id_chunk)) - .returning(self.table.c.id) + delete(self.table).where(meta_filter).returning(self.table.c.id) # type: ignore ) - ids.extend(sess.execute(stmt).scalars() or []) - return ids + result = sess.execute(stmt).scalars() + del_ids.extend(result.fetchall()) + + return del_ids def __getitem__(self, items): """ @@ -516,7 +535,9 @@ def query( stmt = select(*cols) if filters: - stmt = stmt.filter(build_filters(self.table.c.metadata, filters)) # type: ignore + stmt = stmt.filter( + build_filters(self.table.c.metadata, filters) # type: ignore + ) stmt = stmt.order_by(distance_clause) stmt = stmt.limit(limit)