Skip to content

Commit

Permalink
Normalize OR and IN expressions referencing the same symbol as IN
Browse files Browse the repository at this point in the history
  • Loading branch information
chenjian2664 authored and kokosing committed Oct 26, 2023
1 parent 306bb37 commit 8ffbf52
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.ExpressionRewriter;
Expand All @@ -22,19 +24,17 @@
import io.trino.sql.tree.InPredicate;
import io.trino.sql.tree.LogicalExpression;

import java.util.LinkedHashMap;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.sql.ExpressionUtils.and;
import static io.trino.sql.ExpressionUtils.or;
import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL;
import static io.trino.sql.tree.LogicalExpression.Operator.AND;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.mapping;

public final class NormalizeOrExpressionRewriter
{
Expand All @@ -59,35 +59,70 @@ public Expression rewriteLogicalExpression(LogicalExpression node, Void context,
return and(terms);
}

List<InPredicate> comparisons = terms.stream()
.filter(NormalizeOrExpressionRewriter::isEqualityComparisonExpression)
.map(ComparisonExpression.class::cast)
.collect(groupingBy(
ComparisonExpression::getLeft,
LinkedHashMap::new,
mapping(ComparisonExpression::getRight, Collectors.toList())))
.entrySet().stream()
.filter(entry -> entry.getValue().size() > 1)
.map(entry -> new InPredicate(entry.getKey(), new InListExpression(entry.getValue())))
.collect(Collectors.toList());
ImmutableList.Builder<InPredicate> inPredicateBuilder = ImmutableList.builder();
ImmutableSet.Builder<Expression> expressionToSkipBuilder = ImmutableSet.builder();
ImmutableList.Builder<Expression> othersExpressionBuilder = ImmutableList.builder();
groupComparisonAndInPredicate(terms).forEach((expression, values) -> {
if (values.size() > 1) {
inPredicateBuilder.add(new InPredicate(expression, mergeToInListExpression(values)));
expressionToSkipBuilder.add(expression);
}
});

Set<Expression> expressionToSkip = comparisons.stream()
.map(InPredicate::getValue)
.collect(toImmutableSet());

List<Expression> others = terms.stream()
.filter(expression -> !isEqualityComparisonExpression(expression) || !expressionToSkip.contains(((ComparisonExpression) expression).getLeft()))
.collect(Collectors.toList());
Set<Expression> expressionToSkip = expressionToSkipBuilder.build();
for (Expression expression : terms) {
if (expression instanceof ComparisonExpression comparisonExpression && comparisonExpression.getOperator() == EQUAL) {
if (!expressionToSkip.contains(comparisonExpression.getLeft())) {
othersExpressionBuilder.add(expression);
}
}
else if (expression instanceof InPredicate inPredicate && inPredicate.getValueList() instanceof InListExpression) {
if (!expressionToSkip.contains(inPredicate.getValue())) {
othersExpressionBuilder.add(expression);
}
}
else {
othersExpressionBuilder.add(expression);
}
}

return or(ImmutableList.<Expression>builder()
.addAll(others)
.addAll(comparisons)
.addAll(othersExpressionBuilder.build())
.addAll(inPredicateBuilder.build())
.build());
}
}

