Spark 3.5: Support default values in vectorized reads #11815

Merged 3 commits on Dec 19, 2024
Changes from 1 commit
@@ -20,6 +20,7 @@

import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.IntStream;
import org.apache.arrow.memory.BufferAllocator;
@@ -47,13 +48,30 @@ public class VectorizedReaderBuilder extends TypeWithSchemaVisitor<VectorizedReader<?>>
private final Map<Integer, ?> idToConstant;
private final boolean setArrowValidityVector;
private final Function<List<VectorizedReader<?>>, VectorizedReader<?>> readerFactory;
private final BiFunction<org.apache.iceberg.types.Type, Object, Object> convert;

public VectorizedReaderBuilder(
Schema expectedSchema,
MessageType parquetSchema,
boolean setArrowValidityVector,
Map<Integer, ?> idToConstant,
Function<List<VectorizedReader<?>>, VectorizedReader<?>> readerFactory) {
this(
expectedSchema,
parquetSchema,
setArrowValidityVector,
idToConstant,
readerFactory,
(type, value) -> value);
}

protected VectorizedReaderBuilder(
Schema expectedSchema,
MessageType parquetSchema,
boolean setArrowValidityVector,
Map<Integer, ?> idToConstant,
Function<List<VectorizedReader<?>>, VectorizedReader<?>> readerFactory,
BiFunction<org.apache.iceberg.types.Type, Object, Object> convert) {
this.parquetSchema = parquetSchema;
this.icebergSchema = expectedSchema;
this.rootAllocator =
@@ -62,6 +80,7 @@ public VectorizedReaderBuilder(
this.setArrowValidityVector = setArrowValidityVector;
this.idToConstant = idToConstant;
this.readerFactory = readerFactory;
this.convert = convert;
}

@Override
@@ -85,7 +104,7 @@ public VectorizedReader<?> message(
int id = field.fieldId();
VectorizedReader<?> reader = readersById.get(id);
if (idToConstant.containsKey(id)) {
reorderedFields.add(new ConstantVectorReader<>(field, idToConstant.get(id)));
reorderedFields.add(constantReader(field, idToConstant.get(id)));
} else if (id == MetadataColumns.ROW_POSITION.fieldId()) {
if (setArrowValidityVector) {
reorderedFields.add(VectorizedArrowReader.positionsWithSetArrowValidityVector());
@@ -96,13 +115,23 @@ public VectorizedReader<?> message(
reorderedFields.add(new DeletedVectorReader());
} else if (reader != null) {
reorderedFields.add(reader);
} else {
} else if (field.initialDefault() != null) {
reorderedFields.add(
constantReader(field, convert.apply(field.type(), field.initialDefault())));
} else if (field.isOptional()) {
reorderedFields.add(VectorizedArrowReader.nulls());
} else {
throw new IllegalArgumentException(
String.format("Missing required field: %s", field.name()));
}
}
return vectorizedReader(reorderedFields);
}

private <T> ConstantVectorReader<T> constantReader(Types.NestedField field, T constant) {
return new ConstantVectorReader<>(field, constant);
}

protected VectorizedReader<?> vectorizedReader(List<VectorizedReader<?>> reorderedFields) {
return readerFactory.apply(reorderedFields);
}
@@ -120,21 +149,24 @@ public VectorizedReader<?> struct(
@Override
public VectorizedReader<?> primitive(
org.apache.iceberg.types.Type.PrimitiveType expected, PrimitiveType primitive) {

// Create arrow vector for this field
if (primitive.getId() == null) {
return null;
}

int parquetFieldId = primitive.getId().intValue();

ColumnDescriptor desc = parquetSchema.getColumnDescription(currentPath());
// Nested types not yet supported for vectorized reads
if (desc.getMaxRepetitionLevel() > 0) {
return null;
}

Types.NestedField icebergField = icebergSchema.findField(parquetFieldId);
if (icebergField == null) {
return null;
}

// Set the validity buffer if null checking is enabled in arrow
return new VectorizedArrowReader(desc, icebergField, rootAllocator, setArrowValidityVector);
}
@@ -27,6 +27,7 @@
import org.apache.iceberg.data.DeleteFilter;
import org.apache.iceberg.parquet.TypeWithSchemaVisitor;
import org.apache.iceberg.parquet.VectorizedReader;
import org.apache.iceberg.spark.SparkUtil;
import org.apache.parquet.schema.MessageType;
import org.apache.spark.sql.catalyst.InternalRow;
import org.slf4j.Logger;
@@ -112,7 +113,13 @@ private static class ReaderBuilder extends VectorizedReaderBuilder {
Map<Integer, ?> idToConstant,
Function<List<VectorizedReader<?>>, VectorizedReader<?>> readerFactory,
DeleteFilter<InternalRow> deleteFilter) {
super(expectedSchema, parquetSchema, setArrowValidityVector, idToConstant, readerFactory);
super(
expectedSchema,
parquetSchema,
setArrowValidityVector,
idToConstant,
readerFactory,
SparkUtil::internalToSpark);
this.deleteFilter = deleteFilter;
}

@@ -63,6 +63,10 @@ protected boolean supportsDefaultValues() {
return false;
}

protected boolean supportsNestedTypes() {
return true;
}

protected static final StructType SUPPORTED_PRIMITIVES =
StructType.of(
required(100, "id", LongType.get()),
@@ -74,6 +78,7 @@ protected boolean supportsDefaultValues() {
required(106, "d", Types.DoubleType.get()),
optional(107, "date", Types.DateType.get()),
required(108, "ts", Types.TimestampType.withZone()),
required(109, "ts_without_zone", Types.TimestampType.withoutZone()),
required(110, "s", Types.StringType.get()),
required(111, "uuid", Types.UUIDType.get()),
required(112, "fixed", Types.FixedType.ofLength(7)),
@@ -109,12 +114,16 @@ public void testStructWithOptionalFields() throws IOException {

@Test
public void testNestedStruct() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

writeAndValidate(
TypeUtil.assignIncreasingFreshIds(new Schema(required(1, "struct", SUPPORTED_PRIMITIVES))));
}

@Test
public void testArray() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema schema =
new Schema(
required(0, "id", LongType.get()),
@@ -125,6 +134,8 @@ public void testArray() throws IOException {

@Test
public void testArrayOfStructs() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema schema =
TypeUtil.assignIncreasingFreshIds(
new Schema(
@@ -136,6 +147,8 @@ public void testArrayOfStructs() throws IOException {

@Test
public void testMap() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema schema =
new Schema(
required(0, "id", LongType.get()),
@@ -149,6 +162,8 @@ public void testMap() throws IOException {

@Test
public void testNumericMapKey() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema schema =
new Schema(
required(0, "id", LongType.get()),
@@ -160,6 +175,8 @@ public void testNumericMapKey() throws IOException {

@Test
public void testComplexMapKey() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema schema =
new Schema(
required(0, "id", LongType.get()),
@@ -179,6 +196,8 @@ public void testComplexMapKey() throws IOException {

@Test
public void testMapOfStructs() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema schema =
TypeUtil.assignIncreasingFreshIds(
new Schema(
@@ -193,6 +212,8 @@ public void testMapOfStructs() throws IOException {

@Test
public void testMixedTypes() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

StructType structType =
StructType.of(
required(0, "id", LongType.get()),
@@ -248,17 +269,6 @@ public void testMixedTypes() throws IOException {
writeAndValidate(schema);
}

@Test
public void testTimestampWithoutZone() throws IOException {

Comment from the PR author:

Removing this test for TimestampNTZ by adding the type to SUPPORTED_PRIMITIVES (so that it is handled like any other primitive) is what broke the ORC tests. It looks like the problem is that Spark 3.5's ColumnarRow doesn't support TimestampNTZType. As a temporary work-around, I've added validation code that checks the value by accessing it as a TimestampType instead.

This isn't a change to read behavior, just how we access the data to validate it. I expect to be able to remove this workaround in the next Spark version.

@nastra (Contributor) commented on Dec 19, 2024:

Yeah, I noticed that too and was planning on fixing that in Spark. I've opened https://issues.apache.org/jira/browse/SPARK-50624.

Schema schema =
TypeUtil.assignIncreasingFreshIds(
new Schema(
required(0, "id", LongType.get()),
optional(1, "ts_without_zone", Types.TimestampType.withoutZone())));

writeAndValidate(schema);
}
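
To make the workaround described in the review thread above concrete, here is a minimal sketch of what reading a timestamp-without-zone value back through TimestampType can look like during validation. This is illustrative only: the class and method names are assumptions, not the PR's actual test code.

```java
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.DataTypes;

class TimestampNtzReadBack {
  // Hypothetical test helper: Spark 3.5's ColumnarRow does not support
  // TimestampNTZType in its typed getters, so the column is read back as
  // TimestampType instead. Both timestamp types are stored as a long of
  // microseconds since the epoch, so the physical value being validated
  // is the same.
  static Object timestampWithoutZoneValue(InternalRow row, int ordinal) {
    return row.isNullAt(ordinal) ? null : row.get(ordinal, DataTypes.TimestampType);
  }
}
```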

@Test
public void testMissingRequiredWithoutDefault() {
Assumptions.assumeThat(supportsDefaultValues()).isTrue();
@@ -348,6 +358,7 @@ public void testNullDefaultValue() throws IOException {
@Test
public void testNestedDefaultValue() throws IOException {
Assumptions.assumeThat(supportsDefaultValues()).isTrue();
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema writeSchema =
new Schema(
@@ -391,6 +402,7 @@
@Test
public void testMapNestedDefaultValue() throws IOException {
Assumptions.assumeThat(supportsDefaultValues()).isTrue();
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema writeSchema =
new Schema(
@@ -443,6 +455,7 @@
@Test
public void testListNestedDefaultValue() throws IOException {
Assumptions.assumeThat(supportsDefaultValues()).isTrue();
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema writeSchema =
new Schema(
@@ -107,13 +107,25 @@ public static void assertEqualsSafe(Types.StructType struct, Record rec, Row row
public static void assertEqualsBatch(
Types.StructType struct, Iterator<Record> expected, ColumnarBatch batch) {
for (int rowId = 0; rowId < batch.numRows(); rowId++) {
List<Types.NestedField> fields = struct.fields();
InternalRow row = batch.getRow(rowId);
Record rec = expected.next();
for (int i = 0; i < fields.size(); i += 1) {
Type fieldType = fields.get(i).type();
Object expectedValue = rec.get(i);
Object actualValue = row.isNullAt(i) ? null : row.get(i, convert(fieldType));

List<Types.NestedField> fields = struct.fields();
for (int readPos = 0; readPos < fields.size(); readPos += 1) {
Types.NestedField field = fields.get(readPos);
Field writeField = rec.getSchema().getField(field.name());

Type fieldType = field.type();
Object actualValue = row.isNullAt(readPos) ? null : row.get(readPos, convert(fieldType));

Object expectedValue;
if (writeField != null) {
int writePos = writeField.pos();
expectedValue = rec.get(writePos);
} else {
expectedValue = field.initialDefault();
}

assertEqualsUnsafe(fieldType, expectedValue, actualValue);
}
}