Flink: Backport #11662 Fix range distribution NPE when value is null to Flink 1.18 and 1.19 (#11745)
Guosmilesmile authored Jan 3, 2025
1 parent 3b00043 commit 4d35682
Showing 22 changed files with 651 additions and 31 deletions.
@@ -108,4 +108,21 @@ boolean isEmpty() {
return keyFrequency().isEmpty();
}
}

boolean isValid() {
if (type == StatisticsType.Sketch) {
if (null == keySamples) {
return false;
}
} else {
if (null == keyFrequency()) {
return false;
}
if (keyFrequency().values().contains(null)) {
return false;
}
}

return true;
}
}
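
For context on the guard just added: statistics restored with a mismatched serializer version can deserialize into a frequency map whose counts are null, and isValid() turns that silent corruption into a detectable failure. A minimal JDK-only sketch of the same check (class and method names here are illustrative, not the Iceberg API):

import java.util.HashMap;
import java.util.Map;

class StatisticsValidityCheck {
  // Mirrors the isValid() guard above: a frequency map is unusable when it is
  // missing entirely or when any of its counts deserialized to null.
  static boolean isValidFrequency(Map<String, Long> keyFrequency) {
    if (keyFrequency == null) {
      return false;
    }
    return !keyFrequency.containsValue(null);
  }

  public static void main(String[] args) {
    Map<String, Long> healthy = new HashMap<>();
    healthy.put("a", 2L);

    Map<String, Long> corrupted = new HashMap<>();
    corrupted.put("a", null); // what a version-mismatched restore can produce

    System.out.println(isValidFrequency(healthy));   // true
    System.out.println(isValidFrequency(corrupted)); // false
  }
}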
@@ -48,6 +48,18 @@ class CompletedStatisticsSerializer extends TypeSerializer<CompletedStatistics>
this.keySamplesSerializer = new ListSerializer<>(sortKeySerializer);
}

public void changeSortKeySerializerVersion(int version) {
if (sortKeySerializer instanceof SortKeySerializer) {
((SortKeySerializer) sortKeySerializer).setVersion(version);
}
}

public void changeSortKeySerializerVersionLatest() {
if (sortKeySerializer instanceof SortKeySerializer) {
((SortKeySerializer) sortKeySerializer).restoreToLatestVersion();
}
}

@Override
public boolean isImmutableType() {
return false;
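
The two changeSortKeySerializerVersion* helpers added above let restore code temporarily downgrade the nested sort-key serializer, parse old checkpoint bytes, and then flip back to the newest wire format. A self-contained sketch of that toggle pattern (names and the version constant are illustrative, not the Iceberg API):

class VersionTogglingSerializer {
  static final int CURRENT_VERSION = 2;
  private int version = CURRENT_VERSION;

  // Temporarily read state written by an older job.
  void setVersion(int version) {
    this.version = version;
  }

  // Switch back so all subsequent records use the newest wire format.
  void restoreToLatestVersion() {
    this.version = CURRENT_VERSION;
  }

  public static void main(String[] args) {
    VersionTogglingSerializer serializer = new VersionTogglingSerializer();
    serializer.setVersion(1);            // parse v1 checkpoint bytes
    serializer.restoreToLatestVersion(); // handle fresh data from task managers
    System.out.println(serializer.version); // 2
  }
}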
@@ -370,7 +370,8 @@ public void resetToCheckpoint(long checkpointId, byte[] checkpointData) {
"Restoring data statistic coordinator {} from checkpoint {}", operatorName, checkpointId);
this.completedStatistics =
StatisticsUtil.deserializeCompletedStatistics(
checkpointData, completedStatisticsSerializer);
checkpointData, (CompletedStatisticsSerializer) completedStatisticsSerializer);

// recompute global statistics in case downstream parallelism changed
this.globalStatistics =
globalStatistics(
@@ -53,9 +53,12 @@ class SortKeySerializer extends TypeSerializer<SortKey> {
private final int size;
private final Types.NestedField[] transformedFields;

private int version;

private transient SortKey sortKey;

SortKeySerializer(Schema schema, SortOrder sortOrder) {
SortKeySerializer(Schema schema, SortOrder sortOrder, int version) {
this.version = version;
this.schema = schema;
this.sortOrder = sortOrder;
this.size = sortOrder.fields().size();
@@ -76,6 +79,10 @@ class SortKeySerializer extends TypeSerializer<SortKey> {
}
}

SortKeySerializer(Schema schema, SortOrder sortOrder) {
this(schema, sortOrder, SortKeySerializerSnapshot.CURRENT_VERSION);
}

private SortKey lazySortKey() {
if (sortKey == null) {
this.sortKey = new SortKey(schema, sortOrder);
@@ -84,6 +91,18 @@ private SortKey lazySortKey() {
return sortKey;
}

public int getLatestVersion() {
return snapshotConfiguration().getCurrentVersion();
}

public void restoreToLatestVersion() {
this.version = snapshotConfiguration().getCurrentVersion();
}

public void setVersion(int version) {
this.version = version;
}

@Override
public boolean isImmutableType() {
return false;
@@ -125,6 +144,16 @@ public void serialize(SortKey record, DataOutputView target) throws IOException
for (int i = 0; i < size; ++i) {
int fieldId = transformedFields[i].fieldId();
Type.TypeID typeId = transformedFields[i].type().typeId();
if (version > 1) {
Object value = record.get(i, Object.class);
if (value == null) {
target.writeBoolean(true);
continue;
} else {
target.writeBoolean(false);
}
}

switch (typeId) {
case BOOLEAN:
target.writeBoolean(record.get(i, Boolean.class));
@@ -193,6 +222,14 @@ public SortKey deserialize(SortKey reuse, DataInputView source) throws IOException
reuse.size(),
size);
for (int i = 0; i < size; ++i) {
if (version > 1) {
boolean isNull = source.readBoolean();
if (isNull) {
reuse.set(i, null);
continue;
}
}

int fieldId = transformedFields[i].fieldId();
Type.TypeID typeId = transformedFields[i].type().typeId();
switch (typeId) {
@@ -277,11 +314,13 @@ public TypeSerializerSnapshot<SortKey> snapshotConfiguration() {
}

public static class SortKeySerializerSnapshot implements TypeSerializerSnapshot<SortKey> {
private static final int CURRENT_VERSION = 1;
private static final int CURRENT_VERSION = 2;

private Schema schema;
private SortOrder sortOrder;

private int version = CURRENT_VERSION;

/** Constructor for read instantiation. */
@SuppressWarnings({"unused", "checkstyle:RedundantModifier"})
public SortKeySerializerSnapshot() {
@@ -311,10 +350,16 @@ public void writeSnapshot(DataOutputView out) throws IOException {
@Override
public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader)
throws IOException {
if (readVersion == 1) {
readV1(in);
} else {
throw new IllegalArgumentException("Unknown read version: " + readVersion);
switch (readVersion) {
case 1:
read(in);
this.version = 1;
break;
case 2:
read(in);
break;
default:
throw new IllegalArgumentException("Unknown read version: " + readVersion);
}
}

@@ -325,9 +370,13 @@ public TypeSerializerSchemaCompatibility<SortKey> resolveSchemaCompatibility(
return TypeSerializerSchemaCompatibility.incompatible();
}

// Sort order should be identical
SortKeySerializerSnapshot newSnapshot =
(SortKeySerializerSnapshot) newSerializer.snapshotConfiguration();
if (newSnapshot.getCurrentVersion() == 1 && this.getCurrentVersion() == 2) {
return TypeSerializerSchemaCompatibility.compatibleAfterMigration();
}

// Sort order should be identical
if (!sortOrder.sameOrder(newSnapshot.sortOrder)) {
return TypeSerializerSchemaCompatibility.incompatible();
}
@@ -351,10 +400,10 @@ public TypeSerializerSchemaCompatibility<SortKey> resolveSchemaCompatibility(
public TypeSerializer<SortKey> restoreSerializer() {
Preconditions.checkState(schema != null, "Invalid schema: null");
Preconditions.checkState(sortOrder != null, "Invalid sort order: null");
return new SortKeySerializer(schema, sortOrder);
return new SortKeySerializer(schema, sortOrder, version);
}

private void readV1(DataInputView in) throws IOException {
private void read(DataInputView in) throws IOException {
String schemaJson = StringUtils.readString(in);
String sortOrderJson = StringUtils.readString(in);
this.schema = SchemaParser.fromJson(schemaJson);
@@ -73,12 +73,29 @@ static byte[] serializeCompletedStatistics(
}

static CompletedStatistics deserializeCompletedStatistics(
byte[] bytes, TypeSerializer<CompletedStatistics> statisticsSerializer) {
byte[] bytes, CompletedStatisticsSerializer statisticsSerializer) {
try {
DataInputDeserializer input = new DataInputDeserializer(bytes);
return statisticsSerializer.deserialize(input);
} catch (IOException e) {
throw new UncheckedIOException("Fail to deserialize aggregated statistics", e);
CompletedStatistics completedStatistics = statisticsSerializer.deserialize(input);
if (!completedStatistics.isValid()) {
throw new RuntimeException("Fail to deserialize aggregated statistics,change to v1");
}

return completedStatistics;
} catch (Exception e) {
try {
// If we restore from a lower version, the new version of SortKeySerializer cannot correctly
// parse the checkpointData, so we need to first switch the version to v1. Once the state
// data is successfully parsed, we need to switch the serialization version to the latest
// version to parse the subsequent data passed from the TM.
statisticsSerializer.changeSortKeySerializerVersion(1);
DataInputDeserializer input = new DataInputDeserializer(bytes);
CompletedStatistics deserialize = statisticsSerializer.deserialize(input);
statisticsSerializer.changeSortKeySerializerVersionLatest();
return deserialize;
} catch (IOException ioException) {
throw new UncheckedIOException("Fail to deserialize aggregated statistics", ioException);
}
}
}
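
deserializeCompletedStatistics above decodes optimistically with the current format and falls back to v1 only when parsing or validation fails, then restores the latest version for subsequent data. A compact sketch of that try-newest-then-fall-back shape (the Decoder interface and names are illustrative, not the Iceberg API):

import java.io.IOException;
import java.io.UncheckedIOException;

class VersionFallbackRestore {
  interface Decoder<T> {
    T decode(byte[] bytes) throws IOException;
  }

  // Try the newest format first; on any failure, re-read the same bytes as v1,
  // mirroring the restore path above.
  static <T> T restore(byte[] bytes, Decoder<T> latest, Decoder<T> v1) {
    try {
      return latest.decode(bytes);
    } catch (Exception e) {
      try {
        return v1.decode(bytes);
      } catch (IOException ioException) {
        throw new UncheckedIOException("Fail to deserialize aggregated statistics", ioException);
      }
    }
  }

  public static void main(String[] args) {
    byte[] payload = {7};
    String restored =
        restore(
            payload,
            bytes -> {
              throw new IOException("null marker missing: not v2 data");
            },
            bytes -> "v1 record " + bytes[0]);
    System.out.println(restored); // v1 record 7
  }
}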

@@ -46,6 +46,7 @@
import org.apache.iceberg.flink.TestFixtures;
import org.apache.iceberg.flink.sink.shuffle.StatisticsType;
import org.apache.iceberg.flink.source.BoundedTestSource;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
@@ -252,6 +253,44 @@ public void testRangeDistributionWithoutSortOrderPartitioned() throws Exception
assertThat(snapshots).hasSizeGreaterThanOrEqualTo(numOfCheckpoints);
}

@TestTemplate
public void testRangeDistributionWithNullValue() throws Exception {
assumeThat(partitioned).isTrue();

table
.updateProperties()
.set(TableProperties.WRITE_DISTRIBUTION_MODE, DistributionMode.RANGE.modeName())
.commit();

int numOfCheckpoints = 6;
List<List<Row>> charRows = createCharRows(numOfCheckpoints, 10);
charRows.add(ImmutableList.of(Row.of(1, null)));
DataStream<Row> dataStream =
env.addSource(createRangeDistributionBoundedSource(charRows), ROW_TYPE_INFO);
FlinkSink.Builder builder =
FlinkSink.forRow(dataStream, SimpleDataUtil.FLINK_SCHEMA)
.table(table)
.tableLoader(tableLoader)
.writeParallelism(parallelism);

// sort based on partition columns
builder.append();
env.execute(getClass().getSimpleName());

table.refresh();
// ordered in reverse timeline from the newest snapshot to the oldest snapshot
List<Snapshot> snapshots = Lists.newArrayList(table.snapshots().iterator());
// only keep the snapshots with added data files
snapshots =
snapshots.stream()
.filter(snapshot -> snapshot.addedDataFiles(table.io()).iterator().hasNext())
.collect(Collectors.toList());

// Sometimes we will have more checkpoints than the bounded source if we pass the
// auto checkpoint interval. Thus producing multiple snapshots.
assertThat(snapshots).hasSizeGreaterThanOrEqualTo(numOfCheckpoints);
}

@TestTemplate
public void testRangeDistributionWithSortOrder() throws Exception {
table
@@ -19,11 +19,15 @@
package org.apache.iceberg.flink.sink.shuffle;

import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;

import org.apache.flink.api.common.typeutils.SerializerTestBase;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataInputDeserializer;
import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.iceberg.SortKey;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.junit.jupiter.api.Test;

public class TestCompletedStatisticsSerializer extends SerializerTestBase<CompletedStatistics> {

@@ -51,4 +55,49 @@ protected CompletedStatistics[] getTestData() {
CompletedStatistics.fromKeySamples(2L, new SortKey[] {CHAR_KEYS.get("a"), CHAR_KEYS.get("b")})
};
}

@Test
public void testSerializer() throws Exception {
TypeSerializer<CompletedStatistics> completedStatisticsTypeSerializer = createSerializer();
CompletedStatistics[] data = getTestData();
DataOutputSerializer output = new DataOutputSerializer(1024);
completedStatisticsTypeSerializer.serialize(data[0], output);
byte[] serializedBytes = output.getCopyOfBuffer();

DataInputDeserializer input = new DataInputDeserializer(serializedBytes);
CompletedStatistics deserialized = completedStatisticsTypeSerializer.deserialize(input);
assertThat(deserialized).isEqualTo(data[0]);
}

@Test
public void testRestoreOldVersionSerializer() throws Exception {
CompletedStatisticsSerializer completedStatisticsTypeSerializer =
(CompletedStatisticsSerializer) createSerializer();
completedStatisticsTypeSerializer.changeSortKeySerializerVersion(1);
CompletedStatistics[] data = getTestData();
DataOutputSerializer output = new DataOutputSerializer(1024);
completedStatisticsTypeSerializer.serialize(data[0], output);
byte[] serializedBytes = output.getCopyOfBuffer();

completedStatisticsTypeSerializer.changeSortKeySerializerVersionLatest();
CompletedStatistics completedStatistics =
StatisticsUtil.deserializeCompletedStatistics(
serializedBytes, completedStatisticsTypeSerializer);
assertThat(completedStatistics).isEqualTo(data[0]);
}

@Test
public void testRestoreNewSerializer() throws Exception {
CompletedStatisticsSerializer completedStatisticsTypeSerializer =
(CompletedStatisticsSerializer) createSerializer();
CompletedStatistics[] data = getTestData();
DataOutputSerializer output = new DataOutputSerializer(1024);
completedStatisticsTypeSerializer.serialize(data[0], output);
byte[] serializedBytes = output.getCopyOfBuffer();

CompletedStatistics completedStatistics =
StatisticsUtil.deserializeCompletedStatistics(
serializedBytes, completedStatisticsTypeSerializer);
assertThat(completedStatistics).isEqualTo(data[0]);
}
}