From a342e7bf7f1e55c454f0cff5d4f0fd59cd443e21 Mon Sep 17 00:00:00 2001 From: Tejas Shah Date: Mon, 23 Sep 2024 15:33:32 -0700 Subject: [PATCH] Makes sure KNNVectorValues aren't recreated unnecessarily when quantization isn't needed (#2133) Signed-off-by: Tejas Shah (cherry picked from commit e33afa5de5f8658ad7fbe71125707436e81cc5b8) --- CHANGELOG.md | 1 + .../NativeEngines990KnnVectorsWriter.java | 63 +++++++++++-------- ...eEngines990KnnVectorsWriterFlushTests.java | 11 ++++ ...eEngines990KnnVectorsWriterMergeTests.java | 9 +++ 4 files changed, 59 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d4e333da9..a730a201b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Documentation ### Maintenance ### Refactoring +* Does not create additional KNNVectorValues in NativeEngines990KNNVectorWriter when quantization is not needed [#2133](https://github.com/opensearch-project/k-NN/pull/2133) ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.17...2.x) ### Features diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 3f32003ac..23cd2a4de 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -36,6 +36,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.function.Supplier; import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getVectorValues; @@ -82,19 +83,19 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { for (final 
NativeEngineFieldVectorsWriter field : fields) { final FieldInfo fieldInfo = field.getFieldInfo(); final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); - int totalLiveDocs = getLiveDocs(getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors())); + int totalLiveDocs = field.getVectors().size(); if (totalLiveDocs > 0) { - KNNVectorValues knnVectorValues = getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()); - - final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValues, totalLiveDocs); + final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getVectorValues( + vectorDataType, + field.getDocsWithField(), + field.getVectors() + ); + final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); - - knnVectorValues = getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()); + final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); StopWatch stopWatch = new StopWatch().start(); - writer.flushIndex(knnVectorValues, totalLiveDocs); - long time_in_millis = stopWatch.stop().totalTime().millis(); KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); @@ -110,17 +111,20 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState flatVectorsWriter.mergeOneField(fieldInfo, mergeState); final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); - int totalLiveDocs = getLiveDocs(getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState)); + final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getKNNVectorValuesForMerge( + vectorDataType, + fieldInfo, + mergeState + ); + int totalLiveDocs = getLiveDocs(knnVectorValuesSupplier.get()); if (totalLiveDocs == 0) { 
log.debug("[Merge] No live docs for field {}", fieldInfo.getName()); return; } - KNNVectorValues knnVectorValues = getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState); - final QuantizationState quantizationState = train(fieldInfo, knnVectorValues, totalLiveDocs); + final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs); final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); - - knnVectorValues = getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState); + final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); StopWatch stopWatch = new StopWatch().start(); @@ -191,27 +195,36 @@ private KNNVectorValues getKNNVectorValuesForMerge( final VectorDataType vectorDataType, final FieldInfo fieldInfo, final MergeState mergeState - ) throws IOException { - switch (fieldInfo.getVectorEncoding()) { - case FLOAT32: - FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - return getVectorValues(vectorDataType, mergedFloats); - case BYTE: - ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); - return getVectorValues(vectorDataType, mergedBytes); - default: - throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); + ) { + try { + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + return getVectorValues(vectorDataType, mergedFloats); + case BYTE: + ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); + return getVectorValues(vectorDataType, mergedBytes); + default: + throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); + } + } catch (final IOException e) { + log.error("Unable to merge vectors for field [{}]", 
fieldInfo.getName(), e); + throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e); } } - private QuantizationState train(final FieldInfo fieldInfo, final KNNVectorValues knnVectorValues, final int totalLiveDocs) - throws IOException { + private QuantizationState train( + final FieldInfo fieldInfo, + final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier, + final int totalLiveDocs + ) throws IOException { final QuantizationService quantizationService = QuantizationService.getInstance(); final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); QuantizationState quantizationState = null; if (quantizationParams != null && totalLiveDocs > 0) { initQuantizationStateWriterIfNecessary(); + KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs); quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index ad72f5b24..dbb564908 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -44,6 +44,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockConstruction; import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -176,6 +177,11 @@ public void testFlush() { throw new RuntimeException(e); } }); + + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), 
any(DocsWithFieldSet.class), any()), + times(expectedVectorValues.size()) + ); } } @@ -264,6 +270,11 @@ public void testFlush_WithQuantization() { throw new RuntimeException(e); } }); + + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(expectedVectorValues.size() * 2) + ); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java index 440e8bbc5..41940c4d4 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -45,6 +45,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockConstruction; import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; @@ -144,6 +145,10 @@ public void testMerge() { if (!mergedVectors.isEmpty()) { verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues), + times(2) + ); } else { verifyNoInteractions(nativeIndexWriter); } @@ -211,6 +216,10 @@ public void testMerge_WithQuantization() { verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(0, quantizationState); verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + knnVectorValuesFactoryMockedStatic.verify( + () -> 
KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues), + times(3) + ); } else { assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); verifyNoInteractions(nativeIndexWriter);