Skip to content

Commit

Permalink
Spark: Move the Writer to a visitor
Browse files Browse the repository at this point in the history
  • Loading branch information
Fokko committed Jan 9, 2024
1 parent 2101ac2 commit 5465ab5
Showing 1 changed file with 113 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.UUID;
import org.apache.iceberg.parquet.ParquetValueReaders.ReusableEntry;
import org.apache.iceberg.parquet.ParquetValueWriter;
Expand All @@ -48,11 +49,9 @@
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;
Expand Down Expand Up @@ -136,46 +135,120 @@ private ParquetValueWriter<?> newOption(Type fieldType, ParquetValueWriter<?> wr
return ParquetValueWriters.option(fieldType, maxD, writer);
}

private static class LogicalTypeAnnotationParquetValueWriterVisitor
implements LogicalTypeAnnotation.LogicalTypeAnnotationVisitor<ParquetValueWriter<?>> {

private final ColumnDescriptor desc;
private final PrimitiveType primitive;

public LogicalTypeAnnotationParquetValueWriterVisitor(
ColumnDescriptor desc, PrimitiveType primitive) {
this.desc = desc;
this.primitive = primitive;
}

@Override
public Optional<ParquetValueWriter<?>> visit(
LogicalTypeAnnotation.StringLogicalTypeAnnotation stringLogicalType) {
return Optional.of(utf8Strings(desc));
}

@Override
public Optional<ParquetValueWriter<?>> visit(
LogicalTypeAnnotation.EnumLogicalTypeAnnotation enumLogicalType) {
return Optional.of(utf8Strings(desc));
}

@Override
public Optional<ParquetValueWriter<?>> visit(
LogicalTypeAnnotation.JsonLogicalTypeAnnotation jsonLogicalType) {
return Optional.of(utf8Strings(desc));
}

@Override
public Optional<ParquetValueWriter<?>> visit(
LogicalTypeAnnotation.MapLogicalTypeAnnotation mapLogicalType) {
return LogicalTypeAnnotation.LogicalTypeAnnotationVisitor.super.visit(mapLogicalType);
}

@Override
public Optional<ParquetValueWriter<?>> visit(
LogicalTypeAnnotation.ListLogicalTypeAnnotation listLogicalType) {
return LogicalTypeAnnotation.LogicalTypeAnnotationVisitor.super.visit(listLogicalType);
}

@Override
public Optional<ParquetValueWriter<?>> visit(DecimalLogicalTypeAnnotation decimal) {
switch (primitive.getPrimitiveTypeName()) {
case INT32:
return Optional.of(decimalAsInteger(desc, decimal.getPrecision(), decimal.getScale()));
case INT64:
return Optional.of(decimalAsLong(desc, decimal.getPrecision(), decimal.getScale()));
case BINARY:
case FIXED_LEN_BYTE_ARRAY:
return Optional.of(decimalAsFixed(desc, decimal.getPrecision(), decimal.getScale()));
}
return Optional.empty();
}

@Override
public Optional<ParquetValueWriter<?>> visit(
LogicalTypeAnnotation.DateLogicalTypeAnnotation dateLogicalType) {
return Optional.of(ParquetValueWriters.ints(desc));
}

@Override
public Optional<ParquetValueWriter<?>> visit(
LogicalTypeAnnotation.TimeLogicalTypeAnnotation timeLogicalType) {
if (timeLogicalType.getUnit() == LogicalTypeAnnotation.TimeUnit.MICROS) {
return Optional.of(ParquetValueWriters.longs(desc));
}
return Optional.empty();
}

@Override
public Optional<ParquetValueWriter<?>> visit(
LogicalTypeAnnotation.TimestampLogicalTypeAnnotation timestampLogicalType) {
if (timestampLogicalType.getUnit() == LogicalTypeAnnotation.TimeUnit.MICROS) {
return Optional.of(ParquetValueWriters.longs(desc));
}
return Optional.empty();
}

@Override
public Optional<ParquetValueWriter<?>> visit(
LogicalTypeAnnotation.IntLogicalTypeAnnotation intLogicalType) {
int bitWidth = intLogicalType.getBitWidth();
if (bitWidth <= 8) {
return Optional.of(ParquetValueWriters.tinyints(desc));
} else if (bitWidth <= 16) {
return Optional.of(ParquetValueWriters.shorts(desc));
} else if (bitWidth <= 32) {
return Optional.of(ParquetValueWriters.ints(desc));
} else {
return Optional.of(ParquetValueWriters.longs(desc));
}
}

@Override
public Optional<ParquetValueWriter<?>> visit(
LogicalTypeAnnotation.BsonLogicalTypeAnnotation bsonLogicalType) {
return Optional.of(byteArrays(desc));
}
}

@Override
public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive) {
ColumnDescriptor desc = type.getColumnDescription(currentPath());

if (primitive.getOriginalType() != null) {
switch (primitive.getOriginalType()) {
case ENUM:
case JSON:
case UTF8:
return utf8Strings(desc);
case DATE:
case INT_8:
case INT_16:
case INT_32:
return ints(sType, desc);
case INT_64:
case TIME_MICROS:
case TIMESTAMP_MICROS:
return ParquetValueWriters.longs(desc);
case DECIMAL:
DecimalLogicalTypeAnnotation decimal =
(DecimalLogicalTypeAnnotation) primitive.getLogicalTypeAnnotation();
switch (primitive.getPrimitiveTypeName()) {
case INT32:
return decimalAsInteger(desc, decimal.getPrecision(), decimal.getScale());
case INT64:
return decimalAsLong(desc, decimal.getPrecision(), decimal.getScale());
case BINARY:
case FIXED_LEN_BYTE_ARRAY:
return decimalAsFixed(desc, decimal.getPrecision(), decimal.getScale());
default:
throw new UnsupportedOperationException(
"Unsupported base type for decimal: " + primitive.getPrimitiveTypeName());
}
case BSON:
return byteArrays(desc);
default:
throw new UnsupportedOperationException(
"Unsupported logical type: " + primitive.getOriginalType());
}
LogicalTypeAnnotation logicalTypeAnnotation = primitive.getLogicalTypeAnnotation();

if (logicalTypeAnnotation != null) {
logicalTypeAnnotation
.accept(new LogicalTypeAnnotationParquetValueWriterVisitor(desc, primitive))
.orElseThrow(
() ->
new UnsupportedOperationException(
"Unsupported logical type: " + primitive.getLogicalTypeAnnotation()));
}

switch (primitive.getPrimitiveTypeName()) {
Expand All @@ -188,7 +261,7 @@ public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive)
case BOOLEAN:
return ParquetValueWriters.booleans(desc);
case INT32:
return ints(sType, desc);
return ParquetValueWriters.ints(desc);
case INT64:
return ParquetValueWriters.longs(desc);
case FLOAT:
Expand All @@ -201,15 +274,6 @@ public ParquetValueWriter<?> primitive(DataType sType, PrimitiveType primitive)
}
}

private static PrimitiveWriter<?> ints(DataType type, ColumnDescriptor desc) {
if (type instanceof ByteType) {
return ParquetValueWriters.tinyints(desc);
} else if (type instanceof ShortType) {
return ParquetValueWriters.shorts(desc);
}
return ParquetValueWriters.ints(desc);
}

private static PrimitiveWriter<UTF8String> utf8Strings(ColumnDescriptor desc) {
return new UTF8StringWriter(desc);
}
Expand Down

0 comments on commit 5465ab5

Please sign in to comment.