From 6c7a80dbe99c054e03a5abaffde777128d4b00bd Mon Sep 17 00:00:00 2001 From: manuzhang Date: Tue, 19 Dec 2023 12:21:09 +0800 Subject: [PATCH] Spark 3.5: Parallelize reading files in add_files procedure --- .../iceberg/data/TableMigrationUtil.java | 4 +- docs/spark-procedures.md | 1 + .../extensions/TestAddFilesProcedure.java | 22 +++++++ .../apache/iceberg/spark/SparkTableUtil.java | 66 +++++++++++++++---- .../spark/procedures/AddFilesProcedure.java | 54 +++++++++++---- .../spark/procedures/ProcedureInput.java | 12 ++++ 6 files changed, 132 insertions(+), 27 deletions(-) diff --git a/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java b/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java index 0fb290f947f8..5834a074a105 100644 --- a/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java +++ b/data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java @@ -215,11 +215,11 @@ private static DataFile buildDataFile( .build(); } - private static ExecutorService migrationService(int concurrentDeletes) { + private static ExecutorService migrationService(int parallelism) { return MoreExecutors.getExitingExecutorService( (ThreadPoolExecutor) Executors.newFixedThreadPool( - concurrentDeletes, + parallelism, new ThreadFactoryBuilder().setNameFormat("table-migration-%d").build())); } } diff --git a/docs/spark-procedures.md b/docs/spark-procedures.md index cdc1779a88f9..45a9f80ea633 100644 --- a/docs/spark-procedures.md +++ b/docs/spark-procedures.md @@ -640,6 +640,7 @@ Keep in mind the `add_files` procedure will fetch the Parquet metadata from each | `source_table` | ✔️ | string | Table where files should come from, paths are also possible in the form of \`file_format\`.\`path\` | | `partition_filter` | ️ | map | A map of partitions in the source table to import from | | `check_duplicate_files` | ️ | boolean | Whether to prevent files existing in the table from being added (defaults to true) | +| `parallelism` | | int | number of threads to use for file reading (defaults to 1) | Warning : Schema is not validated, adding files with different schema to the Iceberg table will cause issues. diff --git a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java index 3ed99da24947..eaa0a5894c85 100644 --- a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java +++ b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java @@ -935,6 +935,28 @@ public void testPartitionedImportFromEmptyPartitionDoesNotThrow() { sql("SELECT * FROM %s ORDER BY id", tableName)); } + @Test + public void testAddFilesWithParallelism() { + createUnpartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files(table => '%s', source_table => '%s', parallelism => 2)", + catalogName, tableName, sourceTableName); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + private static final List emptyQueryResult = Lists.newArrayList(); private static final StructField[] struct = { diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java index 51df02d56959..4de6a70fc590 100644 --- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java +++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java @@ -277,7 +277,8 @@ private static List listPartition( PartitionSpec spec, SerializableConfiguration conf, MetricsConfig metricsConfig, - NameMapping mapping) { + NameMapping mapping, + int parallelism) { return TableMigrationUtil.listPartition( partition.values, partition.uri, @@ -285,7 +286,8 @@ private static List listPartition( spec, conf.get(), metricsConfig, - mapping); + mapping, + parallelism); } private static SparkPartition toSparkPartition( @@ -374,6 +376,7 @@ private static Iterator buildManifest( * @param partitionFilter only import partitions whose values match those in the map, can be * partially defined * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + * @param parallelism number of threads to use for file reading */ public static void importSparkTable( SparkSession spark, @@ -381,7 +384,8 @@ public static void importSparkTable( Table targetTable, String stagingDir, Map partitionFilter, - boolean checkDuplicateFiles) { + boolean checkDuplicateFiles, + int parallelism) { SessionCatalog catalog = spark.sessionState().catalog(); String db = @@ -402,7 +406,7 @@ public static void importSparkTable( if (Objects.equal(spec, PartitionSpec.unpartitioned())) { importUnpartitionedSparkTable( - spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles); + spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles, parallelism); } else { List sourceTablePartitions = getPartitions(spark, sourceTableIdent, partitionFilter); @@ -410,7 +414,13 @@ public static void importSparkTable( targetTable.newAppend().commit(); } else { importSparkPartitions( - spark, sourceTablePartitions, targetTable, spec, stagingDir, checkDuplicateFiles); + spark, + sourceTablePartitions, + targetTable, + spec, + stagingDir, + checkDuplicateFiles, + parallelism); } } } catch (AnalysisException e) { @@ -443,7 +453,8 @@ public static void importSparkTable( targetTable, stagingDir, Collections.emptyMap(), - checkDuplicateFiles); + checkDuplicateFiles, + 1); } /** @@ -460,14 +471,15 @@ public static void importSparkTable( public static void importSparkTable( SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, String stagingDir) { importSparkTable( - spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false); + spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false, 1); } private static void importUnpartitionedSparkTable( SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, - boolean checkDuplicateFiles) { + boolean checkDuplicateFiles, + int parallelism) { try { CatalogTable sourceTable = spark.sessionState().catalog().getTableMetadata(sourceTableIdent); Option format = @@ -492,7 +504,8 @@ private static void importUnpartitionedSparkTable( spec, conf, metricsConfig, - nameMapping); + nameMapping, + parallelism); if (checkDuplicateFiles) { Dataset importedFiles = @@ -540,9 +553,31 @@ public static void importSparkPartitions( PartitionSpec spec, String stagingDir, boolean checkDuplicateFiles) { + importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, checkDuplicateFiles, 1); + } + + /** + * Import files from given partitions to an Iceberg table. + * + * @param spark a Spark session + * @param partitions partitions to import + * @param targetTable an Iceberg table where to import the data + * @param spec a partition spec + * @param stagingDir a staging directory to store temporary manifest files + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + * @param parallelism number of threads to use for file reading + */ + public static void importSparkPartitions( + SparkSession spark, + List partitions, + Table targetTable, + PartitionSpec spec, + String stagingDir, + boolean checkDuplicateFiles, + int parallelism) { Configuration conf = spark.sessionState().newHadoopConf(); SerializableConfiguration serializableConf = new SerializableConfiguration(conf); - int parallelism = + int listingParallelism = Math.min( partitions.size(), spark.sessionState().conf().parallelPartitionDiscoveryParallelism()); int numShufflePartitions = spark.sessionState().conf().numShufflePartitions(); @@ -552,7 +587,7 @@ public static void importSparkPartitions( nameMappingString != null ? NameMappingParser.fromJson(nameMappingString) : null; JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); - JavaRDD partitionRDD = sparkContext.parallelize(partitions, parallelism); + JavaRDD partitionRDD = sparkContext.parallelize(partitions, listingParallelism); Dataset partitionDS = spark.createDataset(partitionRDD.rdd(), Encoders.javaSerialization(SparkPartition.class)); @@ -562,7 +597,12 @@ public static void importSparkPartitions( (FlatMapFunction) sparkPartition -> listPartition( - sparkPartition, spec, serializableConf, metricsConfig, nameMapping) + sparkPartition, + spec, + serializableConf, + metricsConfig, + nameMapping, + parallelism) .iterator(), Encoders.javaSerialization(DataFile.class)); @@ -635,7 +675,7 @@ public static void importSparkPartitions( Table targetTable, PartitionSpec spec, String stagingDir) { - importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false); + importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false, 1); } public static List filterPartitions( diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java index 6a0570677673..40a343b55b80 100644 --- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java +++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/AddFilesProcedure.java @@ -63,9 +63,16 @@ class AddFilesProcedure extends BaseProcedure { private static final ProcedureParameter CHECK_DUPLICATE_FILES_PARAM = ProcedureParameter.optional("check_duplicate_files", DataTypes.BooleanType); + private static final ProcedureParameter PARALLELISM = + ProcedureParameter.optional("parallelism", DataTypes.IntegerType); + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[] { - TABLE_PARAM, SOURCE_TABLE_PARAM, PARTITION_FILTER_PARAM, CHECK_DUPLICATE_FILES_PARAM + TABLE_PARAM, + SOURCE_TABLE_PARAM, + PARTITION_FILTER_PARAM, + CHECK_DUPLICATE_FILES_PARAM, + PARALLELISM }; private static final StructType OUTPUT_TYPE = @@ -112,7 +119,10 @@ public InternalRow[] call(InternalRow args) { boolean checkDuplicateFiles = input.asBoolean(CHECK_DUPLICATE_FILES_PARAM, true); - return importToIceberg(tableIdent, sourceIdent, partitionFilter, checkDuplicateFiles); + int parallelism = input.asInt(PARALLELISM, 1); + + return importToIceberg( + tableIdent, sourceIdent, partitionFilter, checkDuplicateFiles, parallelism); } private InternalRow[] toOutputRows(Snapshot snapshot) { @@ -142,7 +152,8 @@ private InternalRow[] importToIceberg( Identifier destIdent, Identifier sourceIdent, Map partitionFilter, - boolean checkDuplicateFiles) { + boolean checkDuplicateFiles, + int parallelism) { return modifyIcebergTable( destIdent, table -> { @@ -153,9 +164,16 @@ private InternalRow[] importToIceberg( Path sourcePath = new Path(sourceIdent.name()); String format = sourceIdent.namespace()[0]; importFileTable( - table, sourcePath, format, partitionFilter, checkDuplicateFiles, table.spec()); + table, + sourcePath, + format, + partitionFilter, + checkDuplicateFiles, + table.spec(), + parallelism); } else { - importCatalogTable(table, sourceIdent, partitionFilter, checkDuplicateFiles); + importCatalogTable( + table, sourceIdent, partitionFilter, checkDuplicateFiles, parallelism); } Snapshot snapshot = table.currentSnapshot(); @@ -178,7 +196,8 @@ private void importFileTable( String format, Map partitionFilter, boolean checkDuplicateFiles, - PartitionSpec spec) { + PartitionSpec spec, + int parallelism) { // List Partitions via Spark InMemory file search interface List partitions = Spark3Util.getPartitions(spark(), tableLocation, format, partitionFilter, spec); @@ -193,11 +212,11 @@ private void importFileTable( // Build a Global Partition for the source SparkPartition partition = new SparkPartition(Collections.emptyMap(), tableLocation.toString(), format); - importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles); + importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles, parallelism); } else { Preconditions.checkArgument( !partitions.isEmpty(), "Cannot find any matching partitions in table %s", table.name()); - importPartitions(table, partitions, checkDuplicateFiles); + importPartitions(table, partitions, checkDuplicateFiles, parallelism); } } @@ -205,7 +224,8 @@ private void importCatalogTable( Table table, Identifier sourceIdent, Map partitionFilter, - boolean checkDuplicateFiles) { + boolean checkDuplicateFiles, + int parallelism) { String stagingLocation = getMetadataLocation(table); TableIdentifier sourceTableIdentifier = Spark3Util.toV1TableIdentifier(sourceIdent); SparkTableUtil.importSparkTable( @@ -214,14 +234,24 @@ private void importCatalogTable( table, stagingLocation, partitionFilter, - checkDuplicateFiles); + checkDuplicateFiles, + parallelism); } private void importPartitions( - Table table, List partitions, boolean checkDuplicateFiles) { + Table table, + List partitions, + boolean checkDuplicateFiles, + int parallelism) { String stagingLocation = getMetadataLocation(table); SparkTableUtil.importSparkPartitions( - spark(), partitions, table, table.spec(), stagingLocation, checkDuplicateFiles); + spark(), + partitions, + table, + table.spec(), + stagingLocation, + checkDuplicateFiles, + parallelism); } private String getMetadataLocation(Table table) { diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java index 42e4d8ba0603..0be4b38de79c 100644 --- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java +++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java @@ -68,6 +68,18 @@ public Boolean asBoolean(ProcedureParameter param, Boolean defaultValue) { return args.isNullAt(ordinal) ? defaultValue : (Boolean) args.getBoolean(ordinal); } + public Integer asInt(ProcedureParameter param) { + Integer value = asInt(param, null); + Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name()); + return value; + } + + public Integer asInt(ProcedureParameter param, Integer defaultValue) { + validateParamType(param, DataTypes.IntegerType); + int ordinal = ordinal(param); + return args.isNullAt(ordinal) ? defaultValue : (Integer) args.getInt(ordinal); + } + public long asLong(ProcedureParameter param) { Long value = asLong(param, null); Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name());