From 7e1a4c9fedeb679be85a1921ade7995d5ec2cbec Mon Sep 17 00:00:00 2001
From: Ryan Blue
Date: Wed, 18 Dec 2024 08:10:15 -0800
Subject: [PATCH] Spark 3.5: Support default values in Parquet reader (#11803)

---
 .../org/apache/iceberg/spark/SparkUtil.java   |  66 +++++
 .../spark/data/SparkParquetReaders.java       |  26 +-
 .../iceberg/spark/source/BaseReader.java      |  62 +----
 .../iceberg/spark/data/TestHelpers.java       |  22 +-
 .../spark/data/TestSparkParquetReader.java    | 233 +++++++++++++++++-
 5 files changed, 329 insertions(+), 80 deletions(-)

diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
index de06cceb2677..4bd2e9c21551 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
@@ -18,6 +18,8 @@
  */
 package org.apache.iceberg.spark;
 
+import java.math.BigDecimal;
+import java.nio.ByteBuffer;
 import java.sql.Date;
 import java.sql.Timestamp;
 import java.util.List;
@@ -25,14 +27,20 @@
 import java.util.function.BiFunction;
 import java.util.function.Function;
 import java.util.stream.Collectors;
+import org.apache.avro.generic.GenericData;
+import org.apache.avro.util.Utf8;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.iceberg.PartitionField;
 import org.apache.iceberg.PartitionSpec;
+import org.apache.iceberg.StructLike;
 import org.apache.iceberg.relocated.com.google.common.base.Joiner;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.transforms.Transform;
 import org.apache.iceberg.transforms.UnknownTransform;
+import org.apache.iceberg.types.Type;
+import org.apache.iceberg.types.Types;
+import org.apache.iceberg.util.ByteBuffers;
 import org.apache.iceberg.util.Pair;
 import org.apache.spark.SparkEnv;
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation;
@@ -40,14 +48,17 @@
 import org.apache.spark.sql.catalyst.expressions.BoundReference;
 import org.apache.spark.sql.catalyst.expressions.EqualTo;
 import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
 import org.apache.spark.sql.catalyst.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.NamedReference;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.BlockManagerId;
 import org.apache.spark.storage.BlockManagerMaster;
+import org.apache.spark.unsafe.types.UTF8String;
 import org.joda.time.DateTime;
 import scala.collection.JavaConverters;
 import scala.collection.Seq;
@@ -268,4 +279,59 @@ private static List toJavaList(Seq seq) {
   private static String toExecutorLocation(BlockManagerId id) {
     return ExecutorCacheTaskLocation.apply(id.host(), id.executorId()).toString();
   }
+
+  /**
+   * Converts a value to pass into Spark from Iceberg's internal object model.
+   *
+   * @param type an Iceberg type
+   * @param value a value that is an instance of {@link Type.TypeID#javaClass()}
+   * @return the value converted for Spark
+   */
+  public static Object convertConstant(Type type, Object value) {
+    if (value == null) {
+      return null;
+    }
+
+    switch (type.typeId()) {
+      case DECIMAL:
+        return Decimal.apply((BigDecimal) value);
+      case STRING:
+        if (value instanceof Utf8) {
+          Utf8 utf8 = (Utf8) value;
+          return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength());
+        }
+        return UTF8String.fromString(value.toString());
+      case FIXED:
+        if (value instanceof byte[]) {
+          return value;
+        } else if (value instanceof GenericData.Fixed) {
+          return ((GenericData.Fixed) value).bytes();
+        }
+        return ByteBuffers.toByteArray((ByteBuffer) value);
+      case BINARY:
+        return ByteBuffers.toByteArray((ByteBuffer) value);
+      case STRUCT:
+        Types.StructType structType = (Types.StructType) type;
+
+        if (structType.fields().isEmpty()) {
+          return new GenericInternalRow();
+        }
+
+        List<Types.NestedField> fields = structType.fields();
+        Object[] values = new Object[fields.size()];
+        StructLike struct = (StructLike) value;
+
+        for (int index = 0; index < fields.size(); index++) {
+          Types.NestedField field = fields.get(index);
+          Type fieldType = field.type();
+          values[index] =
+              convertConstant(fieldType, struct.get(index, fieldType.typeId().javaClass()));
+        }
+
+        return new GenericInternalRow(values);
+      default:
+    }
+
+    return value;
+  }
 }
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java
index af16d9bbc290..65e5843e39b3 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java
@@ -44,6 +44,7 @@
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.spark.SparkUtil;
 import org.apache.iceberg.types.Type.TypeID;
 import org.apache.iceberg.types.Types;
 import org.apache.iceberg.util.UUIDUtil;
@@ -165,6 +166,7 @@ public ParquetValueReader struct(
       int defaultMaxDefinitionLevel = type.getMaxDefinitionLevel(currentPath());
       for (Types.NestedField field : expectedFields) {
         int id = field.fieldId();
+        ParquetValueReader<?> reader = readersById.get(id);
         if (idToConstant.containsKey(id)) {
           // containsKey is used because the constant may be null
           int fieldMaxDefinitionLevel =
@@ -178,15 +180,21 @@ public ParquetValueReader struct(
         } else if (id == MetadataColumns.IS_DELETED.fieldId()) {
           reorderedFields.add(ParquetValueReaders.constant(false));
           types.add(null);
+        } else if (reader != null) {
+          reorderedFields.add(reader);
+          types.add(typesById.get(id));
+        } else if (field.initialDefault() != null) {
+          reorderedFields.add(
+              ParquetValueReaders.constant(
+                  SparkUtil.convertConstant(field.type(), field.initialDefault()),
+                  maxDefinitionLevelsById.getOrDefault(id, defaultMaxDefinitionLevel)));
+          types.add(typesById.get(id));
+        } else if (field.isOptional()) {
+          reorderedFields.add(ParquetValueReaders.nulls());
+          types.add(null);
         } else {
-          ParquetValueReader<?> reader = readersById.get(id);
-          if (reader != null) {
-            reorderedFields.add(reader);
-            types.add(typesById.get(id));
-          } else {
-            reorderedFields.add(ParquetValueReaders.nulls());
-            types.add(null);
-          }
+          throw new IllegalArgumentException(
+              String.format("Missing required field: %s", field.name()));
         }
       }
 
@@ -250,7 +258,7 @@ public ParquetValueReader primitive(
           if (expected != null && expected.typeId() == Types.LongType.get().typeId()) {
             return new IntAsLongReader(desc);
           } else {
-            return new UnboxedReader(desc);
+            return new UnboxedReader<>(desc);
           }
         case DATE:
         case INT_64:
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java
index f8e8a1f1dd6b..6f0ee1d2e2a0 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java
@@ -20,8 +20,6 @@
 
 import java.io.Closeable;
 import java.io.IOException;
-import java.math.BigDecimal;
-import java.nio.ByteBuffer;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -29,8 +27,6 @@
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
-import org.apache.avro.generic.GenericData;
-import org.apache.avro.util.Utf8;
 import org.apache.iceberg.ContentFile;
 import org.apache.iceberg.ContentScanTask;
 import org.apache.iceberg.DeleteFile;
@@ -53,16 +49,11 @@
 import org.apache.iceberg.mapping.NameMappingParser;
 import org.apache.iceberg.spark.SparkExecutorCache;
 import org.apache.iceberg.spark.SparkSchemaUtil;
-import org.apache.iceberg.types.Type;
-import org.apache.iceberg.types.Types.NestedField;
+import org.apache.iceberg.spark.SparkUtil;
 import org.apache.iceberg.types.Types.StructType;
-import org.apache.iceberg.util.ByteBuffers;
 import org.apache.iceberg.util.PartitionUtil;
 import org.apache.spark.rdd.InputFileBlockHolder;
 import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.unsafe.types.UTF8String;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -193,59 +184,12 @@ private Map inputFiles() {
   protected Map<Integer, ?> constantsMap(ContentScanTask<?> task, Schema readSchema) {
     if (readSchema.findField(MetadataColumns.PARTITION_COLUMN_ID) != null) {
       StructType partitionType = Partitioning.partitionType(table);
-      return PartitionUtil.constantsMap(task, partitionType, BaseReader::convertConstant);
+      return PartitionUtil.constantsMap(task, partitionType, SparkUtil::convertConstant);
     } else {
-      return PartitionUtil.constantsMap(task, BaseReader::convertConstant);
+      return PartitionUtil.constantsMap(task, SparkUtil::convertConstant);
     }
   }
 
-  protected static Object convertConstant(Type type, Object value) {
-    if (value == null) {
-      return null;
-    }
-
-    switch (type.typeId()) {
-      case DECIMAL:
-        return Decimal.apply((BigDecimal) value);
-      case STRING:
-        if (value instanceof Utf8) {
-          Utf8 utf8 = (Utf8) value;
-          return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength());
-        }
-        return UTF8String.fromString(value.toString());
-      case FIXED:
-        if (value instanceof byte[]) {
-          return value;
-        } else if (value instanceof GenericData.Fixed) {
-          return ((GenericData.Fixed) value).bytes();
-        }
-        return ByteBuffers.toByteArray((ByteBuffer) value);
-      case BINARY:
-        return ByteBuffers.toByteArray((ByteBuffer) value);
-      case STRUCT:
-        StructType structType = (StructType) type;
-
-        if (structType.fields().isEmpty()) {
-          return new GenericInternalRow();
-        }
-
-        List<NestedField> fields = structType.fields();
-        Object[] values = new Object[fields.size()];
-        StructLike struct = (StructLike) value;
-
-        for (int index = 0; index < fields.size(); index++) {
-          NestedField field = fields.get(index);
-          Type fieldType = field.type();
-          values[index] =
-              convertConstant(fieldType, struct.get(index, fieldType.typeId().javaClass()));
-        }
-
-        return new GenericInternalRow(values);
-      default:
-    }
-    return value;
-  }
-
   protected class SparkDeleteFilter extends DeleteFilter<InternalRow> {
     private final InternalRowWrapper asStructLike;
 
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java
index f9f24834546f..5511ce24337e 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java
@@ -42,6 +42,7 @@
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import java.util.stream.StreamSupport;
+import org.apache.avro.Schema.Field;
 import org.apache.avro.generic.GenericData;
 import org.apache.avro.generic.GenericData.Record;
 import org.apache.iceberg.DataFile;
@@ -246,7 +247,7 @@ private static void assertEqualsSafe(Type type, Object expected, Object actual)
         assertThat(expected).as("Should expect a Collection").isInstanceOf(Collection.class);
         assertThat(actual).as("Should be a Seq").isInstanceOf(Seq.class);
         List asList = seqAsJavaListConverter((Seq) actual).asJava();
-        assertEqualsSafe(type.asNestedType().asListType(), (Collection) expected, asList);
+        assertEqualsSafe(type.asNestedType().asListType(), (Collection) expected, asList);
         break;
       case MAP:
         assertThat(expected).as("Should expect a Collection").isInstanceOf(Map.class);
@@ -263,11 +264,20 @@ private static void assertEqualsSafe(Type type, Object expected, Object actual)
 
   public static void assertEqualsUnsafe(Types.StructType struct, Record rec, InternalRow row) {
     List<Types.NestedField> fields = struct.fields();
-    for (int i = 0; i < fields.size(); i += 1) {
-      Type fieldType = fields.get(i).type();
-
-      Object expectedValue = rec.get(i);
-      Object actualValue = row.isNullAt(i) ? null : row.get(i, convert(fieldType));
+    for (int readPos = 0; readPos < fields.size(); readPos += 1) {
+      Types.NestedField field = fields.get(readPos);
+      Field writeField = rec.getSchema().getField(field.name());
+
+      Type fieldType = field.type();
+      Object actualValue = row.isNullAt(readPos) ? null : row.get(readPos, convert(fieldType));
+
+      Object expectedValue;
+      if (writeField != null) {
+        int writePos = writeField.pos();
+        expectedValue = rec.get(writePos);
+      } else {
+        expectedValue = field.initialDefault();
+      }
 
       assertEqualsUnsafe(fieldType, expectedValue, actualValue);
     }
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java
index ab0d45c3b7ca..0ac0bb530c77 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java
@@ -21,6 +21,7 @@
 import static org.apache.iceberg.spark.data.TestHelpers.assertEqualsUnsafe;
 import static org.apache.iceberg.types.Types.NestedField.required;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.assertj.core.api.Assumptions.assumeThat;
 
 import java.io.File;
@@ -63,32 +64,36 @@ public class TestSparkParquetReader extends AvroDataTest {
 
   @Override
   protected void writeAndValidate(Schema schema) throws IOException {
+    writeAndValidate(schema, schema);
+  }
+
+  protected void writeAndValidate(Schema writeSchema, Schema expectedSchema) throws IOException {
     assumeThat(
             TypeUtil.find(
-                schema,
+                writeSchema,
                 type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get()))
         .as("Parquet Avro cannot write non-string map keys")
         .isNull();
 
-    List<GenericData.Record> expected = RandomData.generateList(schema, 100, 0L);
+    List<GenericData.Record> expected = RandomData.generateList(writeSchema, 100, 0L);
 
     File testFile = File.createTempFile("junit", null, temp.toFile());
    assertThat(testFile.delete()).as("Delete should succeed").isTrue();
 
     try (FileAppender<GenericData.Record> writer =
-        Parquet.write(Files.localOutput(testFile)).schema(schema).named("test").build()) {
+        Parquet.write(Files.localOutput(testFile)).schema(writeSchema).named("test").build()) {
       writer.addAll(expected);
     }
 
     try (CloseableIterable<InternalRow> reader =
         Parquet.read(Files.localInput(testFile))
-            .project(schema)
-            .createReaderFunc(type -> SparkParquetReaders.buildReader(schema, type))
+            .project(expectedSchema)
+            .createReaderFunc(type -> SparkParquetReaders.buildReader(expectedSchema, type))
             .build()) {
       Iterator<InternalRow> rows = reader.iterator();
       for (GenericData.Record record : expected) {
         assertThat(rows).as("Should have expected number of rows").hasNext();
-        assertEqualsUnsafe(schema.asStruct(), record, rows.next());
+        assertEqualsUnsafe(expectedSchema.asStruct(), record, rows.next());
       }
       assertThat(rows).as("Should not have extra rows").isExhausted();
     }
@@ -202,4 +207,220 @@ protected WriteSupport getWriteSupport(Configuration configuration)
       return new org.apache.spark.sql.execution.datasources.parquet.ParquetWriteSupport();
     }
   }
+
+  @Test
+  public void testMissingRequiredWithoutDefault() {
+    Schema writeSchema = new Schema(required(1, "id", Types.LongType.get()));
+
+    Schema expectedSchema =
+        new Schema(
+            required(1, "id", Types.LongType.get()),
+            Types.NestedField.required("missing_str")
+                .withId(6)
+                .ofType(Types.StringType.get())
+                .withDoc("Missing required field with no default")
+                .build());
+
+    assertThatThrownBy(() -> writeAndValidate(writeSchema, expectedSchema))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Missing required field: missing_str");
+  }
+
+  @Test
+  public void testDefaultValues() throws IOException {
+    Schema writeSchema =
+        new Schema(
+            required(1, "id", Types.LongType.get()),
+            Types.NestedField.optional("data")
+                .withId(2)
+                .ofType(Types.StringType.get())
+                .withInitialDefault("wrong!")
+                .withDoc("Should not produce default value")
+                .build());
+
+    Schema expectedSchema =
+        new Schema(
+            required(1, "id", Types.LongType.get()),
+            Types.NestedField.optional("data")
+                .withId(2)
+                .ofType(Types.StringType.get())
+                .withInitialDefault("wrong!")
+                .build(),
+            Types.NestedField.required("missing_str")
+                .withId(6)
+                .ofType(Types.StringType.get())
+                .withInitialDefault("orange")
+                .build(),
+            Types.NestedField.optional("missing_int")
+                .withId(7)
+                .ofType(Types.IntegerType.get())
+                .withInitialDefault(34)
+                .build());
+
+    writeAndValidate(writeSchema, expectedSchema);
+  }
+
+  @Test
+  public void testNullDefaultValue() throws IOException {
+    Schema writeSchema =
+        new Schema(
+            required(1, "id", Types.LongType.get()),
+            Types.NestedField.optional("data")
+                .withId(2)
+                .ofType(Types.StringType.get())
+                .withInitialDefault("wrong!")
+                .withDoc("Should not produce default value")
+                .build());
+
+    Schema expectedSchema =
+        new Schema(
+            required(1, "id", Types.LongType.get()),
+            Types.NestedField.optional("data")
+                .withId(2)
+                .ofType(Types.StringType.get())
+                .withInitialDefault("wrong!")
+                .build(),
+            Types.NestedField.optional("missing_date")
+                .withId(3)
+                .ofType(Types.DateType.get())
+                .build());
+
+    writeAndValidate(writeSchema, expectedSchema);
+  }
+
+  @Test
+  public void testNestedDefaultValue() throws IOException {
+    Schema writeSchema =
+        new Schema(
+            required(1, "id", Types.LongType.get()),
+            Types.NestedField.optional("data")
+                .withId(2)
+                .ofType(Types.StringType.get())
+                .withInitialDefault("wrong!")
+                .withDoc("Should not produce default value")
+                .build(),
+            Types.NestedField.optional("nested")
+                .withId(3)
+                .ofType(Types.StructType.of(required(4, "inner", Types.StringType.get())))
+                .withDoc("Used to test nested field defaults")
+                .build());
+
+    Schema expectedSchema =
+        new Schema(
+            required(1, "id", Types.LongType.get()),
+            Types.NestedField.optional("data")
+                .withId(2)
+                .ofType(Types.StringType.get())
+                .withInitialDefault("wrong!")
+                .build(),
+            Types.NestedField.optional("nested")
+                .withId(3)
+                .ofType(
+                    Types.StructType.of(
+                        required(4, "inner", Types.StringType.get()),
+                        Types.NestedField.optional("missing_inner_float")
+                            .withId(5)
+                            .ofType(Types.FloatType.get())
+                            .withInitialDefault(-0.0F)
+                            .build()))
+                .withDoc("Used to test nested field defaults")
+                .build());
+
+    writeAndValidate(writeSchema, expectedSchema);
+  }
+
+  @Test
+  public void testMapNestedDefaultValue() throws IOException {
+    Schema writeSchema =
+        new Schema(
+            required(1, "id", Types.LongType.get()),
+            Types.NestedField.optional("data")
+                .withId(2)
+                .ofType(Types.StringType.get())
+                .withInitialDefault("wrong!")
+                .withDoc("Should not produce default value")
+                .build(),
+            Types.NestedField.optional("nested_map")
+                .withId(3)
+                .ofType(
+                    Types.MapType.ofOptional(
+                        4,
+                        5,
+                        Types.StringType.get(),
+                        Types.StructType.of(required(6, "value_str", Types.StringType.get()))))
+                .withDoc("Used to test nested map value field defaults")
+                .build());
+
+    Schema expectedSchema =
+        new Schema(
+            required(1, "id", Types.LongType.get()),
+            Types.NestedField.optional("data")
+                .withId(2)
+                .ofType(Types.StringType.get())
+                .withInitialDefault("wrong!")
+                .build(),
+            Types.NestedField.optional("nested_map")
+                .withId(3)
+                .ofType(
+                    Types.MapType.ofOptional(
+                        4,
+                        5,
+                        Types.StringType.get(),
+                        Types.StructType.of(
+                            required(6, "value_str", Types.StringType.get()),
+                            Types.NestedField.optional("value_int")
+                                .withId(7)
+                                .ofType(Types.IntegerType.get())
+                                .withInitialDefault(34)
+                                .build())))
+                .withDoc("Used to test nested field defaults")
+                .build());
+
+    writeAndValidate(writeSchema, expectedSchema);
+  }
+
+  @Test
+  public void testListNestedDefaultValue() throws IOException {
+    Schema writeSchema =
+        new Schema(
+            required(1, "id", Types.LongType.get()),
+            Types.NestedField.optional("data")
+                .withId(2)
+                .ofType(Types.StringType.get())
+                .withInitialDefault("wrong!")
+                .withDoc("Should not produce default value")
+                .build(),
+            Types.NestedField.optional("nested_list")
+                .withId(3)
+                .ofType(
+                    Types.ListType.ofOptional(
+                        4, Types.StructType.of(required(5, "element_str", Types.StringType.get()))))
+                .withDoc("Used to test nested field defaults")
+                .build());
+
+    Schema expectedSchema =
+        new Schema(
+            required(1, "id", Types.LongType.get()),
+            Types.NestedField.optional("data")
+                .withId(2)
+                .ofType(Types.StringType.get())
+                .withInitialDefault("wrong!")
+                .build(),
+            Types.NestedField.optional("nested_list")
+                .withId(3)
+                .ofType(
+                    Types.ListType.ofOptional(
+                        4,
+                        Types.StructType.of(
+                            required(5, "element_str", Types.StringType.get()),
+                            Types.NestedField.optional("element_int")
+                                .withId(7)
+                                .ofType(Types.IntegerType.get())
+                                .withInitialDefault(34)
+                                .build())))
+                .withDoc("Used to test nested field defaults")
+                .build());
+
+    writeAndValidate(writeSchema, expectedSchema);
+  }
 }
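Usage note (editor's addition, not part of the upstream patch): with this change, the struct reader resolves each projected field in order — a constant from idToConstant, the _pos and _deleted metadata columns, a reader backed by the Parquet file, the field's initialDefault(), null for an optional field, and finally an IllegalArgumentException for a required field with no reader and no default. The sketch below is adapted from the testDefaultValues case above and only shows how that surfaces when an old file is read with an evolved schema; testFile stands for a Parquet file previously written with writeSchema (as in the tests), and the variable names are illustrative, not part of the change.

    // Data was written before "missing_int" existed.
    Schema writeSchema = new Schema(required(1, "id", Types.LongType.get()));

    // Evolved read schema: the new column carries an initial default.
    Schema readSchema =
        new Schema(
            required(1, "id", Types.LongType.get()),
            Types.NestedField.optional("missing_int")
                .withId(2)
                .ofType(Types.IntegerType.get())
                .withInitialDefault(34)
                .build());

    // Rows written under writeSchema now come back with 34 in "missing_int" instead of
    // null; a required column with no default fails fast instead of reading as null.
    try (CloseableIterable<InternalRow> rows =
        Parquet.read(Files.localInput(testFile))
            .project(readSchema)
            .createReaderFunc(type -> SparkParquetReaders.buildReader(readSchema, type))
            .build()) {
      for (InternalRow row : rows) {
        // row.getInt(1) == 34 for rows that predate the column
      }
    }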