From 4a57e8931b72e883b642eeaf42af660c12ea4444 Mon Sep 17 00:00:00 2001 From: Dax Pryce Date: Tue, 7 Nov 2023 09:28:04 -0800 Subject: [PATCH] Adding Filtered Index support to Python bindings (#482) * Halfway approach to the new indexfactory, but it doesn't have the same featureset as the old way. Committing this for posterity but reverting my changes ultimately * Revert "Halfway approach to the new indexfactory, but it doesn't have the same featureset as the old way. Committing this for posterity but reverting my changes ultimately" This reverts commit 03dccb599449881f64664a10b397a790a7d00985. * Adding filtered search. API is going to change still. * Further enhancements to the new filter capability in the static memory index. * Ran automatic formatting * Fixing my logic and ensuring the unit tests pass. * Setting this up as a rc build first * list[list[Hashable]] -> list[list[str]] * Adding halfway to a solution where we query for more items than exist in the filter set. We need to replicate this behavior across all indices though - dynamic, static disk and memory w/o filters, etc * Removing the import of Hashable too --- pyproject.toml | 2 +- python/include/builder.h | 5 +- python/include/static_memory_index.h | 4 + python/src/_builder.py | 51 +++++++++-- python/src/_builder.pyi | 12 +-- python/src/_common.py | 22 ++--- python/src/_dynamic_memory_index.py | 2 +- python/src/_static_memory_index.py | 45 +++++++++- python/src/builder.cpp | 60 +++++++++++-- python/src/diskann_bindings.cpp | 1 - python/src/module.cpp | 5 +- python/src/static_memory_index.cpp | 12 +++ python/tests/test_static_memory_index.py | 108 ++++++++++++++++++++++- 13 files changed, 290 insertions(+), 39 deletions(-) delete mode 100644 python/src/diskann_bindings.cpp diff --git a/pyproject.toml b/pyproject.toml index df2a342ff..39c79b81e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ build-backend = "setuptools.build_meta" [project] name = "diskannpy" -version = "0.6.1" +version = "0.7.0rc1" description = "DiskANN Python extension module" readme = "python/README.md" diff --git a/python/include/builder.h b/python/include/builder.h index fc12976e7..6b1a5b4f3 100644 --- a/python/include/builder.h +++ b/python/include/builder.h @@ -20,7 +20,8 @@ template void build_memory_index(diskann::Metric metric, const std::string &vector_bin_path, const std::string &index_output_path, uint32_t graph_degree, uint32_t complexity, float alpha, uint32_t num_threads, bool use_pq_build, - size_t num_pq_bytes, bool use_opq, uint32_t filter_complexity, - bool use_tags = false); + size_t num_pq_bytes, bool use_opq, bool use_tags = false, + const std::string& filter_labels_file = "", const std::string& universal_label = "", + uint32_t filter_complexity = 0); } diff --git a/python/include/static_memory_index.h b/python/include/static_memory_index.h index 6a222bedb..6ed5a0822 100644 --- a/python/include/static_memory_index.h +++ b/python/include/static_memory_index.h @@ -26,6 +26,10 @@ template class StaticMemoryIndex NeighborsAndDistances search(py::array_t &query, uint64_t knn, uint64_t complexity); + NeighborsAndDistances search_with_filter( + py::array_t &query, uint64_t knn, uint64_t complexity, + filterT filter); + NeighborsAndDistances batch_search( py::array_t &queries, uint64_t num_queries, uint64_t knn, uint64_t complexity, uint32_t num_threads); diff --git a/python/src/_builder.py b/python/src/_builder.py index db2b200db..013b7f2c9 100644 --- a/python/src/_builder.py +++ b/python/src/_builder.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT license. +import json import os import shutil from pathlib import Path @@ -174,8 +175,10 @@ def build_memory_index( num_pq_bytes: int = defaults.NUM_PQ_BYTES, use_opq: bool = defaults.USE_OPQ, vector_dtype: Optional[VectorDType] = None, - filter_complexity: int = defaults.FILTER_COMPLEXITY, tags: Union[str, VectorIdentifierBatch] = "", + filter_labels: Optional[list[list[str]]] = None, + universal_label: str = "", + filter_complexity: int = defaults.FILTER_COMPLEXITY, index_prefix: str = "ann", ) -> None: """ @@ -223,10 +226,20 @@ def build_memory_index( Default is `0`. - **use_opq**: Use optimized product quantization during build. - **vector_dtype**: Required if the provided `data` is of type `str`, else we use the `data.dtype` if np array. - - **filter_complexity**: Complexity to use when using filters. Default is 0. - - **tags**: A `str` representing a path to a pre-built tags file on disk, or a `numpy.ndarray` of uint32 ids - corresponding to the ordinal position of the vectors provided to build the index. Defaults to "". **This value - must be provided if you want to build a memory index intended for use with `diskannpy.DynamicMemoryIndex`**. + - **tags**: Tags can be defined either as a path on disk to an existing .tags file, or provided as a np.array of + the same length as the number of vectors. Tags are used to identify vectors in the index via your *own* + numbering conventions, and is absolutely required for loading DynamicMemoryIndex indices `from_file`. + - **filter_labels**: An optional, but exhaustive list of categories for each vector. This is used to filter + search results by category. If provided, this must be a list of lists, where each inner list is a list of + categories for the corresponding vector. For example, if you have 3 vectors, and the first vector belongs to + categories "a" and "b", the second vector belongs to category "b", and the third vector belongs to no categories, + you would provide `filter_labels=[["a", "b"], ["b"], []]`. If you do not want to provide categories for a + particular vector, you can provide an empty list. If you do not want to provide categories for any vectors, + you can provide `None` for this parameter (which is the default) + - **universal_label**: An optional label that indicates that this vector should be included in *every* search + in which it also meets the knn search criteria. + - **filter_complexity**: Complexity to use when using filters. Default is 0. 0 is strictly invalid if you are + using filters. - **index_prefix**: The prefix of the index files. Defaults to "ann". """ _assert( @@ -245,6 +258,10 @@ def build_memory_index( _assert_is_nonnegative_uint32(num_pq_bytes, "num_pq_bytes") _assert_is_nonnegative_uint32(filter_complexity, "filter_complexity") _assert(index_prefix != "", "index_prefix cannot be an empty string") + _assert( + filter_labels is None or filter_complexity > 0, + "if filter_labels is provided, filter_complexity must not be 0" + ) index_path = Path(index_directory) _assert( @@ -262,6 +279,11 @@ def build_memory_index( ) num_points, dimensions = vectors_metadata_from_file(vector_bin_path) + if filter_labels is not None: + _assert( + len(filter_labels) == num_points, + "filter_labels must be the same length as the number of points" + ) if vector_dtype_actual == np.uint8: _builder = _native_dap.build_memory_uint8_index @@ -272,6 +294,21 @@ def build_memory_index( index_prefix_path = os.path.join(index_directory, index_prefix) + filter_labels_file = "" + if filter_labels is not None: + label_counts = {} + filter_labels_file = f"{index_prefix_path}_pylabels.txt" + with open(filter_labels_file, "w") as labels_file: + for labels in filter_labels: + for label in labels: + label_counts[label] = 1 if label not in label_counts else label_counts[label] + 1 + if len(labels) == 0: + print("default", file=labels_file) + else: + print(",".join(labels), file=labels_file) + with open(f"{index_prefix_path}_label_metadata.json", "w") as label_metadata_file: + json.dump(label_counts, label_metadata_file, indent=True) + if isinstance(tags, str) and tags != "": use_tags = True shutil.copy(tags, index_prefix_path + ".tags") @@ -299,8 +336,10 @@ def build_memory_index( use_pq_build=use_pq_build, num_pq_bytes=num_pq_bytes, use_opq=use_opq, - filter_complexity=filter_complexity, use_tags=use_tags, + filter_labels_file=filter_labels_file, + universal_label=universal_label, + filter_complexity=filter_complexity, ) _write_index_metadata( diff --git a/python/src/_builder.pyi b/python/src/_builder.pyi index 5014880c6..223e6c923 100644 --- a/python/src/_builder.pyi +++ b/python/src/_builder.pyi @@ -47,11 +47,11 @@ def build_memory_index( use_pq_build: bool, num_pq_bytes: int, use_opq: bool, - label_file: str, + tags: Union[str, VectorIdentifierBatch], + filter_labels: Optional[list[list[str]]], universal_label: str, filter_complexity: int, - tags: Optional[VectorIdentifierBatch], - index_prefix: str, + index_prefix: str ) -> None: ... @overload def build_memory_index( @@ -66,9 +66,9 @@ def build_memory_index( num_pq_bytes: int, use_opq: bool, vector_dtype: VectorDType, - label_file: str, + tags: Union[str, VectorIdentifierBatch], + filter_labels_file: Optional[list[list[str]]], universal_label: str, filter_complexity: int, - tags: Optional[str], - index_prefix: str, + index_prefix: str ) -> None: ... diff --git a/python/src/_common.py b/python/src/_common.py index 53f1dbcab..2b28802ff 100644 --- a/python/src/_common.py +++ b/python/src/_common.py @@ -211,6 +211,7 @@ def _ensure_index_metadata( distance_metric: Optional[DistanceMetric], max_vectors: int, dimensions: Optional[int], + warn_size_exceeded: bool = False, ) -> Tuple[VectorDType, str, np.uint64, np.uint64]: possible_metadata = _read_index_metadata(index_path_and_prefix) if possible_metadata is None: @@ -226,16 +227,17 @@ def _ensure_index_metadata( return vector_dtype, distance_metric, max_vectors, dimensions # type: ignore else: vector_dtype, distance_metric, num_vectors, dimensions = possible_metadata - if max_vectors is not None and num_vectors > max_vectors: - warnings.warn( - "The number of vectors in the saved index exceeds the max_vectors parameter. " - "max_vectors is being adjusted to accommodate the dataset, but any insertions will fail." - ) - max_vectors = num_vectors - if num_vectors == max_vectors: - warnings.warn( - "The number of vectors in the saved index equals max_vectors parameter. Any insertions will fail." - ) + if warn_size_exceeded: + if max_vectors is not None and num_vectors > max_vectors: + warnings.warn( + "The number of vectors in the saved index exceeds the max_vectors parameter. " + "max_vectors is being adjusted to accommodate the dataset, but any insertions will fail." + ) + max_vectors = num_vectors + if num_vectors == max_vectors: + warnings.warn( + "The number of vectors in the saved index equals max_vectors parameter. Any insertions will fail." + ) return possible_metadata diff --git a/python/src/_dynamic_memory_index.py b/python/src/_dynamic_memory_index.py index 0346a2c76..cdf643208 100644 --- a/python/src/_dynamic_memory_index.py +++ b/python/src/_dynamic_memory_index.py @@ -144,7 +144,7 @@ def from_file( f"The file {tags_file} does not exist in {index_directory}", ) vector_dtype, dap_metric, num_vectors, dimensions = _ensure_index_metadata( - index_prefix_path, vector_dtype, distance_metric, max_vectors, dimensions + index_prefix_path, vector_dtype, distance_metric, max_vectors, dimensions, warn_size_exceeded=True ) index = cls( diff --git a/python/src/_static_memory_index.py b/python/src/_static_memory_index.py index b1ffb468d..f9bd7e8cc 100644 --- a/python/src/_static_memory_index.py +++ b/python/src/_static_memory_index.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT license. +import json import os import warnings from typing import Optional @@ -43,6 +44,7 @@ def __init__( distance_metric: Optional[DistanceMetric] = None, vector_dtype: Optional[VectorDType] = None, dimensions: Optional[int] = None, + enable_filters: bool = False ): """ ### Parameters @@ -73,8 +75,22 @@ def __init__( - **dimensions**: The vector dimensionality of this index. All new vectors inserted must be the same dimensionality. **This value is only used if a `{index_prefix}_metadata.bin` file does not exist.** If it does not exist, you are required to provide it. + - **enable_filters**: Indexes built with filters can also be used for filtered search. """ index_prefix = _valid_index_prefix(index_directory, index_prefix) + self._labels_map = {} + self._labels_metadata = {} + if enable_filters: + try: + with open(index_prefix + "_labels_map.txt", "r") as labels_map_if: + for line in labels_map_if: + (key, val) = line.split("\t") + self._labels_map[key] = int(val) + with open(f"{index_prefix}_label_metadata.json", "r") as labels_metadata_if: + self._labels_metadata = json.load(labels_metadata_if) + except: # noqa: E722 + # exceptions are basically presumed to be either file not found or file not formatted correctly + raise RuntimeException("Filter labels file was unable to be processed.") vector_dtype, metric, num_points, dims = _ensure_index_metadata( index_prefix, vector_dtype, @@ -109,7 +125,7 @@ def __init__( ) def search( - self, query: VectorLike, k_neighbors: int, complexity: int + self, query: VectorLike, k_neighbors: int, complexity: int, filter_label: str = "" ) -> QueryResponse: """ Searches the index by a single query vector. @@ -121,13 +137,25 @@ def search( - **complexity**: Size of distance ordered list of candidate neighbors to use while searching. List size increases accuracy at the cost of latency. Must be at least k_neighbors in size. """ + if filter_label != "": + if len(self._labels_map) == 0: + raise ValueError( + f"A filter label of {filter_label} was provided, but this class was not initialized with filters " + "enabled, e.g. StaticDiskMemory(..., enable_filters=True)" + ) + if filter_label not in self._labels_map: + raise ValueError( + f"A filter label of {filter_label} was provided, but the external(str)->internal(np.uint32) labels map " + f"does not include that label." + ) + k_neighbors = min(k_neighbors, self._labels_metadata[filter_label]) _query = _castable_dtype_or_raise(query, expected=self._vector_dtype) _assert(len(_query.shape) == 1, "query vector must be 1-d") _assert( _query.shape[0] == self._dimensions, f"query vector must have the same dimensionality as the index; index dimensionality: {self._dimensions}, " f"query dimensionality: {_query.shape[0]}", - ) + ) _assert_is_positive_uint32(k_neighbors, "k_neighbors") _assert_is_nonnegative_uint32(complexity, "complexity") @@ -136,9 +164,20 @@ def search( f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}" ) complexity = k_neighbors - neighbors, distances = self._index.search(query=_query, knn=k_neighbors, complexity=complexity) + + if filter_label == "": + neighbors, distances = self._index.search(query=_query, knn=k_neighbors, complexity=complexity) + else: + filter = self._labels_map[filter_label] + neighbors, distances = self._index.search_with_filter( + query=query, + knn=k_neighbors, + complexity=complexity, + filter=filter + ) return QueryResponse(identifiers=neighbors, distances=distances) + def batch_search( self, queries: VectorLikeBatch, diff --git a/python/src/builder.cpp b/python/src/builder.cpp index 3576cab6d..e02a86d6c 100644 --- a/python/src/builder.cpp +++ b/python/src/builder.cpp @@ -31,12 +31,37 @@ template void build_disk_index(diskann::Metric, const std::string &, co template void build_disk_index(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t, double, double, uint32_t, uint32_t); +template +std::string prepare_filtered_label_map(diskann::Index &index, const std::string &index_output_path, + const std::string &filter_labels_file, const std::string &universal_label) +{ + std::string labels_file_to_use = index_output_path + "_label_formatted.txt"; + std::string mem_labels_int_map_file = index_output_path + "_labels_map.txt"; + convert_labels_string_to_int(filter_labels_file, labels_file_to_use, mem_labels_int_map_file, universal_label); + if (!universal_label.empty()) + { + uint32_t unv_label_as_num = 0; + index.set_universal_label(unv_label_as_num); + } + return labels_file_to_use; +} + +template std::string prepare_filtered_label_map(diskann::Index &, const std::string &, + const std::string &, const std::string &); + +template std::string prepare_filtered_label_map(diskann::Index &, + const std::string &, const std::string &, const std::string &); + +template std::string prepare_filtered_label_map(diskann::Index &, + const std::string &, const std::string &, const std::string &); + template void build_memory_index(const diskann::Metric metric, const std::string &vector_bin_path, const std::string &index_output_path, const uint32_t graph_degree, const uint32_t complexity, const float alpha, const uint32_t num_threads, const bool use_pq_build, - const size_t num_pq_bytes, const bool use_opq, const uint32_t filter_complexity, - const bool use_tags) + const size_t num_pq_bytes, const bool use_opq, const bool use_tags, + const std::string &filter_labels_file, const std::string &universal_label, + const uint32_t filter_complexity) { diskann::IndexWriteParameters index_build_params = diskann::IndexWriteParametersBuilder(complexity, graph_degree) .with_filter_list_size(filter_complexity) @@ -65,23 +90,44 @@ void build_memory_index(const diskann::Metric metric, const std::string &vector_ size_t tag_dims = 1; diskann::load_bin(tags_file, tags_data, data_num, tag_dims); std::vector tags(tags_data, tags_data + data_num); - index.build(vector_bin_path.c_str(), data_num, tags); + if (filter_labels_file.empty()) + { + index.build(vector_bin_path.c_str(), data_num, tags); + } + else + { + auto labels_file = prepare_filtered_label_map(index, index_output_path, filter_labels_file, + universal_label); + index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num, tags); + } } else { - index.build(vector_bin_path.c_str(), data_num); + if (filter_labels_file.empty()) + { + index.build(vector_bin_path.c_str(), data_num); + } + else + { + auto labels_file = prepare_filtered_label_map(index, index_output_path, filter_labels_file, + universal_label); + index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num); + } } index.save(index_output_path.c_str()); } template void build_memory_index(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t, - float, uint32_t, bool, size_t, bool, uint32_t, bool); + float, uint32_t, bool, size_t, bool, bool, const std::string &, + const std::string &, uint32_t); template void build_memory_index(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t, - float, uint32_t, bool, size_t, bool, uint32_t, bool); + float, uint32_t, bool, size_t, bool, bool, const std::string &, + const std::string &, uint32_t); template void build_memory_index(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t, - float, uint32_t, bool, size_t, bool, uint32_t, bool); + float, uint32_t, bool, size_t, bool, bool, const std::string &, + const std::string &, uint32_t); } // namespace diskannpy diff --git a/python/src/diskann_bindings.cpp b/python/src/diskann_bindings.cpp deleted file mode 100644 index 8b1378917..000000000 --- a/python/src/diskann_bindings.cpp +++ /dev/null @@ -1 +0,0 @@ - diff --git a/python/src/module.cpp b/python/src/module.cpp index 7aea9fc03..376515661 100644 --- a/python/src/module.cpp +++ b/python/src/module.cpp @@ -48,7 +48,8 @@ template inline void add_variant(py::module_ &m, const Variant &var m.def(variant.memory_builder_name.c_str(), &diskannpy::build_memory_index, "distance_metric"_a, "data_file_path"_a, "index_output_path"_a, "graph_degree"_a, "complexity"_a, "alpha"_a, "num_threads"_a, - "use_pq_build"_a, "num_pq_bytes"_a, "use_opq"_a, "filter_complexity"_a = 0, "use_tags"_a = false); + "use_pq_build"_a, "num_pq_bytes"_a, "use_opq"_a, "use_tags"_a = false, "filter_labels_file"_a = "", + "universal_label"_a = "", "filter_complexity"_a = 0); py::class_>(m, variant.static_memory_index_name.c_str()) .def(py::init inline void add_variant(py::module_ &m, const Variant &var "distance_metric"_a, "index_path"_a, "num_points"_a, "dimensions"_a, "num_threads"_a, "initial_search_complexity"_a) .def("search", &diskannpy::StaticMemoryIndex::search, "query"_a, "knn"_a, "complexity"_a) + .def("search_with_filter", &diskannpy::StaticMemoryIndex::search_with_filter, "query"_a, "knn"_a, + "complexity"_a, "filter"_a) .def("batch_search", &diskannpy::StaticMemoryIndex::batch_search, "queries"_a, "num_queries"_a, "knn"_a, "complexity"_a, "num_threads"_a); diff --git a/python/src/static_memory_index.cpp b/python/src/static_memory_index.cpp index 23a349fac..d3ac079af 100644 --- a/python/src/static_memory_index.cpp +++ b/python/src/static_memory_index.cpp @@ -51,6 +51,18 @@ NeighborsAndDistances StaticMemoryIndex
::search( return std::make_pair(ids, dists); } +template +NeighborsAndDistances StaticMemoryIndex
::search_with_filter( + py::array_t &query, const uint64_t knn, const uint64_t complexity, + const filterT filter) +{ + py::array_t ids(knn); + py::array_t dists(knn); + std::vector
empty_vector; + _index.search_with_filters(query.data(), filter, knn, complexity, ids.mutable_data(), dists.mutable_data()); + return std::make_pair(ids, dists); +} + template NeighborsAndDistances StaticMemoryIndex
::batch_search( py::array_t &queries, const uint64_t num_queries, const uint64_t knn, diff --git a/python/tests/test_static_memory_index.py b/python/tests/test_static_memory_index.py index a04f98928..3078f15a5 100644 --- a/python/tests/test_static_memory_index.py +++ b/python/tests/test_static_memory_index.py @@ -1,12 +1,16 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT license. +import os import shutil import unittest +from tempfile import mkdtemp + import diskannpy as dap import numpy as np from fixtures import build_random_vectors_and_memory_index, calculate_recall +from fixtures import random_vectors from sklearn.neighbors import NearestNeighbors @@ -185,4 +189,106 @@ def test_zero_threads(self): ) k = 5 - ids, dists = index.batch_search(query_vectors, k_neighbors=k, complexity=5, num_threads=0) \ No newline at end of file + ids, dists = index.batch_search(query_vectors, k_neighbors=k, complexity=5, num_threads=0) + + +class TestFilteredStaticMemoryIndex(unittest.TestCase): + def test_simple_scenario(self): + vectors: np.ndarray = random_vectors(10000, 10, dtype=np.float32, seed=54321) + query_vectors: np.ndarray = random_vectors(10, 10, dtype=np.float32) + temp = mkdtemp() + labels = [] + for idx in range(0, vectors.shape[0]): + label_list = [] + if idx % 3 == 0: + label_list.append("even_by_3") + if idx % 5 == 0: + label_list.append("even_by_5") + if len(label_list) == 0: + label_list = ["neither"] + labels.append(label_list) + try: + dap.build_memory_index( + data=vectors, + distance_metric="l2", + index_directory=temp, + complexity=64, + graph_degree=32, + num_threads=16, + filter_labels=labels, + universal_label="all", + filter_complexity=128, + ) + index = dap.StaticMemoryIndex( + index_directory=temp, + num_threads=16, + initial_search_complexity=64, + enable_filters=True + ) + + k = 50 + probable_superset, _ = index.search(query_vectors[0], k_neighbors=k*2, complexity=128) + ids_1, _ = index.search(query_vectors[0], k_neighbors=k, complexity=64, filter_label="even_by_3") + self.assertTrue(all(id % 3 == 0 for id in ids_1)) + ids_2, _ = index.search(query_vectors[0], k_neighbors=k, complexity=64, filter_label="even_by_5") + self.assertTrue(all(id % 5 == 0 for id in ids_2)) + + in_superset = np.intersect1d(probable_superset, np.append(ids_1, ids_2)).shape[0] + self.assertTrue(in_superset/k*2 > 0.98) + finally: + shutil.rmtree(temp, ignore_errors=True) + + + def test_exhaustive_validation(self): + vectors: np.ndarray = random_vectors(10000, 10, dtype=np.float32, seed=54321) + query_vectors: np.ndarray = random_vectors(10, 10, dtype=np.float32) + temp = mkdtemp() + labels = [] + for idx in range(0, vectors.shape[0]): + label_list = [] + label_list.append("all") + if idx % 2 == 0: + label_list.append("even") + else: + label_list.append("odd") + if idx % 3 == 0: + label_list.append("by_three") + labels.append(label_list) + try: + dap.build_memory_index( + data=vectors, + distance_metric="l2", + index_directory=temp, + complexity=64, + graph_degree=32, + num_threads=16, + filter_labels=labels, + universal_label="", + filter_complexity=128, + ) + index = dap.StaticMemoryIndex( + index_directory=temp, + num_threads=16, + initial_search_complexity=64, + enable_filters=True + ) + + k = 5_000 + without_filter, _ = index.search(query_vectors[0], k_neighbors=k*2, complexity=128) + with_filter_but_label_all, _ = index.search( + query_vectors[0], k_neighbors=k*2, complexity=128, filter_label="all" + ) + intersection = np.intersect1d(without_filter, with_filter_but_label_all) + intersect_count = intersection.shape[0] + self.assertEqual(intersect_count, k*2) + + ids_1, _ = index.search(query_vectors[0], k_neighbors=k*10, complexity=128, filter_label="even") + # we ask for more than 5000. prior to the addition of the `_label_metadata.json` file + # asking for more k than we had items with that label would result in nonsense results past the first + # 5000. + self.assertTrue(all(id % 2 == 0 for id in ids_1)) + ids_2, _ = index.search(query_vectors[0], k_neighbors=k, complexity=128, filter_label="odd") + self.assertTrue(all(id % 2 != 0 for id in ids_2)) + + finally: + shutil.rmtree(temp, ignore_errors=True)