
Commit

revert back to table supplier
bryanck committed Sep 19, 2023
1 parent dcd6630 commit f417bfd
Showing 8 changed files with 82 additions and 114 deletions.
@@ -58,10 +58,6 @@ static TableLoader fromHadoopTable(String location, Configuration hadoopConf) {
return new HadoopTableLoader(location, hadoopConf);
}

static TableLoader immutableFromTable(Table table) {
return new ImmutableTableLoader(table);
}

class HadoopTableLoader implements TableLoader {

private static final long serialVersionUID = 1L;
@@ -160,35 +156,4 @@ public String toString() {
.toString();
}
}

class ImmutableTableLoader implements TableLoader {

private final Table table;

private ImmutableTableLoader(Table table) {
this.table = table;
}

@Override
public void open() {}

@Override
public boolean isOpen() {
return true;
}

@Override
public Table loadTable() {
return table;
}

@Override
@SuppressWarnings({"checkstyle:NoClone", "checkstyle:SuperClone"})
public TableLoader clone() {
return new ImmutableTableLoader(table);
}

@Override
public void close() throws IOException {}
}
}
@@ -18,12 +18,12 @@
*/
package org.apache.iceberg.flink.sink;

import java.io.IOException;
import java.time.Duration;
import org.apache.flink.util.Preconditions;
import org.apache.iceberg.Table;
import org.apache.iceberg.flink.TableLoader;
import org.apache.iceberg.util.DateTimeUtil;
import org.apache.iceberg.util.SerializableSupplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -32,17 +32,17 @@
* table loader should be used carefully when used with writer tasks. It could result in heavy load
* on a catalog for jobs with many writers.
*/
class CachingTableLoader implements TableLoader {
class CachingTableSupplier implements SerializableSupplier<Table> {

private static final Logger LOG = LoggerFactory.getLogger(CachingTableLoader.class);
private static final Logger LOG = LoggerFactory.getLogger(CachingTableSupplier.class);

private final Table initialTable;
private final TableLoader tableLoader;
private final Duration tableRefreshInterval;
private long nextReloadTimeMs;
private transient Table table;

CachingTableLoader(Table initialTable, TableLoader tableLoader, Duration tableRefreshInterval) {
CachingTableSupplier(Table initialTable, TableLoader tableLoader, Duration tableRefreshInterval) {
Preconditions.checkArgument(initialTable != null, "initialTable cannot be null");
Preconditions.checkArgument(tableLoader != null, "tableLoader cannot be null");
Preconditions.checkArgument(
@@ -55,25 +55,20 @@ class CachingTableLoader implements TableLoader {
}

@Override
public void open() {
if (!tableLoader.isOpen()) {
tableLoader.open();
}
}

@Override
public boolean isOpen() {
return tableLoader.isOpen();
}

@Override
public Table loadTable() {
public Table get() {
if (table == null) {
this.table = initialTable;
}
return table;
}

public void refresh() {
if (System.currentTimeMillis() > nextReloadTimeMs) {
try {
if (!tableLoader.isOpen()) {
tableLoader.open();
}

this.table = tableLoader.loadTable();
nextReloadTimeMs = System.currentTimeMillis() + tableRefreshInterval.toMillis();

@@ -85,18 +80,5 @@ public Table loadTable() {
LOG.warn("An error occurred reloading table {}, table was not reloaded", table.name(), e);
}
}

return table;
}

@Override
@SuppressWarnings({"checkstyle:NoClone", "checkstyle:SuperClone"})
public TableLoader clone() {
return new CachingTableLoader(initialTable, tableLoader, tableRefreshInterval);
}

@Override
public void close() throws IOException {
tableLoader.close();
}
}
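
Taken together, the hunks above replace the TableLoader-based CachingTableLoader with a CachingTableSupplier that implements SerializableSupplier<Table>: get() serves the cached (initially provided) table, and refresh() reloads it from the loader at most once per refresh interval, keeping the last good table if the reload fails. A minimal standalone analogue of that pattern is sketched below; it is illustrative only, the Source interface and all names are invented, and it is not the Iceberg class.

import java.time.Duration;
import java.util.function.Supplier;

// Illustrative analogue of the caching-supplier pattern above; not the Iceberg implementation.
// "Source" stands in for TableLoader and the type parameter T stands in for Table.
class CachingSupplierSketch<T> implements Supplier<T> {

  interface Source<U> {
    U load();
  }

  private final T initialValue;
  private final Source<T> source;
  private final Duration refreshInterval;
  private long nextReloadTimeMs = 0L;
  private T value;

  CachingSupplierSketch(T initialValue, Source<T> source, Duration refreshInterval) {
    this.initialValue = initialValue;
    this.source = source;
    this.refreshInterval = refreshInterval;
  }

  @Override
  public T get() {
    if (value == null) {
      value = initialValue; // serve the initial value until the first successful refresh
    }
    return value;
  }

  void refresh() {
    if (System.currentTimeMillis() > nextReloadTimeMs) {
      try {
        value = source.load(); // reload from the source at most once per interval
        nextReloadTimeMs = System.currentTimeMillis() + refreshInterval.toMillis();
      } catch (RuntimeException e) {
        // keep serving the last good value if the reload fails
      }
    }
  }
}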
@@ -69,6 +69,7 @@
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.relocated.com.google.common.collect.Sets;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.util.SerializableSupplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -482,16 +483,17 @@ private SingleOutputStreamOperator<WriteResult> appendWriter(
}
}

TableLoader tableRefreshLoader;
Table serializableTable = SerializableTable.copyOf(table);
SerializableSupplier<Table> tableSupplier;
if (tableRefreshInterval != null) {
tableRefreshLoader = new CachingTableLoader(table, tableLoader, tableRefreshInterval);
tableSupplier =
new CachingTableSupplier(serializableTable, tableLoader, tableRefreshInterval);
} else {
tableRefreshLoader = null;
tableSupplier = () -> serializableTable;
}

IcebergStreamWriter<RowData> streamWriter =
createStreamWriter(
table, tableRefreshLoader, flinkWriteConf, flinkRowType, equalityFieldIds);
createStreamWriter(tableSupplier, flinkWriteConf, flinkRowType, equalityFieldIds);

int parallelism =
flinkWriteConf.writeParallelism() == null
@@ -608,30 +610,25 @@ static RowType toFlinkRowType(Schema schema, TableSchema requestedSchema) {
}

static IcebergStreamWriter<RowData> createStreamWriter(
Table table,
TableLoader writeTableLoader,
SerializableSupplier<Table> tableSupplier,
FlinkWriteConf flinkWriteConf,
RowType flinkRowType,
List<Integer> equalityFieldIds) {
Preconditions.checkArgument(table != null, "Iceberg table shouldn't be null");
Preconditions.checkArgument(tableSupplier != null, "Iceberg table supplier shouldn't be null");

Table serializableTable = SerializableTable.copyOf(table);
TableLoader tableLoader =
writeTableLoader != null
? writeTableLoader
: TableLoader.immutableFromTable(serializableTable);
Table initTable = tableSupplier.get();
FileFormat format = flinkWriteConf.dataFileFormat();
TaskWriterFactory<RowData> taskWriterFactory =
new RowDataTaskWriterFactory(
serializableTable,
tableSupplier,
flinkRowType,
flinkWriteConf.targetDataFileSize(),
format,
writeProperties(table, format, flinkWriteConf),
writeProperties(initTable, format, flinkWriteConf),
equalityFieldIds,
flinkWriteConf.upsertMode());

return new IcebergStreamWriter<>(table.name(), taskWriterFactory, tableLoader);
return new IcebergStreamWriter<>(initTable.name(), taskWriterFactory, tableSupplier);
}

/**
@@ -20,12 +20,13 @@

import java.io.IOException;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.ChainingStrategy;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.iceberg.flink.TableLoader;
import org.apache.iceberg.Table;
import org.apache.iceberg.io.TaskWriter;
import org.apache.iceberg.io.WriteResult;
import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;
@@ -37,18 +38,18 @@ class IcebergStreamWriter<T> extends AbstractStreamOperator<WriteResult>

private final String fullTableName;
private final TaskWriterFactory<T> taskWriterFactory;
private final TableLoader tableLoader;
private final Supplier<Table> tableSupplier;

private transient TaskWriter<T> writer;
private transient int subTaskId;
private transient int attemptId;
private transient IcebergStreamWriterMetrics writerMetrics;

IcebergStreamWriter(
String fullTableName, TaskWriterFactory<T> taskWriterFactory, TableLoader tableLoader) {
String fullTableName, TaskWriterFactory<T> taskWriterFactory, Supplier<Table> tableSupplier) {
this.fullTableName = fullTableName;
this.taskWriterFactory = taskWriterFactory;
this.tableLoader = tableLoader;
this.tableSupplier = tableSupplier;
setChainingStrategy(ChainingStrategy.ALWAYS);
}

@@ -58,13 +59,13 @@ public void open() {
this.attemptId = getRuntimeContext().getAttemptNumber();
this.writerMetrics = new IcebergStreamWriterMetrics(super.metrics, fullTableName);

// Initialize the task writer factory.
// Initialize the task writer factory before refreshing the table so that the initial
// schema and partition spec are used.
this.taskWriterFactory.initialize(subTaskId, attemptId);

// Refresh the table if needed.
this.tableLoader.open();
if (this.taskWriterFactory instanceof RowDataTaskWriterFactory) {
((RowDataTaskWriterFactory) this.taskWriterFactory).setTable(tableLoader.loadTable());
if (tableSupplier instanceof CachingTableSupplier) {
((CachingTableSupplier) tableSupplier).refresh();
}

// Initialize the task writer.
@@ -75,8 +76,9 @@
public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
flush();

if (taskWriterFactory instanceof RowDataTaskWriterFactory) {
((RowDataTaskWriterFactory) taskWriterFactory).setTable(tableLoader.loadTable());
// Refresh the table if needed.
if (tableSupplier instanceof CachingTableSupplier) {
((CachingTableSupplier) tableSupplier).refresh();
}

this.writer = taskWriterFactory.create();
@@ -20,6 +20,7 @@

import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.types.logical.RowType;
import org.apache.iceberg.FileFormat;
@@ -38,9 +39,10 @@
import org.apache.iceberg.relocated.com.google.common.collect.Sets;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.util.ArrayUtil;
import org.apache.iceberg.util.SerializableSupplier;

public class RowDataTaskWriterFactory implements TaskWriterFactory<RowData> {
private Table table;
private final Supplier<Table> tableSupplier;
private final Schema schema;
private final RowType flinkSchema;
private final PartitionSpec spec;
@@ -60,8 +62,27 @@ public RowDataTaskWriterFactory(
Map<String, String> writeProperties,
List<Integer> equalityFieldIds,
boolean upsert) {
this(
() -> table,
flinkSchema,
targetFileSizeBytes,
format,
writeProperties,
equalityFieldIds,
upsert);
}

public RowDataTaskWriterFactory(
SerializableSupplier<Table> tableSupplier,
RowType flinkSchema,
long targetFileSizeBytes,
FileFormat format,
Map<String, String> writeProperties,
List<Integer> equalityFieldIds,
boolean upsert) {
this.tableSupplier = tableSupplier;
// rely on the initial table metadata for schema, etc., until schema evolution is supported
this.table = table;
Table table = tableSupplier.get();
this.schema = table.schema();
this.flinkSchema = flinkSchema;
this.spec = table.spec();
@@ -104,16 +125,12 @@ public RowDataTaskWriterFactory(
}
}

void setTable(Table table) {
this.table = table;
}

@Override
public void initialize(int taskId, int attemptId) {
this.outputFileFactory =
OutputFileFactory.builderFor(table, taskId, attemptId)
OutputFileFactory.builderFor(tableSupplier.get(), taskId, attemptId)
.format(format)
.ioSupplier(() -> table.io())
.ioSupplier(() -> tableSupplier.get().io())
.build();
}

@@ -127,14 +144,19 @@ public TaskWriter<RowData> create() {
// Initialize a task writer to write INSERT only.
if (spec.isUnpartitioned()) {
return new UnpartitionedWriter<>(
spec, format, appenderFactory, outputFileFactory, table.io(), targetFileSizeBytes);
spec,
format,
appenderFactory,
outputFileFactory,
tableSupplier.get().io(),
targetFileSizeBytes);
} else {
return new RowDataPartitionedFanoutWriter(
spec,
format,
appenderFactory,
outputFileFactory,
table.io(),
tableSupplier.get().io(),
targetFileSizeBytes,
schema,
flinkSchema);
@@ -147,7 +169,7 @@ public TaskWriter<RowData> create() {
format,
appenderFactory,
outputFileFactory,
table.io(),
tableSupplier.get().io(),
targetFileSizeBytes,
schema,
flinkSchema,
@@ -159,7 +181,7 @@ public TaskWriter<RowData> create() {
format,
appenderFactory,
outputFileFactory,
table.io(),
tableSupplier.get().io(),
targetFileSizeBytes,
schema,
flinkSchema,
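
The diff above also exposes a public RowDataTaskWriterFactory constructor that takes a SerializableSupplier<Table>, so a caller can build the factory directly from a table supplier. A hedged caller-side sketch follows; the file size, format, and empty properties/equality-field arguments are placeholder choices, not values taken from this commit.

import java.util.Collections;
import org.apache.flink.table.types.logical.RowType;
import org.apache.iceberg.FileFormat;
import org.apache.iceberg.Table;
import org.apache.iceberg.flink.sink.RowDataTaskWriterFactory;
import org.apache.iceberg.util.SerializableSupplier;

// Illustrative use of the supplier-based constructor shown in this diff; the argument
// values below are examples, not defaults from this commit.
class WriterFactoryExample {

  static RowDataTaskWriterFactory newFactory(
      SerializableSupplier<Table> tableSupplier, RowType rowType) {
    return new RowDataTaskWriterFactory(
        tableSupplier,              // table is re-read from the supplier in initialize()/create()
        rowType,
        128L * 1024 * 1024,         // target data file size in bytes (example value)
        FileFormat.PARQUET,         // data file format (example value)
        Collections.emptyMap(),     // extra write properties
        Collections.emptyList(),    // no equality field ids -> append-only writers
        false);                     // upsert mode disabled
  }
}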

0 comments on commit f417bfd
