Spark 3.3, 3.4: Parallelize reading files in migrate procedures (apache#11043)

Back-port of apache#9274
Back-port of apache#10037
manuzhang authored and zachdisc committed Dec 12, 2024
1 parent 7dbd29e commit ecb3e4e
Showing 20 changed files with 920 additions and 61 deletions.
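The new parallelism option is exposed through the add_files procedure, as exercised by the test added below. A minimal sketch of such a call, assuming an active SparkSession named spark and placeholder catalog and table names:

spark.sql(
    "CALL spark_catalog.system.add_files("
        + "table => 'db.iceberg_tbl', "
        + "source_table => 'db.hive_tbl', "
        + "parallelism => 2)");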
@@ -939,6 +939,28 @@ public void testPartitionedImportFromEmptyPartitionDoesNotThrow() {
sql("SELECT * FROM %s ORDER BY id", tableName));
}

@Test
public void testAddFilesWithParallelism() {
createUnpartitionedHiveTable();

String createIceberg =
"CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg";

sql(createIceberg, tableName);

List<Object[]> result =
sql(
"CALL %s.system.add_files(table => '%s', source_table => '%s', parallelism => 2)",
catalogName, tableName, sourceTableName);

assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result);

assertEquals(
"Iceberg table contains correct data",
sql("SELECT * FROM %s ORDER BY id", sourceTableName),
sql("SELECT * FROM %s ORDER BY id", tableName));
}

private static final List<Object[]> EMPTY_QUERY_RESULT = Lists.newArrayList();

private static final StructField[] STRUCT = {
@@ -28,6 +28,7 @@
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
@@ -291,7 +292,7 @@ public static List<DataFile> listPartition(
PartitionSpec spec,
SerializableConfiguration conf,
MetricsConfig metricsConfig) {
return listPartition(partition, spec, conf, metricsConfig, null);
return listPartition(partition, spec, conf, metricsConfig, null, 1);
}

/**
@@ -314,15 +315,35 @@ public static List<DataFile> listPartition(
PartitionSpec spec,
SerializableConfiguration conf,
MetricsConfig metricsConfig,
NameMapping mapping) {
NameMapping mapping,
int parallelism) {
return TableMigrationUtil.listPartition(
partition.values,
partition.uri,
partition.format,
spec,
conf.get(),
metricsConfig,
mapping);
mapping,
parallelism);
}

private static List<DataFile> listPartition(
SparkPartition partition,
PartitionSpec spec,
SerializableConfiguration conf,
MetricsConfig metricsConfig,
NameMapping mapping,
ExecutorService service) {
return TableMigrationUtil.listPartition(
partition.values,
partition.uri,
partition.format,
spec,
conf.get(),
metricsConfig,
mapping,
service);
}

private static SparkPartition toSparkPartition(
@@ -419,6 +440,114 @@ public static void importSparkTable(
String stagingDir,
Map<String, String> partitionFilter,
boolean checkDuplicateFiles) {
importSparkTable(
spark, sourceTableIdent, targetTable, stagingDir, partitionFilter, checkDuplicateFiles, 1);
}

/**
* Import files from an existing Spark table to an Iceberg table.
*
* <p>The import uses the Spark session to get table metadata. It assumes no operation is running
* against the original and target tables and thus is not thread-safe.
*
* @param spark a Spark session
* @param sourceTableIdent an identifier of the source Spark table
* @param targetTable an Iceberg table where to import the data
* @param stagingDir a staging directory to store temporary manifest files
* @param parallelism number of threads to use for file reading
*/
public static void importSparkTable(
SparkSession spark,
TableIdentifier sourceTableIdent,
Table targetTable,
String stagingDir,
int parallelism) {
importSparkTable(
spark,
sourceTableIdent,
targetTable,
stagingDir,
TableMigrationUtil.migrationService(parallelism));
}
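// Usage sketch for the int-parallelism overload above; `spark` (SparkSession), `targetTable`
// (Iceberg Table), and the staging path are assumed to exist, and all identifiers are placeholders.
// (Assumes org.apache.spark.sql.catalyst.TableIdentifier and scala.Option are imported.)
TableIdentifier sourceIdent = new TableIdentifier("source_tbl", Option.apply("db"));
SparkTableUtil.importSparkTable(
    spark, sourceIdent, targetTable, "file:/tmp/iceberg-staging", 8 /* reader threads */);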

/**
* Import files from an existing Spark table to an Iceberg table.
*
* <p>The import uses the Spark session to get table metadata. It assumes no operation is running
* against the original and target tables and thus is not thread-safe.
*
* @param spark a Spark session
* @param sourceTableIdent an identifier of the source Spark table
* @param targetTable an Iceberg table where to import the data
* @param stagingDir a staging directory to store temporary manifest files
* @param service executor service to use for file reading
*/
public static void importSparkTable(
SparkSession spark,
TableIdentifier sourceTableIdent,
Table targetTable,
String stagingDir,
ExecutorService service) {
importSparkTable(
spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false, service);
}
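// Usage sketch for the ExecutorService overload above: the caller owns the pool and its shutdown.
// `spark`, `sourceIdent`, and `targetTable` are assumed as in the previous sketch.
ExecutorService pool = Executors.newFixedThreadPool(4); // assumes java.util.concurrent.Executors
try {
  SparkTableUtil.importSparkTable(
      spark, sourceIdent, targetTable, "file:/tmp/iceberg-staging", pool);
} finally {
  pool.shutdown();
}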

/**
* Import files from an existing Spark table to an Iceberg table.
*
* <p>The import uses the Spark session to get table metadata. It assumes no operation is running
* against the original and target tables and thus is not thread-safe.
*
* @param spark a Spark session
* @param sourceTableIdent an identifier of the source Spark table
* @param targetTable an Iceberg table where to import the data
* @param stagingDir a staging directory to store temporary manifest files
* @param partitionFilter only import partitions whose values match those in the map, can be
* partially defined
* @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
* @param parallelism number of threads to use for file reading
*/
public static void importSparkTable(
SparkSession spark,
TableIdentifier sourceTableIdent,
Table targetTable,
String stagingDir,
Map<String, String> partitionFilter,
boolean checkDuplicateFiles,
int parallelism) {
importSparkTable(
spark,
sourceTableIdent,
targetTable,
stagingDir,
partitionFilter,
checkDuplicateFiles,
TableMigrationUtil.migrationService(parallelism));
}
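// Usage sketch for the filtered overload above: import only partitions matching the (possibly
// partial) filter, fail on duplicate data files, and read file metadata with 4 threads.
// Partition column and value are placeholders.
Map<String, String> filter = Collections.singletonMap("dt", "2024-12-01");
SparkTableUtil.importSparkTable(
    spark, sourceIdent, targetTable, "file:/tmp/iceberg-staging", filter, true, 4);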

/**
* Import files from an existing Spark table to an Iceberg table.
*
* <p>The import uses the Spark session to get table metadata. It assumes no operation is running
* against the original and target tables and thus is not thread-safe.
*
* @param spark a Spark session
* @param sourceTableIdent an identifier of the source Spark table
* @param targetTable an Iceberg table where to import the data
* @param stagingDir a staging directory to store temporary manifest files
* @param partitionFilter only import partitions whose values match those in the map, can be
* partially defined
* @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
* @param service executor service to use for file reading
*/
public static void importSparkTable(
SparkSession spark,
TableIdentifier sourceTableIdent,
Table targetTable,
String stagingDir,
Map<String, String> partitionFilter,
boolean checkDuplicateFiles,
ExecutorService service) {
SessionCatalog catalog = spark.sessionState().catalog();

String db =
@@ -439,15 +568,21 @@ public static void importSparkTable(

if (Objects.equal(spec, PartitionSpec.unpartitioned())) {
importUnpartitionedSparkTable(
spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles);
spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles, service);
} else {
List<SparkPartition> sourceTablePartitions =
getPartitions(spark, sourceTableIdent, partitionFilter);
if (sourceTablePartitions.isEmpty()) {
targetTable.newAppend().commit();
} else {
importSparkPartitions(
spark, sourceTablePartitions, targetTable, spec, stagingDir, checkDuplicateFiles);
spark,
sourceTablePartitions,
targetTable,
spec,
stagingDir,
checkDuplicateFiles,
service);
}
}
} catch (AnalysisException e) {
@@ -480,7 +615,8 @@ public static void importSparkTable(
targetTable,
stagingDir,
Collections.emptyMap(),
checkDuplicateFiles);
checkDuplicateFiles,
1);
}

/**
@@ -497,14 +633,15 @@ public static void importSparkTable(
public static void importSparkTable(
SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, String stagingDir) {
importSparkTable(
spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false);
spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false, 1);
}

private static void importUnpartitionedSparkTable(
SparkSession spark,
TableIdentifier sourceTableIdent,
Table targetTable,
boolean checkDuplicateFiles) {
boolean checkDuplicateFiles,
ExecutorService service) {
try {
CatalogTable sourceTable = spark.sessionState().catalog().getTableMetadata(sourceTableIdent);
Option<String> format =
@@ -529,7 +666,8 @@ private static void importUnpartitionedSparkTable(
spec,
conf,
metricsConfig,
nameMapping);
nameMapping,
service);

if (checkDuplicateFiles) {
Dataset<Row> importedFiles =
@@ -577,9 +715,60 @@ public static void importSparkPartitions(
PartitionSpec spec,
String stagingDir,
boolean checkDuplicateFiles) {
importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, checkDuplicateFiles, 1);
}

/**
* Import files from given partitions to an Iceberg table.
*
* @param spark a Spark session
* @param partitions partitions to import
* @param targetTable an Iceberg table where to import the data
* @param spec a partition spec
* @param stagingDir a staging directory to store temporary manifest files
* @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
* @param parallelism number of threads to use for file reading
*/
public static void importSparkPartitions(
SparkSession spark,
List<SparkPartition> partitions,
Table targetTable,
PartitionSpec spec,
String stagingDir,
boolean checkDuplicateFiles,
int parallelism) {
importSparkPartitions(
spark,
partitions,
targetTable,
spec,
stagingDir,
checkDuplicateFiles,
TableMigrationUtil.migrationService(parallelism));
}
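// Usage sketch for the partition-level overload above; `partitions` is assumed to be a
// pre-computed List<SparkPartition> (e.g. a filtered listing of the source table) and
// `targetTable` an Iceberg Table whose partition spec matches those partitions.
SparkTableUtil.importSparkPartitions(
    spark,
    partitions,
    targetTable,
    targetTable.spec(),
    "file:/tmp/iceberg-staging",
    false, // checkDuplicateFiles
    4);    // reader threads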

/**
* Import files from given partitions to an Iceberg table.
*
* @param spark a Spark session
* @param partitions partitions to import
* @param targetTable an Iceberg table where to import the data
* @param spec a partition spec
* @param stagingDir a staging directory to store temporary manifest files
* @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
* @param service executor service to use for file reading
*/
public static void importSparkPartitions(
SparkSession spark,
List<SparkPartition> partitions,
Table targetTable,
PartitionSpec spec,
String stagingDir,
boolean checkDuplicateFiles,
ExecutorService service) {
Configuration conf = spark.sessionState().newHadoopConf();
SerializableConfiguration serializableConf = new SerializableConfiguration(conf);
int parallelism =
int listingParallelism =
Math.min(
partitions.size(), spark.sessionState().conf().parallelPartitionDiscoveryParallelism());
int numShufflePartitions = spark.sessionState().conf().numShufflePartitions();
@@ -589,7 +778,7 @@ public static void importSparkPartitions(
nameMappingString != null ? NameMappingParser.fromJson(nameMappingString) : null;

JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext());
JavaRDD<SparkPartition> partitionRDD = sparkContext.parallelize(partitions, parallelism);
JavaRDD<SparkPartition> partitionRDD = sparkContext.parallelize(partitions, listingParallelism);

Dataset<SparkPartition> partitionDS =
spark.createDataset(partitionRDD.rdd(), Encoders.javaSerialization(SparkPartition.class));
@@ -599,7 +788,12 @@
(FlatMapFunction<SparkPartition, DataFile>)
sparkPartition ->
listPartition(
sparkPartition, spec, serializableConf, metricsConfig, nameMapping)
sparkPartition,
spec,
serializableConf,
metricsConfig,
nameMapping,
service)
.iterator(),
Encoders.javaSerialization(DataFile.class));

@@ -672,7 +866,7 @@ public static void importSparkPartitions(
Table targetTable,
PartitionSpec spec,
String stagingDir) {
importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false);
importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false, 1);
}

public static List<SparkPartition> filterPartitions(
@@ -19,6 +19,7 @@
package org.apache.iceberg.spark.actions;

import java.util.Map;
import java.util.concurrent.ExecutorService;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.SnapshotSummary;
import org.apache.iceberg.Table;
@@ -59,6 +60,7 @@ public class MigrateTableSparkAction extends BaseTableCreationSparkAction<Migrat
private final Identifier backupIdent;

private boolean dropBackup = false;
private ExecutorService executorService;

MigrateTableSparkAction(
SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) {
@@ -102,6 +104,12 @@ public MigrateTableSparkAction dropBackup() {
return this;
}

@Override
public MigrateTableSparkAction executeWith(ExecutorService service) {
this.executorService = service;
return this;
}
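// Usage sketch for executeWith: run the migration's file reading on a caller-managed pool.
// `spark` is assumed to be an active SparkSession; the table name and pool size are placeholders.
ExecutorService pool = Executors.newFixedThreadPool(4);
try {
  MigrateTable.Result result =
      SparkActions.get(spark).migrateTable("db.source_tbl").executeWith(pool).execute();
  LOG.info("Migrated {} data files", result.migratedDataFilesCount());
} finally {
  pool.shutdown();
}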

@Override
public MigrateTable.Result execute() {
String desc = String.format("Migrating table %s", destTableIdent().toString());
@@ -131,7 +139,8 @@ private MigrateTable.Result doExecute() {
TableIdentifier v1BackupIdent = new TableIdentifier(backupIdent.name(), backupNamespace);
String stagingLocation = getMetadataLocation(icebergTable);
LOG.info("Generating Iceberg metadata for {} in {}", destTableIdent(), stagingLocation);
SparkTableUtil.importSparkTable(spark(), v1BackupIdent, icebergTable, stagingLocation);
SparkTableUtil.importSparkTable(
spark(), v1BackupIdent, icebergTable, stagingLocation, executorService);

LOG.info("Committing staged changes to {}", destTableIdent());
stagedTable.commitStagedChanges();
