Spark 3.3, Spark 3.4: Parallelize reading files in snapshot and migrate procedures

Back-port of #10037
manuzhang committed Aug 29, 2024
1 parent 2512315 commit b5f81c4
Showing 14 changed files with 635 additions and 22 deletions.
SparkTableUtil.java

@@ -28,6 +28,7 @@
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
@@ -327,6 +328,24 @@ public static List<DataFile> listPartition(
parallelism);
}

private static List<DataFile> listPartition(
SparkPartition partition,
PartitionSpec spec,
SerializableConfiguration conf,
MetricsConfig metricsConfig,
NameMapping mapping,
ExecutorService service) {
return TableMigrationUtil.listPartition(
partition.values,
partition.uri,
partition.format,
spec,
conf.get(),
metricsConfig,
mapping,
service);
}

private static SparkPartition toSparkPartition(
CatalogTablePartition partition, CatalogTable table) {
Option<URI> locationUri = partition.storage().locationUri();
@@ -425,6 +444,54 @@ public static void importSparkTable(
spark, sourceTableIdent, targetTable, stagingDir, partitionFilter, checkDuplicateFiles, 1);
}

/**
* Import files from an existing Spark table to an Iceberg table.
*
* <p>The import uses the Spark session to get table metadata. It assumes that no operations are
* being performed on either the original or the target table, and is therefore not thread-safe.
*
* @param spark a Spark session
* @param sourceTableIdent an identifier of the source Spark table
* @param targetTable an Iceberg table where to import the data
* @param stagingDir a staging directory to store temporary manifest files
* @param parallelism number of threads to use for file reading
*/
public static void importSparkTable(
SparkSession spark,
TableIdentifier sourceTableIdent,
Table targetTable,
String stagingDir,
int parallelism) {
importSparkTable(
spark,
sourceTableIdent,
targetTable,
stagingDir,
TableMigrationUtil.migrationService(parallelism));
}

/**
* Import files from an existing Spark table to an Iceberg table.
*
* <p>The import uses the Spark session to get table metadata. It assumes that no operations are
* being performed on either the original or the target table, and is therefore not thread-safe.
*
* @param spark a Spark session
* @param sourceTableIdent an identifier of the source Spark table
* @param targetTable an Iceberg table where to import the data
* @param stagingDir a staging directory to store temporary manifest files
* @param service executor service to use for file reading
*/
public static void importSparkTable(
SparkSession spark,
TableIdentifier sourceTableIdent,
Table targetTable,
String stagingDir,
ExecutorService service) {
importSparkTable(
spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false, service);
}

/**
* Import files from an existing Spark table to an Iceberg table.
*
@@ -448,6 +515,39 @@ public static void importSparkTable(
Map<String, String> partitionFilter,
boolean checkDuplicateFiles,
int parallelism) {
importSparkTable(
spark,
sourceTableIdent,
targetTable,
stagingDir,
partitionFilter,
checkDuplicateFiles,
TableMigrationUtil.migrationService(parallelism));
}

/**
* Import files from an existing Spark table to an Iceberg table.
*
* <p>The import uses the Spark session to get table metadata. It assumes that no operations are
* being performed on either the original or the target table, and is therefore not thread-safe.
*
* @param spark a Spark session
* @param sourceTableIdent an identifier of the source Spark table
* @param targetTable an Iceberg table where to import the data
* @param stagingDir a staging directory to store temporary manifest files
* @param partitionFilter only import partitions whose values match those in the map, can be
* partially defined
* @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
* @param service executor service to use for file reading
*/
public static void importSparkTable(
SparkSession spark,
TableIdentifier sourceTableIdent,
Table targetTable,
String stagingDir,
Map<String, String> partitionFilter,
boolean checkDuplicateFiles,
ExecutorService service) {
SessionCatalog catalog = spark.sessionState().catalog();

String db =
@@ -468,7 +568,7 @@ public static void importSparkTable(

if (Objects.equal(spec, PartitionSpec.unpartitioned())) {
importUnpartitionedSparkTable(
spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles, parallelism);
spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles, service);
} else {
List<SparkPartition> sourceTablePartitions =
getPartitions(spark, sourceTableIdent, partitionFilter);
@@ -482,7 +582,7 @@ public static void importSparkTable(
spec,
stagingDir,
checkDuplicateFiles,
parallelism);
service);
}
}
} catch (AnalysisException e) {
@@ -541,7 +641,7 @@ private static void importUnpartitionedSparkTable(
TableIdentifier sourceTableIdent,
Table targetTable,
boolean checkDuplicateFiles,
int parallelism) {
ExecutorService service) {
try {
CatalogTable sourceTable = spark.sessionState().catalog().getTableMetadata(sourceTableIdent);
Option<String> format =
@@ -567,7 +667,7 @@ private static void importUnpartitionedSparkTable(
conf,
metricsConfig,
nameMapping,
parallelism);
service);

if (checkDuplicateFiles) {
Dataset<Row> importedFiles =
@@ -637,6 +737,35 @@ public static void importSparkPartitions(
String stagingDir,
boolean checkDuplicateFiles,
int parallelism) {
importSparkPartitions(
spark,
partitions,
targetTable,
spec,
stagingDir,
checkDuplicateFiles,
TableMigrationUtil.migrationService(parallelism));
}

/**
* Import files from given partitions to an Iceberg table.
*
* @param spark a Spark session
* @param partitions partitions to import
* @param targetTable an Iceberg table where to import the data
* @param spec a partition spec
* @param stagingDir a staging directory to store temporary manifest files
* @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
* @param service executor service to use for file reading
*/
public static void importSparkPartitions(
SparkSession spark,
List<SparkPartition> partitions,
Table targetTable,
PartitionSpec spec,
String stagingDir,
boolean checkDuplicateFiles,
ExecutorService service) {
Configuration conf = spark.sessionState().newHadoopConf();
SerializableConfiguration serializableConf = new SerializableConfiguration(conf);
int listingParallelism =
@@ -664,7 +793,7 @@ public static void importSparkPartitions(
serializableConf,
metricsConfig,
nameMapping,
parallelism)
service)
.iterator(),
Encoders.javaSerialization(DataFile.class));

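
The new overloads above let callers choose between a simple `parallelism` count (turned into a fixed pool via `TableMigrationUtil.migrationService`) and a caller-managed `ExecutorService`. Below is a minimal sketch of the integer overload from a Java driver; it assumes the target Iceberg table already exists, and the paths, table names, and thread count are illustrative:

```java
import org.apache.hadoop.conf.Configuration;
import org.apache.iceberg.Table;
import org.apache.iceberg.hadoop.HadoopTables;
import org.apache.iceberg.spark.SparkTableUtil;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.TableIdentifier;

public class ParallelImportExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("parallel-import").getOrCreate();

    // Load an existing path-based Iceberg table (hypothetical location).
    Configuration conf = spark.sessionState().newHadoopConf();
    Table targetTable = new HadoopTables(conf).load("hdfs://nn:8020/warehouse/db/target");

    // Source Spark/Hive table registered in the session catalog (illustrative identifier).
    TableIdentifier sourceIdent = new TableIdentifier("source_table", scala.Option.apply("db"));

    // New overload: file metadata for the source table is read with 8 threads.
    SparkTableUtil.importSparkTable(
        spark,
        sourceIdent,
        targetTable,
        "hdfs://nn:8020/warehouse/db/target/metadata", // staging dir for temporary manifests
        8);

    spark.stop();
  }
}
```
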
MigrateTableSparkAction.java

@@ -19,6 +19,7 @@
package org.apache.iceberg.spark.actions;

import java.util.Map;
import java.util.concurrent.ExecutorService;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.SnapshotSummary;
import org.apache.iceberg.Table;
@@ -59,6 +60,7 @@ public class MigrateTableSparkAction extends BaseTableCreationSparkAction<Migrat
private final Identifier backupIdent;

private boolean dropBackup = false;
private ExecutorService executorService;

MigrateTableSparkAction(
SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) {
@@ -102,6 +104,12 @@ public MigrateTableSparkAction dropBackup() {
return this;
}

@Override
public MigrateTableSparkAction executeWith(ExecutorService service) {
this.executorService = service;
return this;
}

@Override
public MigrateTable.Result execute() {
String desc = String.format("Migrating table %s", destTableIdent().toString());
@@ -131,7 +139,8 @@ private MigrateTable.Result doExecute() {
TableIdentifier v1BackupIdent = new TableIdentifier(backupIdent.name(), backupNamespace);
String stagingLocation = getMetadataLocation(icebergTable);
LOG.info("Generating Iceberg metadata for {} in {}", destTableIdent(), stagingLocation);
SparkTableUtil.importSparkTable(spark(), v1BackupIdent, icebergTable, stagingLocation);
SparkTableUtil.importSparkTable(
spark(), v1BackupIdent, icebergTable, stagingLocation, executorService);

LOG.info("Committing staged changes to {}", destTableIdent());
stagedTable.commitStagedChanges();
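
Because `MigrateTableSparkAction` now implements `executeWith`, the action API can be driven with a caller-supplied executor instead of the default single-threaded listing. A hedged sketch; the catalog/table name and pool size are illustrative:

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.iceberg.actions.MigrateTable;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.spark.sql.SparkSession;

public class MigrateWithExecutorExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("migrate-example").getOrCreate();

    // Pool used only while reading data file metadata during the migration.
    ExecutorService pool = Executors.newFixedThreadPool(8);
    try {
      MigrateTable.Result result =
          SparkActions.get(spark)
              .migrateTable("spark_catalog.db.sample") // source Spark/Hive table (illustrative)
              .executeWith(pool)                       // hook added by this change
              .execute();
      System.out.println("Migrated data files: " + result.migratedDataFilesCount());
    } finally {
      pool.shutdown();
      spark.stop();
    }
  }
}
```
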
SnapshotTableSparkAction.java

@@ -19,6 +19,7 @@
package org.apache.iceberg.spark.actions;

import java.util.Map;
import java.util.concurrent.ExecutorService;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.SnapshotSummary;
import org.apache.iceberg.Table;
@@ -54,6 +55,7 @@ public class SnapshotTableSparkAction extends BaseTableCreationSparkAction<Snaps
private StagingTableCatalog destCatalog;
private Identifier destTableIdent;
private String destTableLocation = null;
private ExecutorService executorService;

SnapshotTableSparkAction(
SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) {
@@ -98,6 +100,12 @@ public SnapshotTableSparkAction tableProperty(String property, String value) {
return this;
}

@Override
public SnapshotTableSparkAction executeWith(ExecutorService service) {
this.executorService = service;
return this;
}

@Override
public SnapshotTable.Result execute() {
String desc = String.format("Snapshotting table %s as %s", sourceTableIdent(), destTableIdent);
@@ -126,7 +134,8 @@ private SnapshotTable.Result doExecute() {
TableIdentifier v1TableIdent = v1SourceTable().identifier();
String stagingLocation = getMetadataLocation(icebergTable);
LOG.info("Generating Iceberg metadata for {} in {}", destTableIdent(), stagingLocation);
SparkTableUtil.importSparkTable(spark(), v1TableIdent, icebergTable, stagingLocation);
SparkTableUtil.importSparkTable(
spark(), v1TableIdent, icebergTable, stagingLocation, executorService);

LOG.info("Committing staged changes to {}", destTableIdent());
stagedTable.commitStagedChanges();
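
`SnapshotTableSparkAction` gains the same hook; a sketch mirroring the one above, again with illustrative names:

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.iceberg.actions.SnapshotTable;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.spark.sql.SparkSession;

public class SnapshotWithExecutorExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("snapshot-example").getOrCreate();

    ExecutorService pool = Executors.newFixedThreadPool(4);
    try {
      SnapshotTable.Result result =
          SparkActions.get(spark)
              .snapshotTable("spark_catalog.db.source")  // source table (illustrative)
              .as("spark_catalog.db.source_snapshot")    // new Iceberg table to create
              .executeWith(pool)                         // parallel file reading
              .execute();
      System.out.println("Imported data files: " + result.importedDataFilesCount());
    } finally {
      pool.shutdown();
      spark.stop();
    }
  }
}
```
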
MigrateTableProcedure.java

@@ -39,7 +39,8 @@ class MigrateTableProcedure extends BaseProcedure {
new ProcedureParameter[] {
ProcedureParameter.required("table", DataTypes.StringType),
ProcedureParameter.optional("properties", STRING_MAP),
ProcedureParameter.optional("drop_backup", DataTypes.BooleanType)
ProcedureParameter.optional("drop_backup", DataTypes.BooleanType),
ProcedureParameter.optional("parallelism", DataTypes.IntegerType)
};

private static final StructType OUTPUT_TYPE =
@@ -95,13 +96,19 @@ public InternalRow[] call(InternalRow args) {
MigrateTableSparkAction migrateTableSparkAction =
SparkActions.get().migrateTable(tableName).tableProperties(properties);

MigrateTable.Result result;
if (dropBackup) {
result = migrateTableSparkAction.dropBackup().execute();
} else {
result = migrateTableSparkAction.execute();
migrateTableSparkAction = migrateTableSparkAction.dropBackup();
}

if (!args.isNullAt(4)) {
int parallelism = args.getInt(4);
Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0");
migrateTableSparkAction =
migrateTableSparkAction.executeWith(executorService(parallelism, "table-migration"));
}

MigrateTable.Result result = migrateTableSparkAction.execute();

return new InternalRow[] {newInternalRow(result.migratedDataFilesCount())};
}

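
At the SQL level the change surfaces as the optional `parallelism` argument registered above. A sketch of calling the procedure through `spark.sql`, assuming `spark_catalog` is configured as an Iceberg session catalog; the table name and thread count are illustrative:

```java
import org.apache.spark.sql.SparkSession;

public class MigrateProcedureCallExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("migrate-call").getOrCreate();

    // Migrate a Spark/Hive table in place to Iceberg, reading file metadata with 4 threads.
    spark
        .sql("CALL spark_catalog.system.migrate(table => 'db.sample', parallelism => 4)")
        .show();

    spark.stop();
  }
}
```
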
SnapshotTableProcedure.java

@@ -38,7 +38,8 @@ class SnapshotTableProcedure extends BaseProcedure {
ProcedureParameter.required("source_table", DataTypes.StringType),
ProcedureParameter.required("table", DataTypes.StringType),
ProcedureParameter.optional("location", DataTypes.StringType),
ProcedureParameter.optional("properties", STRING_MAP)
ProcedureParameter.optional("properties", STRING_MAP),
ProcedureParameter.optional("parallelism", DataTypes.IntegerType)
};

private static final StructType OUTPUT_TYPE =
@@ -102,6 +103,12 @@ public InternalRow[] call(InternalRow args) {
action.tableLocation(snapshotLocation);
}

if (!args.isNullAt(4)) {
int parallelism = args.getInt(4);
Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0");
action = action.executeWith(executorService(parallelism, "table-snapshot"));
}

SnapshotTable.Result result = action.tableProperties(properties).execute();
return new InternalRow[] {newInternalRow(result.importedDataFilesCount())};
}
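
The `snapshot` procedure accepts the same optional `parallelism` argument. A sketch with illustrative names, again assuming an Iceberg-enabled session catalog:

```java
import org.apache.spark.sql.SparkSession;

public class SnapshotProcedureCallExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("snapshot-call").getOrCreate();

    // Snapshot an existing Spark/Hive table into a new Iceberg table without
    // modifying the source, reading file metadata with 4 threads.
    spark
        .sql(
            "CALL spark_catalog.system.snapshot("
                + "source_table => 'db.source', "
                + "table => 'db.source_iceberg', "
                + "parallelism => 4)")
        .show();

    spark.stop();
  }
}
```
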
(Diffs for the remaining changed files in this commit are not shown here.)