private static boolean isEqualityComparisonExpression(Expression expression)
{
return expression instanceof ComparisonExpression && ((ComparisonExpression) expression).getOperator() == EQUAL;
/**
 * Merges the given terms into a single IN-list.
 * <p>
 * An equality comparison contributes its right-hand operand; an IN predicate whose
 * value list is an {@link InListExpression} contributes every value of that list.
 * Repeated values are kept only once, at the position where they were first seen.
 *
 * @param expressions equality comparisons and IN-list predicates to merge
 * @return an IN-list holding the distinct collected values in first-seen order
 * @throws IllegalStateException if a term is neither an equality comparison nor an IN-list predicate
 */
private InListExpression mergeToInListExpression(Collection<Expression> expressions)
{
    // Insertion-ordered set: removes duplicates without reordering the values
    Set<Expression> mergedValues = new LinkedHashSet<>();
    for (Expression term : expressions) {
        if (term instanceof ComparisonExpression equality && equality.getOperator() == EQUAL) {
            mergedValues.add(equality.getRight());
            continue;
        }
        if (term instanceof InPredicate membership && membership.getValueList() instanceof InListExpression candidates) {
            mergedValues.addAll(candidates.getValues());
            continue;
        }
        // Anything else cannot be folded into an IN-list
        throw new IllegalStateException("Unexpected expression: " + term);
    }

    return new InListExpression(ImmutableList.copyOf(mergedValues));
}

/**
 * Groups equality comparisons and IN-list predicates by the expression they test.
 * <p>
 * An equality comparison is keyed by its left operand; an IN predicate whose value
 * list is an {@link InListExpression} is keyed by its value expression. Every other
 * term is left out of the result.
 *
 * @param terms the terms to inspect
 * @return a map from tested expression to the terms that test it, backed by an {@link ImmutableMultimap}
 */
private Map<Expression, Collection<Expression>> groupComparisonAndInPredicate(List<Expression> terms)
{
    ImmutableMultimap.Builder<Expression, Expression> grouped = ImmutableMultimap.builder();
    for (Expression term : terms) {
        Expression groupingKey;
        if (term instanceof ComparisonExpression equality && equality.getOperator() == EQUAL) {
            groupingKey = equality.getLeft();
        }
        else if (term instanceof InPredicate membership && membership.getValueList() instanceof InListExpression) {
            groupingKey = membership.getValue();
        }
        else {
            // Not an equality or IN-list term: it takes no part in the grouping
            continue;
        }
        grouped.put(groupingKey, term);
    }

    return grouped.build().asMap();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,23 @@ public void testPushesDownNegationsNumericTypes()
// Exercises OR-expression normalization: each call passes an input expression
// and the simplified form it is expected to produce.
public void testRewriteOrExpression()
{
// Two equalities on the same operand collapse into one IN
assertSimplifiesNumericTypes("I1 = 1 OR I1 = 2 ", "I1 IN (1, 2)");
// NOTE(review): the TODO and assertion below expect the un-merged form, which
// conflicts with the next assertion on the same input. This looks like the
// pre-change expectation left behind; confirm and drop it.
// TODO: Implement rule for Merging IN expression
assertSimplifiesNumericTypes("I1 = 1 OR I1 = 2 OR I1 IN (3, 4)", "I1 IN (3, 4) OR I1 IN (1, 2)");
// Equalities and IN-lists on the same operand merge into a single IN
assertSimplifiesNumericTypes("I1 = 1 OR I1 = 2 OR I1 IN (3, 4)", "I1 IN (1, 2, 3, 4)");
assertSimplifiesNumericTypes("I1 = 1 OR I1 = 2 OR I1 = I2", "I1 IN (1, 2, I2)");
assertSimplifiesNumericTypes("I1 = 1 OR I1 = 2 OR I2 = 3 OR I2 = 4", "I1 IN (1, 2) OR I2 IN (3, 4)");
// Duplicate values appear only once in the merged list
assertSimplifiesNumericTypes("I1 = 1 OR I1 IN (1, 2)", "I1 IN (1, 2)");
assertSimplifiesNumericTypes("I1 = 1 OR I2 IN (1, 2) OR I2 IN (2, 3)", "I1 = 1 OR I2 IN (1, 2, 3)");
// A single-value IN is expected to come out as a plain equality
assertSimplifiesNumericTypes("I1 IN (1)", "I1 = 1");
assertSimplifiesNumericTypes("I1 = 1 OR I1 IN (1)", "I1 = 1");
assertSimplifiesNumericTypes("I1 = 1 OR I1 IN (2)", "I1 IN (1, 2)");
assertSimplifiesNumericTypes("I1 IN (1, 2) OR I1 = 1", "I1 IN (1, 2)");
assertSimplifiesNumericTypes("I1 IN (1, 2) OR I2 = 1 OR I1 = 3 OR I2 = 4", "I1 IN (1, 2, 3) OR I2 IN (1, 4)");
assertSimplifiesNumericTypes("I1 IN (1, 2) OR I1 = 3 OR I1 IN (4, 5, 6) OR I2 = 3 OR I2 IN (3, 4)", "I1 IN (1, 2, 3, 4, 5, 6) OR I2 IN (3, 4)");

// An IN over a subquery is expected to stay separate from the merged IN-list
assertSimplifiesNumericTypes("I1 = 1 OR I1 = 2 OR I1 IN (3, 4) OR I1 IN (SELECT 1)", "I1 IN (1, 2, 3, 4) OR I1 IN (SELECT 1)");
// Interleaved terms on different operands are grouped per operand
assertSimplifiesNumericTypes("I1 = 1 OR I2 = 2 OR I1 = 3 OR I2 = 4", "I1 IN (1, 3) OR I2 IN (2, 4)");
// A lone equality next to a non-equality term on the same operand is left as is
assertSimplifiesNumericTypes("I1 = 1 OR I2 = 2 OR I1 = 3 OR I2 IS NULL", "I1 IN(1, 3) OR I2 = 2 OR I2 IS NULL");
assertSimplifiesNumericTypes("I1 = 1 OR I2 IN (2, 3) OR I1 = 4 OR I2 IN (5, 6)", "I1 IN (1, 4) OR I2 IN (2, 3, 5, 6)");
assertSimplifiesNumericTypes("I1 = 1 OR I2 = 2 OR I1 = 3 OR I2 = I1", "I1 IN (1, 3) OR I2 IN (2, I1)");
}

private static void assertSimplifiesNumericTypes(String expression, String expected)
Expand Down

0 comments on commit 8ffbf52

Please sign in to comment.