
API, Spark: Make StrictMetricsEvaluator not fail on nested column predicates (#11261)
zhongyujiang authored Oct 14, 2024
1 parent 12ff959 commit ca8a3a4
Showing 3 changed files with 124 additions and 35 deletions.
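In short, the visitor methods in StrictMetricsEvaluator that used to fail a Preconditions.checkNotNull whenever a predicate referenced a field nested inside a struct now answer ROWS_MIGHT_NOT_MATCH instead. A minimal sketch of the resulting behavior, reusing the SCHEMA, FILE, and INT_MIN_VALUE fixtures from the test diff below; the sketch itself is illustrative and not part of this commit:

// Assumes: import org.apache.iceberg.expressions.Expressions;
//          import org.apache.iceberg.expressions.StrictMetricsEvaluator;
// "struct.nested_col_with_stats" is a field nested inside the "struct" column of SCHEMA.
StrictMetricsEvaluator evaluator =
    new StrictMetricsEvaluator(
        SCHEMA, Expressions.greaterThanOrEqual("struct.nested_col_with_stats", INT_MIN_VALUE));

// Before this commit: eval() threw from Preconditions.checkNotNull
// ("Cannot filter by nested column: ..."), failing the whole operation.
// After this commit: eval() returns false (ROWS_MIGHT_NOT_MATCH), so the file is
// conservatively treated as not provably all-matching and planning continues.
boolean allRowsMatch = evaluator.eval(FILE);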
StrictMetricsEvaluator.java
@@ -29,9 +29,7 @@
import org.apache.iceberg.DataFile;
import org.apache.iceberg.Schema;
import org.apache.iceberg.expressions.ExpressionVisitors.BoundExpressionVisitor;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.types.Conversions;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.types.Types.StructType;
import org.apache.iceberg.util.NaNUtil;

@@ -51,7 +49,6 @@
* checks for NaN is necessary in order to not include files that may contain rows that don't match.
*/
public class StrictMetricsEvaluator {
private final Schema schema;
private final StructType struct;
private final Expression expr;

@@ -60,7 +57,6 @@ public StrictMetricsEvaluator(Schema schema, Expression unbound) {
}

public StrictMetricsEvaluator(Schema schema, Expression unbound, boolean caseSensitive) {
this.schema = schema;
this.struct = schema.asStruct();
this.expr = Binder.bind(struct, rewriteNot(unbound), caseSensitive);
}
@@ -144,8 +140,9 @@ public <T> Boolean isNull(BoundReference<T> ref) {
// no need to check whether the field is required because binding evaluates that case
// if the column has any non-null values, the expression does not match
int id = ref.fieldId();
Preconditions.checkNotNull(
struct.field(id), "Cannot filter by nested column: %s", schema.findField(id));
if (isNestedColumn(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (containsNullsOnly(id)) {
return ROWS_MUST_MATCH;
@@ -159,8 +156,9 @@ public <T> Boolean notNull(BoundReference<T> ref) {
// no need to check whether the field is required because binding evaluates that case
// if the column has any null values, the expression does not match
int id = ref.fieldId();
Preconditions.checkNotNull(
struct.field(id), "Cannot filter by nested column: %s", schema.findField(id));
if (isNestedColumn(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (nullCounts != null && nullCounts.containsKey(id) && nullCounts.get(id) == 0) {
return ROWS_MUST_MATCH;
@@ -199,15 +197,16 @@ public <T> Boolean notNaN(BoundReference<T> ref) {
public <T> Boolean lt(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when: <----------Min----Max---X------->
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (isNestedColumn(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (upperBounds != null && upperBounds.containsKey(id)) {
T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));

int cmp = lit.comparator().compare(upper, lit.value());
if (cmp < 0) {
@@ -222,15 +221,16 @@ public <T> Boolean lt(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean ltEq(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when: <----------Min----Max---X------->
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (isNestedColumn(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (upperBounds != null && upperBounds.containsKey(id)) {
T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));

int cmp = lit.comparator().compare(upper, lit.value());
if (cmp <= 0) {
@@ -245,15 +245,16 @@ public <T> Boolean ltEq(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean gt(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when: <-------X---Min----Max---------->
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (isNestedColumn(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (lowerBounds != null && lowerBounds.containsKey(id)) {
T lower = Conversions.fromByteBuffer(field.type(), lowerBounds.get(id));
T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id));

if (NaNUtil.isNaN(lower)) {
// NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for more.
@@ -273,15 +274,16 @@ public <T> Boolean gt(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean gtEq(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when: <-------X---Min----Max---------->
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (isNestedColumn(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (lowerBounds != null && lowerBounds.containsKey(id)) {
T lower = Conversions.fromByteBuffer(field.type(), lowerBounds.get(id));
T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id));

if (NaNUtil.isNaN(lower)) {
// NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for more.
@@ -301,8 +303,9 @@ public <T> Boolean gtEq(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean eq(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when Min == X == Max
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (isNestedColumn(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
@@ -319,7 +322,7 @@ public <T> Boolean eq(BoundReference<T> ref, Literal<T> lit) {
return ROWS_MIGHT_NOT_MATCH;
}

T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));

cmp = lit.comparator().compare(upper, lit.value());
if (cmp != 0) {
@@ -336,8 +339,9 @@ public <T> Boolean eq(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean notEq(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when X < Min or Max < X because it is not in the range
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (isNestedColumn(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (containsNullsOnly(id) || containsNaNsOnly(id)) {
return ROWS_MUST_MATCH;
@@ -358,7 +362,7 @@ public <T> Boolean notEq(BoundReference<T> ref, Literal<T> lit) {
}

if (upperBounds != null && upperBounds.containsKey(id)) {
T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));

int cmp = lit.comparator().compare(upper, lit.value());
if (cmp < 0) {
@@ -372,8 +376,9 @@ public <T> Boolean notEq(BoundReference<T> ref, Literal<T> lit) {
@Override
public <T> Boolean in(BoundReference<T> ref, Set<T> literalSet) {
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (isNestedColumn(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
@@ -390,7 +395,7 @@ public <T> Boolean in(BoundReference<T> ref, Set<T> literalSet) {
}

// check if the upper bound is in the set
T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));
if (!literalSet.contains(upper)) {
return ROWS_MIGHT_NOT_MATCH;
}
@@ -411,8 +416,9 @@ public <T> Boolean in(BoundReference<T> ref, Set<T> literalSet) {
@Override
public <T> Boolean notIn(BoundReference<T> ref, Set<T> literalSet) {
Integer id = ref.fieldId();
Types.NestedField field = struct.field(id);
Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
if (isNestedColumn(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (containsNullsOnly(id) || containsNaNsOnly(id)) {
return ROWS_MUST_MATCH;
@@ -439,7 +445,7 @@ public <T> Boolean notIn(BoundReference<T> ref, Set<T> literalSet) {
}

if (upperBounds != null && upperBounds.containsKey(id)) {
T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));
literals =
literals.stream()
.filter(v -> ref.comparator().compare(upper, v) >= 0)
@@ -466,6 +472,10 @@ public <T> Boolean notStartsWith(BoundReference<T> ref, Literal<T> lit) {
return ROWS_MIGHT_NOT_MATCH;
}

private boolean isNestedColumn(int id) {
return struct.field(id) == null;
}

private boolean canContainNulls(Integer id) {
return nullCounts == null || (nullCounts.containsKey(id) && nullCounts.get(id) > 0);
}
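The new isNestedColumn helper leans on the fact that the expression has already been bound against the schema, so the field id is guaranteed to exist; if the root struct cannot resolve it directly, the reference must point into a nested type. A small sketch of that lookup distinction, assuming the same semantics the removed Preconditions message relied on (Schema.findField searches nested structs, while StructType.field sees only direct fields); the field names are illustrative:

// Assumes: import org.apache.iceberg.Schema; import org.apache.iceberg.types.Types;
Schema schema =
    new Schema(
        Types.NestedField.optional(1, "id", Types.IntegerType.get()),
        Types.NestedField.optional(
            2,
            "outer",
            Types.StructType.of(
                Types.NestedField.optional(3, "leaf", Types.IntegerType.get()))));
Types.StructType struct = schema.asStruct();

struct.field(1);     // resolves the top-level "id" field
schema.findField(3); // resolves the nested "leaf" field anywhere in the schema
struct.field(3);     // null: id 3 is not a direct field of the root struct, i.e. a nested column

The companion switch from field.type() to ref.type() in the Conversions.fromByteBuffer calls keeps the bound-value conversions working without that struct.field lookup, since the bound reference already carries the referenced field's type.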
TestStrictMetricsEvaluator.java
@@ -66,7 +66,14 @@ public class TestStrictMetricsEvaluator {
optional(11, "all_nulls_double", Types.DoubleType.get()),
optional(12, "all_nans_v1_stats", Types.FloatType.get()),
optional(13, "nan_and_null_only", Types.DoubleType.get()),
optional(14, "no_nan_stats", Types.DoubleType.get()));
optional(14, "no_nan_stats", Types.DoubleType.get()),
optional(
15,
"struct",
Types.StructType.of(
Types.NestedField.optional(16, "nested_col_no_stats", Types.IntegerType.get()),
Types.NestedField.optional(
17, "nested_col_with_stats", Types.IntegerType.get()))));

private static final int INT_MIN_VALUE = 30;
private static final int INT_MAX_VALUE = 79;
@@ -88,6 +95,7 @@ public class TestStrictMetricsEvaluator {
.put(12, 50L)
.put(13, 50L)
.put(14, 50L)
.put(17, 50L)
.buildOrThrow(),
// null value counts
ImmutableMap.<Integer, Long>builder()
@@ -97,6 +105,7 @@ public class TestStrictMetricsEvaluator {
.put(11, 50L)
.put(12, 0L)
.put(13, 1L)
.put(17, 0L)
.buildOrThrow(),
// nan value counts
ImmutableMap.of(
@@ -108,13 +117,15 @@ public class TestStrictMetricsEvaluator {
1, toByteBuffer(IntegerType.get(), INT_MIN_VALUE),
7, toByteBuffer(IntegerType.get(), 5),
12, toByteBuffer(Types.FloatType.get(), Float.NaN),
13, toByteBuffer(Types.DoubleType.get(), Double.NaN)),
13, toByteBuffer(Types.DoubleType.get(), Double.NaN),
17, toByteBuffer(Types.IntegerType.get(), INT_MIN_VALUE)),
// upper bounds
ImmutableMap.of(
1, toByteBuffer(IntegerType.get(), INT_MAX_VALUE),
7, toByteBuffer(IntegerType.get(), 5),
12, toByteBuffer(Types.FloatType.get(), Float.NaN),
13, toByteBuffer(Types.DoubleType.get(), Double.NaN)));
13, toByteBuffer(Types.DoubleType.get(), Double.NaN),
17, toByteBuffer(IntegerType.get(), INT_MAX_VALUE)));

private static final DataFile FILE_2 =
new TestDataFile(
@@ -627,4 +638,50 @@ public void testIntegerNotIn() {
shouldRead = new StrictMetricsEvaluator(SCHEMA, notIn("no_nulls", "abc", "def")).eval(FILE);
assertThat(shouldRead).as("Should not match: no_nulls field does not have bounds").isFalse();
}

@Test
public void testEvaluateOnNestedColumnWithoutStats() {
boolean shouldRead =
new StrictMetricsEvaluator(
SCHEMA, greaterThanOrEqual("struct.nested_col_no_stats", INT_MIN_VALUE))
.eval(FILE);
assertThat(shouldRead).as("greaterThanOrEqual nested column should not match").isFalse();

shouldRead =
new StrictMetricsEvaluator(
SCHEMA, lessThanOrEqual("struct.nested_col_no_stats", INT_MAX_VALUE))
.eval(FILE);
assertThat(shouldRead).as("lessThanOrEqual nested column should not match").isFalse();

shouldRead =
new StrictMetricsEvaluator(SCHEMA, isNull("struct.nested_col_no_stats")).eval(FILE);
assertThat(shouldRead).as("isNull nested column should not match").isFalse();

shouldRead =
new StrictMetricsEvaluator(SCHEMA, notNull("struct.nested_col_no_stats")).eval(FILE);
assertThat(shouldRead).as("notNull nested column should not match").isFalse();
}

@Test
public void testEvaluateOnNestedColumnWithStats() {
boolean shouldRead =
new StrictMetricsEvaluator(
SCHEMA, greaterThanOrEqual("struct.nested_col_with_stats", INT_MIN_VALUE))
.eval(FILE);
assertThat(shouldRead).as("greaterThanOrEqual nested column should not match").isFalse();

shouldRead =
new StrictMetricsEvaluator(
SCHEMA, lessThanOrEqual("struct.nested_col_with_stats", INT_MAX_VALUE))
.eval(FILE);
assertThat(shouldRead).as("lessThanOrEqual nested column should not match").isFalse();

shouldRead =
new StrictMetricsEvaluator(SCHEMA, isNull("struct.nested_col_with_stats")).eval(FILE);
assertThat(shouldRead).as("isNull nested column should not match").isFalse();

shouldRead =
new StrictMetricsEvaluator(SCHEMA, notNull("struct.nested_col_with_stats")).eval(FILE);
assertThat(shouldRead).as("notNull nested column should not match").isFalse();
}
}
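Note that the new tests assert isFalse() even for struct.nested_col_with_stats, whose field id (17) has value counts, null counts, and bounds in FILE: the evaluator makes no attempt to prove a strict match through nested fields. That stays correct because only a true answer from a strict evaluator is load-bearing (it asserts that every row in the file matches); a false answer just forfeits an optimization. A hedged sketch of that caller-side contract, where the helper name is illustrative and not from this commit:

// Illustrative only: a strict metrics evaluator proves that *every* row in a file matches
// the predicate, so `true` lets a caller drop the file by metadata alone, while `false`
// merely means the file has to be scanned or rewritten.
static boolean canDropWholeFile(Schema schema, Expression deleteFilter, DataFile file) {
  return new StrictMetricsEvaluator(schema, deleteFilter).eval(file);
}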
@@ -1401,6 +1401,28 @@ public void testDeleteToCustomWapBranchWithoutWhereClause() throws NoSuchTableEx
});
}

@TestTemplate
public void testDeleteWithFilterOnNestedColumn() {
createAndInitNestedColumnsTable();

sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\"))", tableName);
sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2, \"c2\", \"v2\"))", tableName);

sql("DELETE FROM %s WHERE complex.c1 > 3", tableName);
assertEquals(
"Should have expected rows",
ImmutableList.of(row(1), row(2)),
sql("SELECT id FROM %s order by id", tableName));

sql("DELETE FROM %s WHERE complex.c1 = 3", tableName);
assertEquals(
"Should have expected rows", ImmutableList.of(row(2)), sql("SELECT id FROM %s", tableName));

sql("DELETE FROM %s t WHERE t.complex.c1 = 2", tableName);
assertEquals(
"Should have expected rows", ImmutableList.of(), sql("SELECT id FROM %s", tableName));
}

// TODO: multiple stripes for ORC

protected void createAndInitPartitionedTable() {
