Spark: Test reading default values in Spark (#11832)
rdblue authored Dec 21, 2024
1 parent dea2fd1 commit cd187c5
Showing 18 changed files with 644 additions and 938 deletions.
6 changes: 1 addition & 5 deletions api/src/main/java/org/apache/iceberg/types/ReassignIds.java
@@ -79,11 +79,7 @@ public Type struct(Types.StructType struct, Iterable<Type> fieldTypes) {
for (int i = 0; i < length; i += 1) {
Types.NestedField field = fields.get(i);
int fieldId = id(sourceStruct, field.name());
- if (field.isRequired()) {
- newFields.add(Types.NestedField.required(fieldId, field.name(), types.get(i), field.doc()));
- } else {
- newFields.add(Types.NestedField.optional(fieldId, field.name(), types.get(i), field.doc()));
- }
+ newFields.add(Types.NestedField.from(field).withId(fieldId).ofType(types.get(i)).build());
}

return Types.StructType.of(newFields);
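The hunk above swaps the required/optional branching for the NestedField builder, presumably so that per-field metadata (the doc string and the newly introduced default values) survives id reassignment instead of being rebuilt from scratch. A minimal sketch of the pattern, using hypothetical field names and values that are not part of the commit:

    Types.NestedField source =
        Types.NestedField.optional("count")
            .withId(7)
            .ofType(Types.IntegerType.get())
            .withDoc("row count")
            .withInitialDefault(0)
            .withWriteDefault(0)
            .build();

    // Rebuild the field with a new id; the doc string and both defaults are
    // expected to carry over through from().
    Types.NestedField reassigned =
        Types.NestedField.from(source).withId(42).ofType(source.type()).build();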
8 changes: 7 additions & 1 deletion api/src/main/java/org/apache/iceberg/types/Types.java
@@ -711,8 +711,14 @@ public boolean equals(Object o) {
return false;
} else if (!Objects.equals(doc, that.doc)) {
return false;
+ } else if (!type.equals(that.type)) {
+ return false;
+ } else if (!Objects.equals(initialDefault, that.initialDefault)) {
+ return false;
+ } else if (!Objects.equals(writeDefault, that.writeDefault)) {
+ return false;
}
- return type.equals(that.type);
+ return true;
}

@Override
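With the extended equals(), fields that differ only in a default value no longer compare equal. An illustrative example (values chosen here for demonstration, not taken from the commit):

    Types.NestedField a =
        Types.NestedField.required("id").withId(1).ofType(Types.LongType.get()).build();
    Types.NestedField b =
        Types.NestedField.required("id")
            .withId(1)
            .ofType(Types.LongType.get())
            .withWriteDefault(-1L)
            .build();

    boolean equal = a.equals(b); // expected: false, because the write defaults differ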
20 changes: 16 additions & 4 deletions api/src/test/java/org/apache/iceberg/types/TestTypeUtil.java
@@ -595,21 +595,33 @@ public void testReassignOrRefreshIds() {
new Schema(
Lists.newArrayList(
required(10, "a", Types.IntegerType.get()),
required(11, "c", Types.IntegerType.get()),
Types.NestedField.required("c")
.withId(11)
.ofType(Types.IntegerType.get())
.withInitialDefault(23)
.withWriteDefault(34)
.build(),
required(12, "B", Types.IntegerType.get())),
Sets.newHashSet(10));
Schema sourceSchema =
new Schema(
Lists.newArrayList(
required(1, "a", Types.IntegerType.get()),
required(15, "B", Types.IntegerType.get())));
- final Schema actualSchema = TypeUtil.reassignOrRefreshIds(schema, sourceSchema);
- final Schema expectedSchema =
+
+ Schema actualSchema = TypeUtil.reassignOrRefreshIds(schema, sourceSchema);
+ Schema expectedSchema =
new Schema(
Lists.newArrayList(
required(1, "a", Types.IntegerType.get()),
required(16, "c", Types.IntegerType.get()),
Types.NestedField.required("c")
.withId(16)
.ofType(Types.IntegerType.get())
.withInitialDefault(23)
.withWriteDefault(34)
.build(),
required(15, "B", Types.IntegerType.get())));

assertThat(actualSchema.asStruct()).isEqualTo(expectedSchema.asStruct());
}

45 changes: 40 additions & 5 deletions core/src/main/java/org/apache/iceberg/SchemaParser.java
@@ -49,6 +49,8 @@ private SchemaParser() {}
private static final String DOC = "doc";
private static final String NAME = "name";
private static final String ID = "id";
private static final String INITIAL_DEFAULT = "initial-default";
private static final String WRITE_DEFAULT = "write-default";
private static final String ELEMENT_ID = "element-id";
private static final String KEY_ID = "key-id";
private static final String VALUE_ID = "value-id";
@@ -88,6 +90,17 @@ private static void toJson(
if (field.doc() != null) {
generator.writeStringField(DOC, field.doc());
}

if (field.initialDefault() != null) {
generator.writeFieldName(INITIAL_DEFAULT);
SingleValueParser.toJson(field.type(), field.initialDefault(), generator);
}

if (field.writeDefault() != null) {
generator.writeFieldName(WRITE_DEFAULT);
SingleValueParser.toJson(field.type(), field.writeDefault(), generator);
}

generator.writeEndObject();
}
generator.writeEndArray();
@@ -184,6 +197,22 @@ private static Type typeFromJson(JsonNode json) {
throw new IllegalArgumentException("Cannot parse type from json: " + json);
}

private static Object defaultFromJson(String defaultField, Type type, JsonNode json) {
if (json.has(defaultField)) {
return SingleValueParser.fromJson(type, json.get(defaultField));
}

return null;
}

private static Types.NestedField.Builder fieldBuilder(boolean isRequired, String name) {
if (isRequired) {
return Types.NestedField.required(name);
} else {
return Types.NestedField.optional(name);
}
}

private static Types.StructType structFromJson(JsonNode json) {
JsonNode fieldArray = JsonUtil.get(FIELDS, json);
Preconditions.checkArgument(
@@ -200,13 +229,19 @@ private static Types.StructType structFromJson(JsonNode json) {
String name = JsonUtil.getString(NAME, field);
Type type = typeFromJson(JsonUtil.get(TYPE, field));

Object initialDefault = defaultFromJson(INITIAL_DEFAULT, type, field);
Object writeDefault = defaultFromJson(WRITE_DEFAULT, type, field);

String doc = JsonUtil.getStringOrNull(DOC, field);
boolean isRequired = JsonUtil.getBool(REQUIRED, field);
- if (isRequired) {
- fields.add(Types.NestedField.required(id, name, type, doc));
- } else {
- fields.add(Types.NestedField.optional(id, name, type, doc));
- }
+ fields.add(
+ fieldBuilder(isRequired, name)
+ .withId(id)
+ .ofType(type)
+ .withDoc(doc)
+ .withInitialDefault(initialDefault)
+ .withWriteDefault(writeDefault)
+ .build());
}

return Types.StructType.of(fields);
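With the parser changes above, field defaults survive JSON serialization under the new initial-default and write-default keys. A rough sketch of the round trip; the literal values and the JSON layout shown in the comment are illustrative approximations, not taken from the commit:

    // A field with defaults should now serialize roughly as:
    //   {"id": 2, "name": "col", "required": true, "type": "int",
    //    "initial-default": 34, "write-default": 34}
    Schema schema =
        new Schema(
            Types.NestedField.required("col")
                .withId(2)
                .ofType(Types.IntegerType.get())
                .withInitialDefault(34)
                .withWriteDefault(34)
                .build());

    Schema roundTripped = SchemaParser.fromJson(SchemaParser.toJson(schema));
    // roundTripped.findField("col").initialDefault() and .writeDefault() should both be 34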
126 changes: 126 additions & 0 deletions core/src/test/java/org/apache/iceberg/TestSchemaParser.java
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg;

import static org.apache.iceberg.types.Types.NestedField.optional;
import static org.apache.iceberg.types.Types.NestedField.required;
import static org.assertj.core.api.Assertions.assertThat;

import java.io.IOException;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.UUID;
import java.util.stream.Stream;
import org.apache.iceberg.avro.AvroDataTest;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Sets;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.DateTimeUtil;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

public class TestSchemaParser extends AvroDataTest {
@Override
protected void writeAndValidate(Schema schema) throws IOException {
Schema serialized = SchemaParser.fromJson(SchemaParser.toJson(schema));
assertThat(serialized.asStruct()).isEqualTo(schema.asStruct());
}

@Test
public void testSchemaId() {
Schema schema = new Schema(34, required(1, "id", Types.LongType.get()));

Schema serialized = SchemaParser.fromJson(SchemaParser.toJson(schema));
assertThat(serialized.schemaId()).isEqualTo(schema.schemaId());
}

@Test
public void testIdentifierColumns() {
Schema schema =
new Schema(
Lists.newArrayList(
required(1, "id-1", Types.LongType.get()),
required(2, "id-2", Types.LongType.get()),
optional(3, "data", Types.StringType.get())),
Sets.newHashSet(1, 2));

Schema serialized = SchemaParser.fromJson(SchemaParser.toJson(schema));
assertThat(serialized.identifierFieldIds()).isEqualTo(Sets.newHashSet(1, 2));
}

@Test
public void testDocStrings() {
Schema schema =
new Schema(
required(1, "id", Types.LongType.get(), "unique identifier"),
Types.NestedField.optional("data")
.withId(2)
.ofType(Types.StringType.get())
.withDoc("payload")
.build());

Schema serialized = SchemaParser.fromJson(SchemaParser.toJson(schema));
assertThat(serialized.findField("id").doc()).isEqualTo("unique identifier");
assertThat(serialized.findField("data").doc()).isEqualTo("payload");
}

private static Stream<Arguments> primitiveTypesAndDefaults() {
return Stream.of(
Arguments.of(Types.BooleanType.get(), false),
Arguments.of(Types.IntegerType.get(), 34),
Arguments.of(Types.LongType.get(), 4900000000L),
Arguments.of(Types.FloatType.get(), 12.21F),
Arguments.of(Types.DoubleType.get(), -0.0D),
Arguments.of(Types.DateType.get(), DateTimeUtil.isoDateToDays("2024-12-17")),
// Arguments.of(Types.TimeType.get(), DateTimeUtil.isoTimeToMicros("23:59:59.999999")),
Arguments.of(
Types.TimestampType.withZone(),
DateTimeUtil.isoTimestamptzToMicros("2024-12-17T23:59:59.999999+00:00")),
Arguments.of(
Types.TimestampType.withoutZone(),
DateTimeUtil.isoTimestampToMicros("2024-12-17T23:59:59.999999")),
Arguments.of(Types.StringType.get(), "iceberg"),
Arguments.of(Types.UUIDType.get(), UUID.randomUUID()),
Arguments.of(
Types.FixedType.ofLength(4), ByteBuffer.wrap(new byte[] {0x0a, 0x0b, 0x0c, 0x0d})),
Arguments.of(Types.BinaryType.get(), ByteBuffer.wrap(new byte[] {0x0a, 0x0b})),
Arguments.of(Types.DecimalType.of(9, 2), new BigDecimal("12.34")));
}

@ParameterizedTest
@MethodSource("primitiveTypesAndDefaults")
public void testPrimitiveTypeDefaultValues(Type.PrimitiveType type, Object defaultValue) {
Schema schema =
new Schema(
required(1, "id", Types.LongType.get()),
Types.NestedField.required("col_with_default")
.withId(2)
.ofType(type)
.withInitialDefault(defaultValue)
.withWriteDefault(defaultValue)
.build());

Schema serialized = SchemaParser.fromJson(SchemaParser.toJson(schema));
assertThat(serialized.findField("col_with_default").initialDefault()).isEqualTo(defaultValue);
assertThat(serialized.findField("col_with_default").writeDefault()).isEqualTo(defaultValue);
}
}
@@ -77,7 +77,7 @@ private CloseableIterable<InternalRow> newAvroIterable(
.reuseContainers()
.project(projection)
.split(start, length)
- .createReaderFunc(readSchema -> SparkPlannedAvroReader.create(projection, idToConstant))
+ .createResolvingReader(schema -> SparkPlannedAvroReader.create(schema, idToConstant))
.withNameMapping(nameMapping())
.build();
}
@@ -27,13 +27,11 @@
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.nio.file.Path;
- import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import org.apache.iceberg.Schema;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
- import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.types.Types;
@@ -42,8 +40,8 @@
import org.apache.iceberg.types.Types.MapType;
import org.apache.iceberg.types.Types.StructType;
import org.apache.iceberg.util.DateTimeUtil;
- import org.apache.spark.sql.internal.SQLConf;
import org.assertj.core.api.Assumptions;
+ import org.assertj.core.api.Condition;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
@@ -285,8 +283,13 @@ public void testMissingRequiredWithoutDefault() {
.build());

assertThatThrownBy(() -> writeAndValidate(writeSchema, expectedSchema))
- .isInstanceOf(IllegalArgumentException.class)
- .hasMessage("Missing required field: missing_str");
+ .has(
+ new Condition<>(
+ t ->
+ IllegalArgumentException.class.isInstance(t)
+ || IllegalArgumentException.class.isInstance(t.getCause()),
+ "Expecting a throwable or cause that is an instance of IllegalArgumentException"))
+ .hasMessageContaining("Missing required field: missing_str");
}

@Test
@@ -542,44 +545,4 @@ public void testPrimitiveTypeDefaultValues(Type.PrimitiveType type, Object defau

writeAndValidate(writeSchema, readSchema);
}

- protected void withSQLConf(Map<String, String> conf, Action action) throws IOException {
- SQLConf sqlConf = SQLConf.get();
-
- Map<String, String> currentConfValues = Maps.newHashMap();
- conf.keySet()
- .forEach(
- confKey -> {
- if (sqlConf.contains(confKey)) {
- String currentConfValue = sqlConf.getConfString(confKey);
- currentConfValues.put(confKey, currentConfValue);
- }
- });
-
- conf.forEach(
- (confKey, confValue) -> {
- if (SQLConf.isStaticConfigKey(confKey)) {
- throw new RuntimeException("Cannot modify the value of a static config: " + confKey);
- }
- sqlConf.setConfString(confKey, confValue);
- });
-
- try {
- action.invoke();
- } finally {
- conf.forEach(
- (confKey, confValue) -> {
- if (currentConfValues.containsKey(confKey)) {
- sqlConf.setConfString(confKey, currentConfValues.get(confKey));
- } else {
- sqlConf.unsetConf(confKey);
- }
- });
- }
- }
-
- @FunctionalInterface
- protected interface Action {
- void invoke() throws IOException;
- }
}