From c91b8876b1c21c93ee9021abe6d0d04cf1f50a32 Mon Sep 17 00:00:00 2001 From: Christophe Le Saec <51320496+clesaec@users.noreply.github.com> Date: Mon, 9 Oct 2023 13:32:37 +0200 Subject: [PATCH] AVRO-3779: using rust bigdecimal (#2302) * AVRO-3779: using rust bigdecimal Signed-off-by: Martin Tzvetanov Grigorov Co-authored-by: Martin Tzvetanov Grigorov --- lang/rust/Cargo.lock | 14 ++++++ lang/rust/avro/Cargo.toml | 1 + lang/rust/avro/src/decimal.rs | 85 ++++++++++++++++++++++++++++++++++- lang/rust/avro/src/decode.rs | 12 +++-- lang/rust/avro/src/encode.rs | 9 +++- lang/rust/avro/src/error.rs | 12 +++++ lang/rust/avro/src/schema.rs | 39 ++++++++++++++++ lang/rust/avro/src/types.rs | 37 ++++++++++++++- 8 files changed, 200 insertions(+), 9 deletions(-) diff --git a/lang/rust/Cargo.lock b/lang/rust/Cargo.lock index c4fa6e3c60c..d585c8fb403 100644 --- a/lang/rust/Cargo.lock +++ b/lang/rust/Cargo.lock @@ -68,6 +68,7 @@ dependencies = [ "anyhow", "apache-avro-derive", "apache-avro-test-helper", + "bigdecimal", "bzip2", "crc32fast", "criterion", @@ -155,6 +156,19 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "bigdecimal" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "454bca3db10617b88b566f205ed190aedb0e0e6dd4cad61d3988a72e8c5594cb" +dependencies = [ + "autocfg", + "libm", + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "bitflags" version = "1.3.2" diff --git a/lang/rust/avro/Cargo.toml b/lang/rust/avro/Cargo.toml index 163bbfe003f..b9610a8b544 100644 --- a/lang/rust/avro/Cargo.toml +++ b/lang/rust/avro/Cargo.toml @@ -73,6 +73,7 @@ typed-builder = { default-features = false, version = "0.16.2" } uuid = { default-features = false, version = "1.4.1", features = ["serde", "std"] } xz2 = { default-features = false, version = "0.1.7", optional = true } zstd = { default-features = false, version = "0.12.4+zstd.1.5.2", optional = true } +bigdecimal = "0.4" [target.'cfg(target_arch = "wasm32")'.dependencies] quad-rand = { default-features = false, version = "0.2.1" } diff --git a/lang/rust/avro/src/decimal.rs b/lang/rust/avro/src/decimal.rs index a06ab45a6ca..da5cb35f00f 100644 --- a/lang/rust/avro/src/decimal.rs +++ b/lang/rust/avro/src/decimal.rs @@ -15,8 +15,15 @@ // specific language governing permissions and limitations // under the License. -use crate::{AvroResult, Error}; +use crate::{ + decode::{decode_len, decode_long}, + encode::{encode_bytes, encode_long}, + types::Value, + AvroResult, Error, +}; +use bigdecimal::BigDecimal; use num_bigint::{BigInt, Sign}; +use std::io::Read; #[derive(Debug, Clone)] pub struct Decimal { @@ -105,12 +112,47 @@ impl> From for Decimal { } } +pub(crate) fn serialize_big_decimal(decimal: &BigDecimal) -> Vec { + let mut buffer: Vec = Vec::new(); + let (big_int, exponent): (BigInt, i64) = decimal.as_bigint_and_exponent(); + let big_endian_value: Vec = big_int.to_signed_bytes_be(); + encode_bytes(&big_endian_value, &mut buffer); + encode_long(exponent, &mut buffer); + + buffer +} + +pub(crate) fn deserialize_big_decimal(bytes: &Vec) -> Result { + let mut bytes: &[u8] = bytes.as_slice(); + let mut big_decimal_buffer = match decode_len(&mut bytes) { + Ok(size) => vec![0u8; size], + Err(_err) => return Err(Error::BigDecimalLen), + }; + + bytes + .read_exact(&mut big_decimal_buffer[..]) + .map_err(Error::ReadDouble)?; + + match decode_long(&mut bytes) { + Ok(Value::Long(scale_value)) => { + let big_int: BigInt = BigInt::from_signed_bytes_be(&big_decimal_buffer); + let decimal = BigDecimal::new(big_int, scale_value); + Ok(decimal) + } + _ => Err(Error::BigDecimalScale), + } +} + #[cfg(test)] mod tests { use super::*; use apache_avro_test_helper::TestResult; + use bigdecimal::{One, Zero}; use pretty_assertions::assert_eq; - use std::convert::TryFrom; + use std::{ + convert::TryFrom, + ops::{Div, Mul}, + }; #[test] fn test_decimal_from_bytes_from_ref_decimal() -> TestResult { @@ -133,4 +175,43 @@ mod tests { Ok(()) } + + #[test] + fn test_avro_3779_bigdecimal_serial() -> TestResult { + let value: bigdecimal::BigDecimal = + bigdecimal::BigDecimal::from(-1421).div(bigdecimal::BigDecimal::from(2)); + let mut current: bigdecimal::BigDecimal = bigdecimal::BigDecimal::one(); + + for iter in 1..180 { + let result: Vec = serialize_big_decimal(¤t); + + let deserialize_big_decimal: Result = + deserialize_big_decimal(&result); + assert!( + deserialize_big_decimal.is_ok(), + "can't deserialize for iter {iter}" + ); + assert_eq!( + current, + deserialize_big_decimal.unwrap(), + "not equals for ${iter}" + ); + current = current.mul(&value); + } + + let result: Vec = serialize_big_decimal(&BigDecimal::zero()); + let deserialize_big_decimal: Result = + deserialize_big_decimal(&result); + assert!( + deserialize_big_decimal.is_ok(), + "can't deserialize for zero" + ); + assert_eq!( + BigDecimal::zero(), + deserialize_big_decimal.unwrap(), + "not equals for zero" + ); + + Ok(()) + } } diff --git a/lang/rust/avro/src/decode.rs b/lang/rust/avro/src/decode.rs index b13c76739b9..7857bbec565 100644 --- a/lang/rust/avro/src/decode.rs +++ b/lang/rust/avro/src/decode.rs @@ -16,7 +16,7 @@ // under the License. use crate::{ - decimal::Decimal, + decimal::{deserialize_big_decimal, Decimal}, duration::Duration, schema::{ DecimalSchema, EnumSchema, FixedSchema, Name, Namespace, RecordSchema, ResolvedSchema, @@ -36,7 +36,7 @@ use std::{ use uuid::Uuid; #[inline] -fn decode_long(reader: &mut R) -> AvroResult { +pub(crate) fn decode_long(reader: &mut R) -> AvroResult { zag_i64(reader).map(Value::Long) } @@ -46,7 +46,7 @@ fn decode_int(reader: &mut R) -> AvroResult { } #[inline] -fn decode_len(reader: &mut R) -> AvroResult { +pub(crate) fn decode_len(reader: &mut R) -> AvroResult { let len = zag_i64(reader)?; safe_len(usize::try_from(len).map_err(|e| Error::ConvertI64ToUsize(e, len))?) } @@ -114,6 +114,12 @@ pub(crate) fn decode_internal>( }, schema => Err(Error::ResolveDecimalSchema(schema.into())), }, + Schema::BigDecimal => { + match decode_internal(&Schema::Bytes, names, enclosing_namespace, reader)? { + Value::Bytes(bytes) => deserialize_big_decimal(&bytes).map(Value::BigDecimal), + value => Err(Error::BytesValue(value.into())), + } + } Schema::Uuid => Ok(Value::Uuid( Uuid::from_str( match decode_internal(&Schema::String, names, enclosing_namespace, reader)? { diff --git a/lang/rust/avro/src/encode.rs b/lang/rust/avro/src/encode.rs index 6e52e0c3b1e..f347767f3cc 100644 --- a/lang/rust/avro/src/encode.rs +++ b/lang/rust/avro/src/encode.rs @@ -16,6 +16,7 @@ // under the License. use crate::{ + decimal::serialize_big_decimal, schema::{ DecimalSchema, EnumSchema, FixedSchema, Name, Namespace, RecordSchema, ResolvedSchema, Schema, SchemaKind, @@ -40,13 +41,13 @@ pub fn encode(value: &Value, schema: &Schema, buffer: &mut Vec) -> AvroResul encode_internal(value, schema, rs.get_names(), &None, buffer) } -fn encode_bytes + ?Sized>(s: &B, buffer: &mut Vec) { +pub(crate) fn encode_bytes + ?Sized>(s: &B, buffer: &mut Vec) { let bytes = s.as_ref(); encode_long(bytes.len() as i64, buffer); buffer.extend_from_slice(bytes); } -fn encode_long(i: i64, buffer: &mut Vec) { +pub(crate) fn encode_long(i: i64, buffer: &mut Vec) { zig_i64(i, buffer) } @@ -116,6 +117,10 @@ pub(crate) fn encode_internal>( &uuid.to_string(), buffer, ), + Value::BigDecimal(bg) => { + let mut buf: Vec = serialize_big_decimal(bg); + buffer.append(&mut buf); + } Value::Bytes(bytes) => match *schema { Schema::Bytes => encode_bytes(bytes, buffer), Schema::Fixed { .. } => buffer.extend(bytes), diff --git a/lang/rust/avro/src/error.rs b/lang/rust/avro/src/error.rs index bf066b8a5ee..1487296a716 100644 --- a/lang/rust/avro/src/error.rs +++ b/lang/rust/avro/src/error.rs @@ -115,6 +115,9 @@ pub enum Error { #[error("expected UUID, got: {0:?}")] GetUuid(ValueKind), + #[error("expected BigDecimal, got: {0:?}")] + GetBigdecimal(ValueKind), + #[error("Fixed bytes of size 12 expected, got Fixed of size {0}")] GetDecimalFixedBytes(usize), @@ -289,6 +292,15 @@ pub enum Error { #[error("The decimal precision ({precision}) must be a positive number")] DecimalPrecisionMuBePositive { precision: usize }, + #[error("Unreadable decimal sign")] + BigDecimalSign, + + #[error("Unreadable length for decimal inner bytes")] + BigDecimalLen, + + #[error("Unreadable decimal scale")] + BigDecimalScale, + #[error("Unexpected `type` {0} variant for `logicalType`")] GetLogicalTypeVariant(serde_json::Value), diff --git a/lang/rust/avro/src/schema.rs b/lang/rust/avro/src/schema.rs index 7c6b7dc9e43..bb914f4a357 100644 --- a/lang/rust/avro/src/schema.rs +++ b/lang/rust/avro/src/schema.rs @@ -112,6 +112,9 @@ pub enum Schema { /// Logical type which represents `Decimal` values. The underlying type is serialized and /// deserialized as `Schema::Bytes` or `Schema::Fixed`. Decimal(DecimalSchema), + /// Logical type which represents `Decimal` values without predefined scale. + /// The underlying type is serialized and deserialized as `Schema::Bytes` + BigDecimal, /// A universally unique identifier, annotating a string. Uuid, /// Logical type which represents the number of days since the unix epoch. @@ -189,6 +192,7 @@ impl From<&types::Value> for SchemaKind { Value::Enum(_, _) => Self::Enum, Value::Fixed(_, _) => Self::Fixed, Value::Decimal { .. } => Self::Decimal, + Value::BigDecimal(_) => Self::BigDecimal, Value::Uuid(_) => Self::Uuid, Value::Date(_) => Self::Date, Value::TimeMillis(_) => Self::TimeMillis, @@ -1359,6 +1363,10 @@ impl Parser { inner, })); } + "big-decimal" => { + logical_verify_type(complex, &[SchemaKind::Bytes], self, enclosing_namespace)?; + return Ok(Schema::BigDecimal); + } "uuid" => { logical_verify_type(complex, &[SchemaKind::String], self, enclosing_namespace)?; return Ok(Schema::Uuid); @@ -1909,6 +1917,12 @@ impl Serialize for Schema { map.serialize_entry("precision", precision)?; map.end() } + Schema::BigDecimal => { + let mut map = serializer.serialize_map(None)?; + map.serialize_entry("type", "bytes")?; + map.serialize_entry("logicalType", "big-decimal")?; + map.end() + } Schema::Uuid => { let mut map = serializer.serialize_map(None)?; map.serialize_entry("type", "string")?; @@ -5155,6 +5169,31 @@ mod tests { Ok(()) } + #[test] + fn test_avro_3779_bigdecimal_schema() -> TestResult { + let schema = json!( + { + "type": "record", + "name": "recordWithDecimal", + "fields": [ + { + "name": "decimal", + "type": "bytes", + "logicalType": "big-decimal" + } + ] + }); + + let parse_result = Schema::parse(&schema); + assert!( + parse_result.is_ok(), + "parse result must be ok, got: {:?}", + parse_result + ); + + Ok(()) + } + #[test] fn test_avro_3820_deny_invalid_field_names() -> TestResult { let schema_str = r#" diff --git a/lang/rust/avro/src/types.rs b/lang/rust/avro/src/types.rs index 9bb60770562..715094f7ff2 100644 --- a/lang/rust/avro/src/types.rs +++ b/lang/rust/avro/src/types.rs @@ -17,7 +17,7 @@ //! Logic handling the intermediate representation of Avro values. use crate::{ - decimal::Decimal, + decimal::{deserialize_big_decimal, serialize_big_decimal, Decimal}, duration::Duration, schema::{ DecimalSchema, EnumSchema, FixedSchema, Name, Namespace, Precision, RecordField, @@ -25,6 +25,7 @@ use crate::{ }, AvroResult, Error, }; +use bigdecimal::BigDecimal; use serde_json::{Number, Value as JsonValue}; use std::{ borrow::Borrow, @@ -100,6 +101,8 @@ pub enum Value { Date(i32), /// An Avro Decimal value. Bytes are in big-endian order, per the Avro spec. Decimal(Decimal), + /// An Avro Decimal value. + BigDecimal(BigDecimal), /// Time in milliseconds. TimeMillis(i32), /// Time in microseconds. @@ -154,6 +157,7 @@ to_value!(String, Value::String); to_value!(Vec, Value::Bytes); to_value!(uuid::Uuid, Value::Uuid); to_value!(Decimal, Value::Decimal); +to_value!(BigDecimal, Value::BigDecimal); to_value!(Duration, Value::Duration); impl From<()> for Value { @@ -327,6 +331,10 @@ impl TryFrom for JsonValue { Value::Date(d) => Ok(Self::Number(d.into())), Value::Decimal(ref d) => >::try_from(d) .map(|vec| Self::Array(vec.into_iter().map(|v| v.into()).collect())), + Value::BigDecimal(ref bg) => { + let vec1: Vec = serialize_big_decimal(bg); + Ok(Self::Array(vec1.into_iter().map(|b| b.into()).collect())) + } Value::TimeMillis(t) => Ok(Self::Number(t.into())), Value::TimeMicros(t) => Ok(Self::Number(t.into())), Value::TimestampMillis(t) => Ok(Self::Number(t.into())), @@ -425,6 +433,7 @@ impl Value { (&Value::TimeMillis(_), &Schema::TimeMillis) => None, (&Value::Date(_), &Schema::Date) => None, (&Value::Decimal(_), &Schema::Decimal { .. }) => None, + (&Value::BigDecimal(_), &Schema::BigDecimal) => None, (&Value::Duration(_), &Schema::Duration) => None, (&Value::Uuid(_), &Schema::Uuid) => None, (&Value::Float(_), &Schema::Float) => None, @@ -634,7 +643,6 @@ impl Value { }; self = v; } - match *schema { Schema::Ref { ref name } => { let name = name.fully_qualified_name(enclosing_namespace); @@ -674,6 +682,7 @@ impl Value { precision, ref inner, }) => self.resolve_decimal(precision, scale, inner), + Schema::BigDecimal => self.resolve_bigdecimal(), Schema::Date => self.resolve_date(), Schema::TimeMillis => self.resolve_time_millis(), Schema::TimeMicros => self.resolve_time_micros(), @@ -696,6 +705,14 @@ impl Value { }) } + fn resolve_bigdecimal(self) -> Result { + Ok(match self { + bg @ Value::BigDecimal(_) => bg, + Value::Bytes(b) => Value::BigDecimal(deserialize_big_decimal(&b).unwrap()), + other => return Err(Error::GetBigdecimal(other.into())), + }) + } + fn resolve_duration(self) -> Result { Ok(match self { duration @ Value::Duration { .. } => duration, @@ -2925,4 +2942,20 @@ Field with name '"b"' is not a member of the map items"#, Ok(()) } + + #[test] + fn test_avro_3779_bigdecimal_resolving() -> TestResult { + let schema = + r#"{"name": "bigDecimalSchema", "logicalType": "big-decimal", "type": "bytes" }"#; + + let avro_value = Value::BigDecimal(BigDecimal::from(12345678u32)); + let schema = Schema::parse_str(schema)?; + let resolve_result: AvroResult = avro_value.resolve(&schema); + assert!( + resolve_result.is_ok(), + "resolve result must be ok, got: {resolve_result:?}" + ); + + Ok(()) + } }