From 08f5455ad8c70b52bbdd1eb6ee2b8a0a15cdb6a5 Mon Sep 17 00:00:00 2001 From: ZhongYujiang <42907416+zhongyujiang@users.noreply.github.com> Date: Fri, 17 Mar 2023 23:38:20 +0800 Subject: [PATCH] Core, Spark: Fix delete with filter on nested columns. --- .../expressions/StrictMetricsEvaluator.java | 114 ++++++++++++------ .../TestStrictMetricsEvaluator.java | 49 +++++++- .../iceberg/spark/extensions/TestDelete.java | 16 +++ 3 files changed, 138 insertions(+), 41 deletions(-) diff --git a/api/src/main/java/org/apache/iceberg/expressions/StrictMetricsEvaluator.java b/api/src/main/java/org/apache/iceberg/expressions/StrictMetricsEvaluator.java index 4aee75c447d3..0565419a797d 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/StrictMetricsEvaluator.java +++ b/api/src/main/java/org/apache/iceberg/expressions/StrictMetricsEvaluator.java @@ -29,10 +29,10 @@ import org.apache.iceberg.DataFile; import org.apache.iceberg.Schema; import org.apache.iceberg.expressions.ExpressionVisitors.BoundExpressionVisitor; -import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.types.Conversions; -import org.apache.iceberg.types.Types; -import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.util.NaNUtil; /** @@ -52,8 +52,9 @@ */ public class StrictMetricsEvaluator { private final Schema schema; - private final StructType struct; private final Expression expr; + private final Map idToParent; + private final Map canEvaluate = Maps.newHashMap(); public StrictMetricsEvaluator(Schema schema, Expression unbound) { this(schema, unbound, true); @@ -61,8 +62,8 @@ public StrictMetricsEvaluator(Schema schema, Expression unbound) { public StrictMetricsEvaluator(Schema schema, Expression unbound, boolean caseSensitive) { this.schema = schema; - this.struct = schema.asStruct(); - this.expr = Binder.bind(struct, rewriteNot(unbound), caseSensitive); + this.expr = Binder.bind(schema.asStruct(), rewriteNot(unbound), caseSensitive); + this.idToParent = TypeUtil.indexParents(schema.asStruct()); } /** @@ -144,8 +145,9 @@ public Boolean isNull(BoundReference ref) { // no need to check whether the field is required because binding evaluates that case // if the column has any non-null values, the expression does not match int id = ref.fieldId(); - Preconditions.checkNotNull( - struct.field(id), "Cannot filter by nested column: %s", schema.findField(id)); + if (!supportsEvaluate(id)) { + return ROWS_MIGHT_NOT_MATCH; + } if (containsNullsOnly(id)) { return ROWS_MUST_MATCH; @@ -159,8 +161,9 @@ public Boolean notNull(BoundReference ref) { // no need to check whether the field is required because binding evaluates that case // if the column has any null values, the expression does not match int id = ref.fieldId(); - Preconditions.checkNotNull( - struct.field(id), "Cannot filter by nested column: %s", schema.findField(id)); + if (!supportsEvaluate(id)) { + return ROWS_MIGHT_NOT_MATCH; + } if (nullCounts != null && nullCounts.containsKey(id) && nullCounts.get(id) == 0) { return ROWS_MUST_MATCH; @@ -199,15 +202,16 @@ public Boolean notNaN(BoundReference ref) { public Boolean lt(BoundReference ref, Literal lit) { // Rows must match when: <----------Min----Max---X-------> Integer id = ref.fieldId(); - Types.NestedField field = struct.field(id); - Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id)); + if (!supportsEvaluate(id)) { + return ROWS_MIGHT_NOT_MATCH; + } if (canContainNulls(id) || canContainNaNs(id)) { return ROWS_MIGHT_NOT_MATCH; } if (upperBounds != null && upperBounds.containsKey(id)) { - T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id)); + T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id)); int cmp = lit.comparator().compare(upper, lit.value()); if (cmp < 0) { @@ -222,15 +226,16 @@ public Boolean lt(BoundReference ref, Literal lit) { public Boolean ltEq(BoundReference ref, Literal lit) { // Rows must match when: <----------Min----Max---X-------> Integer id = ref.fieldId(); - Types.NestedField field = struct.field(id); - Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id)); + if (!supportsEvaluate(id)) { + return ROWS_MIGHT_NOT_MATCH; + } if (canContainNulls(id) || canContainNaNs(id)) { return ROWS_MIGHT_NOT_MATCH; } if (upperBounds != null && upperBounds.containsKey(id)) { - T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id)); + T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id)); int cmp = lit.comparator().compare(upper, lit.value()); if (cmp <= 0) { @@ -245,15 +250,16 @@ public Boolean ltEq(BoundReference ref, Literal lit) { public Boolean gt(BoundReference ref, Literal lit) { // Rows must match when: <-------X---Min----Max----------> Integer id = ref.fieldId(); - Types.NestedField field = struct.field(id); - Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id)); + if (!supportsEvaluate(id)) { + return ROWS_MIGHT_NOT_MATCH; + } if (canContainNulls(id) || canContainNaNs(id)) { return ROWS_MIGHT_NOT_MATCH; } if (lowerBounds != null && lowerBounds.containsKey(id)) { - T lower = Conversions.fromByteBuffer(field.type(), lowerBounds.get(id)); + T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id)); if (NaNUtil.isNaN(lower)) { // NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for more. @@ -273,15 +279,16 @@ public Boolean gt(BoundReference ref, Literal lit) { public Boolean gtEq(BoundReference ref, Literal lit) { // Rows must match when: <-------X---Min----Max----------> Integer id = ref.fieldId(); - Types.NestedField field = struct.field(id); - Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id)); + if (!supportsEvaluate(id)) { + return ROWS_MIGHT_NOT_MATCH; + } if (canContainNulls(id) || canContainNaNs(id)) { return ROWS_MIGHT_NOT_MATCH; } if (lowerBounds != null && lowerBounds.containsKey(id)) { - T lower = Conversions.fromByteBuffer(field.type(), lowerBounds.get(id)); + T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id)); if (NaNUtil.isNaN(lower)) { // NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for more. @@ -301,8 +308,9 @@ public Boolean gtEq(BoundReference ref, Literal lit) { public Boolean eq(BoundReference ref, Literal lit) { // Rows must match when Min == X == Max Integer id = ref.fieldId(); - Types.NestedField field = struct.field(id); - Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id)); + if (!supportsEvaluate(id)) { + return ROWS_MIGHT_NOT_MATCH; + } if (canContainNulls(id) || canContainNaNs(id)) { return ROWS_MIGHT_NOT_MATCH; @@ -312,14 +320,14 @@ public Boolean eq(BoundReference ref, Literal lit) { && lowerBounds.containsKey(id) && upperBounds != null && upperBounds.containsKey(id)) { - T lower = Conversions.fromByteBuffer(struct.field(id).type(), lowerBounds.get(id)); + T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id)); int cmp = lit.comparator().compare(lower, lit.value()); if (cmp != 0) { return ROWS_MIGHT_NOT_MATCH; } - T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id)); + T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id)); cmp = lit.comparator().compare(upper, lit.value()); if (cmp != 0) { @@ -336,15 +344,16 @@ public Boolean eq(BoundReference ref, Literal lit) { public Boolean notEq(BoundReference ref, Literal lit) { // Rows must match when X < Min or Max < X because it is not in the range Integer id = ref.fieldId(); - Types.NestedField field = struct.field(id); - Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id)); + if (!supportsEvaluate(id)) { + return ROWS_MIGHT_NOT_MATCH; + } if (containsNullsOnly(id) || containsNaNsOnly(id)) { return ROWS_MUST_MATCH; } if (lowerBounds != null && lowerBounds.containsKey(id)) { - T lower = Conversions.fromByteBuffer(struct.field(id).type(), lowerBounds.get(id)); + T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id)); if (NaNUtil.isNaN(lower)) { // NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for more. @@ -358,7 +367,7 @@ public Boolean notEq(BoundReference ref, Literal lit) { } if (upperBounds != null && upperBounds.containsKey(id)) { - T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id)); + T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id)); int cmp = lit.comparator().compare(upper, lit.value()); if (cmp < 0) { @@ -372,8 +381,9 @@ public Boolean notEq(BoundReference ref, Literal lit) { @Override public Boolean in(BoundReference ref, Set literalSet) { Integer id = ref.fieldId(); - Types.NestedField field = struct.field(id); - Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id)); + if (!supportsEvaluate(id)) { + return ROWS_MIGHT_NOT_MATCH; + } if (canContainNulls(id) || canContainNaNs(id)) { return ROWS_MIGHT_NOT_MATCH; @@ -384,13 +394,13 @@ public Boolean in(BoundReference ref, Set literalSet) { && upperBounds != null && upperBounds.containsKey(id)) { // similar to the implementation in eq, first check if the lower bound is in the set - T lower = Conversions.fromByteBuffer(struct.field(id).type(), lowerBounds.get(id)); + T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id)); if (!literalSet.contains(lower)) { return ROWS_MIGHT_NOT_MATCH; } // check if the upper bound is in the set - T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id)); + T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id)); if (!literalSet.contains(upper)) { return ROWS_MIGHT_NOT_MATCH; } @@ -411,8 +421,9 @@ public Boolean in(BoundReference ref, Set literalSet) { @Override public Boolean notIn(BoundReference ref, Set literalSet) { Integer id = ref.fieldId(); - Types.NestedField field = struct.field(id); - Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id)); + if (!supportsEvaluate(id)) { + return ROWS_MIGHT_NOT_MATCH; + } if (containsNullsOnly(id) || containsNaNsOnly(id)) { return ROWS_MUST_MATCH; @@ -421,7 +432,7 @@ public Boolean notIn(BoundReference ref, Set literalSet) { Collection literals = literalSet; if (lowerBounds != null && lowerBounds.containsKey(id)) { - T lower = Conversions.fromByteBuffer(struct.field(id).type(), lowerBounds.get(id)); + T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id)); if (NaNUtil.isNaN(lower)) { // NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for more. @@ -439,7 +450,7 @@ public Boolean notIn(BoundReference ref, Set literalSet) { } if (upperBounds != null && upperBounds.containsKey(id)) { - T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id)); + T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id)); literals = literals.stream() .filter(v -> ref.comparator().compare(upper, v) >= 0) @@ -489,5 +500,32 @@ private boolean containsNaNsOnly(Integer id) { && valueCounts != null && nanCounts.get(id).equals(valueCounts.get(id)); } + + private boolean supportsEvaluate(int fieldId) { + Boolean evaluable = canEvaluate.get(fieldId); + if (evaluable != null) { + return evaluable; + } + + evaluable = true; + // Cannot evaluate on complex types or repeated primitive types. + if (!schema.findType(fieldId).isPrimitiveType()) { + evaluable = false; + } else { + Integer parent = idToParent.get(fieldId); + while (parent != null) { + Type type = schema.findType(parent); + if (type.isListType() || type.isMapType()) { + evaluable = false; + break; + } + + parent = idToParent.get(parent); + } + } + + canEvaluate.put(fieldId, evaluable); + return evaluable; + } } } diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestStrictMetricsEvaluator.java b/api/src/test/java/org/apache/iceberg/expressions/TestStrictMetricsEvaluator.java index 790d429f1d7a..2541a9188524 100644 --- a/api/src/test/java/org/apache/iceberg/expressions/TestStrictMetricsEvaluator.java +++ b/api/src/test/java/org/apache/iceberg/expressions/TestStrictMetricsEvaluator.java @@ -66,7 +66,18 @@ public class TestStrictMetricsEvaluator { optional(11, "all_nulls_double", Types.DoubleType.get()), optional(12, "all_nans_v1_stats", Types.FloatType.get()), optional(13, "nan_and_null_only", Types.DoubleType.get()), - optional(14, "no_nan_stats", Types.DoubleType.get())); + optional(14, "no_nan_stats", Types.DoubleType.get()), + optional( + 15, + "struct", + Types.StructType.of(Types.NestedField.required(16, "c1", Types.IntegerType.get()))), + optional( + 17, + "list", + Types.ListType.ofRequired( + 18, + Types.StructType.of( + Types.NestedField.required(19, "c2", Types.IntegerType.get()))))); private static final int INT_MIN_VALUE = 30; private static final int INT_MAX_VALUE = 79; @@ -88,6 +99,7 @@ public class TestStrictMetricsEvaluator { .put(12, 50L) .put(13, 50L) .put(14, 50L) + .put(16, 50L) .buildOrThrow(), // null value counts ImmutableMap.builder() @@ -97,6 +109,7 @@ public class TestStrictMetricsEvaluator { .put(11, 50L) .put(12, 0L) .put(13, 1L) + .put(16, 0L) .buildOrThrow(), // nan value counts ImmutableMap.of( @@ -108,13 +121,15 @@ public class TestStrictMetricsEvaluator { 1, toByteBuffer(IntegerType.get(), INT_MIN_VALUE), 7, toByteBuffer(IntegerType.get(), 5), 12, toByteBuffer(Types.FloatType.get(), Float.NaN), - 13, toByteBuffer(Types.DoubleType.get(), Double.NaN)), + 13, toByteBuffer(Types.DoubleType.get(), Double.NaN), + 16, toByteBuffer(Types.IntegerType.get(), INT_MIN_VALUE)), // upper bounds ImmutableMap.of( 1, toByteBuffer(IntegerType.get(), INT_MAX_VALUE), 7, toByteBuffer(IntegerType.get(), 5), 12, toByteBuffer(Types.FloatType.get(), Float.NaN), - 13, toByteBuffer(Types.DoubleType.get(), Double.NaN))); + 13, toByteBuffer(Types.DoubleType.get(), Double.NaN), + 16, toByteBuffer(IntegerType.get(), INT_MAX_VALUE))); private static final DataFile FILE_2 = new TestDataFile( @@ -617,4 +632,32 @@ public void testIntegerNotIn() { shouldRead = new StrictMetricsEvaluator(SCHEMA, notIn("no_nulls", "abc", "def")).eval(FILE); Assert.assertFalse("Should not match: no_nulls field does not have bounds", shouldRead); } + + @Test + public void testEvaluateOnNestedColumns() { + boolean shouldRead = + new StrictMetricsEvaluator(SCHEMA, greaterThan("struct.c1", INT_MAX_VALUE)).eval(FILE); + Assert.assertFalse("Should not match: always false", shouldRead); + + shouldRead = + new StrictMetricsEvaluator(SCHEMA, lessThanOrEqual("struct.c1", INT_MAX_VALUE)).eval(FILE); + Assert.assertTrue("Should match: always true", shouldRead); + + shouldRead = + new StrictMetricsEvaluator(SCHEMA, greaterThan("struct.c1", INT_MAX_VALUE)).eval(FILE_2); + Assert.assertFalse("Should never match when stats are missing", shouldRead); + + shouldRead = + new StrictMetricsEvaluator(SCHEMA, lessThanOrEqual("struct.c1", INT_MAX_VALUE)) + .eval(FILE_2); + Assert.assertFalse("Should never match when stats are missing", shouldRead); + + shouldRead = + new StrictMetricsEvaluator(SCHEMA, lessThanOrEqual("list.element.c2", INT_MAX_VALUE)) + .eval(FILE); + Assert.assertFalse("Should never match for repeated fields", shouldRead); + + shouldRead = new StrictMetricsEvaluator(SCHEMA, isNull("list")).eval(FILE); + Assert.assertFalse("Should never match for complex fields", shouldRead); + } } diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java index 0b73821c617d..d081454c42d5 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java @@ -510,6 +510,22 @@ public void testDeleteWithConditionOnNestedColumn() { "Should have expected rows", ImmutableList.of(), sql("SELECT id FROM %s", tableName)); } + @Test + public void testDeleteWithFilterOnNestedColumn() { + createAndInitNestedColumnsTable(); + + sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\"))", tableName); + sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2, \"c2\", \"v2\"))", tableName); + + sql("DELETE FROM %s WHERE complex.c1 = 3", tableName); + assertEquals( + "Should have expected rows", ImmutableList.of(row(2)), sql("SELECT id FROM %s", tableName)); + + sql("DELETE FROM %s t WHERE t.complex.c1 = 2", tableName); + assertEquals( + "Should have expected rows", ImmutableList.of(), sql("SELECT id FROM %s", tableName)); + } + @Test public void testDeleteWithInSubquery() throws NoSuchTableException { createAndInitUnpartitionedTable();