Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Adding simple In support #1114

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ public static Literal timestampLiteral(String value) {
return literal(value, DataType.TIMESTAMP);
}

public static <T> Literal arrayLiteral(List<T> value) {
return literal(value, DataType.ARRAY);
}

public static Literal doubleLiteral(Double value) {
return literal(value, DataType.DOUBLE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public enum DataType {
DATE(ExprCoreType.DATE),
TIME(ExprCoreType.TIME),
TIMESTAMP(ExprCoreType.TIMESTAMP),
ARRAY(ExprCoreType.ARRAY),
INTERVAL(ExprCoreType.INTERVAL);

@Getter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
* Params include the field expression and/or wildcard field expression,
* nested field expression (@field).
* And the values that the field is mapped to (@valueList).
*
* @deprecated use function ("in") instead
*/
@Deprecated
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to remove this file if the code base has no dependency on it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

@Getter
@ToString
@EqualsAndHashCode(callSuper = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public static ExprValue tupleValue(Map<String, Object> map) {
/**
* {@link ExprCollectionValue} constructor.
*/
public static ExprValue collectionValue(List<Object> list) {
public static <T> ExprValue collectionValue(List<T> list) {
List<ExprValue> valueList = new ArrayList<>();
list.forEach(o -> valueList.add(fromObjectValue(o)));
return new ExprCollectionValue(valueList);
Expand Down Expand Up @@ -129,6 +129,10 @@ public static ExprValue fromObjectValue(Object o) {
return stringValue((String) o);
} else if (o instanceof Float) {
return floatValue((Float) o);
} else if (o instanceof ExprValue) {
// since there is no primitive in Java for differentiating TIMESTAMP DATETIME and DATE
// we can allow passing a ExprValue that already contains this information
return (ExprValue) o;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any insights here? @penghuo

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll wait for a few more comment/suggestions before making changes to the rest.
Thank you @chloe-zh for the comments!

} else {
throw new ExpressionEvaluationException("unsupported object " + o.getClass());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository;
import com.amazon.opendistroforelasticsearch.sql.expression.window.ranking.RankingWindowFunction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import lombok.RequiredArgsConstructor;

@RequiredArgsConstructor
Expand Down Expand Up @@ -252,7 +254,7 @@ public FunctionExpression subtract(Expression... expressions) {
public FunctionExpression multiply(Expression... expressions) {
return function(BuiltinFunctionName.MULTIPLY, expressions);
}

public FunctionExpression adddate(Expression... expressions) {
return function(BuiltinFunctionName.ADDDATE, expressions);
}
Expand Down Expand Up @@ -364,7 +366,7 @@ public FunctionExpression module(Expression... expressions) {
public FunctionExpression substr(Expression... expressions) {
return function(BuiltinFunctionName.SUBSTR, expressions);
}

public FunctionExpression substring(Expression... expressions) {
return function(BuiltinFunctionName.SUBSTR, expressions);
}
Expand Down Expand Up @@ -588,4 +590,19 @@ public FunctionExpression castTimestamp(Expression value) {
return (FunctionExpression) repository
.compile(BuiltinFunctionName.CAST_TO_TIMESTAMP.getName(), Arrays.asList(value));
}

/**
* Check that a field is contained in a set of values.
*/
public FunctionExpression in(Expression field, Expression... expressions) {
List<Expression> where = new ArrayList<>();
where.add(field);
where.addAll(Arrays.asList(expressions));

return function(BuiltinFunctionName.IN, where.toArray(new Expression[0]));
}

public FunctionExpression not_in(Expression field, Expression... expressions) {
return not(in(field, expressions));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ public enum BuiltinFunctionName {
GTE(FunctionName.of(">=")),
LIKE(FunctionName.of("like")),
NOT_LIKE(FunctionName.of("not like")),
IN(FunctionName.of("in")),

/**
* Aggregation Function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_MISSING;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_NULL;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_TRUE;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.ARRAY;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATE;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATETIME;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DOUBLE;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.FLOAT;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.LONG;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIMESTAMP;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprBooleanValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
Expand Down Expand Up @@ -63,6 +70,7 @@ public static void register(BuiltinFunctionRepository repository) {
repository.register(like());
repository.register(notLike());
repository.register(regexp());
repository.register(in());
}

/**
Expand Down Expand Up @@ -262,6 +270,26 @@ private static FunctionResolver notLike() {
STRING));
}

private static FunctionResolver in() {
return FunctionDSL.define(BuiltinFunctionName.IN.getName(),
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, INTEGER, ARRAY),
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, STRING, ARRAY),
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, LONG, ARRAY),
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, FLOAT, ARRAY),
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, DOUBLE, ARRAY),
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, DATE, ARRAY),
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, DATETIME, ARRAY),
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, TIMESTAMP, ARRAY));
}

private static ExprValue lookupTableFunction(ExprValue arg1, ExprValue arg2,
Table<ExprValue, ExprValue, ExprValue> table) {
if (table.contains(arg1, arg2)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@

package com.amazon.opendistroforelasticsearch.sql.utils;

import com.amazon.opendistroforelasticsearch.sql.data.model.AbstractExprNumberValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprBooleanValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprStringValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import java.util.regex.Pattern;
import lombok.experimental.UtilityClass;
Expand Down Expand Up @@ -99,4 +101,24 @@ private static String patternToRegex(String patternString) {
regex.append('$');
return regex.toString();
}


/**
* IN (..., ...) operator util.
* Expression { expr IN (collection of values..) } is to judge
* if expr is contained in a given collection.
*/
public static ExprBooleanValue in(ExprValue expr, ExprValue setOfValues) {
return ExprBooleanValue.of(isIn(expr, setOfValues));
}

private static boolean isIn(ExprValue expr, ExprValue setOfValues) {
if (expr instanceof AbstractExprNumberValue) {
return setOfValues.collectionValue().contains(expr.value());
} else if (expr instanceof ExprStringValue) {
return setOfValues.collectionValue().contains(expr.stringValue());
} else {
return setOfValues.collectionValue().contains(expr);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,24 @@ public void invalidConvertExprValue(ExprValue value, Function<ExprValue, Object>
assertThat(exception.getMessage(), Matchers.containsString("invalid"));
}

// disabling test because in case of expr collections, we could pass ExprValues
// @Test
// public void unSupportedObject() {
// Exception exception = assertThrows(ExpressionEvaluationException.class,
// () -> ExprValueUtils.fromObjectValue(integerValue(1)));
// assertEquals(
// "unsupported object "
// + "class com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue",
// exception.getMessage());
// }

Comment on lines +218 to +228
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove these lines?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

@Test
public void unSupportedObject() {
Exception exception = assertThrows(ExpressionEvaluationException.class,
() -> ExprValueUtils.fromObjectValue(integerValue(1)));
() -> ExprValueUtils.fromObjectValue(new Object()));
assertEquals(
"unsupported object "
+ "class com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue",
+ "class java.lang.Object",
exception.getMessage());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,17 @@
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_TRUE;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.booleanValue;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.fromObjectValue;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.missingValue;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATE;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATETIME;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIMESTAMP;
import static com.amazon.opendistroforelasticsearch.sql.utils.ComparisonUtil.compare;
import static com.amazon.opendistroforelasticsearch.sql.utils.OperatorUtils.matches;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprBooleanValue;
Expand All @@ -49,14 +52,15 @@
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprTupleValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
import com.amazon.opendistroforelasticsearch.sql.exception.ExpressionEvaluationException;
import com.amazon.opendistroforelasticsearch.sql.expression.DSL;
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionTestBase;
import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression;
import com.amazon.opendistroforelasticsearch.sql.utils.OperatorUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.sun.org.apache.xpath.internal.Arg;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
Expand Down Expand Up @@ -832,4 +836,66 @@ public void compare_int_long() {
FunctionExpression equal = dsl.equal(DSL.literal(1), DSL.literal(1L));
assertTrue(equal.valueOf(valueEnv()).booleanValue());
}

private static Stream<Arguments> testInArguments() {
List<List> arguments =
Arrays.asList(Arrays.asList(1, Arrays.asList(0, 2, 1, 3)),
Arrays.asList(1, Arrays.asList(2, 0)), Arrays.asList(1L, Arrays.asList(1L, 2L, 3L)),
Arrays.asList(2L, Arrays.asList(1L, 2L)), Arrays.asList(3F, Arrays.asList(1F, 2F)),
Arrays.asList(0F, Arrays.asList(1F, 2F)), Arrays.asList(1D, Arrays.asList(1D, 1D)),
Arrays.asList(1D, Arrays.asList(2D, 2D)),
Arrays.asList("b", Arrays.asList("a", "c")),
Arrays.asList("b", Arrays.asList("c", "a")),
Arrays.asList("a", Arrays.asList("a", "b")),
Arrays.asList("b", Arrays.asList("a", "b")),
Arrays.asList("c", Arrays.asList("a", "b")),
Arrays.asList("a", Arrays.asList("b", "c")),
Arrays.asList("a", Arrays.asList("a", "a")),
Arrays.asList("b", Arrays.asList("a", "a")));

Stream.Builder<Arguments> builder = Stream.builder();
for (List<Object> argGroup : arguments) {
builder.add(Arguments.of(fromObjectValue(argGroup.get(0)), fromObjectValue(argGroup.get(1))));
}
builder
.add(Arguments.of(fromObjectValue("2021-01-02", DATE),
fromObjectValue(Arrays.asList(fromObjectValue("2021-01-01", DATE),
fromObjectValue("2021-01-03", DATE)))))
.add(Arguments.of(fromObjectValue("2021-01-02", DATE),
fromObjectValue(Arrays.asList(fromObjectValue("2021-01-01", DATE),
fromObjectValue("2021-01-03", DATE)))))
.add(Arguments.of(fromObjectValue("2021-01-01 03:00:00", DATETIME),
fromObjectValue(Arrays.asList(fromObjectValue("2021-01-01 01:00:00", DATETIME),
fromObjectValue("3021-01-01 02:00:00", DATETIME)))))
.add(Arguments.of(fromObjectValue("2021-01-01 01:00:00", TIMESTAMP),
fromObjectValue(Arrays.asList(fromObjectValue("2021-01-01 01:00:00", TIMESTAMP),
fromObjectValue("3021-01-01 01:00:00", TIMESTAMP)))));
return builder.build();
}

@ParameterizedTest(name = "in({0}, ({1}))")
@MethodSource("testInArguments")
public void in(ExprValue field, ExprValue arrayOfArgs) {
FunctionExpression in = dsl.in(
DSL.literal(field), DSL.literal(arrayOfArgs));
assertEquals(BOOLEAN, in.type());
assertEquals(OperatorUtils.in(field, arrayOfArgs), in.valueOf(valueEnv()));
}

@ParameterizedTest(name = "not in({0}, ({1}))")
@MethodSource("testInArguments")
public void not_in(ExprValue field, ExprValue arrayOfArgs) {
FunctionExpression notIn = dsl.not_in(
DSL.literal(field), DSL.literal(arrayOfArgs));
assertEquals(BOOLEAN, notIn.type());
assertEquals(!OperatorUtils.in(field, arrayOfArgs).booleanValue(),
notIn.valueOf(valueEnv()).booleanValue());
}

@Test
public void in_not_an_array() {
assertThrows(ExpressionEvaluationException.class, () ->
dsl.in(DSL.literal(1), DSL.literal("1")));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.RangeQuery;
import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.RangeQuery.Comparison;
import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.TermQuery;
import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.TermsQuery;
import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.WildcardQuery;
import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.serialization.ExpressionSerializer;
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
Expand Down Expand Up @@ -63,6 +64,7 @@ public class FilterQueryBuilder extends ExpressionNodeVisitor<QueryBuilder, Obje
.put(BuiltinFunctionName.LTE.getName(), new RangeQuery(Comparison.LTE))
.put(BuiltinFunctionName.GTE.getName(), new RangeQuery(Comparison.GTE))
.put(BuiltinFunctionName.LIKE.getName(), new WildcardQuery())
.put(BuiltinFunctionName.IN.getName(), new TermsQuery())
.build();

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/

package com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprType;
import java.util.stream.Collectors;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;


/**
* Lucene query that build terms query for equality comparison.
*/
public class TermsQuery extends LuceneQuery {

@Override
protected QueryBuilder doBuild(String fieldName, ExprType fieldType, ExprValue literal) {
fieldName = convertTextToKeyword(fieldName, fieldType);
return QueryBuilders.termsQuery(fieldName,
literal.collectionValue().stream().map(ExprValue::value)
.collect(Collectors.toList()));
}

}
Loading