Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ajantha-bhat committed Jan 8, 2025
1 parent 15b943f commit 3b95d75
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 193 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
import org.apache.avro.Schema;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;

abstract class BaseAvroSchemaVisitor extends AvroSchemaVisitor<ValueWriter<?>> {
abstract class BaseWriteBuilder extends AvroSchemaVisitor<ValueWriter<?>> {

protected abstract ValueWriter<?> createRecordWriter(List<ValueWriter<?>> fields);

protected abstract ValueWriter<?> fixedWriter(int size);
protected abstract ValueWriter<?> fixedWriter(int length);

@Override
public ValueWriter<?> record(Schema record, List<String> names, List<ValueWriter<?>> fields) {
Expand All @@ -37,16 +37,15 @@ public ValueWriter<?> record(Schema record, List<String> names, List<ValueWriter

@Override
public ValueWriter<?> union(Schema union, List<ValueWriter<?>> options) {
Preconditions.checkArgument(
options.contains(ValueWriters.nulls()),
"Cannot create writer for non-option union: %s",
union);
Preconditions.checkArgument(
options.size() == 2, "Cannot create writer for non-option union: %s", union);
if (union.getTypes().get(0).getType() == Schema.Type.NULL) {
return ValueWriters.option(0, options.get(1));
} else {
} else if (union.getTypes().get(1).getType() == Schema.Type.NULL) {
return ValueWriters.option(1, options.get(0));
} else {
throw new IllegalArgumentException(
String.format("Cannot create writer for non-option union: %s", union));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public static <D> GenericAvroWriter<D> create(Schema schema) {
@Override
@SuppressWarnings("unchecked")
public void setSchema(Schema schema) {
this.writer = (ValueWriter<T>) AvroSchemaVisitor.visit(schema, new GenericAvroSchemaVisitor());
this.writer = (ValueWriter<T>) AvroSchemaVisitor.visit(schema, new WriteBuilder());
}

@Override
Expand All @@ -52,16 +52,16 @@ public Stream<FieldMetrics> metrics() {
return writer.metrics();
}

private static class GenericAvroSchemaVisitor extends BaseAvroSchemaVisitor {
private static class WriteBuilder extends BaseWriteBuilder {

@Override
protected ValueWriter<?> createRecordWriter(List<ValueWriter<?>> fields) {
return ValueWriters.record(fields);
}

@Override
protected ValueWriter<?> fixedWriter(int size) {
return ValueWriters.genericFixed(size);
protected ValueWriter<?> fixedWriter(int length) {
return ValueWriters.genericFixed(length);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public static <D> InternalWriter<D> create(Schema schema) {
@Override
@SuppressWarnings("unchecked")
public void setSchema(Schema schema) {
this.writer = (ValueWriter<T>) AvroSchemaVisitor.visit(schema, new GenericAvroSchemaVisitor());
this.writer = (ValueWriter<T>) AvroSchemaVisitor.visit(schema, new WriteBuilder());
}

@Override
Expand All @@ -59,16 +59,16 @@ public Stream<FieldMetrics> metrics() {
return writer.metrics();
}

private static class GenericAvroSchemaVisitor extends BaseAvroSchemaVisitor {
private static class WriteBuilder extends BaseWriteBuilder {

@Override
protected ValueWriter<?> createRecordWriter(List<ValueWriter<?>> fields) {
return ValueWriters.struct(fields);
}

@Override
protected ValueWriter<?> fixedWriter(int size) {
return ValueWriters.byteBuffers();
protected ValueWriter<?> fixedWriter(int length) {
return ValueWriters.fixedBuffers(length);
}
}
}
23 changes: 22 additions & 1 deletion core/src/main/java/org/apache/iceberg/avro/ValueWriters.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ public static ValueWriter<GenericData.Fixed> genericFixed(int length) {
return new GenericFixedWriter(length);
}

public static ValueWriter<ByteBuffer> fixedBuffers(int length) {
return new FixedByteBufferWriter(length);
}

public static ValueWriter<byte[]> bytes() {
return BytesWriter.INSTANCE;
}
Expand Down Expand Up @@ -332,6 +336,24 @@ public void write(ByteBuffer bytes, Encoder encoder) throws IOException {
}
}

private static class FixedByteBufferWriter implements ValueWriter<ByteBuffer> {
private final int length;

private FixedByteBufferWriter(int length) {
this.length = length;
}

@Override
public void write(ByteBuffer bytes, Encoder encoder) throws IOException {
Preconditions.checkArgument(
bytes.remaining() == length,
"Cannot write byte buffer of length %s as fixed[%s]",
bytes.remaining(),
length);
encoder.writeBytes(bytes);
}
}

private static class DecimalWriter implements ValueWriter<BigDecimal> {
private final int precision;
private final int scale;
Expand Down Expand Up @@ -491,7 +513,6 @@ protected Object get(IndexedRecord struct, int pos) {
}

private static class StructLikeWriter extends StructWriter<StructLike> {
@SuppressWarnings("unchecked")
private StructLikeWriter(List<ValueWriter<?>> writers) {
super(writers);
}
Expand Down
31 changes: 28 additions & 3 deletions core/src/test/java/org/apache/iceberg/avro/AvroTestHelpers.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.apache.avro.JsonProperties;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData.Record;
import org.apache.iceberg.StructLike;
import org.apache.iceberg.data.GenericRecord;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;

Expand Down Expand Up @@ -78,6 +80,18 @@ static void assertEquals(Types.StructType struct, Record expected, Record actual
}
}

static void assertEquals(Types.StructType struct, StructLike expected, StructLike actual) {
List<Types.NestedField> fields = struct.fields();
for (int i = 0; i < fields.size(); i += 1) {
Type fieldType = fields.get(i).type();

Object expectedValue = expected.get(i, Object.class);
Object actualValue = actual.get(i, Object.class);

assertEquals(fieldType, expectedValue, actualValue);
}
}

static void assertEquals(Types.ListType list, List<?> expected, List<?> actual) {
Type elementType = list.elementType();

Expand Down Expand Up @@ -126,9 +140,20 @@ private static void assertEquals(Type type, Object expected, Object actual) {
assertThat(actual).as("Primitive value should be equal to expected").isEqualTo(expected);
break;
case STRUCT:
assertThat(expected).as("Expected should be a Record").isInstanceOf(Record.class);
assertThat(actual).as("Actual should be a Record").isInstanceOf(Record.class);
assertEquals(type.asStructType(), (Record) expected, (Record) actual);
assertThat(expected)
.as("Expected should be a Record or GenericRecord")
.isInstanceOfAny(Record.class, GenericRecord.class);

if (expected instanceof Record) {
assertThat(actual).as("Actual should be a Record").isInstanceOf(Record.class);
assertEquals(type.asStructType(), (Record) expected, (Record) actual);
} else {
assertThat(actual)
.as("Actual should be a GenericRecord")
.isInstanceOf(GenericRecord.class);
assertEquals(type.asStructType(), (GenericRecord) expected, (GenericRecord) actual);
}

break;
case LIST:
assertThat(expected).as("Expected should be a List").isInstanceOf(List.class);
Expand Down
163 changes: 122 additions & 41 deletions core/src/test/java/org/apache/iceberg/avro/RandomAvroData.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import org.apache.avro.generic.GenericData.Record;
import org.apache.avro.util.Utf8;
import org.apache.iceberg.Schema;
import org.apache.iceberg.StructLike;
import org.apache.iceberg.data.GenericRecord;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.relocated.com.google.common.collect.Sets;
Expand All @@ -51,6 +53,66 @@ public static List<Record> generate(Schema schema, int numRecords, long seed) {
return records;
}

public static List<StructLike> generateStructLike(Schema schema, int numRecords, long seed) {
RandomInternalDataGenerator generator = new RandomInternalDataGenerator(seed);
List<StructLike> records = Lists.newArrayListWithExpectedSize(numRecords);
for (int i = 0; i < numRecords; i += 1) {
records.add((StructLike) TypeUtil.visit(schema, generator));
}

return records;
}

private static List<Object> generateList(
Random random, Types.ListType list, Supplier<Object> elementResult) {
int numElements = random.nextInt(20);

List<Object> result = Lists.newArrayListWithExpectedSize(numElements);
for (int i = 0; i < numElements; i += 1) {
// return null 5% of the time when the value is optional
if (list.isElementOptional() && random.nextInt(20) == 1) {
result.add(null);
} else {
result.add(elementResult.get());
}
}

return result;
}

private static Map<Object, Object> generateMap(
Random random, Types.MapType map, Supplier<Object> keyResult, Supplier<Object> valueResult) {
int numEntries = random.nextInt(20);

Map<Object, Object> result = Maps.newLinkedHashMap();
Supplier<Object> keyFunc;
if (map.keyType() == Types.StringType.get()) {
keyFunc = () -> keyResult.get().toString();
} else {
keyFunc = keyResult;
}

Set<Object> keySet = Sets.newHashSet();
for (int i = 0; i < numEntries; i += 1) {
Object key = keyFunc.get();
// ensure no collisions
while (keySet.contains(key)) {
key = keyFunc.get();
}

keySet.add(key);

// return null 5% of the time when the value is optional
if (map.isValueOptional() && random.nextInt(20) == 1) {
result.put(key, null);
} else {
result.put(key, valueResult.get());
}
}

return result;
}

private static class RandomDataGenerator extends TypeUtil.CustomOrderSchemaVisitor<Object> {
private final Map<Type, org.apache.avro.Schema> typeToSchema;
private final Random random;
Expand Down Expand Up @@ -88,64 +150,83 @@ public Object field(Types.NestedField field, Supplier<Object> fieldResult) {

@Override
public Object list(Types.ListType list, Supplier<Object> elementResult) {
int numElements = random.nextInt(20);

List<Object> result = Lists.newArrayListWithExpectedSize(numElements);
for (int i = 0; i < numElements; i += 1) {
// return null 5% of the time when the value is optional
if (list.isElementOptional() && random.nextInt(20) == 1) {
result.add(null);
} else {
result.add(elementResult.get());
}
return generateList(random, list, elementResult);
}

@Override
public Object map(Types.MapType map, Supplier<Object> keyResult, Supplier<Object> valueResult) {
return generateMap(random, map, keyResult, valueResult);
}

@Override
public Object primitive(Type.PrimitiveType primitive) {
Object result = RandomUtil.generatePrimitive(primitive, random);
// For the primitives that Avro needs a different type than Spark, fix
// them here.
switch (primitive.typeId()) {
case STRING:
return new Utf8((String) result);
case FIXED:
return new GenericData.Fixed(typeToSchema.get(primitive), (byte[]) result);
case BINARY:
return ByteBuffer.wrap((byte[]) result);
case UUID:
return UUID.nameUUIDFromBytes((byte[]) result);
default:
return result;
}
}
}

return result;
private static class RandomInternalDataGenerator
extends TypeUtil.CustomOrderSchemaVisitor<Object> {
private final Random random;

private RandomInternalDataGenerator(long seed) {
this.random = new Random(seed);
}

@Override
public Object map(Types.MapType map, Supplier<Object> keyResult, Supplier<Object> valueResult) {
int numEntries = random.nextInt(20);
public StructLike schema(Schema schema, Supplier<Object> structResult) {
return (StructLike) structResult.get();
}

Map<Object, Object> result = Maps.newLinkedHashMap();
Supplier<Object> keyFunc;
if (map.keyType() == Types.StringType.get()) {
keyFunc = () -> keyResult.get().toString();
} else {
keyFunc = keyResult;
@Override
public StructLike struct(Types.StructType struct, Iterable<Object> fieldResults) {
StructLike rec = GenericRecord.create(struct);
List<Object> values = Lists.newArrayList(fieldResults);
for (int i = 0; i < values.size(); i += 1) {
rec.set(i, values.get(i));
}

Set<Object> keySet = Sets.newHashSet();
for (int i = 0; i < numEntries; i += 1) {
Object key = keyFunc.get();
// ensure no collisions
while (keySet.contains(key)) {
key = keyFunc.get();
}

keySet.add(key);

// return null 5% of the time when the value is optional
if (map.isValueOptional() && random.nextInt(20) == 1) {
result.put(key, null);
} else {
result.put(key, valueResult.get());
}
return rec;
}

@Override
public Object field(Types.NestedField field, Supplier<Object> fieldResult) {
// return null 5% of the time when the value is optional
if (field.isOptional() && random.nextInt(20) == 1) {
return null;
}
return fieldResult.get();
}

@Override
public Object list(Types.ListType list, Supplier<Object> elementResult) {
return generateList(random, list, elementResult);
}

return result;
@Override
public Object map(Types.MapType map, Supplier<Object> keyResult, Supplier<Object> valueResult) {
return generateMap(random, map, keyResult, valueResult);
}

@Override
public Object primitive(Type.PrimitiveType primitive) {
Object result = RandomUtil.generatePrimitive(primitive, random);
// For the primitives that Avro needs a different type than Spark, fix
// them here.

switch (primitive.typeId()) {
case STRING:
return new Utf8((String) result);
case FIXED:
return new GenericData.Fixed(typeToSchema.get(primitive), (byte[]) result);
case BINARY:
return ByteBuffer.wrap((byte[]) result);
case UUID:
Expand Down
Loading

0 comments on commit 3b95d75

Please sign in to comment.