Skip to content

Commit

Permalink
feature: adds support for repository interface projection methods for Hashes
Browse files Browse the repository at this point in the history
  • Loading branch information
bsbodden committed May 7, 2024
1 parent c200fc5 commit 49d999e
Show file tree
Hide file tree
Showing 15 changed files with 416 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import org.springframework.data.mapping.model.EntityInstantiators;
import org.springframework.data.mapping.model.PersistentEntityParameterValueProvider;
import org.springframework.data.mapping.model.PropertyValueProvider;
import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
import org.springframework.data.redis.core.PartialUpdate;
import org.springframework.data.redis.core.PartialUpdate.PropertyUpdate;
import org.springframework.data.redis.core.PartialUpdate.UpdateCommand;
Expand Down Expand Up @@ -50,6 +52,7 @@ public class MappingRedisOMConverter implements RedisConverter, InitializingBean

private @Nullable ReferenceResolver referenceResolver;
private CustomConversions customConversions;
private final ProjectionFactory projectionFactory;

public MappingRedisOMConverter() {
this(null);
Expand Down Expand Up @@ -94,6 +97,7 @@ public MappingRedisOMConverter(@Nullable RedisMappingContext mappingContext,
this.customConversions = new RedisOMCustomConversions();
this.typeMapper = typeMapper != null ? typeMapper
: new DefaultRedisTypeMapper(DefaultRedisTypeMapper.DEFAULT_TYPE_KEY, this.mappingContext);
this.projectionFactory = new SpelAwareProxyProjectionFactory();
this.referenceResolver = referenceResolver;
afterPropertiesSet();
}
Expand All @@ -107,9 +111,15 @@ public MappingRedisOMConverter(@Nullable RedisMappingContext mappingContext,
public <R> R read(Class<R> type, RedisData source) {
TypeInformation<?> readType = typeMapper.readType(source.getBucket().getPath(), TypeInformation.of(type));

return readType.isCollectionLike()
? (R) readCollectionOrArray(type, "", ArrayList.class, Object.class, source.getBucket())
: doReadInternal(type, "", type, source);
if (readType.isCollectionLike()) {
return (R) readCollectionOrArray(type, "", ArrayList.class, Object.class, source.getBucket());
}

// if (type.isInterface()) {
// return (R) projectionFactory.createProjection(type, result);
// }

return readInternal(type, "", type, source);
}

@Nullable
Expand All @@ -119,70 +129,85 @@ private <R> R readInternal(Class<?> entityClass, String path, Class<R> type, Red

@SuppressWarnings("unchecked")
private <R> R doReadInternal(Class<?> entityClass, String path, Class<R> type, RedisData source) {

TypeInformation<?> readType = typeMapper.readType(source.getBucket().getPath(), TypeInformation.of(type));

if (customConversions.hasCustomReadTarget(Map.class, readType.getType())) {

Map<String, byte[]> partial = new HashMap<>();

if (!path.isEmpty()) {

for (Entry<String, byte[]> entry : source.getBucket().extract(path + ".").entrySet()) {
partial.put(entry.getKey().substring(path.length() + 1), entry.getValue());
}

} else {
partial.putAll(source.getBucket().asMap());
}

R instance = (R) conversionService.convert(partial, readType.getType());

RedisPersistentEntity<?> entity = mappingContext.getPersistentEntity(readType);
if (entity != null && instance != null && entity.hasIdProperty()) {

PersistentPropertyAccessor<R> propertyAccessor = entity.getPropertyAccessor(instance);

propertyAccessor.setProperty(entity.getRequiredIdProperty(), source.getId());
instance = propertyAccessor.getBean();
}

return instance;
}

if (conversionService.canConvert(byte[].class, readType.getType())) {
return (R) conversionService.convert(source.getBucket().get(StringUtils.hasText(path) ? path : "_raw"),
readType.getType());
return (R) conversionService.convert(source.getBucket().get(StringUtils.hasText(path) ? path : "_raw"), readType.getType());
}

RedisPersistentEntity<?> entity = mappingContext.getRequiredPersistentEntity(readType);
EntityInstantiator instantiator = entityInstantiators.getInstantiatorFor(entity);

Object instance = instantiator.createInstance((RedisPersistentEntity<RedisPersistentProperty>) entity,
Object instance;
if (type.isInterface()) {
instance = source.getBucket().asMap();
} else {
instance = instantiator.createInstance((RedisPersistentEntity<RedisPersistentProperty>) entity,
new PersistentEntityParameterValueProvider<>(entity,
new ConverterAwareParameterValueProvider(entityClass, path, source, conversionService),
this.conversionService));

PersistentPropertyAccessor<Object> accessor = entity.getPropertyAccessor(instance);

entity.doWithProperties((PropertyHandler<RedisPersistentProperty>) persistentProperty -> {

InstanceCreatorMetadata<RedisPersistentProperty> constructor = entity.getInstanceCreatorMetadata();

if (constructor != null && constructor.isCreatorParameter(persistentProperty)) {
return;
new ConverterAwareParameterValueProvider(entityClass, path, source, conversionService),
this.conversionService));
}

if (type.isInterface()) {
Map<String, Object> map = new HashMap<>();
RedisPersistentEntity<?> persistentEntity = mappingContext.getRequiredPersistentEntity(readType);
for (Entry<String, byte[]> entry : source.getBucket().asMap().entrySet()) {
String key = entry.getKey();
byte[] value = entry.getValue();
RedisPersistentProperty persistentProperty = persistentEntity.getPersistentProperty(key);
Object convertedValue;
if (persistentProperty != null) {
// Convert the byte[] value to the appropriate type
convertedValue = conversionService.convert(value, persistentProperty.getType());
} else {
// If the property is not found, treat the value as a String
convertedValue = new String(value);
}
map.put(key, convertedValue);
}

Object targetValue = readProperty(entityClass, path, source, persistentProperty);

if (targetValue != null) {
accessor.setProperty(persistentProperty, targetValue);
}
});

readAssociation(path, source, entity, accessor);
// Create a proxy instance of the interface using Spring's ProxyFactory
return projectionFactory.createProjection(type, map);
} else {
PersistentPropertyAccessor<Object> accessor = entity.getPropertyAccessor(instance);
entity.doWithProperties((PropertyHandler<RedisPersistentProperty>) persistentProperty -> {
InstanceCreatorMetadata<RedisPersistentProperty> constructor = entity.getInstanceCreatorMetadata();
if (constructor != null && constructor.isCreatorParameter(persistentProperty)) {
return;
}
Object targetValue = readProperty(entityClass, path, source, persistentProperty);
if (targetValue != null) {
accessor.setProperty(persistentProperty, targetValue);
}
});
readAssociation(path, source, entity, accessor);
}

return (R) accessor.getBean();
return (R) instance;
}

@Nullable
protected Object readProperty(Class<?> entityClass, String path, RedisData source,
RedisPersistentProperty persistentProperty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,6 @@ private Object parseDocumentResult(redis.clients.jedis.search.Document doc) {
private Object executeDeleteQuery(Object[] parameters) {
SearchOperations<String> ops = modulesOperations.opsForSearch(searchIndex);
String baseQuery = prepareQuery(parameters, true);
// String[] fields = new String[] { "@__key" };
AggregationBuilder aggregation = new AggregationBuilder(baseQuery);

// Load fields with IS_NULL or IS_NOT_NULL query clauses
Expand Down Expand Up @@ -851,7 +850,9 @@ private Object executeNullQuery(Object[] parameters) {
AggregationResult aggregationResult = ops.aggregate(aggregation);

// extract the keys from the aggregation result
String[] keys = aggregationResult.getResults().stream().map(d -> d.get("__key").toString()).toArray(String[]::new);
String[] keys = aggregationResult.getResults().stream() //
.map(d -> d.get("__key").toString())
.toArray(String[]::new);

var entities = modulesOperations.opsForJSON().mget(domainType, keys);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,26 @@
import com.redis.om.spring.util.ObjectUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort.Order;
import org.springframework.data.geo.Point;
import org.springframework.data.keyvalue.core.KeyValueOperations;
import org.springframework.data.mapping.PropertyPath;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisOperations;
import org.springframework.data.redis.core.convert.RedisData;
import org.springframework.data.redis.core.convert.ReferenceResolverImpl;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.query.Parameter;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.repository.query.QueryMethodEvaluationContextProvider;
import org.springframework.data.repository.query.RepositoryQuery;
import org.springframework.data.repository.query.*;
import org.springframework.data.repository.query.parser.AbstractQueryCreator;
import org.springframework.data.repository.query.parser.Part;
import org.springframework.data.repository.query.parser.PartTree;
import org.springframework.data.util.Pair;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
import redis.clients.jedis.search.FieldName;
import redis.clients.jedis.search.Query;
import redis.clients.jedis.search.Schema.FieldType;
import redis.clients.jedis.search.SearchResult;
Expand Down Expand Up @@ -95,6 +96,7 @@ public class RedisEnhancedQuery implements RepositoryQuery {
private final boolean isANDQuery;
private boolean isNullParamQuery;
private final KeyValueOperations keyValueOperations;
private final RedisOperations<?, ?> redisOperations;

@SuppressWarnings("unchecked")
public RedisEnhancedQuery(
Expand All @@ -117,6 +119,7 @@ public RedisEnhancedQuery(
Optional<String> maybeIndex = indexer.getIndexName(this.domainType);
this.searchIndex = maybeIndex.orElse(this.domainType.getName() + "Idx");
this.redisOMProperties = redisOMProperties;
this.redisOperations = redisOperations;
this.mappingConverter = new MappingRedisOMConverter(null, new ReferenceResolverImpl(redisOperations));

bloomQueryExecutor = new BloomQueryExecutor(this, modulesOperations);
Expand Down Expand Up @@ -234,7 +237,7 @@ public RedisEnhancedQuery(
});

this.isNullParamQuery = !nullParamNames.isEmpty() || !notNullParamNames.isEmpty();
this.type = RediSearchQueryType.QUERY;
this.type = queryMethod.getName().matches("(?:remove|delete).*") ? RediSearchQueryType.DELETE : RediSearchQueryType.QUERY;
this.returnFields = new String[] {};
processPartTree(pt, nullParamNames, notNullParamNames);
}
Expand Down Expand Up @@ -383,6 +386,8 @@ public Object execute(Object[] parameters) {
return !isNullParamQuery ? executeQuery(parameters) : executeNullQuery(parameters);
} else if (type == RediSearchQueryType.AGGREGATION) {
return executeAggregation(parameters);
} else if (type == RediSearchQueryType.DELETE) {
return executeDeleteQuery(parameters);
} else if (type == RediSearchQueryType.TAGVALS) {
return executeFtTagVals();
} else if (type == RediSearchQueryType.AUTOCOMPLETE) {
Expand All @@ -400,11 +405,29 @@ public QueryMethod getQueryMethod() {
}

private Object executeQuery(Object[] parameters) {
ParameterAccessor accessor = new ParametersParameterAccessor(queryMethod.getParameters(), parameters);
ResultProcessor processor = queryMethod.getResultProcessor().withDynamicProjection(accessor);

SearchOperations<String> ops = modulesOperations.opsForSearch(searchIndex);
boolean excludeNullParams = !isNullParamQuery;
String preparedQuery = prepareQuery(parameters, excludeNullParams);
Query query = new Query(preparedQuery);
query.returnFields(returnFields);

ReturnedType returnedType = processor.getReturnedType();

boolean isProjecting = returnedType.isProjecting() && returnedType.getReturnedType() != SearchResult.class;
boolean isOpenProjecting = Arrays.stream(returnedType.getReturnedType().getMethods())
.anyMatch(m -> m.isAnnotationPresent(Value.class));
boolean canPerformQueryOptimization = isProjecting && !isOpenProjecting;

if (canPerformQueryOptimization) {
query.returnFields(returnedType.getInputProperties()
.stream()
.map(inputProperty -> new FieldName( "$." + inputProperty, inputProperty))
.toArray(FieldName[]::new));
} else {
query.returnFields(returnFields);
}

Optional<Pageable> maybePageable = Optional.empty();

Expand Down Expand Up @@ -462,33 +485,104 @@ private Object executeQuery(Object[] parameters) {

SearchResult searchResult = ops.search(query);

// what to return
Object result = null;
Object result;

if (queryMethod.getReturnedObjectType() == SearchResult.class) {
result = searchResult;
} else if (queryMethod.isPageQuery()) {
List<Object> content = searchResult.getDocuments().stream() //
.map(d -> ObjectUtils.documentToObject(d, queryMethod.getReturnedObjectType(), mappingConverter)) //
.collect(Collectors.toList());
List<Object> content = searchResult.getDocuments().stream()
.map(d -> ObjectUtils.documentToObject(d, queryMethod.getReturnedObjectType(), mappingConverter))
.collect(Collectors.toList());

if (maybePageable.isPresent()) {
Pageable pageable = maybePageable.get();
result = new PageImpl<>(content, pageable, searchResult.getTotalResults());
} else {
result = content;
}

} else if (queryMethod.isQueryForEntity() && !queryMethod.isCollectionQuery()) {
} else if (!queryMethod.isCollectionQuery()) {
if (searchResult.getTotalResults() > 0 && !searchResult.getDocuments().isEmpty()) {
result = ObjectUtils.documentToObject(searchResult.getDocuments().get(0), queryMethod.getReturnedObjectType(),
mappingConverter);
result = ObjectUtils.documentToObject(searchResult.getDocuments().get(0), queryMethod.getReturnedObjectType(), mappingConverter);
} else {
result = null;
}
} else if (queryMethod.isQueryForEntity() && queryMethod.isCollectionQuery()) {
result = searchResult.getDocuments().stream() //
.map(d -> ObjectUtils.documentToObject(d, queryMethod.getReturnedObjectType(), mappingConverter)) //
.collect(Collectors.toList());
} else if (queryMethod.isCollectionQuery()) {
result = searchResult.getDocuments().stream()
.map(d -> ObjectUtils.documentToObject(d, queryMethod.getReturnedObjectType(), mappingConverter))
.collect(Collectors.toList());
} else {
result = null;
}

return result;
return processor.processResult(result);
}

/**
 * Executes a derived {@code delete…}/{@code remove…} repository query.
 * <p>
 * Matching keys are resolved via a RediSearch aggregation over the entity's
 * search index. Depending on the query method's declared return type, either
 * the count of deleted entities is returned (numeric return types) or the
 * deleted entities themselves (collection return types), which are fetched in
 * a single pipelined round-trip before deletion.
 *
 * @param parameters the query method invocation arguments
 * @return the number of deleted entities, or the list of deleted entities
 */
private Object executeDeleteQuery(Object[] parameters) {
  SearchOperations<String> ops = modulesOperations.opsForSearch(searchIndex);
  String baseQuery = prepareQuery(parameters, true);
  AggregationBuilder aggregation = new AggregationBuilder(baseQuery);

  // Load the key plus any fields referenced by IS_NULL / IS_NOT_NULL query
  // clauses: the exists()/!exists() filters below need those fields present
  // in the aggregation pipeline.
  String[] fields = Stream.concat(Stream.of("@__key"), queryOrParts.stream().flatMap(List::stream)
      .filter(pair -> pair.getSecond() == QueryClause.IS_NULL || pair.getSecond() == QueryClause.IS_NOT_NULL)
      .map(pair -> String.format("@%s", pair.getFirst()))).toArray(String[]::new);
  aggregation.load(fields);

  // Translate null-check query parts into RediSearch exists()/!exists() filters.
  for (List<Pair<String, QueryClause>> orPartParts : queryOrParts) {
    for (Pair<String, QueryClause> pair : orPartParts) {
      if (pair.getSecond() == QueryClause.IS_NULL) {
        aggregation.filter("!exists(@" + pair.getFirst() + ")");
      } else if (pair.getSecond() == QueryClause.IS_NOT_NULL) {
        aggregation.filter("exists(@" + pair.getFirst() + ")");
      }
    }
  }

  aggregation.sortBy(aggregationSortedFields.toArray(new SortedField[] {}));
  aggregation.limit(0, redisOMProperties.getRepository().getQuery().getLimit());

  // Execute the aggregation and extract the matching entity keys.
  AggregationResult aggregationResult = ops.aggregate(aggregation);
  List<String> keys = aggregationResult.getResults().stream().map(d -> d.get("__key").toString()).toList();

  Class<?> returnType = queryMethod.getReturnedObjectType();
  if (Number.class.isAssignableFrom(returnType) || returnType.equals(int.class) || returnType.equals(long.class)
      || returnType.equals(short.class)) {
    // Numeric return type: caller only wants the number of deleted entities,
    // so no need to materialize them before deleting.
    if (keys.isEmpty()) {
      return 0;
    }
    return modulesOperations.template().delete(keys);
  }

  if (keys.isEmpty()) {
    return Collections.emptyList();
  }

  // Collection return type: fetch every hash in one pipelined round-trip,
  // materialize the entities, then delete the keys and return the entities.
  var entities = new ArrayList<>();
  redisOperations.executePipelined((RedisCallback<Map<byte[], Map<byte[], byte[]>>>) connection -> {
    for (String key : keys) {
      // Encode explicitly: the platform default charset is not guaranteed to
      // be UTF-8 before Java 18, and Redis keys must round-trip byte-exact.
      connection.hashCommands().hGetAll(key.getBytes(java.nio.charset.StandardCharsets.UTF_8));
    }

    List<Object> results = connection.closePipeline();

    for (Object result : results) {
      // Pipelined HGETALL results are untyped; each entry is a field->value map.
      @SuppressWarnings("unchecked")
      Map<byte[], byte[]> hashMap = (Map<byte[], byte[]>) result;
      entities.add(mappingConverter.read(returnType, new RedisData(hashMap)));
    }
    return null;
  });
  modulesOperations.template().delete(keys);

  return entities;
}

private Object executeAggregation(Object[] parameters) {
Expand Down
Loading

0 comments on commit 49d999e

Please sign in to comment.