diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java index af6f65a089b6..34197cabf08a 100644 --- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java +++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.NoSuchElementException; +import java.util.Optional; import java.util.UUID; import org.apache.iceberg.parquet.ParquetValueReaders.ReusableEntry; import org.apache.iceberg.parquet.ParquetValueWriter; @@ -48,11 +49,9 @@ import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.ArrayType; -import org.apache.spark.sql.types.ByteType; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.types.MapType; -import org.apache.spark.sql.types.ShortType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.types.UTF8String; @@ -136,46 +135,120 @@ private ParquetValueWriter newOption(Type fieldType, ParquetValueWriter wr return ParquetValueWriters.option(fieldType, maxD, writer); } + private static class LogicalTypeAnnotationParquetValueWriterVisitor + implements LogicalTypeAnnotation.LogicalTypeAnnotationVisitor> { + + private final ColumnDescriptor desc; + private final PrimitiveType primitive; + + public LogicalTypeAnnotationParquetValueWriterVisitor( + ColumnDescriptor desc, PrimitiveType primitive) { + this.desc = desc; + this.primitive = primitive; + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.StringLogicalTypeAnnotation stringLogicalType) { + return Optional.of(utf8Strings(desc)); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.EnumLogicalTypeAnnotation enumLogicalType) { + return Optional.of(utf8Strings(desc)); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.JsonLogicalTypeAnnotation jsonLogicalType) { + return Optional.of(utf8Strings(desc)); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.MapLogicalTypeAnnotation mapLogicalType) { + return LogicalTypeAnnotation.LogicalTypeAnnotationVisitor.super.visit(mapLogicalType); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.ListLogicalTypeAnnotation listLogicalType) { + return LogicalTypeAnnotation.LogicalTypeAnnotationVisitor.super.visit(listLogicalType); + } + + @Override + public Optional> visit(DecimalLogicalTypeAnnotation decimal) { + switch (primitive.getPrimitiveTypeName()) { + case INT32: + return Optional.of(decimalAsInteger(desc, decimal.getPrecision(), decimal.getScale())); + case INT64: + return Optional.of(decimalAsLong(desc, decimal.getPrecision(), decimal.getScale())); + case BINARY: + case FIXED_LEN_BYTE_ARRAY: + return Optional.of(decimalAsFixed(desc, decimal.getPrecision(), decimal.getScale())); + } + return Optional.empty(); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.DateLogicalTypeAnnotation dateLogicalType) { + return Optional.of(ParquetValueWriters.ints(desc)); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.TimeLogicalTypeAnnotation timeLogicalType) { + if (timeLogicalType.getUnit() == LogicalTypeAnnotation.TimeUnit.MICROS) { + return Optional.of(ParquetValueWriters.longs(desc)); + } + return Optional.empty(); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.TimestampLogicalTypeAnnotation timestampLogicalType) { + if (timestampLogicalType.getUnit() == LogicalTypeAnnotation.TimeUnit.MICROS) { + return Optional.of(ParquetValueWriters.longs(desc)); + } + return Optional.empty(); + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.IntLogicalTypeAnnotation intLogicalType) { + int bitWidth = intLogicalType.getBitWidth(); + if (bitWidth <= 8) { + return Optional.of(ParquetValueWriters.tinyints(desc)); + } else if (bitWidth <= 16) { + return Optional.of(ParquetValueWriters.shorts(desc)); + } else if (bitWidth <= 32) { + return Optional.of(ParquetValueWriters.ints(desc)); + } else { + return Optional.of(ParquetValueWriters.longs(desc)); + } + } + + @Override + public Optional> visit( + LogicalTypeAnnotation.BsonLogicalTypeAnnotation bsonLogicalType) { + return Optional.of(byteArrays(desc)); + } + } + @Override public ParquetValueWriter primitive(DataType sType, PrimitiveType primitive) { ColumnDescriptor desc = type.getColumnDescription(currentPath()); - - if (primitive.getOriginalType() != null) { - switch (primitive.getOriginalType()) { - case ENUM: - case JSON: - case UTF8: - return utf8Strings(desc); - case DATE: - case INT_8: - case INT_16: - case INT_32: - return ints(sType, desc); - case INT_64: - case TIME_MICROS: - case TIMESTAMP_MICROS: - return ParquetValueWriters.longs(desc); - case DECIMAL: - DecimalLogicalTypeAnnotation decimal = - (DecimalLogicalTypeAnnotation) primitive.getLogicalTypeAnnotation(); - switch (primitive.getPrimitiveTypeName()) { - case INT32: - return decimalAsInteger(desc, decimal.getPrecision(), decimal.getScale()); - case INT64: - return decimalAsLong(desc, decimal.getPrecision(), decimal.getScale()); - case BINARY: - case FIXED_LEN_BYTE_ARRAY: - return decimalAsFixed(desc, decimal.getPrecision(), decimal.getScale()); - default: - throw new UnsupportedOperationException( - "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName()); - } - case BSON: - return byteArrays(desc); - default: - throw new UnsupportedOperationException( - "Unsupported logical type: " + primitive.getOriginalType()); - } + LogicalTypeAnnotation logicalTypeAnnotation = primitive.getLogicalTypeAnnotation(); + + if (logicalTypeAnnotation != null) { + logicalTypeAnnotation + .accept(new LogicalTypeAnnotationParquetValueWriterVisitor(desc, primitive)) + .orElseThrow( + () -> + new UnsupportedOperationException( + "Unsupported logical type: " + primitive.getLogicalTypeAnnotation())); } switch (primitive.getPrimitiveTypeName()) { @@ -188,7 +261,7 @@ public ParquetValueWriter primitive(DataType sType, PrimitiveType primitive) case BOOLEAN: return ParquetValueWriters.booleans(desc); case INT32: - return ints(sType, desc); + return ParquetValueWriters.ints(desc); case INT64: return ParquetValueWriters.longs(desc); case FLOAT: @@ -201,15 +274,6 @@ public ParquetValueWriter primitive(DataType sType, PrimitiveType primitive) } } - private static PrimitiveWriter ints(DataType type, ColumnDescriptor desc) { - if (type instanceof ByteType) { - return ParquetValueWriters.tinyints(desc); - } else if (type instanceof ShortType) { - return ParquetValueWriters.shorts(desc); - } - return ParquetValueWriters.ints(desc); - } - private static PrimitiveWriter utf8Strings(ColumnDescriptor desc) { return new UTF8StringWriter(desc); }