diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java index 44ee2ebdb646..b24ebfcf7977 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java @@ -939,6 +939,28 @@ public void testPartitionedImportFromEmptyPartitionDoesNotThrow() { sql("SELECT * FROM %s ORDER BY id", tableName)); } + @Test + public void testAddFilesWithParallelism() { + createUnpartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files(table => '%s', source_table => '%s', parallelism => 2)", + catalogName, tableName, sourceTableName); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + private static final List EMPTY_QUERY_RESULT = Lists.newArrayList(); private static final StructField[] STRUCT = { diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java index af1e99df71d3..dfd0b58ffbee 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java @@ -28,6 +28,7 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.concurrent.ExecutorService; import java.util.stream.Collectors; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; @@ -291,7 +292,7 @@ public static List listPartition( PartitionSpec spec, SerializableConfiguration conf, MetricsConfig metricsConfig) { - return listPartition(partition, spec, conf, metricsConfig, null); + return listPartition(partition, spec, conf, metricsConfig, null, 1); } /** @@ -314,7 +315,8 @@ public static List listPartition( PartitionSpec spec, SerializableConfiguration conf, MetricsConfig metricsConfig, - NameMapping mapping) { + NameMapping mapping, + int parallelism) { return TableMigrationUtil.listPartition( partition.values, partition.uri, @@ -322,7 +324,26 @@ public static List listPartition( spec, conf.get(), metricsConfig, - mapping); + mapping, + parallelism); + } + + private static List listPartition( + SparkPartition partition, + PartitionSpec spec, + SerializableConfiguration conf, + MetricsConfig metricsConfig, + NameMapping mapping, + ExecutorService service) { + return TableMigrationUtil.listPartition( + partition.values, + partition.uri, + partition.format, + spec, + conf.get(), + metricsConfig, + mapping, + service); } private static SparkPartition toSparkPartition( @@ -419,6 +440,114 @@ public static void importSparkTable( String stagingDir, Map partitionFilter, boolean checkDuplicateFiles) { + importSparkTable( + spark, sourceTableIdent, targetTable, stagingDir, partitionFilter, checkDuplicateFiles, 1); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *
<p>
The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param parallelism number of threads to use for file reading + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + int parallelism) { + importSparkTable( + spark, + sourceTableIdent, + targetTable, + stagingDir, + TableMigrationUtil.migrationService(parallelism)); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *
<p>
The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param service executor service to use for file reading + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + ExecutorService service) { + importSparkTable( + spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false, service); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *
<p>
The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param partitionFilter only import partitions whose values match those in the map, can be + * partially defined + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + * @param parallelism number of threads to use for file reading + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + Map partitionFilter, + boolean checkDuplicateFiles, + int parallelism) { + importSparkTable( + spark, + sourceTableIdent, + targetTable, + stagingDir, + partitionFilter, + checkDuplicateFiles, + TableMigrationUtil.migrationService(parallelism)); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *
<p>
The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param partitionFilter only import partitions whose values match those in the map, can be + * partially defined + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + * @param service executor service to use for file reading + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + Map partitionFilter, + boolean checkDuplicateFiles, + ExecutorService service) { SessionCatalog catalog = spark.sessionState().catalog(); String db = @@ -439,7 +568,7 @@ public static void importSparkTable( if (Objects.equal(spec, PartitionSpec.unpartitioned())) { importUnpartitionedSparkTable( - spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles); + spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles, service); } else { List sourceTablePartitions = getPartitions(spark, sourceTableIdent, partitionFilter); @@ -447,7 +576,13 @@ public static void importSparkTable( targetTable.newAppend().commit(); } else { importSparkPartitions( - spark, sourceTablePartitions, targetTable, spec, stagingDir, checkDuplicateFiles); + spark, + sourceTablePartitions, + targetTable, + spec, + stagingDir, + checkDuplicateFiles, + service); } } } catch (AnalysisException e) { @@ -480,7 +615,8 @@ public static void importSparkTable( targetTable, stagingDir, Collections.emptyMap(), - checkDuplicateFiles); + checkDuplicateFiles, + 1); } /** @@ -497,14 +633,15 @@ public static void importSparkTable( public static void importSparkTable( SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, String stagingDir) { importSparkTable( - spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false); + spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false, 1); } private static void importUnpartitionedSparkTable( SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, - boolean checkDuplicateFiles) { + boolean checkDuplicateFiles, + ExecutorService service) { try { CatalogTable sourceTable = spark.sessionState().catalog().getTableMetadata(sourceTableIdent); Option format = @@ -529,7 +666,8 @@ private static void importUnpartitionedSparkTable( spec, conf, metricsConfig, - nameMapping); + nameMapping, + service); if (checkDuplicateFiles) { Dataset importedFiles = @@ -577,9 +715,60 @@ public static void importSparkPartitions( PartitionSpec spec, String stagingDir, boolean checkDuplicateFiles) { + importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, checkDuplicateFiles, 1); + } + + /** + * Import files from given partitions to an Iceberg table. 
+ * + * @param spark a Spark session + * @param partitions partitions to import + * @param targetTable an Iceberg table where to import the data + * @param spec a partition spec + * @param stagingDir a staging directory to store temporary manifest files + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + * @param parallelism number of threads to use for file reading + */ + public static void importSparkPartitions( + SparkSession spark, + List partitions, + Table targetTable, + PartitionSpec spec, + String stagingDir, + boolean checkDuplicateFiles, + int parallelism) { + importSparkPartitions( + spark, + partitions, + targetTable, + spec, + stagingDir, + checkDuplicateFiles, + TableMigrationUtil.migrationService(parallelism)); + } + + /** + * Import files from given partitions to an Iceberg table. + * + * @param spark a Spark session + * @param partitions partitions to import + * @param targetTable an Iceberg table where to import the data + * @param spec a partition spec + * @param stagingDir a staging directory to store temporary manifest files + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + * @param service executor service to use for file reading + */ + public static void importSparkPartitions( + SparkSession spark, + List partitions, + Table targetTable, + PartitionSpec spec, + String stagingDir, + boolean checkDuplicateFiles, + ExecutorService service) { Configuration conf = spark.sessionState().newHadoopConf(); SerializableConfiguration serializableConf = new SerializableConfiguration(conf); - int parallelism = + int listingParallelism = Math.min( partitions.size(), spark.sessionState().conf().parallelPartitionDiscoveryParallelism()); int numShufflePartitions = spark.sessionState().conf().numShufflePartitions(); @@ -589,7 +778,7 @@ public static void importSparkPartitions( nameMappingString != null ? 
NameMappingParser.fromJson(nameMappingString) : null; JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); - JavaRDD partitionRDD = sparkContext.parallelize(partitions, parallelism); + JavaRDD partitionRDD = sparkContext.parallelize(partitions, listingParallelism); Dataset partitionDS = spark.createDataset(partitionRDD.rdd(), Encoders.javaSerialization(SparkPartition.class)); @@ -599,7 +788,12 @@ public static void importSparkPartitions( (FlatMapFunction) sparkPartition -> listPartition( - sparkPartition, spec, serializableConf, metricsConfig, nameMapping) + sparkPartition, + spec, + serializableConf, + metricsConfig, + nameMapping, + service) .iterator(), Encoders.javaSerialization(DataFile.class)); @@ -672,7 +866,7 @@ public static void importSparkPartitions( Table targetTable, PartitionSpec spec, String stagingDir) { - importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false); + importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false, 1); } public static List filterPartitions( diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/MigrateTableSparkAction.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/MigrateTableSparkAction.java index fe8acf0157d3..0eb2a99f2f49 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/MigrateTableSparkAction.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/MigrateTableSparkAction.java @@ -19,6 +19,7 @@ package org.apache.iceberg.spark.actions; import java.util.Map; +import java.util.concurrent.ExecutorService; import org.apache.iceberg.Snapshot; import org.apache.iceberg.SnapshotSummary; import org.apache.iceberg.Table; @@ -59,6 +60,7 @@ public class MigrateTableSparkAction extends BaseTableCreationSparkAction partitionFilter, - boolean checkDuplicateFiles) { + boolean checkDuplicateFiles, + int parallelism) { return modifyIcebergTable( destIdent, table -> { @@ -153,9 +164,16 @@ private InternalRow[] importToIceberg( Path sourcePath = new Path(sourceIdent.name()); String format = sourceIdent.namespace()[0]; importFileTable( - table, sourcePath, format, partitionFilter, checkDuplicateFiles, table.spec()); + table, + sourcePath, + format, + partitionFilter, + checkDuplicateFiles, + table.spec(), + parallelism); } else { - importCatalogTable(table, sourceIdent, partitionFilter, checkDuplicateFiles); + importCatalogTable( + table, sourceIdent, partitionFilter, checkDuplicateFiles, parallelism); } Snapshot snapshot = table.currentSnapshot(); @@ -178,7 +196,8 @@ private void importFileTable( String format, Map partitionFilter, boolean checkDuplicateFiles, - PartitionSpec spec) { + PartitionSpec spec, + int parallelism) { // List Partitions via Spark InMemory file search interface List partitions = Spark3Util.getPartitions(spark(), tableLocation, format, partitionFilter, spec); @@ -193,11 +212,11 @@ private void importFileTable( // Build a Global Partition for the source SparkPartition partition = new SparkPartition(Collections.emptyMap(), tableLocation.toString(), format); - importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles); + importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles, parallelism); } else { Preconditions.checkArgument( !partitions.isEmpty(), "Cannot find any matching partitions in table %s", table.name()); - importPartitions(table, partitions, checkDuplicateFiles); + importPartitions(table, partitions, checkDuplicateFiles, parallelism); 
} } @@ -205,7 +224,8 @@ private void importCatalogTable( Table table, Identifier sourceIdent, Map partitionFilter, - boolean checkDuplicateFiles) { + boolean checkDuplicateFiles, + int parallelism) { String stagingLocation = getMetadataLocation(table); TableIdentifier sourceTableIdentifier = Spark3Util.toV1TableIdentifier(sourceIdent); SparkTableUtil.importSparkTable( @@ -214,14 +234,24 @@ private void importCatalogTable( table, stagingLocation, partitionFilter, - checkDuplicateFiles); + checkDuplicateFiles, + parallelism); } private void importPartitions( - Table table, List partitions, boolean checkDuplicateFiles) { + Table table, + List partitions, + boolean checkDuplicateFiles, + int parallelism) { String stagingLocation = getMetadataLocation(table); SparkTableUtil.importSparkPartitions( - spark(), partitions, table, table.spec(), stagingLocation, checkDuplicateFiles); + spark(), + partitions, + table, + table.spec(), + stagingLocation, + checkDuplicateFiles, + parallelism); } private String getMetadataLocation(Table table) { diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java index aaa6d2cb238d..69e4ef20ea50 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java @@ -39,7 +39,8 @@ class MigrateTableProcedure extends BaseProcedure { new ProcedureParameter[] { ProcedureParameter.required("table", DataTypes.StringType), ProcedureParameter.optional("properties", STRING_MAP), - ProcedureParameter.optional("drop_backup", DataTypes.BooleanType) + ProcedureParameter.optional("drop_backup", DataTypes.BooleanType), + ProcedureParameter.optional("parallelism", DataTypes.IntegerType) }; private static final StructType OUTPUT_TYPE = @@ -95,13 +96,19 @@ public InternalRow[] call(InternalRow args) { MigrateTableSparkAction migrateTableSparkAction = SparkActions.get().migrateTable(tableName).tableProperties(properties); - MigrateTable.Result result; if (dropBackup) { - result = migrateTableSparkAction.dropBackup().execute(); - } else { - result = migrateTableSparkAction.execute(); + migrateTableSparkAction = migrateTableSparkAction.dropBackup(); } + if (!args.isNullAt(3)) { + int parallelism = args.getInt(3); + Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0"); + migrateTableSparkAction = + migrateTableSparkAction.executeWith(executorService(parallelism, "table-migration")); + } + + MigrateTable.Result result = migrateTableSparkAction.execute(); + return new InternalRow[] {newInternalRow(result.migratedDataFilesCount())}; } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java index 42e4d8ba0603..0be4b38de79c 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java @@ -68,6 +68,18 @@ public Boolean asBoolean(ProcedureParameter param, Boolean defaultValue) { return args.isNullAt(ordinal) ? 
defaultValue : (Boolean) args.getBoolean(ordinal); } + public Integer asInt(ProcedureParameter param) { + Integer value = asInt(param, null); + Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name()); + return value; + } + + public Integer asInt(ProcedureParameter param, Integer defaultValue) { + validateParamType(param, DataTypes.IntegerType); + int ordinal = ordinal(param); + return args.isNullAt(ordinal) ? defaultValue : (Integer) args.getInt(ordinal); + } + public long asLong(ProcedureParameter param) { Long value = asLong(param, null); Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name()); diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java index 7a015a51e8ed..f709f64ebf62 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java @@ -38,7 +38,8 @@ class SnapshotTableProcedure extends BaseProcedure { ProcedureParameter.required("source_table", DataTypes.StringType), ProcedureParameter.required("table", DataTypes.StringType), ProcedureParameter.optional("location", DataTypes.StringType), - ProcedureParameter.optional("properties", STRING_MAP) + ProcedureParameter.optional("properties", STRING_MAP), + ProcedureParameter.optional("parallelism", DataTypes.IntegerType) }; private static final StructType OUTPUT_TYPE = @@ -102,6 +103,12 @@ public InternalRow[] call(InternalRow args) { action.tableLocation(snapshotLocation); } + if (!args.isNullAt(4)) { + int parallelism = args.getInt(4); + Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0"); + action = action.executeWith(executorService(parallelism, "table-snapshot")); + } + SnapshotTable.Result result = action.tableProperties(properties).execute(); return new InternalRow[] {newInternalRow(result.importedDataFilesCount())}; } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestMigrateTableAction.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestMigrateTableAction.java new file mode 100644 index 000000000000..7bed72b7cc2c --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestMigrateTableAction.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestMigrateTableAction extends SparkCatalogTestBase { + + public TestMigrateTableAction( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s_BACKUP_", tableName); + } + + @Test + public void testMigrateWithParallelTasks() throws IOException { + String location = temp.newFolder().toURI().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + AtomicInteger migrationThreadsIndex = new AtomicInteger(0); + SparkActions.get() + .migrateTable(tableName) + .executeWith( + Executors.newFixedThreadPool( + 4, + runnable -> { + Thread thread = new Thread(runnable); + thread.setName("table-migration-" + migrationThreadsIndex.getAndIncrement()); + thread.setDaemon(true); + return thread; + })) + .execute(); + Assert.assertEquals(migrationThreadsIndex.get(), 2); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestSnapshotTableAction.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestSnapshotTableAction.java new file mode 100644 index 000000000000..8e6358f51bcd --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestSnapshotTableAction.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestSnapshotTableAction extends SparkCatalogTestBase { + private static final String SOURCE_NAME = "spark_catalog.default.source"; + + public TestSnapshotTableAction( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s PURGE", SOURCE_NAME); + } + + @Test + public void testSnapshotWithParallelTasks() throws IOException { + String location = temp.newFolder().toURI().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + SOURCE_NAME, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", SOURCE_NAME); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", SOURCE_NAME); + + AtomicInteger snapshotThreadsIndex = new AtomicInteger(0); + SparkActions.get() + .snapshotTable(SOURCE_NAME) + .as(tableName) + .executeWith( + Executors.newFixedThreadPool( + 4, + runnable -> { + Thread thread = new Thread(runnable); + thread.setName("table-snapshot-" + snapshotThreadsIndex.getAndIncrement()); + thread.setDaemon(true); + return thread; + })) + .execute(); + Assert.assertEquals(snapshotThreadsIndex.get(), 2); + } +} diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java index 5f995a7776c3..ae2062691d77 100644 --- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java @@ -935,6 +935,28 @@ public void testPartitionedImportFromEmptyPartitionDoesNotThrow() { sql("SELECT * FROM %s ORDER BY id", tableName)); } + @Test + public void testAddFilesWithParallelism() { + createUnpartitionedHiveTable(); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg"; + + sql(createIceberg, tableName); + + List result = + sql( + "CALL %s.system.add_files(table => '%s', source_table => '%s', parallelism => 2)", + catalogName, tableName, sourceTableName); + + assertEquals("Procedure output must match", ImmutableList.of(row(2L, 1L)), result); + + assertEquals( + "Iceberg table contains correct data", + sql("SELECT * FROM %s ORDER BY id", sourceTableName), + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + private static final List EMPTY_QUERY_RESULT = Lists.newArrayList(); private static final StructField[] STRUCT = { diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java index 6f57c7ae376c..7a96e97fb98a 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java @@ -28,6 +28,7 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.concurrent.ExecutorService; import java.util.stream.Collectors; import org.apache.hadoop.conf.Configuration; 
import org.apache.hadoop.fs.Path; @@ -277,7 +278,8 @@ private static List listPartition( PartitionSpec spec, SerializableConfiguration conf, MetricsConfig metricsConfig, - NameMapping mapping) { + NameMapping mapping, + int parallelism) { return TableMigrationUtil.listPartition( partition.values, partition.uri, @@ -285,7 +287,26 @@ private static List listPartition( spec, conf.get(), metricsConfig, - mapping); + mapping, + parallelism); + } + + private static List listPartition( + SparkPartition partition, + PartitionSpec spec, + SerializableConfiguration conf, + MetricsConfig metricsConfig, + NameMapping mapping, + ExecutorService service) { + return TableMigrationUtil.listPartition( + partition.values, + partition.uri, + partition.format, + spec, + conf.get(), + metricsConfig, + mapping, + service); } private static SparkPartition toSparkPartition( @@ -382,6 +403,114 @@ public static void importSparkTable( String stagingDir, Map partitionFilter, boolean checkDuplicateFiles) { + importSparkTable( + spark, sourceTableIdent, targetTable, stagingDir, partitionFilter, checkDuplicateFiles, 1); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *
<p>
The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param parallelism number of threads to use for file reading + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + int parallelism) { + importSparkTable( + spark, + sourceTableIdent, + targetTable, + stagingDir, + TableMigrationUtil.migrationService(parallelism)); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *
<p>
The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param service executor service to use for file reading + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + ExecutorService service) { + importSparkTable( + spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false, service); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *
<p>
The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param partitionFilter only import partitions whose values match those in the map, can be + * partially defined + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + * @param parallelism number of threads to use for file reading + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + Map partitionFilter, + boolean checkDuplicateFiles, + int parallelism) { + importSparkTable( + spark, + sourceTableIdent, + targetTable, + stagingDir, + partitionFilter, + checkDuplicateFiles, + TableMigrationUtil.migrationService(parallelism)); + } + + /** + * Import files from an existing Spark table to an Iceberg table. + * + *
<p>
The import uses the Spark session to get table metadata. It assumes no operation is going on + * the original and target table and thus is not thread-safe. + * + * @param spark a Spark session + * @param sourceTableIdent an identifier of the source Spark table + * @param targetTable an Iceberg table where to import the data + * @param stagingDir a staging directory to store temporary manifest files + * @param partitionFilter only import partitions whose values match those in the map, can be + * partially defined + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + * @param service executor service to use for file reading + */ + public static void importSparkTable( + SparkSession spark, + TableIdentifier sourceTableIdent, + Table targetTable, + String stagingDir, + Map partitionFilter, + boolean checkDuplicateFiles, + ExecutorService service) { SessionCatalog catalog = spark.sessionState().catalog(); String db = @@ -402,7 +531,7 @@ public static void importSparkTable( if (Objects.equal(spec, PartitionSpec.unpartitioned())) { importUnpartitionedSparkTable( - spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles); + spark, sourceTableIdentWithDB, targetTable, checkDuplicateFiles, service); } else { List sourceTablePartitions = getPartitions(spark, sourceTableIdent, partitionFilter); @@ -410,7 +539,13 @@ public static void importSparkTable( targetTable.newAppend().commit(); } else { importSparkPartitions( - spark, sourceTablePartitions, targetTable, spec, stagingDir, checkDuplicateFiles); + spark, + sourceTablePartitions, + targetTable, + spec, + stagingDir, + checkDuplicateFiles, + service); } } } catch (AnalysisException e) { @@ -443,7 +578,8 @@ public static void importSparkTable( targetTable, stagingDir, Collections.emptyMap(), - checkDuplicateFiles); + checkDuplicateFiles, + 1); } /** @@ -460,14 +596,15 @@ public static void importSparkTable( public static void importSparkTable( SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, String stagingDir) { importSparkTable( - spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false); + spark, sourceTableIdent, targetTable, stagingDir, Collections.emptyMap(), false, 1); } private static void importUnpartitionedSparkTable( SparkSession spark, TableIdentifier sourceTableIdent, Table targetTable, - boolean checkDuplicateFiles) { + boolean checkDuplicateFiles, + ExecutorService service) { try { CatalogTable sourceTable = spark.sessionState().catalog().getTableMetadata(sourceTableIdent); Option format = @@ -492,7 +629,8 @@ private static void importUnpartitionedSparkTable( spec, conf, metricsConfig, - nameMapping); + nameMapping, + service); if (checkDuplicateFiles) { Dataset importedFiles = @@ -540,9 +678,60 @@ public static void importSparkPartitions( PartitionSpec spec, String stagingDir, boolean checkDuplicateFiles) { + importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, checkDuplicateFiles, 1); + } + + /** + * Import files from given partitions to an Iceberg table. 
+ * + * @param spark a Spark session + * @param partitions partitions to import + * @param targetTable an Iceberg table where to import the data + * @param spec a partition spec + * @param stagingDir a staging directory to store temporary manifest files + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + * @param parallelism number of threads to use for file reading + */ + public static void importSparkPartitions( + SparkSession spark, + List partitions, + Table targetTable, + PartitionSpec spec, + String stagingDir, + boolean checkDuplicateFiles, + int parallelism) { + importSparkPartitions( + spark, + partitions, + targetTable, + spec, + stagingDir, + checkDuplicateFiles, + TableMigrationUtil.migrationService(parallelism)); + } + + /** + * Import files from given partitions to an Iceberg table. + * + * @param spark a Spark session + * @param partitions partitions to import + * @param targetTable an Iceberg table where to import the data + * @param spec a partition spec + * @param stagingDir a staging directory to store temporary manifest files + * @param checkDuplicateFiles if true, throw exception if import results in a duplicate data file + * @param service executor service to use for file reading + */ + public static void importSparkPartitions( + SparkSession spark, + List partitions, + Table targetTable, + PartitionSpec spec, + String stagingDir, + boolean checkDuplicateFiles, + ExecutorService service) { Configuration conf = spark.sessionState().newHadoopConf(); SerializableConfiguration serializableConf = new SerializableConfiguration(conf); - int parallelism = + int listingParallelism = Math.min( partitions.size(), spark.sessionState().conf().parallelPartitionDiscoveryParallelism()); int numShufflePartitions = spark.sessionState().conf().numShufflePartitions(); @@ -552,7 +741,7 @@ public static void importSparkPartitions( nameMappingString != null ? 
NameMappingParser.fromJson(nameMappingString) : null; JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); - JavaRDD partitionRDD = sparkContext.parallelize(partitions, parallelism); + JavaRDD partitionRDD = sparkContext.parallelize(partitions, listingParallelism); Dataset partitionDS = spark.createDataset(partitionRDD.rdd(), Encoders.javaSerialization(SparkPartition.class)); @@ -562,7 +751,12 @@ public static void importSparkPartitions( (FlatMapFunction) sparkPartition -> listPartition( - sparkPartition, spec, serializableConf, metricsConfig, nameMapping) + sparkPartition, + spec, + serializableConf, + metricsConfig, + nameMapping, + service) .iterator(), Encoders.javaSerialization(DataFile.class)); @@ -635,7 +829,7 @@ public static void importSparkPartitions( Table targetTable, PartitionSpec spec, String stagingDir) { - importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false); + importSparkPartitions(spark, partitions, targetTable, spec, stagingDir, false, 1); } public static List filterPartitions( diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/MigrateTableSparkAction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/MigrateTableSparkAction.java index 5f3cdd3f035c..bdffeb465405 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/MigrateTableSparkAction.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/actions/MigrateTableSparkAction.java @@ -19,6 +19,7 @@ package org.apache.iceberg.spark.actions; import java.util.Map; +import java.util.concurrent.ExecutorService; import org.apache.iceberg.Snapshot; import org.apache.iceberg.SnapshotSummary; import org.apache.iceberg.Table; @@ -59,6 +60,7 @@ public class MigrateTableSparkAction extends BaseTableCreationSparkAction partitionFilter, - boolean checkDuplicateFiles) { + boolean checkDuplicateFiles, + int parallelism) { return modifyIcebergTable( destIdent, table -> { @@ -153,9 +164,16 @@ private InternalRow[] importToIceberg( Path sourcePath = new Path(sourceIdent.name()); String format = sourceIdent.namespace()[0]; importFileTable( - table, sourcePath, format, partitionFilter, checkDuplicateFiles, table.spec()); + table, + sourcePath, + format, + partitionFilter, + checkDuplicateFiles, + table.spec(), + parallelism); } else { - importCatalogTable(table, sourceIdent, partitionFilter, checkDuplicateFiles); + importCatalogTable( + table, sourceIdent, partitionFilter, checkDuplicateFiles, parallelism); } Snapshot snapshot = table.currentSnapshot(); @@ -178,7 +196,8 @@ private void importFileTable( String format, Map partitionFilter, boolean checkDuplicateFiles, - PartitionSpec spec) { + PartitionSpec spec, + int parallelism) { // List Partitions via Spark InMemory file search interface List partitions = Spark3Util.getPartitions(spark(), tableLocation, format, partitionFilter, spec); @@ -193,11 +212,11 @@ private void importFileTable( // Build a Global Partition for the source SparkPartition partition = new SparkPartition(Collections.emptyMap(), tableLocation.toString(), format); - importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles); + importPartitions(table, ImmutableList.of(partition), checkDuplicateFiles, parallelism); } else { Preconditions.checkArgument( !partitions.isEmpty(), "Cannot find any matching partitions in table %s", table.name()); - importPartitions(table, partitions, checkDuplicateFiles); + importPartitions(table, partitions, checkDuplicateFiles, parallelism); 
} } @@ -205,7 +224,8 @@ private void importCatalogTable( Table table, Identifier sourceIdent, Map partitionFilter, - boolean checkDuplicateFiles) { + boolean checkDuplicateFiles, + int parallelism) { String stagingLocation = getMetadataLocation(table); TableIdentifier sourceTableIdentifier = Spark3Util.toV1TableIdentifier(sourceIdent); SparkTableUtil.importSparkTable( @@ -214,14 +234,24 @@ private void importCatalogTable( table, stagingLocation, partitionFilter, - checkDuplicateFiles); + checkDuplicateFiles, + parallelism); } private void importPartitions( - Table table, List partitions, boolean checkDuplicateFiles) { + Table table, + List partitions, + boolean checkDuplicateFiles, + int parallelism) { String stagingLocation = getMetadataLocation(table); SparkTableUtil.importSparkPartitions( - spark(), partitions, table, table.spec(), stagingLocation, checkDuplicateFiles); + spark(), + partitions, + table, + table.spec(), + stagingLocation, + checkDuplicateFiles, + parallelism); } private String getMetadataLocation(Table table) { diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java index 37b1e3bf0195..a0bd04dd997e 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/MigrateTableProcedure.java @@ -40,7 +40,8 @@ class MigrateTableProcedure extends BaseProcedure { ProcedureParameter.required("table", DataTypes.StringType), ProcedureParameter.optional("properties", STRING_MAP), ProcedureParameter.optional("drop_backup", DataTypes.BooleanType), - ProcedureParameter.optional("backup_table_name", DataTypes.StringType) + ProcedureParameter.optional("backup_table_name", DataTypes.StringType), + ProcedureParameter.optional("parallelism", DataTypes.IntegerType) }; private static final StructType OUTPUT_TYPE = @@ -105,6 +106,13 @@ public InternalRow[] call(InternalRow args) { migrateTableSparkAction = migrateTableSparkAction.backupTableName(backupTableName); } + if (!args.isNullAt(4)) { + int parallelism = args.getInt(4); + Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0"); + migrateTableSparkAction = + migrateTableSparkAction.executeWith(executorService(parallelism, "table-migration")); + } + MigrateTable.Result result = migrateTableSparkAction.execute(); return new InternalRow[] {newInternalRow(result.migratedDataFilesCount())}; } diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java index 42e4d8ba0603..0be4b38de79c 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java @@ -68,6 +68,18 @@ public Boolean asBoolean(ProcedureParameter param, Boolean defaultValue) { return args.isNullAt(ordinal) ? 
defaultValue : (Boolean) args.getBoolean(ordinal); } + public Integer asInt(ProcedureParameter param) { + Integer value = asInt(param, null); + Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name()); + return value; + } + + public Integer asInt(ProcedureParameter param, Integer defaultValue) { + validateParamType(param, DataTypes.IntegerType); + int ordinal = ordinal(param); + return args.isNullAt(ordinal) ? defaultValue : (Integer) args.getInt(ordinal); + } + public long asLong(ProcedureParameter param) { Long value = asLong(param, null); Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name()); diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java index 7a015a51e8ed..f709f64ebf62 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SnapshotTableProcedure.java @@ -38,7 +38,8 @@ class SnapshotTableProcedure extends BaseProcedure { ProcedureParameter.required("source_table", DataTypes.StringType), ProcedureParameter.required("table", DataTypes.StringType), ProcedureParameter.optional("location", DataTypes.StringType), - ProcedureParameter.optional("properties", STRING_MAP) + ProcedureParameter.optional("properties", STRING_MAP), + ProcedureParameter.optional("parallelism", DataTypes.IntegerType) }; private static final StructType OUTPUT_TYPE = @@ -102,6 +103,12 @@ public InternalRow[] call(InternalRow args) { action.tableLocation(snapshotLocation); } + if (!args.isNullAt(4)) { + int parallelism = args.getInt(4); + Preconditions.checkArgument(parallelism > 0, "Parallelism should be larger than 0"); + action = action.executeWith(executorService(parallelism, "table-snapshot")); + } + SnapshotTable.Result result = action.tableProperties(properties).execute(); return new InternalRow[] {newInternalRow(result.importedDataFilesCount())}; } diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestMigrateTableAction.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestMigrateTableAction.java new file mode 100644 index 000000000000..7bed72b7cc2c --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestMigrateTableAction.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import static org.assertj.core.api.Assumptions.assumeThat; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestMigrateTableAction extends SparkCatalogTestBase { + + public TestMigrateTableAction( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog"); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s_BACKUP_", tableName); + } + + @Test + public void testMigrateWithParallelTasks() throws IOException { + String location = temp.newFolder().toURI().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + tableName, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName); + + AtomicInteger migrationThreadsIndex = new AtomicInteger(0); + SparkActions.get() + .migrateTable(tableName) + .executeWith( + Executors.newFixedThreadPool( + 4, + runnable -> { + Thread thread = new Thread(runnable); + thread.setName("table-migration-" + migrationThreadsIndex.getAndIncrement()); + thread.setDaemon(true); + return thread; + })) + .execute(); + Assert.assertEquals(migrationThreadsIndex.get(), 2); + } +} diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestSnapshotTableAction.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestSnapshotTableAction.java new file mode 100644 index 000000000000..8e6358f51bcd --- /dev/null +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/actions/TestSnapshotTableAction.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.actions; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; + +public class TestSnapshotTableAction extends SparkCatalogTestBase { + private static final String SOURCE_NAME = "spark_catalog.default.source"; + + public TestSnapshotTableAction( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS %s PURGE", SOURCE_NAME); + } + + @Test + public void testSnapshotWithParallelTasks() throws IOException { + String location = temp.newFolder().toURI().toString(); + sql( + "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'", + SOURCE_NAME, location); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", SOURCE_NAME); + sql("INSERT INTO TABLE %s VALUES (2, 'b')", SOURCE_NAME); + + AtomicInteger snapshotThreadsIndex = new AtomicInteger(0); + SparkActions.get() + .snapshotTable(SOURCE_NAME) + .as(tableName) + .executeWith( + Executors.newFixedThreadPool( + 4, + runnable -> { + Thread thread = new Thread(runnable); + thread.setName("table-snapshot-" + snapshotThreadsIndex.getAndIncrement()); + thread.setDaemon(true); + return thread; + })) + .execute(); + Assert.assertEquals(snapshotThreadsIndex.get(), 2); + } +}
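
Usage sketch (not part of the patch above): the new optional parallelism argument can be passed to add_files exactly as in the new test, and the migrate and snapshot procedures gain the same parameter in this change. The catalog, database, and table names below (spark_catalog, db.hive_source, db.iceberg_target, db.parquet_source, db.snapshot_copy) are placeholders, and the SparkSession is assumed to be configured with the Iceberg extensions and catalog.

// Hypothetical driver code; names and session configuration are assumptions, not part of this change.
import org.apache.spark.sql.SparkSession;

public class ParallelProcedureCallsSketch {
  public static void main(String[] args) {
    // Assumes Iceberg extensions and a session catalog are already configured.
    SparkSession spark =
        SparkSession.builder().appName("parallel-procedure-sketch").getOrCreate();

    // Import files from an existing Hive/Spark table using 4 reader threads.
    spark.sql(
        "CALL spark_catalog.system.add_files("
            + "table => 'db.iceberg_target', "
            + "source_table => 'db.hive_source', "
            + "parallelism => 4)");

    // migrate and snapshot accept the same optional parameter; non-positive values
    // are rejected by the Preconditions checks added in this change.
    spark.sql(
        "CALL spark_catalog.system.migrate(table => 'db.parquet_source', parallelism => 4)");
    spark.sql(
        "CALL spark_catalog.system.snapshot("
            + "source_table => 'db.parquet_source', "
            + "table => 'db.snapshot_copy', "
            + "parallelism => 4)");

    spark.stop();
  }
}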
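
For the programmatic path, the int overloads added to SparkTableUtil build a pool via TableMigrationUtil.migrationService(parallelism), while the new ExecutorService overloads and executeWith(...) on the migrate/snapshot actions let the caller supply the pool, as the new tests do. Below is a minimal, hypothetical sketch of the action-based route; the table names and session setup are placeholders.

// Hypothetical action-API usage mirroring TestSnapshotTableAction; table names are placeholders.
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.iceberg.actions.SnapshotTable;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.spark.sql.SparkSession;

public class ParallelSnapshotActionSketch {
  public static void main(String[] args) {
    // Assumes Spark is configured with an Iceberg session catalog.
    SparkSession.builder().appName("parallel-snapshot-sketch").getOrCreate();

    ExecutorService pool = Executors.newFixedThreadPool(4);
    try {
      SnapshotTable.Result result =
          SparkActions.get()
              .snapshotTable("spark_catalog.default.parquet_source") // existing source table
              .as("db.snapshot_copy") // Iceberg table to create
              .executeWith(pool) // data file listing/reading runs on this pool
              .execute();
      System.out.println("Imported data files: " + result.importedDataFilesCount());
    } finally {
      pool.shutdown();
    }
  }
}

When a pool is passed through executeWith, shutting it down afterwards remains the caller's responsibility; the parallelism-based overloads instead create and own their internal executor.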