diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
index 27ed07fcace8..1454fc534e7d 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
@@ -697,7 +697,12 @@ private long advisoryPartitionSize(long defaultValue) {
   private long advisoryPartitionSize(
       long expectedFileSize, FileFormat outputFileFormat, String outputCodec) {
     double shuffleCompressionRatio = shuffleCompressionRatio(outputFileFormat, outputCodec);
-    return (long) (expectedFileSize * shuffleCompressionRatio);
+    long suggestedAdvisoryPartitionSize = (long) (expectedFileSize * shuffleCompressionRatio);
+    return Math.max(suggestedAdvisoryPartitionSize, sparkAdvisoryPartitionSize());
+  }
+
+  private long sparkAdvisoryPartitionSize() {
+    return (long) spark.sessionState().conf().getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES());
   }
 
   private double shuffleCompressionRatio(FileFormat outputFileFormat, String outputCodec) {
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java
index b7af797d149c..abf40ebd953d 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java
@@ -44,6 +44,7 @@
 import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE;
 import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE;
 import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE;
+import static org.assertj.core.api.Assertions.assertThat;
 
 import java.util.List;
 import java.util.Map;
@@ -53,6 +54,7 @@
 import org.apache.iceberg.UpdateProperties;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.spark.sql.internal.SQLConf;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
@@ -74,6 +76,24 @@ public void after() {
     sql("DROP TABLE IF EXISTS %s", tableName);
   }
 
+  @Test
+  public void testAdvisoryPartitionSize() {
+    Table table = validationCatalog.loadTable(tableIdent);
+
+    SparkWriteConf writeConf = new SparkWriteConf(spark, table, ImmutableMap.of());
+
+    long value1 = writeConf.writeRequirements().advisoryPartitionSize();
+    assertThat(value1).isGreaterThan(64L * 1024 * 1024).isLessThan(2L * 1024 * 1024 * 1024);
+
+    spark.conf().set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES().key(), "2GB");
+    long value2 = writeConf.writeRequirements().advisoryPartitionSize();
+    assertThat(value2).isEqualTo(2L * 1024 * 1024 * 1024);
+
+    spark.conf().set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES().key(), "10MB");
+    long value3 = writeConf.writeRequirements().advisoryPartitionSize();
+    assertThat(value3).isGreaterThan(10L * 1024 * 1024);
+  }
+
   @Test
   public void testSparkWriteConfDistributionDefault() {
     Table table = validationCatalog.loadTable(tableIdent);