API, Spark: Make StrictMetricsEvaluator not fail on nested column predicates #11261

Merged
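Summary of the change, with a rough usage sketch (the sketch is not part of the PR): StrictMetricsEvaluator's visitor methods used to call Preconditions.checkNotNull(struct.field(id), "Cannot filter by nested column: %s", ...), and struct.field(id) only resolves top-level fields, so any predicate bound to a field nested inside a struct made evaluation throw. With this change the evaluator answers ROWS_MIGHT_NOT_MATCH for such predicates instead: the file is simply not treated as a guaranteed full match, and callers such as Spark's DELETE planning no longer fail. The sketch below uses the same public API the new tests exercise; the wrapper class, the nested column name "complex.c1", and the literal value are illustrative only, and the imports assume the evaluator lives in org.apache.iceberg.expressions alongside Expressions.

import org.apache.iceberg.DataFile;
import org.apache.iceberg.Schema;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.expressions.StrictMetricsEvaluator;

// Illustrative sketch only; "complex.c1" stands for any column nested inside a struct.
class NestedPredicateSketch {
  // eval() returns true only when file metrics prove that every row matches the predicate.
  // Before this change, a predicate on a nested column made eval() throw
  // ("Cannot filter by nested column: ..."); after it, the evaluator answers
  // ROWS_MIGHT_NOT_MATCH and this method simply returns false instead of failing.
  static boolean allRowsMatch(Schema schema, DataFile file) {
    return new StrictMetricsEvaluator(schema, Expressions.greaterThan("complex.c1", 3))
        .eval(file);
  }
}

As with the rest of the strict evaluator, a false answer only means the rows cannot be proven to match; it never causes a file to be skipped.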
StrictMetricsEvaluator.java
@@ -29,9 +29,7 @@
import org.apache.iceberg.DataFile;
import org.apache.iceberg.Schema;
import org.apache.iceberg.expressions.ExpressionVisitors.BoundExpressionVisitor;
- import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.types.Conversions;
- import org.apache.iceberg.types.Types;
import org.apache.iceberg.types.Types.StructType;
import org.apache.iceberg.util.NaNUtil;

@@ -51,7 +49,6 @@
* checks for NaN is necessary in order to not include files that may contain rows that don't match.
*/
public class StrictMetricsEvaluator {
- private final Schema schema;
private final StructType struct;
private final Expression expr;

@@ -60,7 +57,6 @@ public StrictMetricsEvaluator(Schema schema, Expression unbound) {
}

public StrictMetricsEvaluator(Schema schema, Expression unbound, boolean caseSensitive) {
- this.schema = schema;
this.struct = schema.asStruct();
this.expr = Binder.bind(struct, rewriteNot(unbound), caseSensitive);
}
@@ -144,8 +140,9 @@ public <T> Boolean isNull(BoundReference<T> ref) {
// no need to check whether the field is required because binding evaluates that case
// if the column has any non-null values, the expression does not match
int id = ref.fieldId();
- Preconditions.checkNotNull(
- struct.field(id), "Cannot filter by nested column: %s", schema.findField(id));
+ if (isNestedColumn(id)) {
+   return ROWS_MIGHT_NOT_MATCH;
+ }

if (containsNullsOnly(id)) {
return ROWS_MUST_MATCH;
@@ -159,8 +156,9 @@ public <T> Boolean notNull(BoundReference<T> ref) {
// no need to check whether the field is required because binding evaluates that case
// if the column has any null values, the expression does not match
int id = ref.fieldId();
- Preconditions.checkNotNull(
- struct.field(id), "Cannot filter by nested column: %s", schema.findField(id));
+ if (isNestedColumn(id)) {
+   return ROWS_MIGHT_NOT_MATCH;
+ }

if (nullCounts != null && nullCounts.containsKey(id) && nullCounts.get(id) == 0) {
return ROWS_MUST_MATCH;
@@ -199,15 +197,16 @@ public <T> Boolean notNaN(BoundReference<T> ref) {
public <T> Boolean lt(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when: <----------Min----Max---X------->
Integer id = ref.fieldId();
- Types.NestedField field = struct.field(id);
- Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
+ if (isNestedColumn(id)) {
+   return ROWS_MIGHT_NOT_MATCH;
+ }

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (upperBounds != null && upperBounds.containsKey(id)) {
- T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
+ T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));

int cmp = lit.comparator().compare(upper, lit.value());
if (cmp < 0) {
@@ -222,15 +221,16 @@ public <T> Boolean lt(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean ltEq(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when: <----------Min----Max---X------->
Integer id = ref.fieldId();
- Types.NestedField field = struct.field(id);
- Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
+ if (isNestedColumn(id)) {
+   return ROWS_MIGHT_NOT_MATCH;
+ }

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (upperBounds != null && upperBounds.containsKey(id)) {
- T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
+ T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));

int cmp = lit.comparator().compare(upper, lit.value());
if (cmp <= 0) {
@@ -245,15 +245,16 @@ public <T> Boolean ltEq(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean gt(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when: <-------X---Min----Max---------->
Integer id = ref.fieldId();
- Types.NestedField field = struct.field(id);
- Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
+ if (isNestedColumn(id)) {
+   return ROWS_MIGHT_NOT_MATCH;
+ }

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (lowerBounds != null && lowerBounds.containsKey(id)) {
- T lower = Conversions.fromByteBuffer(field.type(), lowerBounds.get(id));
+ T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id));

if (NaNUtil.isNaN(lower)) {
// NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for more.
@@ -273,15 +274,16 @@ public <T> Boolean gt(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean gtEq(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when: <-------X---Min----Max---------->
Integer id = ref.fieldId();
- Types.NestedField field = struct.field(id);
- Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
+ if (isNestedColumn(id)) {
+   return ROWS_MIGHT_NOT_MATCH;
+ }

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
}

if (lowerBounds != null && lowerBounds.containsKey(id)) {
- T lower = Conversions.fromByteBuffer(field.type(), lowerBounds.get(id));
+ T lower = Conversions.fromByteBuffer(ref.type(), lowerBounds.get(id));

if (NaNUtil.isNaN(lower)) {
// NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for more.
@@ -301,8 +303,9 @@ public <T> Boolean gtEq(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean eq(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when Min == X == Max
Integer id = ref.fieldId();
- Types.NestedField field = struct.field(id);
- Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
+ if (isNestedColumn(id)) {
+   return ROWS_MIGHT_NOT_MATCH;
+ }

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
@@ -319,7 +322,7 @@ public <T> Boolean eq(BoundReference<T> ref, Literal<T> lit) {
return ROWS_MIGHT_NOT_MATCH;
}

- T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
+ T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));

cmp = lit.comparator().compare(upper, lit.value());
if (cmp != 0) {
@@ -336,8 +339,9 @@ public <T> Boolean eq(BoundReference<T> ref, Literal<T> lit) {
public <T> Boolean notEq(BoundReference<T> ref, Literal<T> lit) {
// Rows must match when X < Min or Max < X because it is not in the range
Integer id = ref.fieldId();
- Types.NestedField field = struct.field(id);
- Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
+ if (isNestedColumn(id)) {
+   return ROWS_MIGHT_NOT_MATCH;
+ }

if (containsNullsOnly(id) || containsNaNsOnly(id)) {
return ROWS_MUST_MATCH;
@@ -358,7 +362,7 @@ public <T> Boolean notEq(BoundReference<T> ref, Literal<T> lit) {
}

if (upperBounds != null && upperBounds.containsKey(id)) {
- T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
+ T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));

int cmp = lit.comparator().compare(upper, lit.value());
if (cmp < 0) {
@@ -372,8 +376,9 @@ public <T> Boolean notEq(BoundReference<T> ref, Literal<T> lit) {
@Override
public <T> Boolean in(BoundReference<T> ref, Set<T> literalSet) {
Integer id = ref.fieldId();
- Types.NestedField field = struct.field(id);
- Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
+ if (isNestedColumn(id)) {
+   return ROWS_MIGHT_NOT_MATCH;
+ }

if (canContainNulls(id) || canContainNaNs(id)) {
return ROWS_MIGHT_NOT_MATCH;
@@ -390,7 +395,7 @@ public <T> Boolean in(BoundReference<T> ref, Set<T> literalSet) {
}

// check if the upper bound is in the set
- T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
+ T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));
if (!literalSet.contains(upper)) {
return ROWS_MIGHT_NOT_MATCH;
}
@@ -411,8 +416,9 @@ public <T> Boolean in(BoundReference<T> ref, Set<T> literalSet) {
@Override
public <T> Boolean notIn(BoundReference<T> ref, Set<T> literalSet) {
Integer id = ref.fieldId();
- Types.NestedField field = struct.field(id);
- Preconditions.checkNotNull(field, "Cannot filter by nested column: %s", schema.findField(id));
+ if (isNestedColumn(id)) {
+   return ROWS_MIGHT_NOT_MATCH;
+ }

if (containsNullsOnly(id) || containsNaNsOnly(id)) {
return ROWS_MUST_MATCH;
@@ -439,7 +445,7 @@ public <T> Boolean notIn(BoundReference<T> ref, Set<T> literalSet) {
}

if (upperBounds != null && upperBounds.containsKey(id)) {
- T upper = Conversions.fromByteBuffer(field.type(), upperBounds.get(id));
+ T upper = Conversions.fromByteBuffer(ref.type(), upperBounds.get(id));
literals =
literals.stream()
.filter(v -> ref.comparator().compare(upper, v) >= 0)
@@ -466,6 +472,10 @@ public <T> Boolean notStartsWith(BoundReference<T> ref, Literal<T> lit) {
return ROWS_MIGHT_NOT_MATCH;
}

+ private boolean isNestedColumn(int id) {
+   return struct.field(id) == null;
+ }

private boolean canContainNulls(Integer id) {
return nullCounts == null || (nullCounts.containsKey(id) && nullCounts.get(id) > 0);
}
TestStrictMetricsEvaluator.java
@@ -66,7 +66,14 @@ public class TestStrictMetricsEvaluator {
optional(11, "all_nulls_double", Types.DoubleType.get()),
optional(12, "all_nans_v1_stats", Types.FloatType.get()),
optional(13, "nan_and_null_only", Types.DoubleType.get()),
- optional(14, "no_nan_stats", Types.DoubleType.get()));
+ optional(14, "no_nan_stats", Types.DoubleType.get()),
+ optional(
+     15,
+     "struct",
+     Types.StructType.of(
+         Types.NestedField.optional(16, "nested_col_no_stats", Types.IntegerType.get()),
+         Types.NestedField.optional(
+             17, "nested_col_with_stats", Types.IntegerType.get()))));

private static final int INT_MIN_VALUE = 30;
private static final int INT_MAX_VALUE = 79;
@@ -88,6 +95,7 @@ public class TestStrictMetricsEvaluator {
.put(12, 50L)
.put(13, 50L)
.put(14, 50L)
+ .put(17, 50L)
.buildOrThrow(),
// null value counts
ImmutableMap.<Integer, Long>builder()
@@ -97,6 +105,7 @@ public class TestStrictMetricsEvaluator {
.put(11, 50L)
.put(12, 0L)
.put(13, 1L)
+ .put(17, 0L)
.buildOrThrow(),
// nan value counts
ImmutableMap.of(
@@ -108,13 +117,15 @@ public class TestStrictMetricsEvaluator {
1, toByteBuffer(IntegerType.get(), INT_MIN_VALUE),
7, toByteBuffer(IntegerType.get(), 5),
12, toByteBuffer(Types.FloatType.get(), Float.NaN),
- 13, toByteBuffer(Types.DoubleType.get(), Double.NaN)),
+ 13, toByteBuffer(Types.DoubleType.get(), Double.NaN),
+ 17, toByteBuffer(Types.IntegerType.get(), INT_MIN_VALUE)),
// upper bounds
ImmutableMap.of(
1, toByteBuffer(IntegerType.get(), INT_MAX_VALUE),
7, toByteBuffer(IntegerType.get(), 5),
12, toByteBuffer(Types.FloatType.get(), Float.NaN),
- 13, toByteBuffer(Types.DoubleType.get(), Double.NaN)));
+ 13, toByteBuffer(Types.DoubleType.get(), Double.NaN),
+ 17, toByteBuffer(IntegerType.get(), INT_MAX_VALUE)));

private static final DataFile FILE_2 =
new TestDataFile(
@@ -627,4 +638,50 @@ public void testIntegerNotIn() {
shouldRead = new StrictMetricsEvaluator(SCHEMA, notIn("no_nulls", "abc", "def")).eval(FILE);
assertThat(shouldRead).as("Should not match: no_nulls field does not have bounds").isFalse();
}

@Test
public void testEvaluateOnNestedColumnWithoutStats() {
boolean shouldRead =
new StrictMetricsEvaluator(
SCHEMA, greaterThanOrEqual("struct.nested_col_no_stats", INT_MIN_VALUE))
.eval(FILE);
assertThat(shouldRead).as("greaterThanOrEqual nested column should not match").isFalse();

shouldRead =
new StrictMetricsEvaluator(
SCHEMA, lessThanOrEqual("struct.nested_col_no_stats", INT_MAX_VALUE))
.eval(FILE);
assertThat(shouldRead).as("lessThanOrEqual nested column should not match").isFalse();

shouldRead =
new StrictMetricsEvaluator(SCHEMA, isNull("struct.nested_col_no_stats")).eval(FILE);
assertThat(shouldRead).as("isNull nested column should not match").isFalse();

shouldRead =
new StrictMetricsEvaluator(SCHEMA, notNull("struct.nested_col_no_stats")).eval(FILE);
assertThat(shouldRead).as("notNull nested column should not match").isFalse();
}

@Test
public void testEvaluateOnNestedColumnWithStats() {
boolean shouldRead =
new StrictMetricsEvaluator(
SCHEMA, greaterThanOrEqual("struct.nested_col_with_stats", INT_MIN_VALUE))
.eval(FILE);
assertThat(shouldRead).as("greaterThanOrEqual nested column should not match").isFalse();

shouldRead =
new StrictMetricsEvaluator(
SCHEMA, lessThanOrEqual("struct.nested_col_with_stats", INT_MAX_VALUE))
.eval(FILE);
assertThat(shouldRead).as("lessThanOrEqual nested column should not match").isFalse();

shouldRead =
new StrictMetricsEvaluator(SCHEMA, isNull("struct.nested_col_with_stats")).eval(FILE);
assertThat(shouldRead).as("isNull nested column should not match").isFalse();

shouldRead =
new StrictMetricsEvaluator(SCHEMA, notNull("struct.nested_col_with_stats")).eval(FILE);
assertThat(shouldRead).as("notNull nested column should not match").isFalse();
}
}
@@ -1401,6 +1401,28 @@ public void testDeleteToCustomWapBranchWithoutWhereClause() throws NoSuchTableEx
});
}

@TestTemplate
public void testDeleteWithFilterOnNestedColumn() {
createAndInitNestedColumnsTable();

sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\"))", tableName);
sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2, \"c2\", \"v2\"))", tableName);

sql("DELETE FROM %s WHERE complex.c1 > 3", tableName);
assertEquals(
"Should have expected rows",
ImmutableList.of(row(1), row(2)),
sql("SELECT id FROM %s order by id", tableName));

sql("DELETE FROM %s WHERE complex.c1 = 3", tableName);
assertEquals(
"Should have expected rows", ImmutableList.of(row(2)), sql("SELECT id FROM %s", tableName));

sql("DELETE FROM %s t WHERE t.complex.c1 = 2", tableName);
assertEquals(
"Should have expected rows", ImmutableList.of(), sql("SELECT id FROM %s", tableName));
}

// TODO: multiple stripes for ORC

protected void createAndInitPartitionedTable() {