From 840c4b00b0e13c1f5cdc4eee38a92cab4516a878 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Aum=C3=BCller?= Date: Mon, 17 Jun 2024 09:50:52 +0200 Subject: [PATCH] Assert that indexes are unique. (#526) --- ann_benchmarks/runner.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ann_benchmarks/runner.py b/ann_benchmarks/runner.py index 4916e13e..81428114 100644 --- a/ann_benchmarks/runner.py +++ b/ann_benchmarks/runner.py @@ -66,6 +66,10 @@ def single_query(v: numpy.array) -> Tuple[float, List[Tuple[int, float]]]: start = time.time() candidates = algo.query(v, count) total = time.time() - start + + # make sure all returned indices are unique + assert len(candidates) == len(set(candidates)), "Implementation returned duplicated candidates" + candidates = [ (int(idx), float(metrics[distance].distance(v, X_train[idx]))) for idx in candidates # noqa ] @@ -105,6 +109,11 @@ def batch_query(X: numpy.array) -> List[Tuple[float, List[Tuple[int, float]]]]: batch_latencies = algo.get_batch_latencies() else: batch_latencies = [total / float(len(X))] * len(X) + + # make sure all returned indices are unique + for res in results: + assert len(res) == len(set(res)), "Implementation returned duplicated candidates" + candidates = [ [(int(idx), float(metrics[distance].distance(v, X_train[idx]))) for idx in single_results] # noqa for v, single_results in zip(X, results)