Skip to content

Commit

Permalink
Spark 3.5: Parallelize reading files in add_files procedure (apache#9274
Browse files Browse the repository at this point in the history
)
  • Loading branch information
manuzhang authored and geruh committed Jan 25, 2024
1 parent b11a591 commit 10f9c31
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,11 @@ private static DataFile buildDataFile(
.build();
}

private static ExecutorService migrationService(int concurrentDeletes) {
private static ExecutorService migrationService(int parallelism) {
return MoreExecutors.getExitingExecutorService(
(ThreadPoolExecutor)
Executors.newFixedThreadPool(
concurrentDeletes,
parallelism,
new ThreadFactoryBuilder().setNameFormat("table-migration-%d").build()));
}
}
1 change: 1 addition & 0 deletions docs/spark-procedures.md
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ Keep in mind the `add_files` procedure will fetch the Parquet metadata from each
| `source_table` | ✔️ | string | Table where files should come from, paths are also possible in the form of \`file_format\`.\`path\` |
| `partition_filter` || map<string, string> | A map of partitions in the source table to import from |
| `check_duplicate_files` || boolean | Whether to prevent files existing in the table from being added (defaults to true) |
| `parallelism` | | int | number of threads to use for file reading (defaults to 1) |

Warning : Schema is not validated, adding files with different schema to the Iceberg table will cause issues.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,28 @@ public void testPartitionedImportFromEmptyPartitionDoesNotThrow() {
sql("SELECT * FROM %s ORDER BY id", tableName));
}

@Test
public void testAddFilesWithParallelism() {
createUnpartitionedHiveTable();

String createIceberg =
"CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg";

sql(createIceberg, tableName);

List<Object[]> result =
sql(
"CALL %s.system.add_files(table => '%s', source_table => '%s', parallelism => 2)",
catalogName, tableName, sourceTableName);

assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result);

assertEquals(
"Iceberg table contains correct data",
sql("SELECT * FROM %s ORDER BY id", sourceTableName),
sql("SELECT * FROM %s ORDER BY id", tableName));
}

private static final List<Object[]> emptyQueryResult = Lists.newArrayList();

private static final StructField[] struct = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,15 +277,17 @@ private static List<DataFile> listPartition(
PartitionSpec spec,
SerializableConfiguration conf,
MetricsConfig metricsConfig,
NameMapping mapping) {
NameMapping mapping,
int parallelism) {
return TableMigrationUtil.listPartition(
partition.values,
partition.uri,
partition.format,
spec,
conf.get(),
metricsConfig,
mapping);
mapping,
parallelism);
}

private static SparkPartition toSparkPartition(
Expand Down Expand Up @@ -382,6 +384,33 @@ public static void importSparkTable(
String stagingDir,
Map<String, String> partitionFilter,
boolean checkDuplicateFiles) {
importSparkTable(
spark, sourceTableIdent, targetTable, stagingDir, partitionFilter, checkDuplicateFiles, 1);
}

/**
* Import files from an existing Spark table to an Iceberg table.
*
* <p>The import uses the Spark session to get table metadata. It assumes no operation is going on
* the original and target table and thus is not thread-safe.
*
* @param spark a Spark session
* @param sourceTableIdent an identifier of the source Spark table
* @param targetTable an Iceberg table where to import the data
* @param stagingDir a staging directory to store temporary manifest files
* @param partitionFilter only import partitions whose values match those in the map, can be
* partially defined
* @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
* @param parallelism number of threads to use for file reading
*/
public static void importSparkTable(
SparkSession spark,
TableIdentifier sourceTableIdent,
Table targetTable,
String stagingDir,
Map<String, String> partitionFilter,
boolean checkDuplicateFiles,
int parallelism) {
SessionCatalog catalog = spark.sessionState().catalog();

String db =
Expand All @@ -402,15 +431,21 @@ public static void importSparkTable(

if (Objects.equal(spec, PartitionSpec.unpartitioned())) {
importUnpartitionedSparkTable(
spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles);
spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles, parallelism);
} else {
List<SparkPartition> sourceTablePartitions =
getPartitions(spark, sourceTableIdent, partitionFilter);
if (sourceTablePartitions.isEmpty()) {
targetTable.newAppend().commit();
} else {
importSparkPartitions(
spark, sourceTablePartitions, targetTable, spec, stagingDir, checkDuplicateFiles);
spark,
sourceTablePartitions,
targetTable,
spec,
stagingDir,
checkDuplicateFiles,
parallelism);
}
}
} catch (AnalysisException e) {
Expand Down Expand Up @@ -443,7 +478,8 @@ public static void importSparkTable(
targetTable,
stagingDir,
Collections.emptyMap(),
checkDuplicateFiles);
checkDuplicateFiles,
1);
}

/**
Expand All @@ -460,14 +496,15 @@ public static void importSparkTable(
public static void importSparkTable(
SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, String stagingDir) {
importSparkTable(
spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false);
spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false, 1);
}

private static void importUnpartitionedSparkTable(
SparkSession spark,
TableIdentifier sourceTableIdent,
Table targetTable,
boolean checkDuplicateFiles) {
boolean checkDuplicateFiles,
int parallelism) {
try {
CatalogTable sourceTable = spark.sessionState().catalog().getTableMetadata(sourceTableIdent);
Option<String> format =
Expand All @@ -492,7 +529,8 @@ private static void importUnpartitionedSparkTable(
spec,
conf,
metricsConfig,
nameMapping);
nameMapping,
parallelism);

if (checkDuplicateFiles) {
Dataset<Row> importedFiles =
Expand Down Expand Up @@ -540,9 +578,31 @@ public static void importSparkPartitions(
PartitionSpec spec,
String stagingDir,
boolean checkDuplicateFiles) {
importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, checkDuplicateFiles, 1);
}

/**
* Import files from given partitions to an Iceberg table.
*
* @param spark a Spark session
* @param partitions partitions to import
* @param targetTable an Iceberg table where to import the data
* @param spec a partition spec
* @param stagingDir a staging directory to store temporary manifest files
* @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
* @param parallelism number of threads to use for file reading
*/
public static void importSparkPartitions(
SparkSession spark,
List<SparkPartition> partitions,
Table targetTable,
PartitionSpec spec,
String stagingDir,
boolean checkDuplicateFiles,
int parallelism) {
Configuration conf = spark.sessionState().newHadoopConf();
SerializableConfiguration serializableConf = new SerializableConfiguration(conf);
int parallelism =
int listingParallelism =
Math.min(
partitions.size(), spark.sessionState().conf().parallelPartitionDiscoveryParallelism());
int numShufflePartitions = spark.sessionState().conf().numShufflePartitions();
Expand All @@ -552,7 +612,7 @@ public static void importSparkPartitions(
nameMappingString != null ? NameMappingParser.fromJson(nameMappingString) : null;

JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext());
JavaRDD<SparkPartition> partitionRDD = sparkContext.parallelize(partitions, parallelism);
JavaRDD<SparkPartition> partitionRDD = sparkContext.parallelize(partitions, listingParallelism);

Dataset<SparkPartition> partitionDS =
spark.createDataset(partitionRDD.rdd(), Encoders.javaSerialization(SparkPartition.class));
Expand All @@ -562,7 +622,12 @@ public static void importSparkPartitions(
(FlatMapFunction<SparkPartition, DataFile>)
sparkPartition ->
listPartition(
sparkPartition, spec, serializableConf, metricsConfig, nameMapping)
sparkPartition,
spec,
serializableConf,
metricsConfig,
nameMapping,
parallelism)
.iterator(),
Encoders.javaSerialization(DataFile.class));

Expand Down Expand Up @@ -635,7 +700,7 @@ public static void importSparkPartitions(
Table targetTable,
PartitionSpec spec,
String stagingDir) {
importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false);
importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false, 1);
}

public static List<SparkPartition> filterPartitions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,16 @@ class AddFilesProcedure extends BaseProcedure {
private static final ProcedureParameter CHECK_DUPLICATE_FILES_PARAM =
ProcedureParameter.optional("check_duplicate_files", DataTypes.BooleanType);

private static final ProcedureParameter PARALLELISM =
ProcedureParameter.optional("parallelism", DataTypes.IntegerType);

private static final ProcedureParameter[] PARAMETERS =
new ProcedureParameter[] {
TABLE_PARAM, SOURCE_TABLE_PARAM, PARTITION_FILTER_PARAM, CHECK_DUPLICATE_FILES_PARAM
TABLE_PARAM,
SOURCE_TABLE_PARAM,
PARTITION_FILTER_PARAM,
CHECK_DUPLICATE_FILES_PARAM,
PARALLELISM
};

private static final StructType OUTPUT_TYPE =
Expand Down Expand Up @@ -112,7 +119,10 @@ public InternalRow[] call(InternalRow args) {

boolean checkDuplicateFiles = input.asBoolean(CHECK_DUPLICATE_FILES_PARAM, true);

return importToIceberg(tableIdent, sourceIdent, partitionFilter, checkDuplicateFiles);
int parallelism = input.asInt(PARALLELISM, 1);

return importToIceberg(
tableIdent, sourceIdent, partitionFilter, checkDuplicateFiles, parallelism);
}

private InternalRow[] toOutputRows(Snapshot snapshot) {
Expand Down Expand Up @@ -142,7 +152,8 @@ private InternalRow[] importToIceberg(
Identifier destIdent,
Identifier sourceIdent,
Map<String, String> partitionFilter,
boolean checkDuplicateFiles) {
boolean checkDuplicateFiles,
int parallelism) {
return modifyIcebergTable(
destIdent,
table -> {
Expand All @@ -153,9 +164,16 @@ private InternalRow[] importToIceberg(
Path sourcePath = new Path(sourceIdent.name());
String format = sourceIdent.namespace()[0];
importFileTable(
table, sourcePath, format, partitionFilter, checkDuplicateFiles, table.spec());
table,
sourcePath,
format,
partitionFilter,
checkDuplicateFiles,
table.spec(),
parallelism);
} else {
importCatalogTable(table, sourceIdent, partitionFilter, checkDuplicateFiles);
importCatalogTable(
table, sourceIdent, partitionFilter, checkDuplicateFiles, parallelism);
}

Snapshot snapshot = table.currentSnapshot();
Expand All @@ -178,7 +196,8 @@ private void importFileTable(
String format,
Map<String, String> partitionFilter,
boolean checkDuplicateFiles,
PartitionSpec spec) {
PartitionSpec spec,
int parallelism) {
// List Partitions via Spark InMemory file search interface
List<SparkPartition> partitions =
Spark3Util.getPartitions(spark(), tableLocation, format, partitionFilter, spec);
Expand All @@ -193,19 +212,20 @@ private void importFileTable(
// Build a Global Partition for the source
SparkPartition partition =
new SparkPartition(Collections.emptyMap(), tableLocation.toString(), format);
importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles);
importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles, parallelism);
} else {
Preconditions.checkArgument(
!partitions.isEmpty(), "Cannot find any matching partitions in table %s", table.name());
importPartitions(table, partitions, checkDuplicateFiles);
importPartitions(table, partitions, checkDuplicateFiles, parallelism);
}
}

private void importCatalogTable(
Table table,
Identifier sourceIdent,
Map<String, String> partitionFilter,
boolean checkDuplicateFiles) {
boolean checkDuplicateFiles,
int parallelism) {
String stagingLocation = getMetadataLocation(table);
TableIdentifier sourceTableIdentifier = Spark3Util.toV1TableIdentifier(sourceIdent);
SparkTableUtil.importSparkTable(
Expand All @@ -214,14 +234,24 @@ private void importCatalogTable(
table,
stagingLocation,
partitionFilter,
checkDuplicateFiles);
checkDuplicateFiles,
parallelism);
}

private void importPartitions(
Table table, List<SparkTableUtil.SparkPartition> partitions, boolean checkDuplicateFiles) {
Table table,
List<SparkTableUtil.SparkPartition> partitions,
boolean checkDuplicateFiles,
int parallelism) {
String stagingLocation = getMetadataLocation(table);
SparkTableUtil.importSparkPartitions(
spark(), partitions, table, table.spec(), stagingLocation, checkDuplicateFiles);
spark(),
partitions,
table,
table.spec(),
stagingLocation,
checkDuplicateFiles,
parallelism);
}

private String getMetadataLocation(Table table) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ public Boolean asBoolean(ProcedureParameter param, Boolean defaultValue) {
return args.isNullAt(ordinal) ? defaultValue : (Boolean) args.getBoolean(ordinal);
}

public Integer asInt(ProcedureParameter param) {
Integer value = asInt(param, null);
Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name());
return value;
}

public Integer asInt(ProcedureParameter param, Integer defaultValue) {
validateParamType(param, DataTypes.IntegerType);
int ordinal = ordinal(param);
return args.isNullAt(ordinal) ? defaultValue : (Integer) args.getInt(ordinal);
}

public long asLong(ProcedureParameter param) {
Long value = asLong(param, null);
Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name());
Expand Down

0 comments on commit 10f9c31

Please sign in to comment.