From 575e9e628e12aee776cacd8948dff90157629414 Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Thu, 24 Aug 2023 08:10:01 -0700 Subject: [PATCH] feature: implement VSS (Vectorize and Search) for JSON documents --- .../redis/om/spring/RediSearchIndexer.java | 6 +- .../spring/RedisEnhancedKeyValueAdapter.java | 2 +- .../om/spring/RedisJSONKeyValueAdapter.java | 51 ++++---- .../om/spring/RedisModulesConfiguration.java | 10 +- .../redis/om/spring/audit/EntityAuditor.java | 4 + .../spring/metamodel/indexed/VectorField.java | 4 + .../RedisDocumentRepositoryFactory.java | 18 ++- .../RedisDocumentRepositoryFactoryBean.java | 5 +- .../SimpleRedisDocumentRepository.java | 29 ++--- .../SimpleRedisEnhancedRepository.java | 2 +- .../search/stream/SearchStreamImpl.java | 5 +- .../predicates/vector/KNNPredicate.java | 21 +++- .../com/redis/om/spring/util/ObjectUtils.java | 16 ++- .../vectorize/DefaultFeatureExtractor.java | 60 +++++++--- .../om/spring/vectorize/FeatureExtractor.java | 16 ++- .../vectorize/NoopFeatureExtractor.java | 23 +++- .../document/fixtures/Product.java | 52 ++++++++ .../document/fixtures/ProductRepository.java | 9 ++ .../vectorize/VectorizeDocumentTest.java | 111 ++++++++++++++++++ .../hash/serialization/SerializationTest.java | 2 +- .../vectorize/VectorizeHashTest.java} | 4 +- .../vectorize/face/FaceDetectionTest.java | 2 +- 22 files changed, 349 insertions(+), 103 deletions(-) create mode 100644 redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/Product.java create mode 100644 redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/ProductRepository.java create mode 100644 redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeDocumentTest.java rename redis-om-spring/src/test/java/com/redis/om/spring/annotations/{vectorize/VectorizeTest.java => hash/vectorize/VectorizeHashTest.java} (97%) rename redis-om-spring/src/test/java/com/redis/om/spring/annotations/{ => hash}/vectorize/face/FaceDetectionTest.java (97%) diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/RediSearchIndexer.java b/redis-om-spring/src/main/java/com/redis/om/spring/RediSearchIndexer.java index d9210779..ecf53c6c 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/RediSearchIndexer.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/RediSearchIndexer.java @@ -347,11 +347,7 @@ private Field indexAsVectorFieldFor(java.lang.reflect.Field field, boolean isDoc Indexed indexed) { TypeInformation typeInfo = TypeInformation.of(field.getType()); String fieldPrefix = getFieldPrefix(prefix, isDocument); - - String fieldPostfix = (isDocument && typeInfo.isCollectionLike() && !field.isAnnotationPresent(JsonAdapter.class)) - ? "[*]" - : ""; - String fieldName = fieldPrefix + field.getName() + fieldPostfix; + String fieldName = fieldPrefix + field.getName(); Map attributes = new HashMap<>(); attributes.put("TYPE", indexed.type().toString()); diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/RedisEnhancedKeyValueAdapter.java b/redis-om-spring/src/main/java/com/redis/om/spring/RedisEnhancedKeyValueAdapter.java index bca37840..ecd0f72c 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/RedisEnhancedKeyValueAdapter.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/RedisEnhancedKeyValueAdapter.java @@ -134,7 +134,7 @@ public Object put(Object id, Object item, String keyspace) { } else { byte[] redisKey = createKey(keyspace, converter.getConversionService().convert(id, String.class)); auditor.processEntity(redisKey, item); - featureExtractor.processEntity(redisKey, item); + featureExtractor.processEntity(item); rdo = new RedisData(); converter.write(item, rdo); diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/RedisJSONKeyValueAdapter.java b/redis-om-spring/src/main/java/com/redis/om/spring/RedisJSONKeyValueAdapter.java index 31cc4516..1691913c 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/RedisJSONKeyValueAdapter.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/RedisJSONKeyValueAdapter.java @@ -3,17 +3,17 @@ import com.google.gson.Gson; import com.google.gson.GsonBuilder; import com.google.gson.reflect.TypeToken; +import com.redis.om.spring.audit.EntityAuditor; import com.redis.om.spring.convert.RedisOMCustomConversions; import com.redis.om.spring.ops.RedisModulesOperations; import com.redis.om.spring.ops.json.JSONOperations; import com.redis.om.spring.ops.search.SearchOperations; import com.redis.om.spring.util.ObjectUtils; +import com.redis.om.spring.vectorize.FeatureExtractor; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.*; import org.springframework.dao.OptimisticLockingFailureException; -import org.springframework.data.annotation.CreatedDate; -import org.springframework.data.annotation.LastModifiedDate; import org.springframework.data.annotation.Reference; import org.springframework.data.annotation.Version; import org.springframework.data.redis.core.RedisCallback; @@ -34,9 +34,10 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Optional; import java.util.concurrent.TimeUnit; import static com.redis.om.spring.util.ObjectUtils.getKey; @@ -50,6 +51,8 @@ public class RedisJSONKeyValueAdapter extends RedisKeyValueAdapter { private final RedisModulesOperations modulesOperations; private final RediSearchIndexer indexer; private final GsonBuilder gsonBuilder; + private final EntityAuditor auditor; + private final FeatureExtractor featureExtractor; private final RedisOMProperties redisOMProperties; /** @@ -62,16 +65,24 @@ public class RedisJSONKeyValueAdapter extends RedisKeyValueAdapter { * @param keyspaceToIndexMap must not be {@literal null}. */ @SuppressWarnings("unchecked") - public RedisJSONKeyValueAdapter(RedisOperations redisOps, RedisModulesOperations rmo, - RedisMappingContext mappingContext, RediSearchIndexer keyspaceToIndexMap, GsonBuilder gsonBuilder, - RedisOMProperties redisOMProperties) { + public RedisJSONKeyValueAdapter( // + RedisOperations redisOps, // + RedisModulesOperations rmo, // + RedisMappingContext mappingContext, // + RediSearchIndexer keyspaceToIndexMap, // + GsonBuilder gsonBuilder, // + FeatureExtractor featureExtractor, // + RedisOMProperties redisOMProperties + ) { super(redisOps, mappingContext, new RedisOMCustomConversions()); this.modulesOperations = (RedisModulesOperations) rmo; this.redisJSONOperations = modulesOperations.opsForJSON(); this.redisOperations = redisOps; this.mappingContext = mappingContext; this.indexer = keyspaceToIndexMap; + this.auditor = new EntityAuditor(this.redisOperations); this.gsonBuilder = gsonBuilder; + this.featureExtractor = featureExtractor; this.redisOMProperties = redisOMProperties; } @@ -91,7 +102,8 @@ public Object put(Object id, Object item, String keyspace) { String key = getKey(keyspace, id); processVersion(key, item); - processAuditAnnotations(key, item); + auditor.processEntity(key, item); + featureExtractor.processEntity(item); Optional maybeTtl = getTTLForEntity(item); ops.set(key, item); @@ -256,27 +268,6 @@ public boolean contains(Object id, String keyspace) { return exists != null && exists; } - private void processAuditAnnotations(String key, Object item) { - boolean isNew = (boolean) redisOperations - .execute((RedisCallback) connection -> !connection.keyCommands().exists(toBytes(key))); - - var auditClass = isNew ? CreatedDate.class : LastModifiedDate.class; - - List fields = ObjectUtils.getFieldsWithAnnotation(item.getClass(), auditClass); - if (!fields.isEmpty()) { - PropertyAccessor accessor = PropertyAccessorFactory.forBeanPropertyAccess(item); - fields.forEach(f -> { - if (f.getType() == Date.class) { - accessor.setPropertyValue(f.getName(), new Date(System.currentTimeMillis())); - } else if (f.getType() == LocalDateTime.class) { - accessor.setPropertyValue(f.getName(), LocalDateTime.now()); - } else if (f.getType() == LocalDate.class) { - accessor.setPropertyValue(f.getName(), LocalDate.now()); - } - }); - } - } - private void processReferences(String key, Object item) { List fields = ObjectUtils.getFieldsWithAnnotation(item.getClass(), Reference.class); if (!fields.isEmpty()) { diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java b/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java index 7ffe2c59..a8e39fa4 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java @@ -293,9 +293,10 @@ RedisJSONKeyValueAdapter getRedisJSONKeyValueAdapter( // RedisMappingContext mappingContext, // RediSearchIndexer indexer, // @Qualifier("omGsonBuilder") GsonBuilder gsonBuilder, // - RedisOMProperties properties // + RedisOMProperties properties, // + @Nullable @Qualifier("featureExtractor") FeatureExtractor featureExtractor ) { - return new RedisJSONKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, gsonBuilder, properties); + return new RedisJSONKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, gsonBuilder, featureExtractor, properties); } @Bean(name = "redisJSONKeyValueTemplate") @@ -305,10 +306,11 @@ public CustomRedisKeyValueTemplate getRedisJSONKeyValueTemplate( // RedisMappingContext mappingContext, // RediSearchIndexer indexer, // @Qualifier("omGsonBuilder") GsonBuilder gsonBuilder, // - RedisOMProperties properties // + RedisOMProperties properties, // + @Nullable @Qualifier("featureExtractor") FeatureExtractor featureExtractor ) { return new CustomRedisKeyValueTemplate( - new RedisJSONKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, gsonBuilder, properties), + new RedisJSONKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, gsonBuilder, featureExtractor, properties), mappingContext); } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/audit/EntityAuditor.java b/redis-om-spring/src/main/java/com/redis/om/spring/audit/EntityAuditor.java index 9e43c7c1..265f1d44 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/audit/EntityAuditor.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/audit/EntityAuditor.java @@ -26,6 +26,10 @@ public void processEntity(byte[] redisKey, Object item) { processEntity(item, isNew); } + public void processEntity(String redisKey, Object item) { + processEntity(redisKey.getBytes(), item); + } + public void processEntity(Object item, boolean isNew) { var auditClass = isNew ? CreatedDate.class : LastModifiedDate.class; diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/metamodel/indexed/VectorField.java b/redis-om-spring/src/main/java/com/redis/om/spring/metamodel/indexed/VectorField.java index 1428c331..c038181a 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/metamodel/indexed/VectorField.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/metamodel/indexed/VectorField.java @@ -12,4 +12,8 @@ public VectorField(SearchFieldAccessor field, boolean indexed) { public KNNPredicate knn(int k, byte[] blobAttribute) { return new KNNPredicate<>(searchFieldAccessor,k, blobAttribute); } + + public KNNPredicate knn(int k, float[] blobAttribute) { + return new KNNPredicate<>(searchFieldAccessor,k, blobAttribute); + } } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactory.java b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactory.java index b2a17892..350035e5 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactory.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactory.java @@ -5,6 +5,7 @@ import com.redis.om.spring.RedisOMProperties; import com.redis.om.spring.ops.RedisModulesOperations; import com.redis.om.spring.repository.query.RediSearchQuery; +import com.redis.om.spring.vectorize.FeatureExtractor; import org.springframework.beans.BeanUtils; import org.springframework.data.keyvalue.core.KeyValueOperations; import org.springframework.data.keyvalue.repository.query.KeyValuePartTreeQuery; @@ -41,6 +42,7 @@ public class RedisDocumentRepositoryFactory extends KeyValueRepositoryFactory { private final RediSearchIndexer indexer; private final GsonBuilder gsonBuilder; private final RedisMappingContext mappingContext; + private final FeatureExtractor featureExtractor; private final RedisOMProperties properties; /** @@ -59,9 +61,10 @@ public RedisDocumentRepositoryFactory( // RediSearchIndexer keyspaceToIndexMap, // RedisMappingContext mappingContext, // GsonBuilder gsonBuilder, // + FeatureExtractor featureExtractor, // RedisOMProperties properties // ) { - this(keyValueOperations, rmo, keyspaceToIndexMap, DEFAULT_QUERY_CREATOR, mappingContext, gsonBuilder, properties); + this(keyValueOperations, rmo, keyspaceToIndexMap, DEFAULT_QUERY_CREATOR, mappingContext, gsonBuilder, featureExtractor, properties); } /** @@ -82,10 +85,11 @@ public RedisDocumentRepositoryFactory( // Class> queryCreator, // RedisMappingContext mappingContext, // GsonBuilder gsonBuilder, // + FeatureExtractor featureExtractor, // RedisOMProperties properties // ) { - this(keyValueOperations, rmo, keyspaceToIndexMap, queryCreator, RediSearchQuery.class, mappingContext, gsonBuilder, properties); + this(keyValueOperations, rmo, keyspaceToIndexMap, queryCreator, RediSearchQuery.class, mappingContext, gsonBuilder, featureExtractor, properties); } /** @@ -108,12 +112,19 @@ public RedisDocumentRepositoryFactory( // Class repositoryQueryType, // RedisMappingContext mappingContext, // GsonBuilder gsonBuilder, // + FeatureExtractor featureExtractor, // RedisOMProperties properties // ) { super(keyValueOperations, queryCreator, repositoryQueryType); Assert.notNull(rmo, "RedisModulesOperations must not be null!"); + Assert.notNull(keyValueOperations, "KeyValueOperations must not be null!"); + Assert.notNull(rmo, "RedisModulesOperations must not be null!"); + Assert.notNull(queryCreator, "Query creator type must not be null!"); + Assert.notNull(repositoryQueryType, "RepositoryQueryType type must not be null!"); + Assert.notNull(featureExtractor, "FeatureExtractor type must not be null!"); + Assert.notNull(properties, "RedisOMSpringProperties type must not be null!"); this.keyValueOperations = keyValueOperations; this.rmo = rmo; @@ -122,6 +133,7 @@ public RedisDocumentRepositoryFactory( // this.repositoryQueryType = repositoryQueryType; this.mappingContext = mappingContext; this.gsonBuilder = gsonBuilder; + this.featureExtractor = featureExtractor; this.properties = properties; } @@ -129,7 +141,7 @@ public RedisDocumentRepositoryFactory( // protected Object getTargetRepository(RepositoryInformation repositoryInformation) { EntityInformation entityInformation = getEntityInformation(repositoryInformation.getDomainType()); return super.getTargetRepositoryViaReflection( - repositoryInformation, entityInformation, keyValueOperations, rmo, indexer, mappingContext, gsonBuilder, properties); + repositoryInformation, entityInformation, keyValueOperations, rmo, indexer, mappingContext, gsonBuilder, featureExtractor, properties); } @Override diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactoryBean.java b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactoryBean.java index dcba5f59..fd2ef847 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactoryBean.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/RedisDocumentRepositoryFactoryBean.java @@ -4,6 +4,7 @@ import com.redis.om.spring.RediSearchIndexer; import com.redis.om.spring.RedisOMProperties; import com.redis.om.spring.ops.RedisModulesOperations; +import com.redis.om.spring.vectorize.FeatureExtractor; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.keyvalue.core.KeyValueOperations; import org.springframework.data.keyvalue.repository.support.KeyValueRepositoryFactoryBean; @@ -26,6 +27,8 @@ public class RedisDocumentRepositoryFactoryBean, S, @Autowired private GsonBuilder gsonBuilder; @Autowired + private @Nullable FeatureExtractor featureExtractor; + @Autowired private RedisOMProperties properties; /** @@ -46,7 +49,7 @@ protected final RedisDocumentRepositoryFactory createRepositoryFactory( // Class repositoryQueryType // ) { return new RedisDocumentRepositoryFactory(operations, rmo, indexer, queryCreator, repositoryQueryType, - this.mappingContext, this.gsonBuilder, this.properties); + this.mappingContext, this.gsonBuilder, this.featureExtractor, this.properties); } @Override diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisDocumentRepository.java b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisDocumentRepository.java index 45ff95a1..867eb204 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisDocumentRepository.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisDocumentRepository.java @@ -6,6 +6,7 @@ import com.google.gson.reflect.TypeToken; import com.redis.om.spring.RediSearchIndexer; import com.redis.om.spring.RedisOMProperties; +import com.redis.om.spring.audit.EntityAuditor; import com.redis.om.spring.convert.MappingRedisOMConverter; import com.redis.om.spring.id.ULIDIdentifierGenerator; import com.redis.om.spring.metamodel.MetamodelField; @@ -18,6 +19,7 @@ import com.redis.om.spring.search.stream.FluentQueryByExample; import com.redis.om.spring.serialization.gson.GsonListOfType; import com.redis.om.spring.util.ObjectUtils; +import com.redis.om.spring.vectorize.FeatureExtractor; import org.springframework.beans.BeanWrapper; import org.springframework.beans.BeanWrapperImpl; import org.springframework.beans.PropertyAccessor; @@ -82,6 +84,8 @@ public class SimpleRedisDocumentRepository extends SimpleKeyValueReposito private final RedisOMProperties properties; private final RedisMappingContext mappingContext; private final EntityStream entityStream; + protected final EntityAuditor auditor; + protected final FeatureExtractor featureExtractor; @SuppressWarnings("unchecked") public SimpleRedisDocumentRepository( // @@ -91,6 +95,7 @@ public SimpleRedisDocumentRepository( // RediSearchIndexer indexer, // RedisMappingContext mappingContext, GsonBuilder gsonBuilder, + FeatureExtractor featureExtractor, // RedisOMProperties properties) { super(metadata, operations); this.modulesOperations = (RedisModulesOperations) rmo; @@ -102,6 +107,8 @@ public SimpleRedisDocumentRepository( // this.generator = ULIDIdentifierGenerator.INSTANCE; this.gsonBuilder = gsonBuilder; this.mappingContext = mappingContext; + this.auditor = new EntityAuditor(modulesOperations.getTemplate()); + this.featureExtractor = featureExtractor; this.properties = properties; this.entityStream = new EntityStreamImpl(modulesOperations, modulesOperations.getGsonBuilder(), indexer); } @@ -179,7 +186,9 @@ public List saveAll(Iterable entities) { String keyspace = keyValueEntity.getKeySpace(); byte[] objectKey = createKey(keyspace, Objects.requireNonNull(id).toString()); - processAuditAnnotations(entity, isNew); + // process entity pre-save mutation + auditor.processEntity(entity, isNew); + featureExtractor.processEntity(entity); Optional maybeTtl = getTTLForEntity(entity); @@ -244,24 +253,6 @@ public byte[] createKey(String keyspace, String id) { return this.mappingConverter.toBytes(keyspace + ":" + id); } - private void processAuditAnnotations(Object item, boolean isNew) { - var auditClass = isNew ? CreatedDate.class : LastModifiedDate.class; - - List fields = com.redis.om.spring.util.ObjectUtils.getFieldsWithAnnotation(item.getClass(), auditClass); - if (!fields.isEmpty()) { - PropertyAccessor accessor = PropertyAccessorFactory.forBeanPropertyAccess(item); - fields.forEach(f -> { - if (f.getType() == Date.class) { - accessor.setPropertyValue(f.getName(), new Date(System.currentTimeMillis())); - } else if (f.getType() == LocalDateTime.class) { - accessor.setPropertyValue(f.getName(), LocalDateTime.now()); - } else if (f.getType() == LocalDate.class) { - accessor.setPropertyValue(f.getName(), LocalDate.now()); - } - }); - } - } - private void processReferenceAnnotations(byte[] objectKey, Object entity, Pipeline pipeline) { List fields = com.redis.om.spring.util.ObjectUtils.getFieldsWithAnnotation(entity.getClass(), Reference.class); if (!fields.isEmpty()) { diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisEnhancedRepository.java b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisEnhancedRepository.java index 37f45734..5038ac44 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisEnhancedRepository.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisEnhancedRepository.java @@ -250,7 +250,7 @@ public List saveAll(Iterable entities) { String keyspace = keyValueEntity.getKeySpace(); byte[] objectKey = createKey(keyspace, id.toString()); - // process entity pre-save mutation entities + // process entity pre-save mutation auditor.processEntity(entity, isNew); featureExtractor.processEntity(entity); diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/SearchStreamImpl.java b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/SearchStreamImpl.java index e58853a2..fcb5ce73 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/SearchStreamImpl.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/SearchStreamImpl.java @@ -16,8 +16,8 @@ import com.redis.om.spring.tuple.AbstractTupleMapper; import com.redis.om.spring.tuple.Pair; import com.redis.om.spring.tuple.TupleMapper; -import com.redis.om.spring.util.SearchResultRawResponseToObjectConverter; import com.redis.om.spring.util.ObjectUtils; +import com.redis.om.spring.util.SearchResultRawResponseToObjectConverter; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.BeanUtils; @@ -42,6 +42,7 @@ import java.util.stream.*; import static com.redis.om.spring.metamodel.MetamodelUtils.getMetamodelForIdField; +import static com.redis.om.spring.util.ObjectUtils.floatArrayToByteArray; import static java.util.stream.Collectors.toCollection; public class SearchStreamImpl implements SearchStream { @@ -410,7 +411,7 @@ Query prepareQuery() { if (knnPredicate != null) { query = new Query(knnPredicate.apply(rootNode).toString()); - query.addParam(knnPredicate.getBlobAttributeName(), knnPredicate.getBlobAttribute()); + query.addParam(knnPredicate.getBlobAttributeName(), knnPredicate.getBlobAttribute() != null ? knnPredicate.getBlobAttribute() : floatArrayToByteArray(knnPredicate.getDoublesAttribute())); query.addParam("K", knnPredicate.getK()); query.dialect(2); } else { diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/predicates/vector/KNNPredicate.java b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/predicates/vector/KNNPredicate.java index d38ff40a..92297717 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/predicates/vector/KNNPredicate.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/predicates/vector/KNNPredicate.java @@ -7,12 +7,21 @@ public class KNNPredicate extends BaseAbstractPredicate { private final int k; - private final byte[] blobAttribute; + private final byte[] blob; + private final float[] floats; - public KNNPredicate(SearchFieldAccessor field, int k, byte[] blobAttribute) { + public KNNPredicate(SearchFieldAccessor field, int k, byte[] blob) { super(field); this.k = k; - this.blobAttribute = blobAttribute; + this.blob = blob; + this.floats = null; + } + + public KNNPredicate(SearchFieldAccessor field, int k, float[] floats) { + super(field); + this.k = k; + this.blob = null; + this.floats = floats; } public int getK() { @@ -20,7 +29,11 @@ public int getK() { } public byte[] getBlobAttribute() { - return blobAttribute; + return blob; + } + + public float[] getDoublesAttribute() { + return floats; } public String getBlobAttributeName() { diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/util/ObjectUtils.java b/redis-om-spring/src/main/java/com/redis/om/spring/util/ObjectUtils.java index a03d94d7..1ec4f28b 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/util/ObjectUtils.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/util/ObjectUtils.java @@ -36,6 +36,7 @@ import java.math.BigInteger; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.FloatBuffer; import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; @@ -407,11 +408,24 @@ public static byte[] floatArrayToByteArray(float[] input) { } public static byte[] longArrayToByteArray(long[] input) { + return floatArrayToByteArray(longArrayToFloatArray(input)); + } + + public static float[] longArrayToFloatArray(long[] input) { float[] floats = new float[input.length]; for (int i = 0; i < input.length; i++) { floats[i] = input[i]; } - return floatArrayToByteArray(floats); + return floats; + } + + public static float[] byteArrayToFloatArray(byte[] bytes) { + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + byteBuffer.order(ByteOrder.LITTLE_ENDIAN); + FloatBuffer floatBuffer = byteBuffer.asFloatBuffer(); + float[] floatArray = new float[floatBuffer.capacity()]; + floatBuffer.get(floatArray); + return floatArray; } public static Collection instantiateCollection(Type type) { diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/DefaultFeatureExtractor.java b/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/DefaultFeatureExtractor.java index a8628ac7..b4c5eac7 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/DefaultFeatureExtractor.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/DefaultFeatureExtractor.java @@ -9,6 +9,7 @@ import ai.djl.repository.zoo.ZooModel; import ai.djl.translate.Pipeline; import ai.djl.translate.TranslateException; +import com.redis.om.spring.annotations.Document; import com.redis.om.spring.annotations.Vectorize; import com.redis.om.spring.util.ObjectUtils; import org.apache.commons.logging.Log; @@ -23,12 +24,15 @@ import java.lang.reflect.Field; import java.util.List; +import static com.redis.om.spring.util.ObjectUtils.byteArrayToFloatArray; +import static com.redis.om.spring.util.ObjectUtils.longArrayToFloatArray; + public class DefaultFeatureExtractor implements FeatureExtractor { private final ZooModel imageEmbeddingModel; private final ZooModel faceEmbeddingModel; private final ImageFactory imageFactory; private final ApplicationContext applicationContext; - private ImageFeatureExtractor imageFeatureExtractor; + private final ImageFeatureExtractor imageFeatureExtractor; public final Pipeline imagePipeline; public final HuggingFaceTokenizer sentenceTokenizer; @@ -52,12 +56,7 @@ public DefaultFeatureExtractor( // } @Override - public void processEntity(byte[] redisKey, Object item) { - processEntity(item); - } - - @Override - public byte[] getImageEmbeddingsFor(InputStream is) { + public byte[] getImageEmbeddingsAsByteArrayFor(InputStream is) { try { var img = imageFactory.fromInputStream(is); Predictor predictor = imageEmbeddingModel.newPredictor(imageFeatureExtractor); @@ -69,19 +68,35 @@ public byte[] getImageEmbeddingsFor(InputStream is) { } @Override - public byte[] getFacialImageEmbeddingsFor(InputStream is) throws IOException, TranslateException { + public float[] getImageEmbeddingsAsFloatArrayFor(InputStream is) { + return byteArrayToFloatArray(getImageEmbeddingsAsByteArrayFor(is)); + } + + @Override + public byte[] getFacialImageEmbeddingsAsByteArrayFor(InputStream is) throws IOException, TranslateException { + return ObjectUtils.floatArrayToByteArray(getFacialImageEmbeddingsAsFloatArrayFor(is)); + } + + @Override + public float[] getFacialImageEmbeddingsAsFloatArrayFor(InputStream is) throws IOException, TranslateException { try (Predictor predictor = faceEmbeddingModel.newPredictor()) { var img = imageFactory.fromInputStream(is); - return ObjectUtils.floatArrayToByteArray(predictor.predict(img)); + return predictor.predict(img); } } @Override - public byte[] getSentenceEmbeddingsFor(String text) { + public byte[] getSentenceEmbeddingsAsByteArrayFor(String text) { Encoding encoding = sentenceTokenizer.encode(text); return ObjectUtils.longArrayToByteArray(encoding.getIds()); } + @Override + public float[] getSentenceEmbeddingAsFloatArrayFor(String text) { + Encoding encoding = sentenceTokenizer.encode(text); + return longArrayToFloatArray(encoding.getIds()); + } + @Override public void processEntity(Object item) { if (!isReady()) { @@ -93,13 +108,18 @@ public void processEntity(Object item) { fields.forEach(f -> { Vectorize vectorize = f.getAnnotation(Vectorize.class); Object fieldValue = accessor.getPropertyValue(f.getName()); + boolean isDocument = item.getClass().isAnnotationPresent(Document.class); + if (fieldValue != null) { switch (vectorize.embeddingType()) { case IMAGE -> { Resource resource = applicationContext.getResource(fieldValue.toString()); try { - byte[] feature = getImageEmbeddingsFor(resource.getInputStream()); - accessor.setPropertyValue(vectorize.destination(), feature); + if (isDocument) { + accessor.setPropertyValue(vectorize.destination(), getImageEmbeddingsAsFloatArrayFor(resource.getInputStream())); + } else { + accessor.setPropertyValue(vectorize.destination(), getImageEmbeddingsAsByteArrayFor(resource.getInputStream())); + } } catch (IOException e) { logger.warn("Error generating image embedding", e); } @@ -110,14 +130,22 @@ public void processEntity(Object item) { case FACE -> { Resource resource = applicationContext.getResource(fieldValue.toString()); try { - byte[] feature = getFacialImageEmbeddingsFor(resource.getInputStream()); - accessor.setPropertyValue(vectorize.destination(), feature); + if (isDocument) { + accessor.setPropertyValue(vectorize.destination(), getFacialImageEmbeddingsAsFloatArrayFor(resource.getInputStream())); + } else { + accessor.setPropertyValue(vectorize.destination(), getFacialImageEmbeddingsAsByteArrayFor(resource.getInputStream())); + } } catch (IOException | TranslateException e) { logger.warn("Error generating facial image embedding", e); } } - case SENTENCE -> - accessor.setPropertyValue(vectorize.destination(), getSentenceEmbeddingsFor(fieldValue.toString())); + case SENTENCE -> { + if (isDocument) { + accessor.setPropertyValue(vectorize.destination(), getSentenceEmbeddingAsFloatArrayFor(fieldValue.toString())); + } else { + accessor.setPropertyValue(vectorize.destination(), getSentenceEmbeddingsAsByteArrayFor(fieldValue.toString())); + } + } } } }); diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/FeatureExtractor.java b/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/FeatureExtractor.java index 2404cff6..cbca2450 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/FeatureExtractor.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/FeatureExtractor.java @@ -6,15 +6,19 @@ import java.io.InputStream; public interface FeatureExtractor { - void processEntity(byte[] redisKey, Object item); + byte[] getImageEmbeddingsAsByteArrayFor(InputStream is); - byte[] getImageEmbeddingsFor(InputStream is); + float[] getImageEmbeddingsAsFloatArrayFor(InputStream is); - byte[] getFacialImageEmbeddingsFor(InputStream is) throws IOException, TranslateException; + byte[] getFacialImageEmbeddingsAsByteArrayFor(InputStream is) throws IOException, TranslateException; - byte[] getSentenceEmbeddingsFor(String text); + float[] getFacialImageEmbeddingsAsFloatArrayFor(InputStream is) throws IOException, TranslateException; - void processEntity(Object item); + byte[] getSentenceEmbeddingsAsByteArrayFor(String text); - boolean isReady(); + float[] getSentenceEmbeddingAsFloatArrayFor(String text); + + void processEntity(Object item); + + boolean isReady(); } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/NoopFeatureExtractor.java b/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/NoopFeatureExtractor.java index 4b1a7ee5..f3daaa8b 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/NoopFeatureExtractor.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/vectorize/NoopFeatureExtractor.java @@ -6,26 +6,37 @@ import java.io.InputStream; public class NoopFeatureExtractor implements FeatureExtractor { + @Override - public void processEntity(byte[] redisKey, Object item) { - // NOOP + public byte[] getImageEmbeddingsAsByteArrayFor(InputStream is) { + return new byte[0]; } @Override - public byte[] getImageEmbeddingsFor(InputStream is) { - return new byte[0]; + public float[] getImageEmbeddingsAsFloatArrayFor(InputStream is) { + return new float[0]; } @Override - public byte[] getFacialImageEmbeddingsFor(InputStream is) throws IOException, TranslateException { + public byte[] getFacialImageEmbeddingsAsByteArrayFor(InputStream is) throws IOException, TranslateException { return new byte[0]; } @Override - public byte[] getSentenceEmbeddingsFor(String text) { + public float[] getFacialImageEmbeddingsAsFloatArrayFor(InputStream is) throws IOException, TranslateException { + return new float[0]; + } + + @Override + public byte[] getSentenceEmbeddingsAsByteArrayFor(String text) { return new byte[0]; } + @Override + public float[] getSentenceEmbeddingAsFloatArrayFor(String text) { + return new float[0]; + } + @Override public void processEntity(Object item) { // NOOP diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/Product.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/Product.java new file mode 100644 index 00000000..d81726ab --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/Product.java @@ -0,0 +1,52 @@ +package com.redis.om.spring.annotations.document.fixtures; + +import com.redis.om.spring.DistanceMetric; +import com.redis.om.spring.VectorType; +import com.redis.om.spring.annotations.*; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import org.springframework.data.annotation.Id; +import redis.clients.jedis.search.Schema.VectorField.VectorAlgo; + +@Data +@RequiredArgsConstructor(staticName = "of") +@NoArgsConstructor(force = true) +@Document +public class Product { + @Id + private String id; + + @Indexed + @NonNull + private String name; + + @Indexed(// + schemaFieldType = SchemaFieldType.VECTOR, // + algorithm = VectorAlgo.HNSW, // + type = VectorType.FLOAT32, // + dimension = 512, // + distanceMetric = DistanceMetric.L2, // + initialCapacity = 10 + ) + private float[] imageEmbedding; + + @Vectorize(destination = "imageEmbedding", embeddingType = EmbeddingType.IMAGE) + @NonNull + private String imagePath; + + @Indexed(// + schemaFieldType = SchemaFieldType.VECTOR, // + algorithm = VectorAlgo.HNSW, // + type = VectorType.FLOAT32, // + dimension = 768, // + distanceMetric = DistanceMetric.L2, // + initialCapacity = 10 + ) + private float[] sentenceEmbedding; + + @Vectorize(destination = "sentenceEmbedding", embeddingType = EmbeddingType.SENTENCE) + @NonNull + private String description; +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/ProductRepository.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/ProductRepository.java new file mode 100644 index 00000000..f7ec5549 --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/ProductRepository.java @@ -0,0 +1,9 @@ +package com.redis.om.spring.annotations.document.fixtures; + +import com.redis.om.spring.repository.RedisEnhancedRepository; + +import java.util.Optional; + +public interface ProductRepository extends RedisEnhancedRepository { + Optional findFirstByName(String name); +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeDocumentTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeDocumentTest.java new file mode 100644 index 00000000..0638a5d2 --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/vectorize/VectorizeDocumentTest.java @@ -0,0 +1,111 @@ +package com.redis.om.spring.annotations.document.vectorize; + +import com.redis.om.spring.AbstractBaseDocumentTest; +import com.redis.om.spring.annotations.document.fixtures.Product; +import com.redis.om.spring.annotations.document.fixtures.Product$; +import com.redis.om.spring.annotations.document.fixtures.ProductRepository; +import com.redis.om.spring.search.stream.EntityStream; +import com.redis.om.spring.search.stream.SearchStream; +import com.redis.om.spring.vectorize.FeatureExtractor; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.junit.jupiter.EnabledIf; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertAll; + +class VectorizeDocumentTest extends AbstractBaseDocumentTest { + @Autowired ProductRepository repository; + @Autowired EntityStream entityStream; + + @Autowired FeatureExtractor featureExtractor; + + @BeforeEach void loadTestData() throws IOException { + if (repository.count() == 0) { + repository.save(Product.of("cat", "classpath:/images/cat.jpg", + "The cat (Felis catus) is a domestic species of small carnivorous mammal.")); + repository.save(Product.of("cat2", "classpath:/images/cat2.jpg", + "It is the only domesticated species in the family Felidae and is commonly referred to as the domestic cat or house cat")); + repository.save(Product.of("catdog", "classpath:/images/catdog.jpg", "This is a picture of a cat and a dog together")); + repository.save(Product.of("face", "classpath:/images/face.jpg", "Three years later, the coffin was still full of Jello.")); + repository.save(Product.of("face2", "classpath:/images/face2.jpg", "The person box was packed with jelly many dozens of months later.")); + } + } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testImageIsVectorized() { + Optional cat = repository.findFirstByName("cat"); + assertAll( // + () -> assertThat(cat).isPresent(), // + () -> assertThat(cat.get()).extracting("imageEmbedding").isNotNull(), // + () -> assertThat(cat.get().getImageEmbedding()).hasSize(512) + ); + } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testSentenceIsVectorized() { + Optional cat = repository.findFirstByName("cat"); + assertAll( // + () -> assertThat(cat).isPresent(), // + () -> assertThat(cat.get()).extracting("sentenceEmbedding").isNotNull()//, // + //() -> assertThat(cat.get().getSentenceEmbedding()).hasSize(768*Float.BYTES) + ); + } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testKnnImageSimilaritySearch() { + Product cat = repository.findFirstByName("cat").get(); + int K = 5; + + SearchStream stream = entityStream.of(Product.class); + + List results = stream // + .filter(Product$.IMAGE_EMBEDDING.knn(K, cat.getImageEmbedding())) // + .sorted(Product$._IMAGE_EMBEDDING_SCORE) // + .limit(K) // + .collect(Collectors.toList()); + + assertThat(results).hasSize(5).map(Product::getName).containsExactly("cat", "cat2", + "face", "face2", "catdog"); + } + + @Test + @EnabledIf( + expression = "#{@featureExtractor.isReady()}", // + loadContext = true // + ) + void testKnnSentenceSimilaritySearch() { + Product cat = repository.findFirstByName("cat").get(); + int K = 5; + + SearchStream stream = entityStream.of(Product.class); + + List results = stream // + .filter(Product$.SENTENCE_EMBEDDING.knn(K, cat.getSentenceEmbedding())) // + .sorted(Product$._SENTENCE_EMBEDDING_SCORE) // + .limit(K) // + .collect(Collectors.toList()); + + assertThat(results).hasSize(5).map(Product::getName).containsExactly( // + "cat", "catdog", "cat2", "face", "face2" // + ); + } +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/serialization/SerializationTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/serialization/SerializationTest.java index 5ba20b44..f365b6cd 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/serialization/SerializationTest.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/serialization/SerializationTest.java @@ -63,7 +63,7 @@ public void cleanUp() throws IOException { point = new Point(-111.83592170193586,33.62826024782707); ulid = UlidCreator.getMonotonicUlid(); byteArray = "Hello World!".getBytes(); - byteArray2 = featureExtractor.getImageEmbeddingsFor( // + byteArray2 = featureExtractor.getImageEmbeddingsAsByteArrayFor( // applicationContext.getResource("classpath:/images/cat.jpg").getInputStream() ); diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/vectorize/VectorizeTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/VectorizeHashTest.java similarity index 97% rename from redis-om-spring/src/test/java/com/redis/om/spring/annotations/vectorize/VectorizeTest.java rename to redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/VectorizeHashTest.java index 7c1fc5d8..f0e6f6b3 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/vectorize/VectorizeTest.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/VectorizeHashTest.java @@ -1,4 +1,4 @@ -package com.redis.om.spring.annotations.vectorize; +package com.redis.om.spring.annotations.hash.vectorize; import com.redis.om.spring.AbstractBaseEnhancedRedisTest; import com.redis.om.spring.annotations.hash.fixtures.Product; @@ -20,7 +20,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertAll; -class VectorizeTest extends AbstractBaseEnhancedRedisTest { +class VectorizeHashTest extends AbstractBaseEnhancedRedisTest { @Autowired ProductRepository repository; @Autowired EntityStream entityStream; diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/vectorize/face/FaceDetectionTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/face/FaceDetectionTest.java similarity index 97% rename from redis-om-spring/src/test/java/com/redis/om/spring/annotations/vectorize/face/FaceDetectionTest.java rename to redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/face/FaceDetectionTest.java index 7cf6289c..0a6ff655 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/vectorize/face/FaceDetectionTest.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/vectorize/face/FaceDetectionTest.java @@ -1,4 +1,4 @@ -package com.redis.om.spring.annotations.vectorize.face; +package com.redis.om.spring.annotations.hash.vectorize.face; import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image;