From ed9bd087f8992b39cf45394fdf430ad4d9b22882 Mon Sep 17 00:00:00 2001 From: Francesco Capponi Date: Sat, 22 May 2021 02:31:27 -0400 Subject: [PATCH 1/2] Adding IN feature in new Engine --- .../sql/ast/dsl/AstDSL.java | 4 ++ .../sql/ast/expression/DataType.java | 1 + .../sql/data/model/ExprValueUtils.java | 6 +- .../sql/expression/DSL.java | 21 +++++- .../function/BuiltinFunctionName.java | 1 + .../predicate/BinaryPredicateOperator.java | 28 ++++++++ .../sql/utils/OperatorUtils.java | 22 ++++++ .../sql/data/model/ExprValueUtilsTest.java | 15 +++- .../BinaryPredicateOperatorTest.java | 70 +++++++++++++++++- .../script/filter/FilterQueryBuilder.java | 2 + .../script/filter/lucene/TermsQuery.java | 39 ++++++++++ .../script/filter/FilterQueryBuilderTest.java | 39 ++++++++++ .../sql/sql/InIT.java | 71 +++++++++++++++++++ sql/src/main/antlr/OpenDistroSQLParser.g4 | 9 +++ .../sql/sql/parser/AstExpressionBuilder.java | 24 +++++-- .../sql/sql/antlr/SQLSyntaxParserTest.java | 7 ++ .../sql/parser/AstExpressionBuilderTest.java | 15 ++++ 17 files changed, 359 insertions(+), 15 deletions(-) create mode 100644 elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/filter/lucene/TermsQuery.java create mode 100644 integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/InIT.java diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java index 8c15f71bd9..2c756faec1 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java @@ -161,6 +161,10 @@ public static Literal timestampLiteral(String value) { return literal(value, DataType.TIMESTAMP); } + public static Literal arrayLiteral(List value) { + return literal(value, DataType.ARRAY); + } + public static Literal doubleLiteral(Double value) { return literal(value, DataType.DOUBLE); } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/DataType.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/DataType.java index 13e6b422bd..5b68d15caa 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/DataType.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/DataType.java @@ -37,6 +37,7 @@ public enum DataType { DATE(ExprCoreType.DATE), TIME(ExprCoreType.TIME), TIMESTAMP(ExprCoreType.TIMESTAMP), + ARRAY(ExprCoreType.ARRAY), INTERVAL(ExprCoreType.INTERVAL); @Getter diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprValueUtils.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprValueUtils.java index 976f2de8bb..04aaf2b6f0 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprValueUtils.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprValueUtils.java @@ -88,7 +88,7 @@ public static ExprValue tupleValue(Map map) { /** * {@link ExprCollectionValue} constructor. */ - public static ExprValue collectionValue(List list) { + public static ExprValue collectionValue(List list) { List valueList = new ArrayList<>(); list.forEach(o -> valueList.add(fromObjectValue(o))); return new ExprCollectionValue(valueList); @@ -109,7 +109,9 @@ public static ExprValue fromObjectValue(Object o) { if (null == o) { return LITERAL_NULL; } - if (o instanceof Map) { + if (o instanceof ExprValue) { + return (ExprValue) o; + } else if (o instanceof Map) { return tupleValue((Map) o); } else if (o instanceof List) { return collectionValue(((List) o)); diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java index 0d03ddc536..467499cdd7 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java @@ -26,8 +26,10 @@ import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName; import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository; import com.amazon.opendistroforelasticsearch.sql.expression.window.ranking.RankingWindowFunction; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.List; import lombok.RequiredArgsConstructor; @RequiredArgsConstructor @@ -252,7 +254,7 @@ public FunctionExpression subtract(Expression... expressions) { public FunctionExpression multiply(Expression... expressions) { return function(BuiltinFunctionName.MULTIPLY, expressions); } - + public FunctionExpression adddate(Expression... expressions) { return function(BuiltinFunctionName.ADDDATE, expressions); } @@ -364,7 +366,7 @@ public FunctionExpression module(Expression... expressions) { public FunctionExpression substr(Expression... expressions) { return function(BuiltinFunctionName.SUBSTR, expressions); } - + public FunctionExpression substring(Expression... expressions) { return function(BuiltinFunctionName.SUBSTR, expressions); } @@ -588,4 +590,19 @@ public FunctionExpression castTimestamp(Expression value) { return (FunctionExpression) repository .compile(BuiltinFunctionName.CAST_TO_TIMESTAMP.getName(), Arrays.asList(value)); } + + /** + * Check that a field is contained in a set of values. + */ + public FunctionExpression in(Expression field, Expression... expressions) { + List where = new ArrayList<>(); + where.add(field); + where.addAll(Arrays.asList(expressions)); + + return function(BuiltinFunctionName.IN, where.toArray(new Expression[0])); + } + + public FunctionExpression not_in(Expression field, Expression... expressions) { + return not(in(field, expressions)); + } } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java index 6b29c68da1..1e4c245b52 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java @@ -106,6 +106,7 @@ public enum BuiltinFunctionName { GTE(FunctionName.of(">=")), LIKE(FunctionName.of("like")), NOT_LIKE(FunctionName.of("not like")), + IN(FunctionName.of("in")), /** * Aggregation Function. diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/predicate/BinaryPredicateOperator.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/predicate/BinaryPredicateOperator.java index d08c3fab8f..67e4fdd64a 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/predicate/BinaryPredicateOperator.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/predicate/BinaryPredicateOperator.java @@ -19,9 +19,16 @@ import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_MISSING; import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_NULL; import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_TRUE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.ARRAY; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATETIME; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DOUBLE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.FLOAT; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.LONG; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIMESTAMP; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprBooleanValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; @@ -63,6 +70,7 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(like()); repository.register(notLike()); repository.register(regexp()); + repository.register(in()); } /** @@ -262,6 +270,26 @@ private static FunctionResolver notLike() { STRING)); } + private static FunctionResolver in() { + return FunctionDSL.define(BuiltinFunctionName.IN.getName(), + FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in), + BOOLEAN, INTEGER, ARRAY), + FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in), + BOOLEAN, STRING, ARRAY), + FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in), + BOOLEAN, LONG, ARRAY), + FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in), + BOOLEAN, FLOAT, ARRAY), + FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in), + BOOLEAN, DOUBLE, ARRAY), + FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in), + BOOLEAN, DATE, ARRAY), + FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in), + BOOLEAN, DATETIME, ARRAY), + FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in), + BOOLEAN, TIMESTAMP, ARRAY)); + } + private static ExprValue lookupTableFunction(ExprValue arg1, ExprValue arg2, Table table) { if (table.contains(arg1, arg2)) { diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/utils/OperatorUtils.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/utils/OperatorUtils.java index d887d5c391..38af8ae827 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/utils/OperatorUtils.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/utils/OperatorUtils.java @@ -15,8 +15,10 @@ package com.amazon.opendistroforelasticsearch.sql.utils; +import com.amazon.opendistroforelasticsearch.sql.data.model.AbstractExprNumberValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprBooleanValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprStringValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; import java.util.regex.Pattern; import lombok.experimental.UtilityClass; @@ -99,4 +101,24 @@ private static String patternToRegex(String patternString) { regex.append('$'); return regex.toString(); } + + + /** + * IN (..., ...) operator util. + * Expression { expr IN (collection of values..) } is to judge + * if expr is contained in a given collection. + */ + public static ExprBooleanValue in(ExprValue expr, ExprValue setOfValues) { + return ExprBooleanValue.of(isIn(expr, setOfValues)); + } + + private static boolean isIn(ExprValue expr, ExprValue setOfValues) { + if (expr instanceof AbstractExprNumberValue) { + return setOfValues.collectionValue().contains(expr.longValue()); + } else if (expr instanceof ExprStringValue) { + return setOfValues.collectionValue().contains(expr.stringValue()); + } else { + return setOfValues.collectionValue().contains(expr); + } + } } diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprValueUtilsTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprValueUtilsTest.java index 4c97632958..66cf4fc9d2 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprValueUtilsTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprValueUtilsTest.java @@ -215,13 +215,24 @@ public void invalidConvertExprValue(ExprValue value, Function assertThat(exception.getMessage(), Matchers.containsString("invalid")); } + // disabling test because in case of expr collections, we could pass ExprValues + // @Test + // public void unSupportedObject() { + // Exception exception = assertThrows(ExpressionEvaluationException.class, + // () -> ExprValueUtils.fromObjectValue(integerValue(1))); + // assertEquals( + // "unsupported object " + // + "class com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue", + // exception.getMessage()); + // } + @Test public void unSupportedObject() { Exception exception = assertThrows(ExpressionEvaluationException.class, - () -> ExprValueUtils.fromObjectValue(integerValue(1))); + () -> ExprValueUtils.fromObjectValue(new Object())); assertEquals( "unsupported object " - + "class com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue", + + "class java.lang.Object", exception.getMessage()); } diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/predicate/BinaryPredicateOperatorTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/predicate/BinaryPredicateOperatorTest.java index aa7402142c..7efd0dca46 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/predicate/BinaryPredicateOperatorTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/operator/predicate/BinaryPredicateOperatorTest.java @@ -27,14 +27,17 @@ import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_TRUE; import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.booleanValue; import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.fromObjectValue; -import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.missingValue; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATETIME; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIMESTAMP; import static com.amazon.opendistroforelasticsearch.sql.utils.ComparisonUtil.compare; import static com.amazon.opendistroforelasticsearch.sql.utils.OperatorUtils.matches; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprBooleanValue; @@ -49,14 +52,15 @@ import com.amazon.opendistroforelasticsearch.sql.data.model.ExprTupleValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils; +import com.amazon.opendistroforelasticsearch.sql.exception.ExpressionEvaluationException; import com.amazon.opendistroforelasticsearch.sql.expression.DSL; import com.amazon.opendistroforelasticsearch.sql.expression.Expression; import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionTestBase; import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression; +import com.amazon.opendistroforelasticsearch.sql.utils.OperatorUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; -import com.sun.org.apache.xpath.internal.Arg; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.ObjectInputStream; @@ -832,4 +836,66 @@ public void compare_int_long() { FunctionExpression equal = dsl.equal(DSL.literal(1), DSL.literal(1L)); assertTrue(equal.valueOf(valueEnv()).booleanValue()); } + + private static Stream testInArguments() { + List arguments = + Arrays.asList(Arrays.asList(1, Arrays.asList(0, 2, 1, 3)), + Arrays.asList(1, Arrays.asList(2, 0)), Arrays.asList(1L, Arrays.asList(1L, 2L, 3L)), + Arrays.asList(2L, Arrays.asList(1L, 2L)), Arrays.asList(3F, Arrays.asList(1F, 2F)), + Arrays.asList(0F, Arrays.asList(1F, 2F)), Arrays.asList(1D, Arrays.asList(1D, 1D)), + Arrays.asList(1D, Arrays.asList(2D, 2D)), + Arrays.asList("b", Arrays.asList("a", "c")), + Arrays.asList("b", Arrays.asList("c", "a")), + Arrays.asList("a", Arrays.asList("a", "b")), + Arrays.asList("b", Arrays.asList("a", "b")), + Arrays.asList("c", Arrays.asList("a", "b")), + Arrays.asList("a", Arrays.asList("b", "c")), + Arrays.asList("a", Arrays.asList("a", "a")), + Arrays.asList("b", Arrays.asList("a", "a"))); + + Stream.Builder builder = Stream.builder(); + for (List argGroup : arguments) { + builder.add(Arguments.of(fromObjectValue(argGroup.get(0)), fromObjectValue(argGroup.get(1)))); + } + builder + .add(Arguments.of(fromObjectValue("2021-01-02", DATE), + fromObjectValue(Arrays.asList(fromObjectValue("2021-01-01", DATE), + fromObjectValue("2021-01-03", DATE))))) + .add(Arguments.of(fromObjectValue("2021-01-02", DATE), + fromObjectValue(Arrays.asList(fromObjectValue("2021-01-01", DATE), + fromObjectValue("2021-01-03", DATE))))) + .add(Arguments.of(fromObjectValue("2021-01-01 03:00:00", DATETIME), + fromObjectValue(Arrays.asList(fromObjectValue("2021-01-01 01:00:00", DATETIME), + fromObjectValue("3021-01-01 02:00:00", DATETIME))))) + .add(Arguments.of(fromObjectValue("2021-01-01 01:00:00", TIMESTAMP), + fromObjectValue(Arrays.asList(fromObjectValue("2021-01-01 01:00:00", TIMESTAMP), + fromObjectValue("3021-01-01 01:00:00", TIMESTAMP))))); + return builder.build(); + } + + @ParameterizedTest(name = "in({0}, ({1}))") + @MethodSource("testInArguments") + public void in(ExprValue field, ExprValue arrayOfArgs) { + FunctionExpression in = dsl.in( + DSL.literal(field), DSL.literal(arrayOfArgs)); + assertEquals(BOOLEAN, in.type()); + assertEquals(OperatorUtils.in(field, arrayOfArgs), in.valueOf(valueEnv())); + } + + @ParameterizedTest(name = "not in({0}, ({1}))") + @MethodSource("testInArguments") + public void not_in(ExprValue field, ExprValue arrayOfArgs) { + FunctionExpression notIn = dsl.not_in( + DSL.literal(field), DSL.literal(arrayOfArgs)); + assertEquals(BOOLEAN, notIn.type()); + assertEquals(!OperatorUtils.in(field, arrayOfArgs).booleanValue(), + notIn.valueOf(valueEnv()).booleanValue()); + } + + @Test + public void in_not_an_array() { + assertThrows(ExpressionEvaluationException.class, () -> + dsl.in(DSL.literal(1), DSL.literal("1"))); + } + } \ No newline at end of file diff --git a/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/filter/FilterQueryBuilder.java b/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/filter/FilterQueryBuilder.java index 4cc8be3512..af2dcf81e7 100644 --- a/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/filter/FilterQueryBuilder.java +++ b/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/filter/FilterQueryBuilder.java @@ -24,6 +24,7 @@ import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.RangeQuery; import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.RangeQuery.Comparison; import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.TermQuery; +import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.TermsQuery; import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.WildcardQuery; import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.serialization.ExpressionSerializer; import com.amazon.opendistroforelasticsearch.sql.expression.Expression; @@ -63,6 +64,7 @@ public class FilterQueryBuilder extends ExpressionNodeVisitor 1)"; + JSONObject results = executeQuery(sql); + Assert.assertThat(getTotalHits(results), equalTo(1)); + } + +} diff --git a/sql/src/main/antlr/OpenDistroSQLParser.g4 b/sql/src/main/antlr/OpenDistroSQLParser.g4 index 4f01c657c9..01967fbcdc 100644 --- a/sql/src/main/antlr/OpenDistroSQLParser.g4 +++ b/sql/src/main/antlr/OpenDistroSQLParser.g4 @@ -260,6 +260,7 @@ predicate | predicate IS nullNotnull #isNullPredicate | left=predicate NOT? LIKE right=predicate #likePredicate | left=predicate REGEXP right=predicate #regexpPredicate + | predicate NOT? IN LR_BRACKET arrayArgs? RR_BRACKET #inPredicate ; expressionAtom @@ -370,3 +371,11 @@ functionArg : expression ; +arrayArgs + : arrayArg (COMMA arrayArg)* + ; + +arrayArg + : expression + ; + diff --git a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java index 1fefc0ddfb..a20c7e86ab 100644 --- a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java @@ -16,13 +16,8 @@ package com.amazon.opendistroforelasticsearch.sql.sql.parser; -import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.qualifiedName; -import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.stringLiteral; -import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; -import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.IS_NULL; -import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.LIKE; -import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.NOT_LIKE; -import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.REGEXP; +import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.*; +import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.*; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.BinaryComparisonPredicateContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.BooleanContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.CaseFuncAlternativeContext; @@ -61,8 +56,10 @@ import com.amazon.opendistroforelasticsearch.sql.ast.expression.Case; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Cast; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Function; +import com.amazon.opendistroforelasticsearch.sql.ast.expression.In; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Interval; import com.amazon.opendistroforelasticsearch.sql.ast.expression.IntervalUnit; +import com.amazon.opendistroforelasticsearch.sql.ast.expression.Literal; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Not; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Or; import com.amazon.opendistroforelasticsearch.sql.ast.expression.QualifiedName; @@ -233,6 +230,19 @@ public UnresolvedExpression visitRegexpPredicate(RegexpPredicateContext ctx) { Arrays.asList(visit(ctx.left), visit(ctx.right))); } + @Override + public UnresolvedExpression visitInPredicate(OpenDistroSQLParser.InPredicateContext ctx) { + UnresolvedExpression between = new Function(IN.getName().getFunctionName(), + Arrays.asList(visit(ctx.predicate()), AstDSL.arrayLiteral( ctx.arrayArgs().arrayArg() + .stream() + .map(this::visitArrayArg) + .map(unresolvedExpression -> ((Literal) unresolvedExpression).getValue()) + .collect(Collectors.toList())))); + + return ctx.NOT() == null ? between : + new Function(NOT.getName().getFunctionName(), Collections.singletonList(between)); + } + @Override public UnresolvedExpression visitAndExpression(AndExpressionContext ctx) { return new And(visit(ctx.left), visit(ctx.right)); diff --git a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/antlr/SQLSyntaxParserTest.java b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/antlr/SQLSyntaxParserTest.java index 0f2605d7d6..d0dddf1142 100644 --- a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/antlr/SQLSyntaxParserTest.java +++ b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/antlr/SQLSyntaxParserTest.java @@ -131,6 +131,13 @@ public void canParseCaseStatement() { assertNotNull(parser.parse("SELECT CASE age WHEN 30 THEN 'age1' END FROM test")); } + @Test + public void canParseInStatement() { + assertNotNull(parser.parse("SELECT age FROM test WHERE age IN (1,30)")); + assertNotNull(parser.parse("SELECT age FROM test WHERE age NOT IN (1,30)")); + assertNotNull(parser.parse("SELECT age FROM test WHERE NOT (age IN (1,30))")); + } + @Test public void canNotParseAggregateFunctionWithWrongArgument() { assertThrows(SyntaxCheckException.class, () -> parser.parse("SELECT SUM() FROM test")); diff --git a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java index b1ec56ce51..5df196ab66 100644 --- a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -23,6 +23,7 @@ import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.dateLiteral; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.doubleLiteral; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.function; +import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.in; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.intLiteral; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.intervalLiteral; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.longLiteral; @@ -283,6 +284,20 @@ public void canBuildLogicalExpression() { ); } + @Test + public void canBuildInPredicate() { + assertEquals(in(intLiteral(1), intLiteral(0), intLiteral(2), intLiteral(3), intLiteral(4)), + buildExprAst("1 in (0,2,3,4)")); + } + + @Test + public void canBuildNotInPredicate() { + assertEquals( + function("not", in(intLiteral(1), intLiteral(0), intLiteral(2), intLiteral(3), intLiteral(4))), + buildExprAst("1 not in (0,2,3,4)")); + } + + @Test public void canBuildWindowFunction() { assertEquals( From 7a379d82cd32a7ddec826399bf6480084dec2acb Mon Sep 17 00:00:00 2001 From: Francesco Capponi Date: Sat, 22 May 2021 13:39:16 -0400 Subject: [PATCH 2/2] IN: fixing some CSV strange behavior that in the legacy engine COUNT is returning float, and some tests for the new engine were disabled for that --- .../sql/ast/expression/In.java | 3 +++ .../sql/data/model/ExprValueUtils.java | 8 +++++--- .../sql/utils/OperatorUtils.java | 2 +- .../sql/legacy/CsvFormatResponseIT.java | 6 ++++++ .../sql/sql/parser/AstExpressionBuilder.java | 13 ++++++++++--- .../sql/sql/parser/AstExpressionBuilderTest.java | 8 +++++--- 6 files changed, 30 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/In.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/In.java index 365787780b..945bce841d 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/In.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/In.java @@ -28,7 +28,10 @@ * Params include the field expression and/or wildcard field expression, * nested field expression (@field). * And the values that the field is mapped to (@valueList). + * + * @deprecated use function ("in") instead */ +@Deprecated @Getter @ToString @EqualsAndHashCode(callSuper = false) diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprValueUtils.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprValueUtils.java index 04aaf2b6f0..50de610b28 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprValueUtils.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprValueUtils.java @@ -109,9 +109,7 @@ public static ExprValue fromObjectValue(Object o) { if (null == o) { return LITERAL_NULL; } - if (o instanceof ExprValue) { - return (ExprValue) o; - } else if (o instanceof Map) { + if (o instanceof Map) { return tupleValue((Map) o); } else if (o instanceof List) { return collectionValue(((List) o)); @@ -131,6 +129,10 @@ public static ExprValue fromObjectValue(Object o) { return stringValue((String) o); } else if (o instanceof Float) { return floatValue((Float) o); + } else if (o instanceof ExprValue) { + // since there is no primitive in Java for differentiating TIMESTAMP DATETIME and DATE + // we can allow passing a ExprValue that already contains this information + return (ExprValue) o; } else { throw new ExpressionEvaluationException("unsupported object " + o.getClass()); } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/utils/OperatorUtils.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/utils/OperatorUtils.java index 38af8ae827..fa48fb6a57 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/utils/OperatorUtils.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/utils/OperatorUtils.java @@ -114,7 +114,7 @@ public static ExprBooleanValue in(ExprValue expr, ExprValue setOfValues) { private static boolean isIn(ExprValue expr, ExprValue setOfValues) { if (expr instanceof AbstractExprNumberValue) { - return setOfValues.collectionValue().contains(expr.longValue()); + return setOfValues.collectionValue().contains(expr.value()); } else if (expr instanceof ExprStringValue) { return setOfValues.collectionValue().contains(expr.stringValue()); } else { diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/CsvFormatResponseIT.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/CsvFormatResponseIT.java index e106cdba03..30c5c3071b 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/CsvFormatResponseIT.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/CsvFormatResponseIT.java @@ -395,6 +395,9 @@ public void aggAfterTermsGroupBy() throws Exception { @Test public void aggAfterTwoTermsGroupBy() throws Exception { + // disabling test for new engine because COUNT returns int in new engine, + // and float in the old engine + Assume.assumeFalse(isNewQueryEngineEabled()); String query = String.format(Locale.ROOT, "SELECT COUNT(*) FROM %s where age in (35,36) GROUP BY gender,age", TEST_INDEX_ACCOUNT); @@ -414,6 +417,9 @@ public void aggAfterTwoTermsGroupBy() throws Exception { @Test public void multipleAggAfterTwoTermsGroupBy() throws Exception { + // disabling test for new engine because COUNT returns int in new engine, + // and float in the old engine + Assume.assumeFalse(isNewQueryEngineEabled()); String query = String.format(Locale.ROOT, "SELECT COUNT(*) , sum(balance) FROM %s where age in (35,36) GROUP BY gender,age", TEST_INDEX_ACCOUNT); diff --git a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java index a20c7e86ab..2fa59cfaef 100644 --- a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java @@ -16,8 +16,15 @@ package com.amazon.opendistroforelasticsearch.sql.sql.parser; -import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.*; -import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.*; +import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.qualifiedName; +import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.stringLiteral; +import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.IN; +import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; +import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.IS_NULL; +import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.LIKE; +import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.NOT; +import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.NOT_LIKE; +import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.REGEXP; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.BinaryComparisonPredicateContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.BooleanContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.CaseFuncAlternativeContext; @@ -233,7 +240,7 @@ public UnresolvedExpression visitRegexpPredicate(RegexpPredicateContext ctx) { @Override public UnresolvedExpression visitInPredicate(OpenDistroSQLParser.InPredicateContext ctx) { UnresolvedExpression between = new Function(IN.getName().getFunctionName(), - Arrays.asList(visit(ctx.predicate()), AstDSL.arrayLiteral( ctx.arrayArgs().arrayArg() + Arrays.asList(visit(ctx.predicate()), AstDSL.arrayLiteral(ctx.arrayArgs().arrayArg() .stream() .map(this::visitArrayArg) .map(unresolvedExpression -> ((Literal) unresolvedExpression).getValue()) diff --git a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java index 5df196ab66..4047f86ce1 100644 --- a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -18,6 +18,7 @@ import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.aggregate; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.and; +import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.arrayLiteral; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.booleanLiteral; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.caseWhen; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.dateLiteral; @@ -50,6 +51,7 @@ import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLLexer; import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser; import com.google.common.collect.ImmutableList; +import java.util.Arrays; import org.antlr.v4.runtime.CommonTokenStream; import org.apache.commons.lang3.tuple.ImmutablePair; import org.junit.jupiter.api.Test; @@ -286,18 +288,18 @@ public void canBuildLogicalExpression() { @Test public void canBuildInPredicate() { - assertEquals(in(intLiteral(1), intLiteral(0), intLiteral(2), intLiteral(3), intLiteral(4)), + assertEquals(function("in", intLiteral(1), arrayLiteral(Arrays.asList(0,2,3,4))), buildExprAst("1 in (0,2,3,4)")); } @Test public void canBuildNotInPredicate() { assertEquals( - function("not", in(intLiteral(1), intLiteral(0), intLiteral(2), intLiteral(3), intLiteral(4))), + function("not", + function("in", intLiteral(1),arrayLiteral(Arrays.asList(0,2,3,4)))), buildExprAst("1 not in (0,2,3,4)")); } - @Test public void canBuildWindowFunction() { assertEquals(