Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core, Spark: Fix delete with filter on nested columns #7132

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
import org.apache.iceberg.DataFile;
import org.apache.iceberg.Schema;
import org.apache.iceberg.expressions.ExpressionVisitors.BoundExpressionVisitor;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.types.Conversions;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.types.Types.StructType;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.util.NaNUtil;

/**
Expand All @@ -52,17 +52,18 @@
*/
public class StrictMetricsEvaluator {
private final Schema schema;
private final StructType struct;
private final Expression expr;
private final Map<Integer, Integer> idToParent;
private final Map<Integer, Boolean> canEvaluate = Maps.newHashMap();

/** Creates a case-sensitive evaluator for the given schema and filter expression. */
public StrictMetricsEvaluator(Schema schema, Expression unbound) {
  this(schema, unbound, true);
}

/**
 * Creates an evaluator that binds {@code unbound} against {@code schema}.
 *
 * @param schema table schema that the expression references
 * @param unbound unbound filter expression; NOT is rewritten away before binding
 * @param caseSensitive whether column-name resolution during binding is case sensitive
 */
public StrictMetricsEvaluator(Schema schema, Expression unbound, boolean caseSensitive) {
  this.schema = schema;
  this.expr = Binder.bind(schema.asStruct(), rewriteNot(unbound), caseSensitive);
  // parent index lets supportsEvaluate detect fields nested under list/map types,
  // whose per-file metrics cannot be used for strict evaluation
  this.idToParent = TypeUtil.indexParents(schema.asStruct());
}

/**
Expand Down Expand Up @@ -144,8 +145,9 @@ public <T> Boolean isNull(BoundReference<T> ref) {
// no need to check whether the field is required because binding evaluates that case
// if the column has any non-null values, the expression does not match
int id = ref.fieldId();
Preconditions.checkNotNull(
struct.field(id), "Cannot filter by nested column: %s", schema.findField(id));
if (!supportsEvaluate(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (containsNullsOnly(id)) {
return ROWS_MUST_MATCH;
Expand All @@ -159,8 +161,9 @@ public <T> Boolean notNull(BoundReference<T> ref) {
// no need to check whether the field is required because binding evaluates that case
// if the column has any null values, the expression does not match
int id = ref.fieldId();
Preconditions.checkNotNull(
struct.field(id), "Cannot filter by nested column: %s", schema.findField(id));
if (!supportsEvaluate(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (nullCounts != null && nullCounts.containsKey(id) && nullCounts.get(id) == 0) {
return ROWS_MUST_MATCH;
Expand Down Expand Up @@ -199,15 +202,16 @@ public <T> Boolean notNaN(BoundReference<T> ref) {
public <T> Boolean lt(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when: <----------Min----Max---X------->
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (!supportsEvaluate(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (upperBounds != null && upperBounds.containsKey(id)) {
T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));

int cmp = lit.comparator().compare(upper, lit.value());
if (cmp < 0) {
Expand All @@ -222,15 +226,16 @@ public <T> Boolean lt(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean ltEq(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when: <----------Min----Max---X------->
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (!supportsEvaluate(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (upperBounds != null && upperBounds.containsKey(id)) {
T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));

int cmp = lit.comparator().compare(upper, lit.value());
if (cmp <= 0) {
Expand All @@ -245,15 +250,16 @@ public <T> Boolean ltEq(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean gt(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when: <-------X---Min----Max---------->
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (!supportsEvaluate(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (lowerBounds != null && lowerBounds.containsKey(id)) {
T lower = Conversions.fromByteBuffer(field.type(), lowerBounds.get(id));
T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id));

if (NaNUtil.isNaN(lower)) {
// NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for more.
Expand All @@ -273,15 +279,16 @@ public <T> Boolean gt(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean gtEq(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when: <-------X---Min----Max---------->
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (!supportsEvaluate(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (lowerBounds != null && lowerBounds.containsKey(id)) {
T lower = Conversions.fromByteBuffer(field.type(), lowerBounds.get(id));
T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id));

if (NaNUtil.isNaN(lower)) {
// NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for more.
Expand All @@ -301,8 +308,9 @@ public <T> Boolean gtEq(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean eq(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when Min == X == Max
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (!supportsEvaluate(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
Expand All @@ -312,14 +320,14 @@ public <T> Boolean eq(BoundReference<T> ref, Literal<T> lit) {
&& lowerBounds.containsKey(id)
&& upperBounds != null
&& upperBounds.containsKey(id)) {
T lower = Conversions.fromByteBuffer(struct.field(id).type(), lowerBounds.get(id));
T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id));

int cmp = lit.comparator().compare(lower, lit.value());
if (cmp != 0) {
return ROWS_MIGHT_NOT_MATCH;
}

T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));

cmp = lit.comparator().compare(upper, lit.value());
if (cmp != 0) {
Expand All @@ -336,15 +344,16 @@ public <T> Boolean eq(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean notEq(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when X < Min or Max < X because it is not in the range
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (!supportsEvaluate(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (containsNullsOnly(id) || containsNaNsOnly(id)) {
return ROWS_MUST_MATCH;
}

if (lowerBounds != null && lowerBounds.containsKey(id)) {
T lower = Conversions.fromByteBuffer(struct.field(id).type(), lowerBounds.get(id));
T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id));

if (NaNUtil.isNaN(lower)) {
// NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for more.
Expand All @@ -358,7 +367,7 @@ public <T> Boolean notEq(BoundReference<T> ref, Literal<T> lit) {
}

if (upperBounds != null && upperBounds.containsKey(id)) {
T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));

int cmp = lit.comparator().compare(upper, lit.value());
if (cmp < 0) {
Expand All @@ -372,8 +381,9 @@ public <T> Boolean notEq(BoundReference<T> ref, Literal<T> lit) {
@Override
public <T> Boolean in(BoundReference<T> ref, Set<T> literalSet) {
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (!supportsEvaluate(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
Expand All @@ -384,13 +394,13 @@ public <T> Boolean in(BoundReference<T> ref, Set<T> literalSet) {
&& upperBounds != null
&& upperBounds.containsKey(id)) {
// similar to the implementation in eq, first check if the lower bound is in the set
T lower = Conversions.fromByteBuffer(struct.field(id).type(), lowerBounds.get(id));
T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id));
if (!literalSet.contains(lower)) {
return ROWS_MIGHT_NOT_MATCH;
}

// check if the upper bound is in the set
T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));
if (!literalSet.contains(upper)) {
return ROWS_MIGHT_NOT_MATCH;
}
Expand All @@ -411,8 +421,9 @@ public <T> Boolean in(BoundReference<T> ref, Set<T> literalSet) {
@Override
public <T> Boolean notIn(BoundReference<T> ref, Set<T> literalSet) {
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (!supportsEvaluate(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (containsNullsOnly(id) || containsNaNsOnly(id)) {
return ROWS_MUST_MATCH;
Expand All @@ -421,7 +432,7 @@ public <T> Boolean notIn(BoundReference<T> ref, Set<T> literalSet) {
Collection<T> literals = literalSet;

if (lowerBounds != null && lowerBounds.containsKey(id)) {
T lower = Conversions.fromByteBuffer(struct.field(id).type(), lowerBounds.get(id));
T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id));

if (NaNUtil.isNaN(lower)) {
// NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for more.
Expand All @@ -439,7 +450,7 @@ public <T> Boolean notIn(BoundReference<T> ref, Set<T> literalSet) {
}

if (upperBounds != null && upperBounds.containsKey(id)) {
T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));
literals =
literals.stream()
.filter(v -> ref.comparator().compare(upper, v) >= 0)
Expand Down Expand Up @@ -489,5 +500,32 @@ private boolean containsNaNsOnly(Integer id) {
&& valueCounts != null
&& nanCounts.get(id).equals(valueCounts.get(id));
}

private boolean supportsEvaluate(int fieldId) {
Boolean evaluable = canEvaluate.get(fieldId);
if (evaluable != null) {
return evaluable;
}

evaluable = true;
// Cannot evaluate on complex types or repeated primitive types.
if (!schema.findType(fieldId).isPrimitiveType()) {
evaluable = false;
} else {
Integer parent = idToParent.get(fieldId);
while (parent != null) {
Type type = schema.findType(parent);
if (type.isListType() || type.isMapType()) {
evaluable = false;
Copy link
Collaborator

@szehon-ho szehon-ho Mar 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you are skipping if list/map type, but allowing struct, if I understand? I think it makes sense to me, as I feel we have nested column stats. but definitely like @rdblue @RussellSpitzer @aokolnychyi to have a sanity check here on the overall direction.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, your understanding is correct.

break;
}

parent = idToParent.get(parent);
}
}

canEvaluate.put(fieldId, evaluable);
return evaluable;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,18 @@ public class TestStrictMetricsEvaluator {
optional(11, "all_nulls_double", Types.DoubleType.get()),
optional(12, "all_nans_v1_stats", Types.FloatType.get()),
optional(13, "nan_and_null_only", Types.DoubleType.get()),
optional(14, "no_nan_stats", Types.DoubleType.get()));
optional(14, "no_nan_stats", Types.DoubleType.get()),
optional(
15,
"struct",
Types.StructType.of(Types.NestedField.required(16, "c1", Types.IntegerType.get()))),
optional(
17,
"list",
Types.ListType.ofRequired(
18,
Types.StructType.of(
Types.NestedField.required(19, "c2", Types.IntegerType.get())))));

private static final int INT_MIN_VALUE = 30;
private static final int INT_MAX_VALUE = 79;
Expand All @@ -88,6 +99,7 @@ public class TestStrictMetricsEvaluator {
.put(12, 50L)
.put(13, 50L)
.put(14, 50L)
.put(16, 50L)
.buildOrThrow(),
// null value counts
ImmutableMap.<Integer, Long>builder()
Expand All @@ -97,6 +109,7 @@ public class TestStrictMetricsEvaluator {
.put(11, 50L)
.put(12, 0L)
.put(13, 1L)
.put(16, 0L)
.buildOrThrow(),
// nan value counts
ImmutableMap.of(
Expand All @@ -108,13 +121,15 @@ public class TestStrictMetricsEvaluator {
1, toByteBuffer(IntegerType.get(), INT_MIN_VALUE),
7, toByteBuffer(IntegerType.get(), 5),
12, toByteBuffer(Types.FloatType.get(), Float.NaN),
13, toByteBuffer(Types.DoubleType.get(), Double.NaN)),
13, toByteBuffer(Types.DoubleType.get(), Double.NaN),
16, toByteBuffer(Types.IntegerType.get(), INT_MIN_VALUE)),
// upper bounds
ImmutableMap.of(
1, toByteBuffer(IntegerType.get(), INT_MAX_VALUE),
7, toByteBuffer(IntegerType.get(), 5),
12, toByteBuffer(Types.FloatType.get(), Float.NaN),
13, toByteBuffer(Types.DoubleType.get(), Double.NaN)));
13, toByteBuffer(Types.DoubleType.get(), Double.NaN),
16, toByteBuffer(IntegerType.get(), INT_MAX_VALUE)));

private static final DataFile FILE_2 =
new TestDataFile(
Expand Down Expand Up @@ -617,4 +632,32 @@ public void testIntegerNotIn() {
shouldRead = new StrictMetricsEvaluator(SCHEMA, notIn("no_nulls", "abc", "def")).eval(FILE);
Assert.assertFalse("Should not match: no_nulls field does not have bounds", shouldRead);
}

@Test
public void testEvaluateOnNestedColumns() {
  // struct-nested primitives carry usable stats in FILE (bounds cover the int range)
  boolean matched =
      new StrictMetricsEvaluator(SCHEMA, greaterThan("struct.c1", INT_MAX_VALUE)).eval(FILE);
  Assert.assertFalse("Should not match: always false", matched);

  matched =
      new StrictMetricsEvaluator(SCHEMA, lessThanOrEqual("struct.c1", INT_MAX_VALUE)).eval(FILE);
  Assert.assertTrue("Should match: always true", matched);

  // FILE_2 has no stats for the struct field, so strict evaluation can never prove a match
  matched =
      new StrictMetricsEvaluator(SCHEMA, greaterThan("struct.c1", INT_MAX_VALUE)).eval(FILE_2);
  Assert.assertFalse("Should never match when stats are missing", matched);

  matched =
      new StrictMetricsEvaluator(SCHEMA, lessThanOrEqual("struct.c1", INT_MAX_VALUE))
          .eval(FILE_2);
  Assert.assertFalse("Should never match when stats are missing", matched);

  // fields under a list are repeated, so their metrics cannot be used strictly
  matched =
      new StrictMetricsEvaluator(SCHEMA, lessThanOrEqual("list.element.c2", INT_MAX_VALUE))
          .eval(FILE);
  Assert.assertFalse("Should never match for repeated fields", matched);

  // complex (non-primitive) fields are not evaluable at all
  matched = new StrictMetricsEvaluator(SCHEMA, isNull("list")).eval(FILE);
  Assert.assertFalse("Should never match for complex fields", matched);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,22 @@ public void testDeleteWithConditionOnNestedColumn() {
"Should have expected rows", ImmutableList.of(), sql("SELECT id FROM %s", tableName));
}

@Test
public void testDeleteWithFilterOnNestedColumn() {
createAndInitNestedColumnsTable();

sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\"))", tableName);
sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2, \"c2\", \"v2\"))", tableName);

sql("DELETE FROM %s WHERE complex.c1 = 3", tableName);
Copy link
Contributor Author

@zhongyujiang zhongyujiang Mar 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delete conditions in testDeleteWithConditionOnNestedColumn can not be push down, so added this UT to cover the corresponding scenario.

assertEquals(
"Should have expected rows", ImmutableList.of(row(2)), sql("SELECT id FROM %s", tableName));

sql("DELETE FROM %s t WHERE t.complex.c1 = 2", tableName);
assertEquals(
"Should have expected rows", ImmutableList.of(), sql("SELECT id FROM %s", tableName));
}

@Test
public void testDeleteWithInSubquery() throws NoSuchTableException {
createAndInitUnpartitionedTable();
Expand Down