From a5c6bc3b7942b12d635e878bca2430713fb0a00f Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Mon, 30 Sep 2024 12:17:50 -0700 Subject: [PATCH] test --- .../knn/index/query/BaseQueryFactory.java | 1 + .../knn/index/query/ExactSearcher.java | 53 ++++++++++++++++++- .../opensearch/knn/index/query/KNNQuery.java | 35 +----------- .../knn/index/query/KNNQueryBuilder.java | 1 + .../knn/index/query/RNNQueryFactory.java | 4 ++ 5 files changed, 60 insertions(+), 34 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java index cfb604c18..34656412c 100644 --- a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java @@ -46,6 +46,7 @@ public static class CreateQueryRequest { private VectorDataType vectorDataType; private Map methodParameters; private Integer k; + private Float minScore; private Float radius; private QueryBuilder filter; private QueryShardContext context; diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 193cba8c1..c23b8fc72 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -20,6 +20,7 @@ import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.faiss.Faiss; import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.KNNIterator; @@ -52,6 +53,9 @@ public class ExactSearcher { public Map searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext) throws IOException { 
KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext); + if (exactSearcherContext.getK() == null) { // no k supplied: run the radial search instead of top-k + return doRadiusSearch(leafReaderContext, iterator, exactSearcherContext); + } if (exactSearcherContext.getMatchedDocs() != null && exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { return scoreAllDocs(iterator); @@ -59,6 +63,53 @@ public Map searchLeaf(final LeafReaderContext leafReaderContext, return searchTopK(iterator, exactSearcherContext.getK()); } + private Map doRadiusSearch(LeafReaderContext leafReaderContext, KNNIterator iterator, ExactSearcherContext exactSearcherContext) throws IOException { + // Radial search: keep docs scoring >= minScore (maxDistance is translated to a score), at most maxResultWindow hits. + final float minScore; + final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); + if (knnQuery.getMinScore() != null) { + minScore = knnQuery.getMinScore(); + } else { + final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); + final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo); + minScore = spaceType.scoreTranslation(knnQuery.getMaxDistance()); + } + // Creating min heap and init with MAX DocID and Score as -INF; its capacity is maxResultWindow. + final HitQueue queue = new HitQueue(knnQuery.getContext().maxResultWindow, true); + ScoreDoc topDoc = queue.top(); + final Map docToScore = new HashMap<>(); + int docId; + while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + final float currentScore = iterator.score(); + if (currentScore < minScore) { + continue; + } + if (currentScore > topDoc.score) { + topDoc.score = currentScore; + topDoc.doc = docId; + // As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we + // have seen till now on top. + topDoc = queue.updateTop(); + } + } + + // Drop every entry whose score is still negative.
+ // The heap was initialised with placeholder entries scored -INF; when fewer docs qualify than the heap + // capacity, some placeholders are still in the heap and must not be returned as hits. + while (queue.size() > 0 && queue.top().score < 0) { + queue.pop(); + } + + while (queue.size() > 0) { + final ScoreDoc doc = queue.pop(); + docToScore.put(doc.doc, doc.score); + } + + return docToScore; + + } + private Map scoreAllDocs(KNNIterator iterator) throws IOException { final Map docToScore = new HashMap<>(); int docId; @@ -172,7 +223,7 @@ public static class ExactSearcherContext { * re-scoring we need to re-score using full precision vectors and not quantized vectors. */ boolean useQuantizedVectorsForSearch; - int k; + Integer k; BitSet matchedDocs; KNNQuery knnQuery; /** diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index f5c4d3131..bee928640 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -50,6 +50,8 @@ public class KNNQuery extends Query { private Query filterQuery; private BitSetProducer parentsFilter; private Float radius; + private Float minScore; + private Float maxDistance; private Context context; public KNNQuery( @@ -121,39 +123,6 @@ public KNNQuery(String field, float[] queryVector, String indexName, BitSetProdu this(field, queryVector, null, 0, indexName, null, parentsFilter, VectorDataType.FLOAT, null); } - /** - * Constructor for KNNQuery with radius - * - * @param radius engine radius - * @return KNNQuery - */ - public KNNQuery radius(Float radius) { - this.radius = radius; - return this; - } - - /** - * Constructor for KNNQuery with Context - * - * @param context Context for KNNQuery - * @return KNNQuery - */ - public KNNQuery kNNQueryContext(Context context) { - this.context = context; - return this; - } - - /** - * Constructor for KNNQuery with filter query - * - * @param 
filterQuery filter query - * @return KNNQuery - */ - public KNNQuery filterQuery(Query filterQuery) { - this.filterQuery = filterQuery; - return this; - } - /** * Constructs Weight implementation for this query * diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 8f7c5a3ff..45460e0b2 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -548,6 +548,7 @@ protected Query doToQuery(QueryShardContext context) { .byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null) .vectorDataType(vectorDataType) .radius(radius) + .minScore(minScore) .methodParameters(this.methodParameters) .filter(this.filter) .context(context) diff --git a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java index 99152ef6b..5871b46b0 100644 --- a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java @@ -66,6 +66,8 @@ public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest final String indexName = createQueryRequest.getIndexName(); final String fieldName = createQueryRequest.getFieldName(); final Float radius = createQueryRequest.getRadius(); + final Float minScore = createQueryRequest.getMinScore(); // null unless the request carries a min score + final Float maxDistance = createQueryRequest.getRadius(); // NOTE(review): request has no max-distance field; radius stands in for it - confirm final float[] vector = createQueryRequest.getVector(); final byte[] byteVector = createQueryRequest.getByteVector(); final VectorDataType vectorDataType = createQueryRequest.getVectorDataType(); @@ -88,6 +90,8 @@ public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest .indexName(indexName) .parentsFilter(parentFilter) .radius(radius) + .maxDistance(maxDistance) + .minScore(minScore) .methodParameters(methodParameters) 
.context(knnQueryContext) .filterQuery(filterQuery)