Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
VijayanB committed Sep 30, 2024
1 parent 56a27f8 commit a5c6bc3
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public static class CreateQueryRequest {
private VectorDataType vectorDataType;
private Map<String, ?> methodParameters;
private Integer k;
private Float minScore;
private Float radius;
private QueryBuilder filter;
private QueryShardContext context;
Expand Down
53 changes: 52 additions & 1 deletion src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.faiss.Faiss;
import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.KNNIterator;
Expand Down Expand Up @@ -52,13 +53,63 @@ public class ExactSearcher {
public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext)
throws IOException {
KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
if(exactSearcherContext.getK() == null){ // user
return doRadiusSearch(leafReaderContext, iterator, exactSearcherContext);
}
if (exactSearcherContext.getMatchedDocs() != null
&& exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
return scoreAllDocs(iterator);
}
return searchTopK(iterator, exactSearcherContext.getK());
}

private Map<Integer, Float> doRadiusSearch(LeafReaderContext leafReaderContext, KNNIterator iterator, ExactSearcherContext exactSearcherContext) throws IOException {

Float minScore = Float.MIN_VALUE;
KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
if (knnQuery.getMinScore() != null) {
minScore = exactSearcherContext.getKnnQuery().getMinScore();
}else {
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo);
minScore = spaceType.scoreTranslation(knnQuery.getMaxDistance());
}
// Creating min heap and init with MAX DocID and Score as -INF.
final HitQueue queue = new HitQueue(exactSearcherContext.knnQuery.getContext().maxResultWindow, true);
ScoreDoc topDoc = queue.top();
final Map<Integer, Float> docToScore = new HashMap<>();
int docId;
while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
final float currentScore = iterator.score();
if(currentScore < minScore){
continue;
}
if (currentScore > topDoc.score) {
topDoc.score = currentScore;
topDoc.doc = docId;
// As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we
// have seen till now on top.
topDoc = queue.updateTop();
}
}

// If scores are negative we will remove them.
// This is done, because there can be negative values in the Heap as we init the heap with Score as -INF.
// If filterIds < k, the some values in heap can have a negative score.
while (queue.size() > 0 && queue.top().score < 0) {
queue.pop();
}

while (queue.size() > 0) {
final ScoreDoc doc = queue.pop();
docToScore.put(doc.doc, doc.score);
}

return docToScore;

}

private Map<Integer, Float> scoreAllDocs(KNNIterator iterator) throws IOException {
final Map<Integer, Float> docToScore = new HashMap<>();
int docId;
Expand Down Expand Up @@ -172,7 +223,7 @@ public static class ExactSearcherContext {
* re-scoring we need to re-score using full precision vectors and not quantized vectors.
*/
boolean useQuantizedVectorsForSearch;
int k;
Integer k;
BitSet matchedDocs;
KNNQuery knnQuery;
/**
Expand Down
35 changes: 2 additions & 33 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ public class KNNQuery extends Query {
private Query filterQuery;
private BitSetProducer parentsFilter;
private Float radius;
private Float minScore;
private Float maxDistance;
private Context context;

public KNNQuery(
Expand Down Expand Up @@ -121,39 +123,6 @@ public KNNQuery(String field, float[] queryVector, String indexName, BitSetProdu
this(field, queryVector, null, 0, indexName, null, parentsFilter, VectorDataType.FLOAT, null);
}

/**
* Constructor for KNNQuery with radius
*
* @param radius engine radius
* @return KNNQuery
*/
public KNNQuery radius(Float radius) {
this.radius = radius;
return this;
}

/**
* Constructor for KNNQuery with Context
*
* @param context Context for KNNQuery
* @return KNNQuery
*/
public KNNQuery kNNQueryContext(Context context) {
this.context = context;
return this;
}

/**
* Constructor for KNNQuery with filter query
*
* @param filterQuery filter query
* @return KNNQuery
*/
public KNNQuery filterQuery(Query filterQuery) {
this.filterQuery = filterQuery;
return this;
}

/**
* Constructs Weight implementation for this query
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ protected Query doToQuery(QueryShardContext context) {
.byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null)
.vectorDataType(vectorDataType)
.radius(radius)
.minScore(minScore)
.methodParameters(this.methodParameters)
.filter(this.filter)
.context(context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest
final String indexName = createQueryRequest.getIndexName();
final String fieldName = createQueryRequest.getFieldName();
final Float radius = createQueryRequest.getRadius();
final Float minScore = createQueryRequest.getRadius();
final Float maxDistance = createQueryRequest.getRadius();
final float[] vector = createQueryRequest.getVector();
final byte[] byteVector = createQueryRequest.getByteVector();
final VectorDataType vectorDataType = createQueryRequest.getVectorDataType();
Expand All @@ -88,6 +90,8 @@ public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest
.indexName(indexName)
.parentsFilter(parentFilter)
.radius(radius)
.maxDistance(maxDistance)
.minScore(minScore)
.methodParameters(methodParameters)
.context(knnQueryContext)
.filterQuery(filterQuery)
Expand Down

0 comments on commit a5c6bc3

Please sign in to comment.