Spark 3.5: Migrate tests to JUnit5 (#9417)
chinmay-bhat authored Jan 5, 2024
1 parent c416c29 commit 4602824
Showing 25 changed files with 685 additions and 689 deletions.
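The diffs below apply one mechanical pattern throughout: JUnit 4's Assert.* calls become AssertJ assertThat(...) chains with the failure message moved into as(...), and the JUnit 4 TemporaryFolder rule becomes JUnit 5's @TempDir. As a minimal, hypothetical illustration of the target style (the class name and values here are illustrative, not taken from the diff):

import static org.assertj.core.api.Assertions.assertThat;

import org.junit.jupiter.api.Test;

class MigrationStyleExample {

  @Test
  void assertjReplacesJUnit4Assert() {
    int expected = 2;
    int actual = 2;
    // Before (JUnit 4): Assert.assertEquals("sizes should match", expected, actual);
    // After (AssertJ), with the message attached via as():
    assertThat(actual).as("sizes should match").isEqualTo(expected);
  }
}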
TaskCheckHelper.java
@@ -18,10 +18,11 @@
  */
 package org.apache.iceberg;
 
+import static org.assertj.core.api.Assertions.assertThat;
+
 import java.util.Comparator;
 import java.util.List;
 import java.util.stream.Collectors;
-import org.junit.Assert;
 
 public final class TaskCheckHelper {
   private TaskCheckHelper() {}
@@ -31,8 +32,9 @@ public static void assertEquals(
     List<FileScanTask> expectedTasks = getFileScanTasksInFilePathOrder(expected);
     List<FileScanTask> actualTasks = getFileScanTasksInFilePathOrder(actual);
 
-    Assert.assertEquals(
-        "The number of file scan tasks should match", expectedTasks.size(), actualTasks.size());
+    assertThat(actualTasks)
+        .as("The number of file scan tasks should match")
+        .hasSameSizeAs(expectedTasks);
 
     for (int i = 0; i < expectedTasks.size(); i++) {
       FileScanTask expectedTask = expectedTasks.get(i);
@@ -45,60 +47,57 @@ public static void assertEquals(FileScanTask expected, FileScanTask actual) {
     assertEquals(expected.file(), actual.file());
 
     // PartitionSpec implements its own equals method
-    Assert.assertEquals("PartitionSpec doesn't match", expected.spec(), actual.spec());
+    assertThat(actual.spec()).as("PartitionSpec doesn't match").isEqualTo(expected.spec());
 
-    Assert.assertEquals("starting position doesn't match", expected.start(), actual.start());
+    assertThat(actual.start()).as("starting position doesn't match").isEqualTo(expected.start());
 
-    Assert.assertEquals(
-        "the number of bytes to scan doesn't match", expected.start(), actual.start());
+    assertThat(actual.start())
+        .as("the number of bytes to scan doesn't match")
+        .isEqualTo(expected.start());
 
     // simplify comparison on residual expression via comparing toString
-    Assert.assertEquals(
-        "Residual expression doesn't match",
-        expected.residual().toString(),
-        actual.residual().toString());
+    assertThat(actual.residual().toString())
+        .as("Residual expression doesn't match")
+        .isEqualTo(expected.residual().toString());
   }
 
   public static void assertEquals(DataFile expected, DataFile actual) {
-    Assert.assertEquals("Should match the serialized record path", expected.path(), actual.path());
-    Assert.assertEquals(
-        "Should match the serialized record format", expected.format(), actual.format());
-    Assert.assertEquals(
-        "Should match the serialized record partition",
-        expected.partition().get(0, Object.class),
-        actual.partition().get(0, Object.class));
-    Assert.assertEquals(
-        "Should match the serialized record count", expected.recordCount(), actual.recordCount());
-    Assert.assertEquals(
-        "Should match the serialized record size",
-        expected.fileSizeInBytes(),
-        actual.fileSizeInBytes());
-    Assert.assertEquals(
-        "Should match the serialized record value counts",
-        expected.valueCounts(),
-        actual.valueCounts());
-    Assert.assertEquals(
-        "Should match the serialized record null value counts",
-        expected.nullValueCounts(),
-        actual.nullValueCounts());
-    Assert.assertEquals(
-        "Should match the serialized record lower bounds",
-        expected.lowerBounds(),
-        actual.lowerBounds());
-    Assert.assertEquals(
-        "Should match the serialized record upper bounds",
-        expected.upperBounds(),
-        actual.upperBounds());
-    Assert.assertEquals(
-        "Should match the serialized record key metadata",
-        expected.keyMetadata(),
-        actual.keyMetadata());
-    Assert.assertEquals(
-        "Should match the serialized record offsets",
-        expected.splitOffsets(),
-        actual.splitOffsets());
-    Assert.assertEquals(
-        "Should match the serialized record offsets", expected.keyMetadata(), actual.keyMetadata());
+    assertThat(actual.path())
+        .as("Should match the serialized record path")
+        .isEqualTo(expected.path());
+    assertThat(actual.format())
+        .as("Should match the serialized record format")
+        .isEqualTo(expected.format());
+    assertThat(actual.partition().get(0, Object.class))
+        .as("Should match the serialized record partition")
+        .isEqualTo(expected.partition().get(0, Object.class));
+    assertThat(actual.recordCount())
+        .as("Should match the serialized record count")
+        .isEqualTo(expected.recordCount());
+    assertThat(actual.fileSizeInBytes())
+        .as("Should match the serialized record size")
+        .isEqualTo(expected.fileSizeInBytes());
+    assertThat(actual.valueCounts())
+        .as("Should match the serialized record value counts")
+        .isEqualTo(expected.valueCounts());
+    assertThat(actual.nullValueCounts())
+        .as("Should match the serialized record null value counts")
+        .isEqualTo(expected.nullValueCounts());
+    assertThat(actual.lowerBounds())
+        .as("Should match the serialized record lower bounds")
+        .isEqualTo(expected.lowerBounds());
+    assertThat(actual.upperBounds())
+        .as("Should match the serialized record upper bounds")
+        .isEqualTo(expected.upperBounds());
+    assertThat(actual.keyMetadata())
+        .as("Should match the serialized record key metadata")
+        .isEqualTo(expected.keyMetadata());
+    assertThat(actual.splitOffsets())
+        .as("Should match the serialized record offsets")
+        .isEqualTo(expected.splitOffsets());
+    assertThat(actual.keyMetadata())
+        .as("Should match the serialized record offsets")
+        .isEqualTo(expected.keyMetadata());
   }
 
   private static List<FileScanTask> getFileScanTasksInFilePathOrder(
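One detail in the first TaskCheckHelper.java hunk above is that the size check becomes hasSameSizeAs(expectedTasks) rather than isEqualTo(expectedTasks.size()), so a failure reports the actual collection instead of two bare integers. A small sketch of the two options (class and parameter names are illustrative):

import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;

class SizeAssertionSketch {

  void compare(List<String> expectedTasks, List<String> actualTasks) {
    // On failure, prints both sizes plus the actual list's contents:
    assertThat(actualTasks)
        .as("The number of file scan tasks should match")
        .hasSameSizeAs(expectedTasks);
    // Equivalent check, but a failure shows only the two size values:
    assertThat(actualTasks.size()).isEqualTo(expectedTasks.size());
  }
}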
TestDataFileSerialization.java
@@ -21,6 +21,7 @@
 import static org.apache.iceberg.TaskCheckHelper.assertEquals;
 import static org.apache.iceberg.types.Types.NestedField.optional;
 import static org.apache.iceberg.types.Types.NestedField.required;
+import static org.assertj.core.api.Assertions.assertThat;
 
 import com.esotericsoftware.kryo.Kryo;
 import com.esotericsoftware.kryo.io.Input;
@@ -35,6 +36,7 @@
 import java.io.ObjectOutputStream;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
+import java.nio.file.Path;
 import java.util.Map;
 import java.util.UUID;
 import org.apache.iceberg.io.FileAppender;
@@ -48,11 +50,8 @@
 import org.apache.spark.SparkConf;
 import org.apache.spark.serializer.KryoSerializer;
 import org.apache.spark.sql.catalyst.InternalRow;
-import org.assertj.core.api.Assertions;
-import org.junit.Assert;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
 
 public class TestDataFileSerialization {
 
@@ -102,12 +101,12 @@ public class TestDataFileSerialization {
           .withSortOrder(SortOrder.unsorted())
           .build();
 
-  @Rule public TemporaryFolder temp = new TemporaryFolder();
+  @TempDir private Path temp;
 
   @Test
   public void testDataFileKryoSerialization() throws Exception {
-    File data = temp.newFile();
-    Assert.assertTrue(data.delete());
+    File data = File.createTempFile("junit", null, temp.toFile());
+    assertThat(data.delete()).isTrue();
     Kryo kryo = new KryoSerializer(new SparkConf()).newKryo();
 
     try (Output out = new Output(new FileOutputStream(data))) {
@@ -118,7 +117,7 @@ public void testDataFileKryoSerialization() throws Exception {
     try (Input in = new Input(new FileInputStream(data))) {
       for (int i = 0; i < 2; i += 1) {
         Object obj = kryo.readClassAndObject(in);
-        Assertions.assertThat(obj).as("Should be a DataFile").isInstanceOf(DataFile.class);
+        assertThat(obj).as("Should be a DataFile").isInstanceOf(DataFile.class);
         assertEquals(DATA_FILE, (DataFile) obj);
       }
     }
@@ -136,7 +135,7 @@ public void testDataFileJavaSerialization() throws Exception {
         new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
       for (int i = 0; i < 2; i += 1) {
         Object obj = in.readObject();
-        Assertions.assertThat(obj).as("Should be a DataFile").isInstanceOf(DataFile.class);
+        assertThat(obj).as("Should be a DataFile").isInstanceOf(DataFile.class);
         assertEquals(DATA_FILE, (DataFile) obj);
       }
     }
@@ -146,7 +145,7 @@
   public void testParquetWriterSplitOffsets() throws IOException {
     Iterable<InternalRow> records = RandomData.generateSpark(DATE_SCHEMA, 1, 33L);
     File parquetFile =
-        new File(temp.getRoot(), FileFormat.PARQUET.addExtension(UUID.randomUUID().toString()));
+        new File(temp.toFile(), FileFormat.PARQUET.addExtension(UUID.randomUUID().toString()));
     FileAppender<InternalRow> writer =
         Parquet.write(Files.localOutput(parquetFile))
             .schema(DATE_SCHEMA)
@@ -161,7 +160,7 @@ public void testParquetWriterSplitOffsets() throws IOException {
     }
 
     Kryo kryo = new KryoSerializer(new SparkConf()).newKryo();
-    File dataFile = temp.newFile();
+    File dataFile = File.createTempFile("junit", null, temp.toFile());
     try (Output out = new Output(new FileOutputStream(dataFile))) {
       kryo.writeClassAndObject(out, writer.splitOffsets());
     }
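In TestDataFileSerialization above, the JUnit 4 rule @Rule public TemporaryFolder temp becomes a @TempDir-injected java.nio.file.Path, and temp.newFile() becomes File.createTempFile(...) against that path. A condensed, hypothetical sketch of the pattern (the class name and test body are illustrative):

import static org.assertj.core.api.Assertions.assertThat;

import java.io.File;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class TempDirSketch {

  // JUnit 5 injects a per-test temporary directory here.
  @TempDir Path temp;

  @Test
  void createsFilesUnderManagedTempDir() throws Exception {
    // Equivalent of JUnit 4's temp.newFile(): a unique file inside the
    // managed directory, removed when the directory is cleaned up.
    File data = File.createTempFile("junit", null, temp.toFile());
    assertThat(data).exists();
  }
}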
TestFileIOSerialization.java
@@ -20,9 +20,12 @@
 
 import static org.apache.iceberg.types.Types.NestedField.optional;
 import static org.apache.iceberg.types.Types.NestedField.required;
+import static org.assertj.core.api.Assertions.assertThat;
 
 import java.io.File;
 import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
 import java.util.Map;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.iceberg.hadoop.HadoopFileIO;
@@ -32,11 +35,9 @@
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.spark.source.SerializableTableWithSize;
 import org.apache.iceberg.types.Types;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
 
 public class TestFileIOSerialization {
 
@@ -60,15 +61,15 @@ public class TestFileIOSerialization {
     CONF.set("k2", "v2");
   }
 
-  @Rule public TemporaryFolder temp = new TemporaryFolder();
+  @TempDir private Path temp;
   private Table table;
 
-  @Before
+  @BeforeEach
   public void initTable() throws IOException {
     Map<String, String> props = ImmutableMap.of("k1", "v1", "k2", "v2");
 
-    File tableLocation = temp.newFolder();
-    Assert.assertTrue(tableLocation.delete());
+    File tableLocation = Files.createTempDirectory(temp, "junit").toFile();
+    assertThat(tableLocation.delete()).isTrue();
 
     this.table = TABLES.create(SCHEMA, SPEC, SORT_ORDER, props, tableLocation.toString());
   }
@@ -82,9 +83,9 @@ public void testHadoopFileIOKryoSerialization() throws IOException {
     FileIO deserializedIO = KryoHelpers.roundTripSerialize(serializableTable.io());
     Configuration actualConf = ((HadoopFileIO) deserializedIO).conf();
 
-    Assert.assertEquals("Conf pairs must match", toMap(expectedConf), toMap(actualConf));
-    Assert.assertEquals("Conf values must be present", "v1", actualConf.get("k1"));
-    Assert.assertEquals("Conf values must be present", "v2", actualConf.get("k2"));
+    assertThat(toMap(actualConf)).as("Conf pairs must match").isEqualTo(toMap(expectedConf));
+    assertThat(actualConf.get("k1")).as("Conf values must be present").isEqualTo("v1");
+    assertThat(actualConf.get("k2")).as("Conf values must be present").isEqualTo("v2");
   }
 
   @Test
@@ -96,9 +97,9 @@ public void testHadoopFileIOJavaSerialization() throws IOException, ClassNotFoundException {
     FileIO deserializedIO = TestHelpers.roundTripSerialize(serializableTable.io());
     Configuration actualConf = ((HadoopFileIO) deserializedIO).conf();
 
-    Assert.assertEquals("Conf pairs must match", toMap(expectedConf), toMap(actualConf));
-    Assert.assertEquals("Conf values must be present", "v1", actualConf.get("k1"));
-    Assert.assertEquals("Conf values must be present", "v2", actualConf.get("k2"));
+    assertThat(toMap(actualConf)).as("Conf pairs must match").isEqualTo(toMap(expectedConf));
+    assertThat(actualConf.get("k1")).as("Conf values must be present").isEqualTo("v1");
+    assertThat(actualConf.get("k2")).as("Conf values must be present").isEqualTo("v2");
   }
 
   private Map<String, String> toMap(Configuration conf) {
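TestFileIOSerialization shows the directory-flavored variant of the same move: @TempDir has no direct analogue of TemporaryFolder.newFolder(), so the commit creates a subdirectory via java.nio.file.Files. A hypothetical sketch of that setup pattern (class and field names are illustrative):

import static org.assertj.core.api.Assertions.assertThat;

import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.io.TempDir;

class TempFolderSketch {

  @TempDir Path temp;
  private File tableLocation;

  @BeforeEach
  void initTable() throws Exception {
    // JUnit 4 equivalent: File tableLocation = temp.newFolder();
    tableLocation = Files.createTempDirectory(temp, "junit").toFile();
    // Deleted up front so the code under test can create it fresh.
    assertThat(tableLocation.delete()).isTrue();
  }
}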
TestHadoopMetricsContextSerialization.java
@@ -23,7 +23,7 @@
 import org.apache.iceberg.io.FileIOMetricsContext;
 import org.apache.iceberg.metrics.MetricsContext;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 
 public class TestHadoopMetricsContextSerialization {
 
(Diffs for the remaining 21 changed files are not shown.)