From 3ea027ac26b41edc0f08fd8c042ee918c523bdac Mon Sep 17 00:00:00 2001 From: Christophe Le Saec <51320496+clesaec@users.noreply.github.com> Date: Mon, 9 Oct 2023 13:31:39 +0200 Subject: [PATCH] AVRO-3779: [java] any big decimal conversion (#2282) * AVRO-3779: any big decimal conversion --- .../docs/++version++/Specification/_index.md | 14 +++ .../java/org/apache/avro/Conversions.java | 65 ++++++++++++++ .../java/org/apache/avro/LogicalTypes.java | 25 ++++++ .../apache/avro/TestBigDecimalConversion.java | 85 +++++++++++++++++++ .../specific/TestRecordWithLogicalTypes.java | 33 ++++++- .../specific/TestSpecificToFromByteArray.java | 5 +- .../resources/TestRecordWithLogicalTypes.avsc | 8 +- .../org/apache/avro/compiler/idl/idl.jj | 3 + 8 files changed, 231 insertions(+), 7 deletions(-) create mode 100644 lang/java/avro/src/test/java/org/apache/avro/TestBigDecimalConversion.java diff --git a/doc/content/en/docs/++version++/Specification/_index.md b/doc/content/en/docs/++version++/Specification/_index.md index e43120c2e43..63ac2c9cc00 100755 --- a/doc/content/en/docs/++version++/Specification/_index.md +++ b/doc/content/en/docs/++version++/Specification/_index.md @@ -810,6 +810,20 @@ Scale must be zero or a positive integer less than or equal to the precision. For the purposes of schema resolution, two schemas that are `decimal` logical types _match_ if their scales and precisions match. +**alternative** + +As it's not always possible to fix scale and precision in advance for a decimal field, `big-decimal` is another `decimal` logical type restrict to Avro _bytes_. + +_only available in Java_ + +```json +{ + "type": "bytes", + "logicalType": "big-decimal" +} +``` +Here, as scale property is stored in value itself it needs more bytes than preceding `decimal` type, but it allows more flexibility. + ### UUID The `uuid` logical type represents a random generated universally unique identifier (UUID). diff --git a/lang/java/avro/src/main/java/org/apache/avro/Conversions.java b/lang/java/avro/src/main/java/org/apache/avro/Conversions.java index 1a1754226b4..043ddfa0725 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/Conversions.java +++ b/lang/java/avro/src/main/java/org/apache/avro/Conversions.java @@ -22,8 +22,14 @@ import org.apache.avro.generic.GenericEnumSymbol; import org.apache.avro.generic.GenericFixed; import org.apache.avro.generic.IndexedRecord; +import org.apache.avro.io.BinaryDecoder; +import org.apache.avro.io.BinaryEncoder; +import org.apache.avro.io.DecoderFactory; +import org.apache.avro.io.EncoderFactory; import org.apache.avro.util.TimePeriod; +import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.math.BigDecimal; import java.math.BigInteger; import java.math.RoundingMode; @@ -150,6 +156,65 @@ private static BigDecimal validate(final LogicalTypes.Decimal decimal, BigDecima } } + public static class BigDecimalConversion extends Conversion { + + @Override + public Class getConvertedType() { + return BigDecimal.class; + } + + @Override + public String getLogicalTypeName() { + return "big-decimal"; + } + + @Override + public BigDecimal fromBytes(final ByteBuffer value, final Schema schema, final LogicalType type) { + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(value.array(), null); + + try { + BigInteger bg = null; + ByteBuffer buffer = decoder.readBytes(null); + byte[] array = buffer.array(); + if (array != null && array.length > 0) { + bg = new BigInteger(array); + } + + int scale = decoder.readInt(); + return new BigDecimal(bg, scale); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public ByteBuffer toBytes(final BigDecimal value, final Schema schema, final LogicalType type) { + try { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + BinaryEncoder encoder = EncoderFactory.get().binaryEncoder(out, null); + + BigInteger unscaledValue = value.unscaledValue(); + if (unscaledValue != null) { + encoder.writeBytes(unscaledValue.toByteArray()); + } else { + encoder.writeBytes(new byte[] {}); + } + encoder.writeInt(value.scale()); + encoder.flush(); + return ByteBuffer.wrap(out.toByteArray()); + + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + + @Override + public Schema getRecommendedSchema() { + return LogicalTypes.bigDecimal().addToSchema(Schema.create(Schema.Type.BYTES)); + } + } + public static class DurationConversion extends Conversion { @Override public Class getConvertedType() { diff --git a/lang/java/avro/src/main/java/org/apache/avro/LogicalTypes.java b/lang/java/avro/src/main/java/org/apache/avro/LogicalTypes.java index 4292756a2d2..dbf1a1fd867 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/LogicalTypes.java +++ b/lang/java/avro/src/main/java/org/apache/avro/LogicalTypes.java @@ -137,6 +137,9 @@ private static LogicalType fromSchemaImpl(Schema schema, boolean throwErrors) { case DECIMAL: logicalType = new Decimal(schema); break; + case BIG_DECIMAL: + logicalType = BIG_DECIMAL_TYPE; + break; case UUID: logicalType = UUID_TYPE; break; @@ -182,6 +185,7 @@ private static LogicalType fromSchemaImpl(Schema schema, boolean throwErrors) { } private static final String DECIMAL = "decimal"; + private static final String BIG_DECIMAL = "big-decimal"; private static final String DURATION = "duration"; private static final String UUID = "uuid"; private static final String DATE = "date"; @@ -202,6 +206,13 @@ public static Decimal decimal(int precision, int scale) { return new Decimal(precision, scale); } + private static final BigDecimal BIG_DECIMAL_TYPE = new BigDecimal(); + + /** Create a Big Decimal LogicalType that can accept any precision and scale */ + public static BigDecimal bigDecimal() { + return BIG_DECIMAL_TYPE; + } + private static final LogicalType UUID_TYPE = new Uuid(); public static LogicalType uuid() { @@ -402,6 +413,20 @@ public int hashCode() { } } + public static class BigDecimal extends LogicalType { + private BigDecimal() { + super(BIG_DECIMAL); + } + + @Override + public void validate(final Schema schema) { + super.validate(schema); + if (schema.getType() != Schema.Type.BYTES) { + throw new IllegalArgumentException("BigDecimal can only be used with an underlying bytes type"); + } + } + } + /** Date represents a date without a time */ public static class Date extends LogicalType { private Date() { diff --git a/lang/java/avro/src/test/java/org/apache/avro/TestBigDecimalConversion.java b/lang/java/avro/src/test/java/org/apache/avro/TestBigDecimalConversion.java new file mode 100644 index 00000000000..e781fe07bd9 --- /dev/null +++ b/lang/java/avro/src/test/java/org/apache/avro/TestBigDecimalConversion.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.apache.avro; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + +public class TestBigDecimalConversion { + + private Conversion conversion = new Conversions.BigDecimalConversion(); + + private final LogicalType bigDecimal = LogicalTypes.bigDecimal(); + + private Schema bytesSchema = conversion.getRecommendedSchema(); + + @ParameterizedTest + @MethodSource("listBigDecimal") + void bigdec(BigDecimal d1) { + ByteBuffer d1bytes = conversion.toBytes(d1, bytesSchema, bigDecimal); + BigDecimal decimal1 = conversion.fromBytes(d1bytes, bytesSchema, bigDecimal); + Assertions.assertEquals(decimal1, d1); + } + + static Stream listBigDecimal() { + Iterator iterator = new Iterator() { + int index = 0; + + BigDecimal step = new BigDecimal(-2.7d); + + BigDecimal current = new BigDecimal(1.0d); + + @Override + public boolean hasNext() { + if (index == 50) { + // test small bigdecimal + current = new BigDecimal(1.0d); + step = new BigDecimal(-0.71d); + } + return index < 100; + } + + @Override + public BigDecimal next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + index++; + current = current.multiply(step); + return current; + } + }; + return StreamSupport.stream(Spliterators.spliteratorUnknownSize(iterator, Spliterator.ORDERED), false) + .map(Arguments::of); + + } + +} diff --git a/lang/java/avro/src/test/java/org/apache/avro/specific/TestRecordWithLogicalTypes.java b/lang/java/avro/src/test/java/org/apache/avro/specific/TestRecordWithLogicalTypes.java index b7a89db6e59..c2e1ebd384c 100644 --- a/lang/java/avro/src/test/java/org/apache/avro/specific/TestRecordWithLogicalTypes.java +++ b/lang/java/avro/src/test/java/org/apache/avro/specific/TestRecordWithLogicalTypes.java @@ -5,6 +5,7 @@ */ package org.apache.avro.specific; +import org.apache.avro.Conversions; import org.apache.avro.data.TimeConversions; import org.apache.avro.message.BinaryMessageDecoder; import org.apache.avro.message.BinaryMessageEncoder; @@ -16,7 +17,7 @@ public class TestRecordWithLogicalTypes extends org.apache.avro.specific.Specifi implements org.apache.avro.specific.SpecificRecord { private static final long serialVersionUID = 3313339903648295220L; public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse( - "{\"type\":\"record\",\"name\":\"TestRecordWithLogicalTypes\",\"namespace\":\"org.apache.avro.specific\",\"fields\":[{\"name\":\"b\",\"type\":\"boolean\"},{\"name\":\"i32\",\"type\":\"int\"},{\"name\":\"i64\",\"type\":\"long\"},{\"name\":\"f32\",\"type\":\"float\"},{\"name\":\"f64\",\"type\":\"double\"},{\"name\":\"s\",\"type\":[\"null\",\"string\"],\"default\":null},{\"name\":\"d\",\"type\":{\"type\":\"int\",\"logicalType\":\"date\"}},{\"name\":\"t\",\"type\":{\"type\":\"int\",\"logicalType\":\"time-millis\"}},{\"name\":\"ts\",\"type\":{\"type\":\"long\",\"logicalType\":\"timestamp-millis\"}},{\"name\":\"dec\",\"type\":{\"type\":\"bytes\",\"logicalType\":\"decimal\",\"precision\":9,\"scale\":2}}]}"); + "{\"type\":\"record\",\"name\":\"TestRecordWithLogicalTypes\",\"namespace\":\"org.apache.avro.specific\",\"fields\":[{\"name\":\"b\",\"type\":\"boolean\"},{\"name\":\"i32\",\"type\":\"int\"},{\"name\":\"i64\",\"type\":\"long\"},{\"name\":\"f32\",\"type\":\"float\"},{\"name\":\"f64\",\"type\":\"double\"},{\"name\":\"s\",\"type\":[\"null\",\"string\"],\"default\":null},{\"name\":\"d\",\"type\":{\"type\":\"int\",\"logicalType\":\"date\"}},{\"name\":\"t\",\"type\":{\"type\":\"int\",\"logicalType\":\"time-millis\"}},{\"name\":\"ts\",\"type\":{\"type\":\"long\",\"logicalType\":\"timestamp-millis\"}},{\"name\":\"dec\",\"type\":{\"type\":\"bytes\",\"logicalType\":\"decimal\",\"precision\":9,\"scale\":2}},{\"name\":\"bd\",\"type\":{\"type\":\"bytes\",\"logicalType\":\"big-decimal\"}}]}"); public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; @@ -77,6 +78,8 @@ public static TestRecordWithLogicalTypes fromByteBuffer(java.nio.ByteBuffer b) t public java.time.Instant ts; @Deprecated public java.math.BigDecimal dec; + @Deprecated + public java.math.BigDecimal bd; /** * Default constructor. Note that this does not initialize fields to their @@ -99,10 +102,11 @@ public TestRecordWithLogicalTypes() { * @param t The new value for t * @param ts The new value for ts * @param dec The new value for dec + * @param bd The new value for bd */ public TestRecordWithLogicalTypes(java.lang.Boolean b, java.lang.Integer i32, java.lang.Long i64, java.lang.Float f32, java.lang.Double f64, java.lang.CharSequence s, java.time.LocalDate d, java.time.LocalTime t, - java.time.Instant ts, java.math.BigDecimal dec) { + java.time.Instant ts, java.math.BigDecimal dec, java.math.BigDecimal bd) { this.b = b; this.i32 = i32; this.i64 = i64; @@ -113,6 +117,7 @@ public TestRecordWithLogicalTypes(java.lang.Boolean b, java.lang.Integer i32, ja this.t = t; this.ts = ts; this.dec = dec; + this.bd = bd; } @Override @@ -144,18 +149,24 @@ public java.lang.Object get(int field$) { return ts; case 9: return dec; + case 10: + return bd; default: - throw new org.apache.avro.AvroRuntimeException("Bad index"); + throw new org.apache.avro.AvroRuntimeException("Bad index " + field$); } } protected static final org.apache.avro.Conversions.DecimalConversion DECIMAL_CONVERSION = new org.apache.avro.Conversions.DecimalConversion(); + + protected static final Conversions.BigDecimalConversion BIG_DECIMAL_CONVERSION = new org.apache.avro.Conversions.BigDecimalConversion(); + protected static final TimeConversions.DateConversion DATE_CONVERSION = new TimeConversions.DateConversion(); protected static final TimeConversions.TimeMillisConversion TIME_CONVERSION = new TimeConversions.TimeMillisConversion(); protected static final TimeConversions.TimestampMillisConversion TIMESTAMP_CONVERSION = new TimeConversions.TimestampMillisConversion(); private static final org.apache.avro.Conversion[] conversions = new org.apache.avro.Conversion[] { null, null, - null, null, null, null, DATE_CONVERSION, TIME_CONVERSION, TIMESTAMP_CONVERSION, DECIMAL_CONVERSION, null }; + null, null, null, null, DATE_CONVERSION, TIME_CONVERSION, TIMESTAMP_CONVERSION, DECIMAL_CONVERSION, + BIG_DECIMAL_CONVERSION }; @Override public org.apache.avro.Conversion getConversion(int field) { @@ -197,6 +208,9 @@ public void put(int field$, java.lang.Object value$) { case 9: dec = (java.math.BigDecimal) value$; break; + case 10: + bd = (java.math.BigDecimal) value$; + break; default: throw new org.apache.avro.AvroRuntimeException("Bad index"); } @@ -438,6 +452,8 @@ public static class Builder extends org.apache.avro.specific.SpecificRecordBuild private java.time.Instant ts; private java.math.BigDecimal dec; + private java.math.BigDecimal bd; + /** Creates a new Builder */ private Builder() { super(SCHEMA$); @@ -490,6 +506,10 @@ private Builder(TestRecordWithLogicalTypes.Builder other) { this.dec = data().deepCopy(fields()[9].schema(), other.dec); fieldSetFlags()[9] = other.fieldSetFlags()[9]; } + if (isValidValue(fields()[10], other.bd)) { + this.bd = data().deepCopy(fields()[10].schema(), other.bd); + fieldSetFlags()[10] = other.fieldSetFlags()[10]; + } } /** @@ -539,6 +559,10 @@ private Builder(TestRecordWithLogicalTypes other) { this.dec = data().deepCopy(fields()[9].schema(), other.dec); fieldSetFlags()[9] = true; } + if (isValidValue(fields()[10], other.bd)) { + this.bd = data().deepCopy(fields()[10].schema(), other.bd); + fieldSetFlags()[10] = true; + } } /** @@ -968,6 +992,7 @@ public TestRecordWithLogicalTypes build() { record.t = fieldSetFlags()[7] ? this.t : (java.time.LocalTime) defaultValue(fields()[7]); record.ts = fieldSetFlags()[8] ? this.ts : (java.time.Instant) defaultValue(fields()[8]); record.dec = fieldSetFlags()[9] ? this.dec : (java.math.BigDecimal) defaultValue(fields()[9]); + record.bd = fieldSetFlags()[10] ? this.dec : (java.math.BigDecimal) defaultValue(fields()[10]); return record; } catch (java.lang.Exception e) { throw new org.apache.avro.AvroRuntimeException(e); diff --git a/lang/java/avro/src/test/java/org/apache/avro/specific/TestSpecificToFromByteArray.java b/lang/java/avro/src/test/java/org/apache/avro/specific/TestSpecificToFromByteArray.java index a94a3e91e7e..f81dde37407 100644 --- a/lang/java/avro/src/test/java/org/apache/avro/specific/TestSpecificToFromByteArray.java +++ b/lang/java/avro/src/test/java/org/apache/avro/specific/TestSpecificToFromByteArray.java @@ -44,7 +44,7 @@ void specificToFromByteBufferWithLogicalTypes() throws IOException { Instant instant = Instant.now().truncatedTo(ChronoUnit.MILLIS); final TestRecordWithLogicalTypes record = new TestRecordWithLogicalTypes(true, 34, 35L, 3.14F, 3019.34, null, - LocalDate.now(), t, instant, new BigDecimal("123.45")); + LocalDate.now(), t, instant, new BigDecimal("123.45"), new BigDecimal(-23.456562323)); final ByteBuffer b = record.toByteBuffer(); final TestRecordWithLogicalTypes copy = TestRecordWithLogicalTypes.fromByteBuffer(b); @@ -81,7 +81,8 @@ void specificByteArrayIncompatibleWithLogicalTypes() throws IOException { void specificByteArrayIncompatibleWithoutLogicalTypes() throws IOException { assertThrows(MissingSchemaException.class, () -> { final TestRecordWithLogicalTypes withLogicalTypes = new TestRecordWithLogicalTypes(true, 34, 35L, 3.14F, 3019.34, - null, LocalDate.now(), LocalTime.now(), Instant.now(), new BigDecimal("123.45")); + null, LocalDate.now(), LocalTime.now(), Instant.now(), new BigDecimal("123.45"), + new BigDecimal(-23.456562323)); final ByteBuffer b = withLogicalTypes.toByteBuffer(); TestRecordWithoutLogicalTypes.fromByteBuffer(b); diff --git a/lang/java/avro/src/test/resources/TestRecordWithLogicalTypes.avsc b/lang/java/avro/src/test/resources/TestRecordWithLogicalTypes.avsc index f5d212917f4..5f5e870f9c7 100644 --- a/lang/java/avro/src/test/resources/TestRecordWithLogicalTypes.avsc +++ b/lang/java/avro/src/test/resources/TestRecordWithLogicalTypes.avsc @@ -40,6 +40,12 @@ "type" : "long", "logicalType" : "timestamp-millis" } - } ] + }, { + "name" : "bd", + "type" : { + "type" : "bytes", + "logicalType" : "big-decimal" + } + } ] } diff --git a/lang/java/compiler/src/main/javacc/org/apache/avro/compiler/idl/idl.jj b/lang/java/compiler/src/main/javacc/org/apache/avro/compiler/idl/idl.jj index 117764497e3..af2480ce992 100644 --- a/lang/java/compiler/src/main/javacc/org/apache/avro/compiler/idl/idl.jj +++ b/lang/java/compiler/src/main/javacc/org/apache/avro/compiler/idl/idl.jj @@ -292,6 +292,7 @@ TOKEN : | < TIME: "time_ms" > | < TIMESTAMP: "timestamp_ms" > | < DECIMAL: "decimal" > +| < BIG_DECIMAL: "big_decimal" > | < LOCAL_TIMESTAMP: "local_timestamp_ms" > | < UUID: "uuid" > } @@ -1587,6 +1588,7 @@ Schema PrimitiveType(): | "timestamp_ms" { return LogicalTypes.timestampMillis().addToSchema(Schema.create(Type.LONG)); } | "local_timestamp_ms" { return LogicalTypes.localTimestampMillis().addToSchema(Schema.create(Type.LONG)); } | "decimal" s = DecimalTypeProperties() { return s; } +| "big_decimal" { return LogicalTypes.bigDecimal().addToSchema(Schema.create(Type.BYTES)); } | "uuid" {return LogicalTypes.uuid().addToSchema(Schema.create(Type.STRING));} } @@ -1677,6 +1679,7 @@ Token AnyIdentifier(): t = | t = | t = | + t = | t = ) { return t;