Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: implement VSS (Vectorize and Search) for JSON documents #333

Merged
merged 1 commit into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,7 @@ private Field indexAsVectorFieldFor(java.lang.reflect.Field field, boolean isDoc
Indexed indexed) {
TypeInformation<?> typeInfo = TypeInformation.of(field.getType());
String fieldPrefix = getFieldPrefix(prefix, isDocument);

String fieldPostfix = (isDocument && typeInfo.isCollectionLike() && !field.isAnnotationPresent(JsonAdapter.class))
? "[*]"
: "";
String fieldName = fieldPrefix + field.getName() + fieldPostfix;
String fieldName = fieldPrefix + field.getName();

Map<String, Object> attributes = new HashMap<>();
attributes.put("TYPE", indexed.type().toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public Object put(Object id, Object item, String keyspace) {
} else {
byte[] redisKey = createKey(keyspace, converter.getConversionService().convert(id, String.class));
auditor.processEntity(redisKey, item);
featureExtractor.processEntity(redisKey, item);
featureExtractor.processEntity(item);

rdo = new RedisData();
converter.write(item, rdo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
import com.redis.om.spring.audit.EntityAuditor;
import com.redis.om.spring.convert.RedisOMCustomConversions;
import com.redis.om.spring.ops.RedisModulesOperations;
import com.redis.om.spring.ops.json.JSONOperations;
import com.redis.om.spring.ops.search.SearchOperations;
import com.redis.om.spring.util.ObjectUtils;
import com.redis.om.spring.vectorize.FeatureExtractor;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.*;
import org.springframework.dao.OptimisticLockingFailureException;
import org.springframework.data.annotation.CreatedDate;
import org.springframework.data.annotation.LastModifiedDate;
import org.springframework.data.annotation.Reference;
import org.springframework.data.annotation.Version;
import org.springframework.data.redis.core.RedisCallback;
Expand All @@ -34,9 +34,10 @@

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

import static com.redis.om.spring.util.ObjectUtils.getKey;
Expand All @@ -50,6 +51,8 @@ public class RedisJSONKeyValueAdapter extends RedisKeyValueAdapter {
private final RedisModulesOperations<String> modulesOperations;
private final RediSearchIndexer indexer;
private final GsonBuilder gsonBuilder;
private final EntityAuditor auditor;
private final FeatureExtractor featureExtractor;
private final RedisOMProperties redisOMProperties;

/**
Expand All @@ -62,16 +65,24 @@ public class RedisJSONKeyValueAdapter extends RedisKeyValueAdapter {
* @param keyspaceToIndexMap must not be {@literal null}.
*/
@SuppressWarnings("unchecked")
public RedisJSONKeyValueAdapter(RedisOperations<?, ?> redisOps, RedisModulesOperations<?> rmo,
RedisMappingContext mappingContext, RediSearchIndexer keyspaceToIndexMap, GsonBuilder gsonBuilder,
RedisOMProperties redisOMProperties) {
public RedisJSONKeyValueAdapter( //
RedisOperations<?, ?> redisOps, //
RedisModulesOperations<?> rmo, //
RedisMappingContext mappingContext, //
RediSearchIndexer keyspaceToIndexMap, //
GsonBuilder gsonBuilder, //
FeatureExtractor featureExtractor, //
RedisOMProperties redisOMProperties
) {
super(redisOps, mappingContext, new RedisOMCustomConversions());
this.modulesOperations = (RedisModulesOperations<String>) rmo;
this.redisJSONOperations = modulesOperations.opsForJSON();
this.redisOperations = redisOps;
this.mappingContext = mappingContext;
this.indexer = keyspaceToIndexMap;
this.auditor = new EntityAuditor(this.redisOperations);
this.gsonBuilder = gsonBuilder;
this.featureExtractor = featureExtractor;
this.redisOMProperties = redisOMProperties;
}

Expand All @@ -91,7 +102,8 @@ public Object put(Object id, Object item, String keyspace) {
String key = getKey(keyspace, id);

processVersion(key, item);
processAuditAnnotations(key, item);
auditor.processEntity(key, item);
featureExtractor.processEntity(item);
Optional<Long> maybeTtl = getTTLForEntity(item);

ops.set(key, item);
Expand Down Expand Up @@ -256,27 +268,6 @@ public boolean contains(Object id, String keyspace) {
return exists != null && exists;
}

/**
 * Stamps the audit fields of {@code item} with the current date/time.
 *
 * <p>If {@code key} does not yet exist in Redis the fields annotated with
 * {@code CreatedDate} are set; otherwise the fields annotated with
 * {@code LastModifiedDate} are set. Only fields of type {@code Date},
 * {@code LocalDateTime} or {@code LocalDate} are written; any other type is
 * left untouched.
 *
 * @param key  the Redis key whose existence decides which annotation applies
 * @param item the entity whose annotated fields are updated in place
 */
private void processAuditAnnotations(String key, Object item) {
  // The entity counts as new when its key is not present in Redis.
  RedisCallback<Object> keyIsAbsent = connection -> !connection.keyCommands().exists(toBytes(key));
  boolean isNew = (boolean) redisOperations.execute(keyIsAbsent);

  var auditClass = isNew ? CreatedDate.class : LastModifiedDate.class;

  List<Field> auditFields = ObjectUtils.getFieldsWithAnnotation(item.getClass(), auditClass);
  if (auditFields.isEmpty()) {
    return;
  }

  PropertyAccessor accessor = PropertyAccessorFactory.forBeanPropertyAccess(item);
  for (Field auditField : auditFields) {
    Class<?> fieldType = auditField.getType();
    String property = auditField.getName();
    if (fieldType == Date.class) {
      accessor.setPropertyValue(property, new Date(System.currentTimeMillis()));
    } else if (fieldType == LocalDateTime.class) {
      accessor.setPropertyValue(property, LocalDateTime.now());
    } else if (fieldType == LocalDate.class) {
      accessor.setPropertyValue(property, LocalDate.now());
    }
  }
}

private void processReferences(String key, Object item) {
List<Field> fields = ObjectUtils.getFieldsWithAnnotation(item.getClass(), Reference.class);
if (!fields.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,10 @@ RedisJSONKeyValueAdapter getRedisJSONKeyValueAdapter( //
RedisMappingContext mappingContext, //
RediSearchIndexer indexer, //
@Qualifier("omGsonBuilder") GsonBuilder gsonBuilder, //
RedisOMProperties properties //
RedisOMProperties properties, //
@Nullable @Qualifier("featureExtractor") FeatureExtractor featureExtractor
) {
return new RedisJSONKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, gsonBuilder, properties);
return new RedisJSONKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, gsonBuilder, featureExtractor, properties);
}

@Bean(name = "redisJSONKeyValueTemplate")
Expand All @@ -305,10 +306,11 @@ public CustomRedisKeyValueTemplate getRedisJSONKeyValueTemplate( //
RedisMappingContext mappingContext, //
RediSearchIndexer indexer, //
@Qualifier("omGsonBuilder") GsonBuilder gsonBuilder, //
RedisOMProperties properties //
RedisOMProperties properties, //
@Nullable @Qualifier("featureExtractor") FeatureExtractor featureExtractor
) {
return new CustomRedisKeyValueTemplate(
new RedisJSONKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, gsonBuilder, properties),
new RedisJSONKeyValueAdapter(redisOps, redisModulesOperations, mappingContext, indexer, gsonBuilder, featureExtractor, properties),
mappingContext);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ public void processEntity(byte[] redisKey, Object item) {
processEntity(item, isNew);
}

/**
 * Convenience overload of {@link #processEntity(byte[], Object)} for a key
 * supplied as a string.
 *
 * <p>The key is encoded explicitly as UTF-8 so the resulting bytes do not
 * depend on the JVM's default charset.
 *
 * @param redisKey the key of the entity; must not be {@code null}
 * @param item     the entity to process
 */
public void processEntity(String redisKey, Object item) {
  processEntity(redisKey.getBytes(java.nio.charset.StandardCharsets.UTF_8), item);
}

public void processEntity(Object item, boolean isNew) {
var auditClass = isNew ? CreatedDate.class : LastModifiedDate.class;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,8 @@ public VectorField(SearchFieldAccessor field, boolean indexed) {
/**
 * Creates a {@link KNNPredicate} for this field from a vector supplied as raw bytes.
 *
 * @param k             passed through to the predicate unchanged; presumably the number
 *                      of nearest neighbors requested (TODO confirm against KNNPredicate)
 * @param blobAttribute the vector as a byte array
 * @return a new predicate built from this field's {@code searchFieldAccessor},
 *         {@code k} and {@code blobAttribute}
 */
public KNNPredicate<E,T> knn(int k, byte[] blobAttribute) {
return new KNNPredicate<>(searchFieldAccessor,k, blobAttribute);
}

/**
 * Creates a {@link KNNPredicate} for this field from a vector supplied as floats.
 *
 * <p>Overload of the {@code byte[]} variant; the parameter keeps the name
 * {@code blobAttribute} even though it carries a {@code float[]} here.
 *
 * @param k             passed through to the predicate unchanged; presumably the number
 *                      of nearest neighbors requested (TODO confirm against KNNPredicate)
 * @param blobAttribute the vector as a float array
 * @return a new predicate built from this field's {@code searchFieldAccessor},
 *         {@code k} and {@code blobAttribute}
 */
public KNNPredicate<E,T> knn(int k, float[] blobAttribute) {
return new KNNPredicate<>(searchFieldAccessor,k, blobAttribute);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.redis.om.spring.RedisOMProperties;
import com.redis.om.spring.ops.RedisModulesOperations;
import com.redis.om.spring.repository.query.RediSearchQuery;
import com.redis.om.spring.vectorize.FeatureExtractor;
import org.springframework.beans.BeanUtils;
import org.springframework.data.keyvalue.core.KeyValueOperations;
import org.springframework.data.keyvalue.repository.query.KeyValuePartTreeQuery;
Expand Down Expand Up @@ -41,6 +42,7 @@ public class RedisDocumentRepositoryFactory extends KeyValueRepositoryFactory {
private final RediSearchIndexer indexer;
private final GsonBuilder gsonBuilder;
private final RedisMappingContext mappingContext;
private final FeatureExtractor featureExtractor;
private final RedisOMProperties properties;

/**
Expand All @@ -59,9 +61,10 @@ public RedisDocumentRepositoryFactory( //
RediSearchIndexer keyspaceToIndexMap, //
RedisMappingContext mappingContext, //
GsonBuilder gsonBuilder, //
FeatureExtractor featureExtractor, //
RedisOMProperties properties //
) {
this(keyValueOperations, rmo, keyspaceToIndexMap, DEFAULT_QUERY_CREATOR, mappingContext, gsonBuilder, properties);
this(keyValueOperations, rmo, keyspaceToIndexMap, DEFAULT_QUERY_CREATOR, mappingContext, gsonBuilder, featureExtractor, properties);
}

/**
Expand All @@ -82,10 +85,11 @@ public RedisDocumentRepositoryFactory( //
Class<? extends AbstractQueryCreator<?, ?>> queryCreator, //
RedisMappingContext mappingContext, //
GsonBuilder gsonBuilder, //
FeatureExtractor featureExtractor, //
RedisOMProperties properties //
) {

this(keyValueOperations, rmo, keyspaceToIndexMap, queryCreator, RediSearchQuery.class, mappingContext, gsonBuilder, properties);
this(keyValueOperations, rmo, keyspaceToIndexMap, queryCreator, RediSearchQuery.class, mappingContext, gsonBuilder, featureExtractor, properties);
}

/**
Expand All @@ -108,12 +112,19 @@ public RedisDocumentRepositoryFactory( //
Class<? extends RepositoryQuery> repositoryQueryType, //
RedisMappingContext mappingContext, //
GsonBuilder gsonBuilder, //
FeatureExtractor featureExtractor, //
RedisOMProperties properties //
) {

super(keyValueOperations, queryCreator, repositoryQueryType);

Assert.notNull(rmo, "RedisModulesOperations must not be null!");
Assert.notNull(keyValueOperations, "KeyValueOperations must not be null!");
Assert.notNull(rmo, "RedisModulesOperations must not be null!");
Assert.notNull(queryCreator, "Query creator type must not be null!");
Assert.notNull(repositoryQueryType, "RepositoryQueryType type must not be null!");
Assert.notNull(featureExtractor, "FeatureExtractor type must not be null!");
Assert.notNull(properties, "RedisOMSpringProperties type must not be null!");

this.keyValueOperations = keyValueOperations;
this.rmo = rmo;
Expand All @@ -122,14 +133,15 @@ public RedisDocumentRepositoryFactory( //
this.repositoryQueryType = repositoryQueryType;
this.mappingContext = mappingContext;
this.gsonBuilder = gsonBuilder;
this.featureExtractor = featureExtractor;
this.properties = properties;
}

@Override
protected Object getTargetRepository(RepositoryInformation repositoryInformation) {
EntityInformation<?, ?> entityInformation = getEntityInformation(repositoryInformation.getDomainType());
return super.getTargetRepositoryViaReflection(
repositoryInformation, entityInformation, keyValueOperations, rmo, indexer, mappingContext, gsonBuilder, properties);
repositoryInformation, entityInformation, keyValueOperations, rmo, indexer, mappingContext, gsonBuilder, featureExtractor, properties);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.redis.om.spring.RediSearchIndexer;
import com.redis.om.spring.RedisOMProperties;
import com.redis.om.spring.ops.RedisModulesOperations;
import com.redis.om.spring.vectorize.FeatureExtractor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.keyvalue.core.KeyValueOperations;
import org.springframework.data.keyvalue.repository.support.KeyValueRepositoryFactoryBean;
Expand All @@ -26,6 +27,8 @@ public class RedisDocumentRepositoryFactoryBean<T extends Repository<S, ID>, S,
@Autowired
private GsonBuilder gsonBuilder;
@Autowired
private @Nullable FeatureExtractor featureExtractor;
@Autowired
private RedisOMProperties properties;

/**
Expand All @@ -46,7 +49,7 @@ protected final RedisDocumentRepositoryFactory createRepositoryFactory( //
Class<? extends RepositoryQuery> repositoryQueryType //
) {
return new RedisDocumentRepositoryFactory(operations, rmo, indexer, queryCreator, repositoryQueryType,
this.mappingContext, this.gsonBuilder, this.properties);
this.mappingContext, this.gsonBuilder, this.featureExtractor, this.properties);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.google.gson.reflect.TypeToken;
import com.redis.om.spring.RediSearchIndexer;
import com.redis.om.spring.RedisOMProperties;
import com.redis.om.spring.audit.EntityAuditor;
import com.redis.om.spring.convert.MappingRedisOMConverter;
import com.redis.om.spring.id.ULIDIdentifierGenerator;
import com.redis.om.spring.metamodel.MetamodelField;
Expand All @@ -18,6 +19,7 @@
import com.redis.om.spring.search.stream.FluentQueryByExample;
import com.redis.om.spring.serialization.gson.GsonListOfType;
import com.redis.om.spring.util.ObjectUtils;
import com.redis.om.spring.vectorize.FeatureExtractor;
import org.springframework.beans.BeanWrapper;
import org.springframework.beans.BeanWrapperImpl;
import org.springframework.beans.PropertyAccessor;
Expand Down Expand Up @@ -82,6 +84,8 @@ public class SimpleRedisDocumentRepository<T, ID> extends SimpleKeyValueReposito
private final RedisOMProperties properties;
private final RedisMappingContext mappingContext;
private final EntityStream entityStream;
protected final EntityAuditor auditor;
protected final FeatureExtractor featureExtractor;

@SuppressWarnings("unchecked")
public SimpleRedisDocumentRepository( //
Expand All @@ -91,6 +95,7 @@ public SimpleRedisDocumentRepository( //
RediSearchIndexer indexer, //
RedisMappingContext mappingContext,
GsonBuilder gsonBuilder,
FeatureExtractor featureExtractor, //
RedisOMProperties properties) {
super(metadata, operations);
this.modulesOperations = (RedisModulesOperations<String>) rmo;
Expand All @@ -102,6 +107,8 @@ public SimpleRedisDocumentRepository( //
this.generator = ULIDIdentifierGenerator.INSTANCE;
this.gsonBuilder = gsonBuilder;
this.mappingContext = mappingContext;
this.auditor = new EntityAuditor(modulesOperations.getTemplate());
this.featureExtractor = featureExtractor;
this.properties = properties;
this.entityStream = new EntityStreamImpl(modulesOperations, modulesOperations.getGsonBuilder(), indexer);
}
Expand Down Expand Up @@ -179,7 +186,9 @@ public <S extends T> List<S> saveAll(Iterable<S> entities) {
String keyspace = keyValueEntity.getKeySpace();
byte[] objectKey = createKey(keyspace, Objects.requireNonNull(id).toString());

processAuditAnnotations(entity, isNew);
// process entity pre-save mutation
auditor.processEntity(entity, isNew);
featureExtractor.processEntity(entity);

Optional<Long> maybeTtl = getTTLForEntity(entity);

Expand Down Expand Up @@ -244,24 +253,6 @@ public byte[] createKey(String keyspace, String id) {
return this.mappingConverter.toBytes(keyspace + ":" + id);
}

/**
 * Stamps the audit fields of {@code item} with the current date/time.
 *
 * <p>Fields annotated with {@code CreatedDate} are set when {@code isNew} is
 * {@code true}; otherwise fields annotated with {@code LastModifiedDate} are
 * set. Only fields of type {@code Date}, {@code LocalDateTime} or
 * {@code LocalDate} are written; any other type is left untouched.
 *
 * @param item  the entity whose annotated fields are updated in place
 * @param isNew whether the entity is being created rather than updated
 */
private void processAuditAnnotations(Object item, boolean isNew) {
  var auditClass = isNew ? CreatedDate.class : LastModifiedDate.class;

  List<Field> auditFields = com.redis.om.spring.util.ObjectUtils.getFieldsWithAnnotation(item.getClass(), auditClass);
  if (auditFields.isEmpty()) {
    return;
  }

  PropertyAccessor accessor = PropertyAccessorFactory.forBeanPropertyAccess(item);
  for (Field auditField : auditFields) {
    Class<?> fieldType = auditField.getType();
    String property = auditField.getName();
    if (fieldType == Date.class) {
      accessor.setPropertyValue(property, new Date(System.currentTimeMillis()));
    } else if (fieldType == LocalDateTime.class) {
      accessor.setPropertyValue(property, LocalDateTime.now());
    } else if (fieldType == LocalDate.class) {
      accessor.setPropertyValue(property, LocalDate.now());
    }
  }
}

private void processReferenceAnnotations(byte[] objectKey, Object entity, Pipeline pipeline) {
List<Field> fields = com.redis.om.spring.util.ObjectUtils.getFieldsWithAnnotation(entity.getClass(), Reference.class);
if (!fields.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ public <S extends T> List<S> saveAll(Iterable<S> entities) {
String keyspace = keyValueEntity.getKeySpace();
byte[] objectKey = createKey(keyspace, id.toString());

// process entity pre-save mutation entities
// process entity pre-save mutation
auditor.processEntity(entity, isNew);
featureExtractor.processEntity(entity);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import com.redis.om.spring.tuple.AbstractTupleMapper;
import com.redis.om.spring.tuple.Pair;
import com.redis.om.spring.tuple.TupleMapper;
import com.redis.om.spring.util.SearchResultRawResponseToObjectConverter;
import com.redis.om.spring.util.ObjectUtils;
import com.redis.om.spring.util.SearchResultRawResponseToObjectConverter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.BeanUtils;
Expand All @@ -42,6 +42,7 @@
import java.util.stream.*;

import static com.redis.om.spring.metamodel.MetamodelUtils.getMetamodelForIdField;
import static com.redis.om.spring.util.ObjectUtils.floatArrayToByteArray;
import static java.util.stream.Collectors.toCollection;

public class SearchStreamImpl<E> implements SearchStream<E> {
Expand Down Expand Up @@ -410,7 +411,7 @@ Query prepareQuery() {

if (knnPredicate != null) {
query = new Query(knnPredicate.apply(rootNode).toString());
query.addParam(knnPredicate.getBlobAttributeName(), knnPredicate.getBlobAttribute());
query.addParam(knnPredicate.getBlobAttributeName(), knnPredicate.getBlobAttribute() != null ? knnPredicate.getBlobAttribute() : floatArrayToByteArray(knnPredicate.getDoublesAttribute()));
query.addParam("K", knnPredicate.getK());
query.dialect(2);
} else {
Expand Down
Loading
Loading