Spark 3.5: Support default values in Parquet reader (#11803)
rdblue authored Dec 18, 2024
1 parent 204a49c commit 7e1a4c9
Showing 5 changed files with 329 additions and 80 deletions.
@@ -18,36 +18,47 @@
*/
package org.apache.iceberg.spark;

import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.sql.Date;
import java.sql.Timestamp;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.avro.generic.GenericData;
import org.apache.avro.util.Utf8;
import org.apache.hadoop.conf.Configuration;
import org.apache.iceberg.PartitionField;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.StructLike;
import org.apache.iceberg.relocated.com.google.common.base.Joiner;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.transforms.Transform;
import org.apache.iceberg.transforms.UnknownTransform;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.ByteBuffers;
import org.apache.iceberg.util.Pair;
import org.apache.spark.SparkEnv;
import org.apache.spark.scheduler.ExecutorCacheTaskLocation;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.BoundReference;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.BlockManagerId;
import org.apache.spark.storage.BlockManagerMaster;
import org.apache.spark.unsafe.types.UTF8String;
import org.joda.time.DateTime;
import scala.collection.JavaConverters;
import scala.collection.Seq;
@@ -268,4 +279,59 @@ private static <T> List<T> toJavaList(Seq<T> seq) {
private static String toExecutorLocation(BlockManagerId id) {
return ExecutorCacheTaskLocation.apply(id.host(), id.executorId()).toString();
}

/**
* Converts a value to pass into Spark from Iceberg's internal object model.
*
* @param type an Iceberg type
* @param value a value that is an instance of {@link Type.TypeID#javaClass()}
* @return the value converted for Spark
*/
public static Object convertConstant(Type type, Object value) {
if (value == null) {
return null;
}

switch (type.typeId()) {
case DECIMAL:
return Decimal.apply((BigDecimal) value);
case STRING:
if (value instanceof Utf8) {
Utf8 utf8 = (Utf8) value;
return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength());
}
return UTF8String.fromString(value.toString());
case FIXED:
if (value instanceof byte[]) {
return value;
} else if (value instanceof GenericData.Fixed) {
return ((GenericData.Fixed) value).bytes();
}
return ByteBuffers.toByteArray((ByteBuffer) value);
case BINARY:
return ByteBuffers.toByteArray((ByteBuffer) value);
case STRUCT:
Types.StructType structType = (Types.StructType) type;

if (structType.fields().isEmpty()) {
return new GenericInternalRow();
}

List<Types.NestedField> fields = structType.fields();
Object[] values = new Object[fields.size()];
StructLike struct = (StructLike) value;

for (int index = 0; index < fields.size(); index++) {
Types.NestedField field = fields.get(index);
Type fieldType = field.type();
values[index] =
convertConstant(fieldType, struct.get(index, fieldType.typeId().javaClass()));
}

return new GenericInternalRow(values);
default:
}

return value;
}
}
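A minimal usage sketch of the now-public helper (illustrative values, not taken from the commit): convertConstant maps Iceberg's internal Java representations to the objects Spark's internal rows expect, e.g. BigDecimal to Decimal and String or Avro Utf8 to UTF8String.

import java.math.BigDecimal;
import org.apache.iceberg.spark.SparkUtil;
import org.apache.iceberg.types.Types;

public class ConvertConstantSketch {
  public static void main(String[] args) {
    // DECIMAL constants arrive as java.math.BigDecimal and come back as Spark Decimal
    Object dec = SparkUtil.convertConstant(Types.DecimalType.of(9, 2), new BigDecimal("12.34"));

    // STRING constants (String or Avro Utf8) come back as Spark UTF8String
    Object str = SparkUtil.convertConstant(Types.StringType.get(), "unknown");

    System.out.println(dec.getClass().getName() + ", " + str.getClass().getName());
  }
}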
@@ -44,6 +44,7 @@
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.spark.SparkUtil;
import org.apache.iceberg.types.Type.TypeID;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.UUIDUtil;
@@ -165,6 +166,7 @@ public ParquetValueReader<?> struct(
int defaultMaxDefinitionLevel = type.getMaxDefinitionLevel(currentPath());
for (Types.NestedField field : expectedFields) {
int id = field.fieldId();
ParquetValueReader<?> reader = readersById.get(id);
if (idToConstant.containsKey(id)) {
// containsKey is used because the constant may be null
int fieldMaxDefinitionLevel =
@@ -178,15 +180,21 @@
} else if (id == MetadataColumns.IS_DELETED.fieldId()) {
reorderedFields.add(ParquetValueReaders.constant(false));
types.add(null);
} else if (reader != null) {
reorderedFields.add(reader);
types.add(typesById.get(id));
} else if (field.initialDefault() != null) {
reorderedFields.add(
ParquetValueReaders.constant(
SparkUtil.convertConstant(field.type(), field.initialDefault()),
maxDefinitionLevelsById.getOrDefault(id, defaultMaxDefinitionLevel)));
types.add(typesById.get(id));
} else if (field.isOptional()) {
reorderedFields.add(ParquetValueReaders.nulls());
types.add(null);
} else {
ParquetValueReader<?> reader = readersById.get(id);
if (reader != null) {
reorderedFields.add(reader);
types.add(typesById.get(id));
} else {
reorderedFields.add(ParquetValueReaders.nulls());
types.add(null);
}
throw new IllegalArgumentException(
String.format("Missing required field: %s", field.name()));
}
}
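Read together, the new branches give each projected field a clear resolution order: a constant from idToConstant, a metadata column, a reader backed by the data file, the field's initial default, nulls for an absent optional field, and an error for an absent required field with no default. A condensed, hypothetical sketch of that precedence (helper name and parameters are illustrative; the types bookkeeping is omitted):

import org.apache.iceberg.parquet.ParquetValueReader;
import org.apache.iceberg.parquet.ParquetValueReaders;
import org.apache.iceberg.spark.SparkUtil;
import org.apache.iceberg.types.Types;

class FieldReaderPrecedence {
  // fileReader is the reader built from the Parquet file, or null if the column is absent
  static ParquetValueReader<?> resolve(
      Types.NestedField field, ParquetValueReader<?> fileReader, int maxDefinitionLevel) {
    if (fileReader != null) {
      return fileReader; // column present in the data file
    } else if (field.initialDefault() != null) {
      // column added after this file was written: materialize the default as a constant
      return ParquetValueReaders.constant(
          SparkUtil.convertConstant(field.type(), field.initialDefault()), maxDefinitionLevel);
    } else if (field.isOptional()) {
      return ParquetValueReaders.nulls(); // absent and optional: read as nulls
    } else {
      // absent, required, and no default: fail instead of silently producing nulls
      throw new IllegalArgumentException(
          String.format("Missing required field: %s", field.name()));
    }
  }
}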

@@ -250,7 +258,7 @@ public ParquetValueReader<?> primitive(
if (expected != null && expected.typeId() == Types.LongType.get().typeId()) {
return new IntAsLongReader(desc);
} else {
return new UnboxedReader(desc);
return new UnboxedReader<>(desc);
}
case DATE:
case INT_64:
@@ -20,17 +20,13 @@

import java.io.Closeable;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.avro.generic.GenericData;
import org.apache.avro.util.Utf8;
import org.apache.iceberg.ContentFile;
import org.apache.iceberg.ContentScanTask;
import org.apache.iceberg.DeleteFile;
@@ -53,16 +49,11 @@
import org.apache.iceberg.mapping.NameMappingParser;
import org.apache.iceberg.spark.SparkExecutorCache;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types.NestedField;
import org.apache.iceberg.spark.SparkUtil;
import org.apache.iceberg.types.Types.StructType;
import org.apache.iceberg.util.ByteBuffers;
import org.apache.iceberg.util.PartitionUtil;
import org.apache.spark.rdd.InputFileBlockHolder;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.types.UTF8String;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -193,59 +184,12 @@ private Map<String, InputFile> inputFiles() {
protected Map<Integer, ?> constantsMap(ContentScanTask<?> task, Schema readSchema) {
if (readSchema.findField(MetadataColumns.PARTITION_COLUMN_ID) != null) {
StructType partitionType = Partitioning.partitionType(table);
return PartitionUtil.constantsMap(task, partitionType, BaseReader::convertConstant);
return PartitionUtil.constantsMap(task, partitionType, SparkUtil::convertConstant);
} else {
return PartitionUtil.constantsMap(task, BaseReader::convertConstant);
return PartitionUtil.constantsMap(task, SparkUtil::convertConstant);
}
}

protected static Object convertConstant(Type type, Object value) {
if (value == null) {
return null;
}

switch (type.typeId()) {
case DECIMAL:
return Decimal.apply((BigDecimal) value);
case STRING:
if (value instanceof Utf8) {
Utf8 utf8 = (Utf8) value;
return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength());
}
return UTF8String.fromString(value.toString());
case FIXED:
if (value instanceof byte[]) {
return value;
} else if (value instanceof GenericData.Fixed) {
return ((GenericData.Fixed) value).bytes();
}
return ByteBuffers.toByteArray((ByteBuffer) value);
case BINARY:
return ByteBuffers.toByteArray((ByteBuffer) value);
case STRUCT:
StructType structType = (StructType) type;

if (structType.fields().isEmpty()) {
return new GenericInternalRow();
}

List<NestedField> fields = structType.fields();
Object[] values = new Object[fields.size()];
StructLike struct = (StructLike) value;

for (int index = 0; index < fields.size(); index++) {
NestedField field = fields.get(index);
Type fieldType = field.type();
values[index] =
convertConstant(fieldType, struct.get(index, fieldType.typeId().javaClass()));
}

return new GenericInternalRow(values);
default:
}
return value;
}

protected class SparkDeleteFilter extends DeleteFilter<InternalRow> {
private final InternalRowWrapper asStructLike;

@@ -42,6 +42,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.avro.Schema.Field;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericData.Record;
import org.apache.iceberg.DataFile;
@@ -246,7 +247,7 @@ private static void assertEqualsSafe(Type type, Object expected, Object actual)
assertThat(expected).as("Should expect a Collection").isInstanceOf(Collection.class);
assertThat(actual).as("Should be a Seq").isInstanceOf(Seq.class);
List<?> asList = seqAsJavaListConverter((Seq<?>) actual).asJava();
assertEqualsSafe(type.asNestedType().asListType(), (Collection) expected, asList);
assertEqualsSafe(type.asNestedType().asListType(), (Collection<?>) expected, asList);
break;
case MAP:
assertThat(expected).as("Should expect a Collection").isInstanceOf(Map.class);
@@ -263,11 +264,20 @@ private static void assertEqualsSafe(Type type, Object expected, Object actual)

public static void assertEqualsUnsafe(Types.StructType struct, Record rec, InternalRow row) {
List<Types.NestedField> fields = struct.fields();
for (int i = 0; i < fields.size(); i += 1) {
Type fieldType = fields.get(i).type();

Object expectedValue = rec.get(i);
Object actualValue = row.isNullAt(i) ? null : row.get(i, convert(fieldType));
for (int readPos = 0; readPos < fields.size(); readPos += 1) {
Types.NestedField field = fields.get(readPos);
Field writeField = rec.getSchema().getField(field.name());

Type fieldType = field.type();
Object actualValue = row.isNullAt(readPos) ? null : row.get(readPos, convert(fieldType));

Object expectedValue;
if (writeField != null) {
int writePos = writeField.pos();
expectedValue = rec.get(writePos);
} else {
expectedValue = field.initialDefault();
}

assertEqualsUnsafe(fieldType, expectedValue, actualValue);
}