Skip to content

Commit

Permalink
Spark 3.5: Parallelize reading files in add_files procedure
Browse files Browse the repository at this point in the history
  • Loading branch information
manuzhang committed Dec 19, 2023
1 parent d247b20 commit b8603a9
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public static List<DataFile> listPartition(
* @param conf a Hadoop conf
* @param metricsSpec a metrics conf
* @param mapping a name mapping
* @param parallelism number of threads to use for file reading
* @param readFilesParallelism number of threads to use when reading files
* @return a List of DataFile
*/
public static List<DataFile> listPartition(
Expand All @@ -111,7 +111,7 @@ public static List<DataFile> listPartition(
Configuration conf,
MetricsConfig metricsSpec,
NameMapping mapping,
int parallelism) {
int readFilesParallelism) {
ExecutorService service = null;
try {
List<String> partitionValues =
Expand All @@ -130,8 +130,8 @@ public static List<DataFile> listPartition(
Tasks.Builder<Integer> task =
Tasks.range(fileStatus.size()).stopOnFailure().throwFailureWhenFinished();

if (parallelism > 1) {
service = migrationService(parallelism);
if (readFilesParallelism > 1) {
service = migrationService(readFilesParallelism);
task.executeWith(service);
}

Expand Down Expand Up @@ -215,11 +215,11 @@ private static DataFile buildDataFile(
.build();
}

private static ExecutorService migrationService(int concurrentDeletes) {
private static ExecutorService migrationService(int parallelism) {
return MoreExecutors.getExitingExecutorService(
(ThreadPoolExecutor)
Executors.newFixedThreadPool(
concurrentDeletes,
parallelism,
new ThreadFactoryBuilder().setNameFormat("table-migration-%d").build()));
}
}
13 changes: 7 additions & 6 deletions docs/spark-procedures.md
Original file line number Diff line number Diff line change
Expand Up @@ -633,12 +633,13 @@ Keep in mind the `add_files` procedure will fetch the Parquet metadata from each

#### Usage

| Argument Name | Required? | Type | Description |
|-------------------------|-----------|---------------------|-----------------------------------------------------------------------------------------------------|
| `table` | ✔️ | string | Table which will have files added to |
| `source_table` | ✔️ | string | Table where files should come from, paths are also possible in the form of \`file_format\`.\`path\` |
| `partition_filter` || map<string, string> | A map of partitions in the source table to import from |
| `check_duplicate_files` || boolean | Whether to prevent files existing in the table from being added (defaults to true) |
| Argument Name | Required? | Type | Description |
|--------------------------|-----------|---------------------|-----------------------------------------------------------------------------------------------------|
| `table` | ✔️ | string | Table which will have files added to |
| `source_table` | ✔️ | string | Table where files should come from, paths are also possible in the form of \`file_format\`.\`path\` |
| `partition_filter` || map<string, string> | A map of partitions in the source table to import from |
| `check_duplicate_files` || boolean | Whether to prevent files existing in the table from being added (defaults to true) |
| `read_files_parallelism` | | int | Number of threads to use when reading files (defaults to 1) |

Warning : Schema is not validated, adding files with different schema to the Iceberg table will cause issues.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,28 @@ public void testPartitionedImportFromEmptyPartitionDoesNotThrow() {
sql("SELECT * FROM %s ORDER BY id", tableName));
}

@Test
public void testAddFilesWithReadFilesParallelism() {
createUnpartitionedHiveTable();

String createIceberg =
"CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg";

sql(createIceberg, tableName);

List<Object[]> result =
sql(
"CALL %s.system.add_files(table => '%s', source_table => '%s', read_files_parallelism => 2)",
catalogName, tableName, sourceTableName);

assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result);

assertEquals(
"Iceberg table contains correct data",
sql("SELECT * FROM %s ORDER BY id", sourceTableName),
sql("SELECT * FROM %s ORDER BY id", tableName));
}

private static final List<Object[]> emptyQueryResult = Lists.newArrayList();

private static final StructField[] struct = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,15 +275,17 @@ private static List<DataFile> listPartition(
PartitionSpec spec,
SerializableConfiguration conf,
MetricsConfig metricsConfig,
NameMapping mapping) {
NameMapping mapping,
int readFilesParallelism) {
return TableMigrationUtil.listPartition(
partition.values,
partition.uri,
partition.format,
spec,
conf.get(),
metricsConfig,
mapping);
mapping,
readFilesParallelism);
}

private static SparkPartition toSparkPartition(
Expand Down Expand Up @@ -372,14 +374,16 @@ private static Iterator<ManifestFile> buildManifest(
* @param partitionFilter only import partitions whose values match those in the map, can be
* partially defined
* @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
* @param readFilesParallelism number of threads to use when reading files
*/
public static void importSparkTable(
SparkSession spark,
TableIdentifier sourceTableIdent,
Table targetTable,
String stagingDir,
Map<String, String> partitionFilter,
boolean checkDuplicateFiles) {
boolean checkDuplicateFiles,
int readFilesParallelism) {
SessionCatalog catalog = spark.sessionState().catalog();

String db =
Expand All @@ -400,15 +404,21 @@ public static void importSparkTable(

if (Objects.equal(spec, PartitionSpec.unpartitioned())) {
importUnpartitionedSparkTable(
spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles);
spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles, readFilesParallelism);
} else {
List<SparkPartition> sourceTablePartitions =
getPartitions(spark, sourceTableIdent, partitionFilter);
if (sourceTablePartitions.isEmpty()) {
targetTable.newAppend().commit();
} else {
importSparkPartitions(
spark, sourceTablePartitions, targetTable, spec, stagingDir, checkDuplicateFiles);
spark,
sourceTablePartitions,
targetTable,
spec,
stagingDir,
checkDuplicateFiles,
readFilesParallelism);
}
}
} catch (AnalysisException e) {
Expand Down Expand Up @@ -441,7 +451,8 @@ public static void importSparkTable(
targetTable,
stagingDir,
Collections.emptyMap(),
checkDuplicateFiles);
checkDuplicateFiles,
1);
}

/**
Expand All @@ -458,14 +469,15 @@ public static void importSparkTable(
public static void importSparkTable(
SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, String stagingDir) {
importSparkTable(
spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false);
spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false, 1);
}

private static void importUnpartitionedSparkTable(
SparkSession spark,
TableIdentifier sourceTableIdent,
Table targetTable,
boolean checkDuplicateFiles) {
boolean checkDuplicateFiles,
int readFilesParallelism) {
try {
CatalogTable sourceTable = spark.sessionState().catalog().getTableMetadata(sourceTableIdent);
Option<String> format =
Expand All @@ -490,7 +502,8 @@ private static void importUnpartitionedSparkTable(
spec,
conf,
metricsConfig,
nameMapping);
nameMapping,
readFilesParallelism);

if (checkDuplicateFiles) {
Dataset<Row> importedFiles =
Expand Down Expand Up @@ -538,6 +551,28 @@ public static void importSparkPartitions(
PartitionSpec spec,
String stagingDir,
boolean checkDuplicateFiles) {
importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, checkDuplicateFiles, 1);
}

/**
* Import files from given partitions to an Iceberg table.
*
* @param spark a Spark session
* @param partitions partitions to import
* @param targetTable an Iceberg table where to import the data
* @param spec a partition spec
* @param stagingDir a staging directory to store temporary manifest files
* @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file
* @param readFilesParallelism number of threads to use when reading files
*/
public static void importSparkPartitions(
SparkSession spark,
List<SparkPartition> partitions,
Table targetTable,
PartitionSpec spec,
String stagingDir,
boolean checkDuplicateFiles,
int readFilesParallelism) {
Configuration conf = spark.sessionState().newHadoopConf();
SerializableConfiguration serializableConf = new SerializableConfiguration(conf);
int parallelism =
Expand All @@ -560,7 +595,12 @@ public static void importSparkPartitions(
(FlatMapFunction<SparkPartition, DataFile>)
sparkPartition ->
listPartition(
sparkPartition, spec, serializableConf, metricsConfig, nameMapping)
sparkPartition,
spec,
serializableConf,
metricsConfig,
nameMapping,
readFilesParallelism)
.iterator(),
Encoders.javaSerialization(DataFile.class));

Expand Down Expand Up @@ -631,7 +671,7 @@ public static void importSparkPartitions(
Table targetTable,
PartitionSpec spec,
String stagingDir) {
importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false);
importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false, 1);
}

public static List<SparkPartition> filterPartitions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,16 @@ class AddFilesProcedure extends BaseProcedure {
private static final ProcedureParameter CHECK_DUPLICATE_FILES_PARAM =
ProcedureParameter.optional("check_duplicate_files", DataTypes.BooleanType);

private static final ProcedureParameter READ_FILES_PARALLELISM_PARAM =
ProcedureParameter.optional("read_files_parallelism", DataTypes.IntegerType);

private static final ProcedureParameter[] PARAMETERS =
new ProcedureParameter[] {
TABLE_PARAM, SOURCE_TABLE_PARAM, PARTITION_FILTER_PARAM, CHECK_DUPLICATE_FILES_PARAM
TABLE_PARAM,
SOURCE_TABLE_PARAM,
PARTITION_FILTER_PARAM,
CHECK_DUPLICATE_FILES_PARAM,
READ_FILES_PARALLELISM_PARAM
};

private static final StructType OUTPUT_TYPE =
Expand Down Expand Up @@ -111,7 +118,10 @@ public InternalRow[] call(InternalRow args) {

boolean checkDuplicateFiles = input.asBoolean(CHECK_DUPLICATE_FILES_PARAM, true);

return importToIceberg(tableIdent, sourceIdent, partitionFilter, checkDuplicateFiles);
int readFilesParallelism = input.asInt(READ_FILES_PARALLELISM_PARAM, 1);

return importToIceberg(
tableIdent, sourceIdent, partitionFilter, checkDuplicateFiles, readFilesParallelism);
}

private InternalRow[] toOutputRows(Snapshot snapshot) {
Expand All @@ -135,7 +145,8 @@ private InternalRow[] importToIceberg(
Identifier destIdent,
Identifier sourceIdent,
Map<String, String> partitionFilter,
boolean checkDuplicateFiles) {
boolean checkDuplicateFiles,
int readFilesParallelism) {
return modifyIcebergTable(
destIdent,
table -> {
Expand All @@ -146,9 +157,16 @@ private InternalRow[] importToIceberg(
Path sourcePath = new Path(sourceIdent.name());
String format = sourceIdent.namespace()[0];
importFileTable(
table, sourcePath, format, partitionFilter, checkDuplicateFiles, table.spec());
table,
sourcePath,
format,
partitionFilter,
checkDuplicateFiles,
table.spec(),
readFilesParallelism);
} else {
importCatalogTable(table, sourceIdent, partitionFilter, checkDuplicateFiles);
importCatalogTable(
table, sourceIdent, partitionFilter, checkDuplicateFiles, readFilesParallelism);
}

Snapshot snapshot = table.currentSnapshot();
Expand All @@ -171,7 +189,8 @@ private void importFileTable(
String format,
Map<String, String> partitionFilter,
boolean checkDuplicateFiles,
PartitionSpec spec) {
PartitionSpec spec,
int readFilesParallelism) {
// List Partitions via Spark InMemory file search interface
List<SparkPartition> partitions =
Spark3Util.getPartitions(spark(), tableLocation, format, partitionFilter, spec);
Expand All @@ -186,19 +205,21 @@ private void importFileTable(
// Build a Global Partition for the source
SparkPartition partition =
new SparkPartition(Collections.emptyMap(), tableLocation.toString(), format);
importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles);
importPartitions(
table, ImmutableList.of(partition), checkDuplicateFiles, readFilesParallelism);
} else {
Preconditions.checkArgument(
!partitions.isEmpty(), "Cannot find any matching partitions in table %s", partitions);
importPartitions(table, partitions, checkDuplicateFiles);
importPartitions(table, partitions, checkDuplicateFiles, readFilesParallelism);
}
}

private void importCatalogTable(
Table table,
Identifier sourceIdent,
Map<String, String> partitionFilter,
boolean checkDuplicateFiles) {
boolean checkDuplicateFiles,
int readFilesParallelism) {
String stagingLocation = getMetadataLocation(table);
TableIdentifier sourceTableIdentifier = Spark3Util.toV1TableIdentifier(sourceIdent);
SparkTableUtil.importSparkTable(
Expand All @@ -207,14 +228,24 @@ private void importCatalogTable(
table,
stagingLocation,
partitionFilter,
checkDuplicateFiles);
checkDuplicateFiles,
readFilesParallelism);
}

private void importPartitions(
Table table, List<SparkTableUtil.SparkPartition> partitions, boolean checkDuplicateFiles) {
Table table,
List<SparkTableUtil.SparkPartition> partitions,
boolean checkDuplicateFiles,
int readFilesParallelism) {
String stagingLocation = getMetadataLocation(table);
SparkTableUtil.importSparkPartitions(
spark(), partitions, table, table.spec(), stagingLocation, checkDuplicateFiles);
spark(),
partitions,
table,
table.spec(),
stagingLocation,
checkDuplicateFiles,
readFilesParallelism);
}

private String getMetadataLocation(Table table) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ public Boolean asBoolean(ProcedureParameter param, Boolean defaultValue) {
return args.isNullAt(ordinal) ? defaultValue : (Boolean) args.getBoolean(ordinal);
}

public Integer asInt(ProcedureParameter param) {
Integer value = asInt(param, null);
Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name());
return value;
}

public Integer asInt(ProcedureParameter param, Integer defaultValue) {
validateParamType(param, DataTypes.IntegerType);
int ordinal = ordinal(param);
return args.isNullAt(ordinal) ? defaultValue : (Integer) args.getInt(ordinal);
}

public long asLong(ProcedureParameter param) {
Long value = asLong(param, null);
Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name());
Expand Down

0 comments on commit b8603a9

Please sign in to comment.