Push down sort through eval #2937

Open
wants to merge 4 commits into
base: main
@@ -48,6 +48,7 @@ public static LogicalPlanOptimizer create() {
new MergeFilterAndFilter(),
new PushFilterUnderSort(),
EvalPushDown.PUSH_DOWN_LIMIT,
EvalPushDown.PUSH_DOWN_SORT,
/*
* Phase 2: Transformations that rely on data source push down capability
*/
@@ -7,20 +7,29 @@

import static org.opensearch.sql.planner.optimizer.pattern.Patterns.evalCapture;
import static org.opensearch.sql.planner.optimizer.pattern.Patterns.limit;
import static org.opensearch.sql.planner.optimizer.pattern.Patterns.sort;
import static org.opensearch.sql.planner.optimizer.rule.EvalPushDown.EvalPushDownBuilder.match;

import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.matching.pattern.CapturePattern;
import com.facebook.presto.matching.pattern.WithPattern;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import lombok.Getter;
import lombok.experimental.Accessors;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.planner.logical.LogicalEval;
import org.opensearch.sql.planner.logical.LogicalLimit;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalSort;
import org.opensearch.sql.planner.optimizer.Rule;

/**
@@ -42,6 +51,38 @@ public class EvalPushDown<T extends LogicalPlan> implements Rule<T> {
return logicalEval;
});

public static final Rule<LogicalSort> PUSH_DOWN_SORT =
match(sort(evalCapture()))
.apply(
(sort, logicalEval) -> {
Collaborator

This lambda looks too large. Please extract it into a method or a class, and split parts of it into smaller helpers so the method stays readable.

List<LogicalPlan> child = logicalEval.getChild();
Map<ReferenceExpression, Expression> evalExpressionMap =
logicalEval.getExpressions().stream()
.collect(Collectors.toMap(Pair::getKey, Pair::getValue));
List<Pair<SortOption, Expression>> sortList = sort.getSortList();
List<Pair<SortOption, Expression>> newSortList = new ArrayList<>();
for (Pair<SortOption, Expression> pair : sortList) {
/*
Narrow down the optimization to only support:
1. The expression in sort and replaced expression are both ReferenceExpression.
2. No internal reference in eval.
*/
if (pair.getRight() instanceof ReferenceExpression) {
ReferenceExpression ref = (ReferenceExpression) pair.getRight();
Expression replacedExpr = evalExpressionMap.getOrDefault(ref, ref);
Collaborator

Is it right to return ref as the default value?

Contributor Author

Logically I think it's right, though the name replacedExpr may be a little misleading. It covers two cases:

  1. The ref is produced by the eval operator, so it needs to be replaced.
  2. The ref isn't produced by the eval operator, so no replacement is needed. That's why ref is the default value.

How about renaming replacedExpr to newExpr?
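The two cases can be illustrated with a minimal sketch that is not part of the PR. It assumes plain strings standing in for ReferenceExpression and Expression, and the class name SortFieldResolution and method resolve are hypothetical; only the getOrDefault lookup mirrors the code under review.

```java
import java.util.HashMap;
import java.util.Map;

/** Toy model: field names stand in for ReferenceExpression, strings for Expression. */
class SortFieldResolution {

  /** Returns what eval assigns to the sort field, or the field itself if eval does not define it. */
  static String resolve(Map<String, String> evalExpressions, String sortField) {
    return evalExpressions.getOrDefault(sortField, sortField);
  }

  public static void main(String[] args) {
    Map<String, String> evalExpressions = new HashMap<>();
    evalExpressions.put("newAge", "age"); // models: eval newAge = age

    // Case 1: the sort field is produced by eval, so it is replaced by its source field.
    System.out.println(resolve(evalExpressions, "newAge")); // prints "age"
    // Case 2: the sort field is not produced by eval, so it is kept unchanged.
    System.out.println(resolve(evalExpressions, "balance")); // prints "balance"
  }
}
```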

if (replacedExpr instanceof ReferenceExpression) {
ReferenceExpression newRef = (ReferenceExpression) replacedExpr;
if (!evalExpressionMap.containsKey(newRef)) {
newSortList.add(Pair.of(pair.getLeft(), newRef));
} else return sort;
} else return sort;
} else return sort;
Collaborator

I think you can return sort just once, at the end of the for block.

Contributor Author

@qianheng-aws, Aug 26, 2024

Then we would need a flag to indicate, on each iteration, whether to return sort.

It could be simplified with pattern matching for instanceof, available in Java 14 and later (as a preview feature until Java 16). But since we also need to support Java 11 for 2.x, I didn't use that feature, to keep the code aligned with 2.x.
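One Java 11 compatible way to avoid both the flag and the repeated else branches is to move the loop into a helper that returns an empty Optional when any sort key cannot be pushed down. The sketch below is an illustration, not the PR's code: Ref, Call, and rewriteSortKeys are hypothetical stand-ins for ReferenceExpression, a non-reference expression, and the extracted helper, and the SortOption half of each pair is left out for brevity.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/** Toy model of the sort-key rewrite; the nested types are stand-ins, not the plugin's classes. */
class SortKeyRewriteSketch {
  interface Expr {}

  /** Stand-in for ReferenceExpression. */
  static final class Ref implements Expr {
    final String name;

    Ref(String name) {
      this.name = name;
    }

    @Override
    public boolean equals(Object o) {
      return o instanceof Ref && ((Ref) o).name.equals(name);
    }

    @Override
    public int hashCode() {
      return name.hashCode();
    }

    @Override
    public String toString() {
      return name;
    }
  }

  /** Stand-in for any non-reference expression, such as intV + 1. */
  static final class Call implements Expr {}

  /** Returns the rewritten keys, or empty as soon as one key cannot be pushed below eval. */
  static Optional<List<Ref>> rewriteSortKeys(List<Expr> sortKeys, Map<Ref, Expr> evalExprs) {
    List<Ref> rewritten = new ArrayList<>();
    for (Expr key : sortKeys) {
      if (!(key instanceof Ref)) {
        return Optional.empty(); // the sort key is not a plain reference
      }
      Expr newExpr = evalExprs.getOrDefault(key, key);
      if (!(newExpr instanceof Ref) || evalExprs.containsKey(newExpr)) {
        return Optional.empty(); // replaced by a non-reference, or by another eval output
      }
      rewritten.add((Ref) newExpr);
    }
    return Optional.of(rewritten);
  }

  public static void main(String[] args) {
    Map<Ref, Expr> evalExprs = new HashMap<>();
    evalExprs.put(new Ref("newAge"), new Ref("age")); // models: eval newAge = age
    evalExprs.put(new Ref("ageNext"), new Call()); // models: eval ageNext = <non-reference>

    List<Expr> pushable = List.of(new Ref("newAge"));
    List<Expr> blocked = List.of(new Ref("ageNext"));
    System.out.println(rewriteSortKeys(pushable, evalExprs).isPresent()); // prints "true"
    System.out.println(rewriteSortKeys(blocked, evalExprs).isPresent()); // prints "false"
  }
}
```

With a helper of this shape, the rule body would only need to branch once on whether the Optional is present, returning the original sort when it is empty.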

}
sort = new LogicalSort(child.getFirst(), newSortList);
logicalEval.replaceChildPlans(List.of(sort));
return logicalEval;
});

private final Capture<LogicalEval> capture;

@Accessors(fluent = true)
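As a rough picture of what PUSH_DOWN_SORT does to the plan shape, the following sketch swaps a Sort node with the Eval node directly beneath it. It is a simplified illustration rather than the PR's implementation: Node and pushSortBelowEval are hypothetical names, operators are identified by plain strings, and the sort-key checks discussed in the review are omitted.

```java
/** Toy plan tree used to show the Sort - Eval --> Eval - Sort swap. */
class PlanSwapSketch {

  /** Minimal stand-in for a logical plan node: an operator name and at most one child. */
  static final class Node {
    final String name;
    final Node child; // null for a leaf such as a relation

    Node(String name, Node child) {
      this.name = name;
      this.child = child;
    }

    @Override
    public String toString() {
      return child == null ? name : name + "(" + child + ")";
    }
  }

  /** Sort(Eval(x)) becomes Eval(Sort(x)); any other shape is returned unchanged. */
  static Node pushSortBelowEval(Node plan) {
    if (plan.name.equals("Sort") && plan.child != null && plan.child.name.equals("Eval")) {
      Node evalInput = plan.child.child;
      return new Node("Eval", new Node("Sort", evalInput));
    }
    return plan;
  }

  public static void main(String[] args) {
    Node plan = new Node("Sort", new Node("Eval", new Node("Relation", null)));
    System.out.println(pushSortBelowEval(plan)); // prints "Eval(Sort(Relation))"
  }
}
```

Once the sort sits directly above the relation, a later rule that pushes sort into the scan has a chance to apply, which is the case the second unit test in this PR exercises.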
@@ -42,6 +42,7 @@
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
@@ -368,6 +369,48 @@ void push_limit_through_eval_into_scan() {
optimize(limit(eval(relation("schema", table), evalExpr), 10, 5)));
}

/** Sort - Eval --> Eval - Sort. */
@Test
void push_sort_under_eval() {
ReferenceExpression sortRef = DSL.ref("intV", INTEGER);
ReferenceExpression evalRef = DSL.ref("name1", INTEGER);
Pair<ReferenceExpression, Expression> evalExpr = Pair.of(evalRef, DSL.ref("name", STRING));
Pair<SortOption, Expression> sortExpr = Pair.of(Sort.SortOption.DEFAULT_ASC, sortRef);
assertEquals(
eval(sort(tableScanBuilder, sortExpr), evalExpr),
optimize(sort(eval(relation("schema", table), evalExpr), sortExpr)));

// don't push sort if sort field is not ReferenceExpression
Expression nonRefExpr = DSL.add(DSL.ref("intV", INTEGER), DSL.literal(1));
Pair<SortOption, Expression> sortExprWithNonRef =
Pair.of(Sort.SortOption.DEFAULT_ASC, nonRefExpr);
LogicalPlan originPlan = sort(eval(relation("schema", table), evalExpr), sortExprWithNonRef);
assertEquals(originPlan, optimize(originPlan));

// don't push sort if replaced expr in eval is not ReferenceExpression
Pair<ReferenceExpression, Expression> evalExprWithNonRef = Pair.of(sortRef, nonRefExpr);
originPlan = sort(eval(relation("schema", table), evalExprWithNonRef), sortExpr);
assertEquals(originPlan, optimize(originPlan));

// don't push sort if there are internal reference in eval
Pair<ReferenceExpression, Expression> evalExpr2 = Pair.of(sortRef, evalRef);
originPlan = sort(eval(relation("schema", table), evalExpr, evalExpr2), sortExpr);
assertEquals(originPlan, optimize(originPlan));
Collaborator

Please split this method into a separate test method for each case.

}

/** Sort - Eval - Scan --> Eval - Scan. */
@Test
void push_sort_through_eval_into_scan() {
when(tableScanBuilder.pushDownSort(any())).thenReturn(true);
Pair<ReferenceExpression, Expression> evalExpr =
Pair.of(DSL.ref("name1", STRING), DSL.ref("name", STRING));
Pair<SortOption, Expression> sortExpr =
Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER));
assertEquals(
eval(tableScanBuilder, evalExpr),
optimize(sort(eval(relation("schema", table), evalExpr), sortExpr)));
}

private LogicalPlan optimize(LogicalPlan plan) {
final LogicalPlanOptimizer optimizer = LogicalPlanOptimizer.create();
return optimizer.optimize(plan);
13 changes: 13 additions & 0 deletions integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java
@@ -89,6 +89,19 @@ public void testLimitPushDownExplain() throws Exception {
+ "| fields ageMinus"));
}

@Test
public void testSortPushDownThroughEvalExplain() throws Exception {
String expected = loadFromFile("expectedOutput/ppl/explain_sort_push_through_eval.json");

assertJsonEquals(
expected,
explainQueryToString(
"source=opensearch-sql_test_index_account"
+ "| eval newAge = age"
+ "| sort newAge"
+ "| fields newAge"));
}

String loadFromFile(String filename) throws Exception {
URI uri = Resources.getResource(filename).toURI();
return new String(Files.readAllBytes(Paths.get(uri)));
@@ -0,0 +1,27 @@
{
"root": {
"name": "ProjectOperator",
"description": {
"fields": "[newAge]"
},
"children": [
{
"name": "EvalOperator",
"description": {
"expressions": {
"newAge": "age"
}
},
"children": [
{
"name": "OpenSearchIndexScan",
"description": {
"request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"sort\":[{\"age\":{\"order\":\"asc\",\"missing\":\"_first\"}}]}, searchDone=false)"
},
"children": []
}
]
}
]
}
}