diff --git a/lang/rust/avro/src/decode.rs b/lang/rust/avro/src/decode.rs index bf8477fb70a..48a04f95a0b 100644 --- a/lang/rust/avro/src/decode.rs +++ b/lang/rust/avro/src/decode.rs @@ -196,7 +196,12 @@ pub(crate) fn decode_internal>( items.reserve(len); for _ in 0..len { - items.push(decode_internal(inner, names, enclosing_namespace, reader)?); + items.push(decode_internal( + &inner.items, + names, + enclosing_namespace, + reader, + )?); } } @@ -215,7 +220,8 @@ pub(crate) fn decode_internal>( for _ in 0..len { match decode_internal(&Schema::String, names, enclosing_namespace, reader)? { Value::String(key) => { - let value = decode_internal(inner, names, enclosing_namespace, reader)?; + let value = + decode_internal(&inner.types, names, enclosing_namespace, reader)?; items.insert(key, value); } value => return Err(Error::MapKeyType(value.into())), @@ -321,7 +327,7 @@ mod tests { #[test] fn test_decode_array_without_size() -> TestResult { let mut input: &[u8] = &[6, 2, 4, 6, 0]; - let result = decode(&Schema::Array(Box::new(Schema::Int)), &mut input); + let result = decode(&Schema::array(Schema::Int), &mut input); assert_eq!(Array(vec!(Int(1), Int(2), Int(3))), result?); Ok(()) @@ -330,7 +336,7 @@ mod tests { #[test] fn test_decode_array_with_size() -> TestResult { let mut input: &[u8] = &[5, 6, 2, 4, 6, 0]; - let result = decode(&Schema::Array(Box::new(Schema::Int)), &mut input); + let result = decode(&Schema::array(Schema::Int), &mut input); assert_eq!(Array(vec!(Int(1), Int(2), Int(3))), result?); Ok(()) @@ -339,7 +345,7 @@ mod tests { #[test] fn test_decode_map_without_size() -> TestResult { let mut input: &[u8] = &[0x02, 0x08, 0x74, 0x65, 0x73, 0x74, 0x02, 0x00]; - let result = decode(&Schema::Map(Box::new(Schema::Int)), &mut input); + let result = decode(&Schema::map(Schema::Int), &mut input); let mut expected = HashMap::new(); expected.insert(String::from("test"), Int(1)); assert_eq!(Map(expected), result?); @@ -350,7 +356,7 @@ mod tests { #[test] fn test_decode_map_with_size() -> TestResult { let mut input: &[u8] = &[0x01, 0x0C, 0x08, 0x74, 0x65, 0x73, 0x74, 0x02, 0x00]; - let result = decode(&Schema::Map(Box::new(Schema::Int)), &mut input); + let result = decode(&Schema::map(Schema::Int), &mut input); let mut expected = HashMap::new(); expected.insert(String::from("test"), Int(1)); assert_eq!(Map(expected), result?); diff --git a/lang/rust/avro/src/encode.rs b/lang/rust/avro/src/encode.rs index 23f94664c89..c99f80e27e3 100644 --- a/lang/rust/avro/src/encode.rs +++ b/lang/rust/avro/src/encode.rs @@ -187,7 +187,7 @@ pub(crate) fn encode_internal>( if !items.is_empty() { encode_long(items.len() as i64, buffer); for item in items.iter() { - encode_internal(item, inner, names, enclosing_namespace, buffer)?; + encode_internal(item, &inner.items, names, enclosing_namespace, buffer)?; } } buffer.push(0u8); @@ -205,7 +205,7 @@ pub(crate) fn encode_internal>( encode_long(items.len() as i64, buffer); for (key, value) in items { encode_bytes(key, buffer); - encode_internal(value, inner, names, enclosing_namespace, buffer)?; + encode_internal(value, &inner.types, names, enclosing_namespace, buffer)?; } } buffer.push(0u8); @@ -309,13 +309,10 @@ pub(crate) mod tests { let empty: Vec = Vec::new(); encode( &Value::Array(empty.clone()), - &Schema::Array(Box::new(Schema::Int)), + &Schema::array(Schema::Int), &mut buf, ) - .expect(&success( - &Value::Array(empty), - &Schema::Array(Box::new(Schema::Int)), - )); + .expect(&success(&Value::Array(empty), &Schema::array(Schema::Int))); assert_eq!(vec![0u8], buf); } @@ -325,13 +322,10 @@ pub(crate) mod tests { let empty: HashMap = HashMap::new(); encode( &Value::Map(empty.clone()), - &Schema::Map(Box::new(Schema::Int)), + &Schema::map(Schema::Int), &mut buf, ) - .expect(&success( - &Value::Map(empty), - &Schema::Map(Box::new(Schema::Int)), - )); + .expect(&success(&Value::Map(empty), &Schema::map(Schema::Int))); assert_eq!(vec![0u8], buf); } diff --git a/lang/rust/avro/src/reader.rs b/lang/rust/avro/src/reader.rs index 2ec0b84cb82..9b598315c8a 100644 --- a/lang/rust/avro/src/reader.rs +++ b/lang/rust/avro/src/reader.rs @@ -71,7 +71,7 @@ impl<'r, R: Read> Block<'r, R> { /// Try to read the header and to set the writer `Schema`, the `Codec` and the marker based on /// its content. fn read_header(&mut self) -> AvroResult<()> { - let meta_schema = Schema::Map(Box::new(Schema::Bytes)); + let meta_schema = Schema::map(Schema::Bytes); let mut buf = [0u8; 4]; self.reader diff --git a/lang/rust/avro/src/schema.rs b/lang/rust/avro/src/schema.rs index f4c063df60d..680a54a0208 100644 --- a/lang/rust/avro/src/schema.rs +++ b/lang/rust/avro/src/schema.rs @@ -111,11 +111,11 @@ pub enum Schema { String, /// A `array` Avro schema. Avro arrays are required to have the same type for each element. /// This variant holds the `Schema` for the array element type. - Array(Box), + Array(ArraySchema), /// A `map` Avro schema. /// `Map` holds a pointer to the `Schema` of its values, which must all be the same schema. /// `Map` keys are assumed to be `string`. - Map(Box), + Map(MapSchema), /// A `union` Avro schema. Union(UnionSchema), /// A `record` Avro schema. @@ -159,6 +159,18 @@ pub enum Schema { Ref { name: Name }, } +#[derive(Clone, Debug, PartialEq)] +pub struct MapSchema { + pub types: Box, + pub custom_attributes: BTreeMap, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct ArraySchema { + pub items: Box, + pub custom_attributes: BTreeMap, +} + impl PartialEq for Schema { /// Assess equality of two `Schema` based on [Parsing Canonical Form]. /// @@ -495,8 +507,11 @@ impl<'s> ResolvedSchema<'s> { ) -> AvroResult<()> { for schema in schemata { match schema { - Schema::Array(schema) | Schema::Map(schema) => { - self.resolve(vec![schema], enclosing_namespace, known_schemata)? + Schema::Array(schema) => { + self.resolve(vec![&schema.items], enclosing_namespace, known_schemata)? + } + Schema::Map(schema) => { + self.resolve(vec![&schema.types], enclosing_namespace, known_schemata)? } Schema::Union(UnionSchema { schemas, .. }) => { for schema in schemas { @@ -581,9 +596,8 @@ impl ResolvedOwnedSchema { enclosing_namespace: &Namespace, ) -> AvroResult<()> { match schema { - Schema::Array(schema) | Schema::Map(schema) => { - Self::from_internal(schema, names, enclosing_namespace) - } + Schema::Array(schema) => Self::from_internal(&schema.items, names, enclosing_namespace), + Schema::Map(schema) => Self::from_internal(&schema.types, names, enclosing_namespace), Schema::Union(UnionSchema { schemas, .. }) => { for schema in schemas { Self::from_internal(schema, names, enclosing_namespace)? @@ -1160,6 +1174,41 @@ impl Schema { _ => None, } } + + /// Returns a Schema::Map with the given types. + pub fn map(types: Schema) -> Self { + Schema::Map(MapSchema { + types: Box::new(types), + custom_attributes: Default::default(), + }) + } + + /// Returns a Schema::Map with the given types and custom attributes. + pub fn map_with_attributes(types: Schema, custom_attributes: BTreeMap) -> Self { + Schema::Map(MapSchema { + types: Box::new(types), + custom_attributes, + }) + } + + /// Returns a Schema::Array with the given items. + pub fn array(items: Schema) -> Self { + Schema::Array(ArraySchema { + items: Box::new(items), + custom_attributes: Default::default(), + }) + } + + /// Returns a Schema::Array with the given items and custom attributes. + pub fn array_with_attributes( + items: Schema, + custom_attributes: BTreeMap, + ) -> Self { + Schema::Array(ArraySchema { + items: Box::new(items), + custom_attributes, + }) + } } impl Parser { @@ -1723,7 +1772,7 @@ impl Parser { .get("items") .ok_or(Error::GetArrayItemsField) .and_then(|items| self.parse(items, enclosing_namespace)) - .map(|schema| Schema::Array(Box::new(schema))) + .map(Schema::array) } /// Parse a `serde_json::Value` representing a Avro map type into a @@ -1737,7 +1786,7 @@ impl Parser { .get("values") .ok_or(Error::GetMapValuesField) .and_then(|items| self.parse(items, enclosing_namespace)) - .map(|schema| Schema::Map(Box::new(schema))) + .map(Schema::map) } /// Parse a `serde_json::Value` representing a Avro union type into a @@ -1847,15 +1896,21 @@ impl Serialize for Schema { Schema::Bytes => serializer.serialize_str("bytes"), Schema::String => serializer.serialize_str("string"), Schema::Array(ref inner) => { - let mut map = serializer.serialize_map(Some(2))?; + let mut map = serializer.serialize_map(Some(2 + inner.custom_attributes.len()))?; map.serialize_entry("type", "array")?; - map.serialize_entry("items", &*inner.clone())?; + map.serialize_entry("items", &*inner.items.clone())?; + for attr in &inner.custom_attributes { + map.serialize_entry(attr.0, attr.1)?; + } map.end() } Schema::Map(ref inner) => { - let mut map = serializer.serialize_map(Some(2))?; + let mut map = serializer.serialize_map(Some(2 + inner.custom_attributes.len()))?; map.serialize_entry("type", "map")?; - map.serialize_entry("values", &*inner.clone())?; + map.serialize_entry("values", &*inner.types.clone())?; + for attr in &inner.custom_attributes { + map.serialize_entry(attr.0, attr.1)?; + } map.end() } Schema::Union(ref inner) => { @@ -2270,10 +2325,7 @@ pub mod derive { named_schemas: &mut Names, enclosing_namespace: &Namespace, ) -> Schema { - Schema::Array(Box::new(T::get_schema_in_ctxt( - named_schemas, - enclosing_namespace, - ))) + Schema::array(T::get_schema_in_ctxt(named_schemas, enclosing_namespace)) } } @@ -2305,10 +2357,7 @@ pub mod derive { named_schemas: &mut Names, enclosing_namespace: &Namespace, ) -> Schema { - Schema::Map(Box::new(T::get_schema_in_ctxt( - named_schemas, - enclosing_namespace, - ))) + Schema::map(T::get_schema_in_ctxt(named_schemas, enclosing_namespace)) } } @@ -2320,10 +2369,7 @@ pub mod derive { named_schemas: &mut Names, enclosing_namespace: &Namespace, ) -> Schema { - Schema::Map(Box::new(T::get_schema_in_ctxt( - named_schemas, - enclosing_namespace, - ))) + Schema::map(T::get_schema_in_ctxt(named_schemas, enclosing_namespace)) } } @@ -2387,14 +2433,14 @@ mod tests { #[test] fn test_array_schema() -> TestResult { let schema = Schema::parse_str(r#"{"type": "array", "items": "string"}"#)?; - assert_eq!(Schema::Array(Box::new(Schema::String)), schema); + assert_eq!(Schema::array(Schema::String), schema); Ok(()) } #[test] fn test_map_schema() -> TestResult { let schema = Schema::parse_str(r#"{"type": "map", "values": "double"}"#)?; - assert_eq!(Schema::Map(Box::new(Schema::Double)), schema); + assert_eq!(Schema::map(Schema::Double), schema); Ok(()) } @@ -2748,9 +2794,9 @@ mod tests { doc: None, default: None, aliases: None, - schema: Schema::Array(Box::new(Schema::Ref { + schema: Schema::array(Schema::Ref { name: Name::new("Node")?, - })), + }), order: RecordFieldOrder::Ascending, position: 1, custom_attributes: Default::default(), @@ -4442,7 +4488,7 @@ mod tests { assert_eq!(union.schemas[0], Schema::Null); if let Schema::Array(ref array_schema) = union.schemas[1] { - if let Schema::Long = **array_schema { + if let Schema::Long = *array_schema.items { // OK } else { panic!("Expected a Schema::Array of type Long"); @@ -6529,4 +6575,40 @@ mod tests { Ok(()) } + + #[test] + fn test_avro_3927_serialize_array_with_custom_attributes() -> TestResult { + let expected = Schema::array_with_attributes( + Schema::Long, + BTreeMap::from([("field-id".to_string(), "1".into())]), + ); + + let value = serde_json::to_value(&expected)?; + let serialized = serde_json::to_string(&value)?; + assert_eq!( + r#"{"field-id":"1","items":"long","type":"array"}"#, + &serialized + ); + assert_eq!(expected, Schema::parse_str(&serialized)?); + + Ok(()) + } + + #[test] + fn test_avro_3927_serialize_map_with_custom_attributes() -> TestResult { + let expected = Schema::map_with_attributes( + Schema::Long, + BTreeMap::from([("field-id".to_string(), "1".into())]), + ); + + let value = serde_json::to_value(&expected)?; + let serialized = serde_json::to_string(&value)?; + assert_eq!( + r#"{"field-id":"1","type":"map","values":"long"}"#, + &serialized + ); + assert_eq!(expected, Schema::parse_str(&serialized)?); + + Ok(()) + } } diff --git a/lang/rust/avro/src/schema_compatibility.rs b/lang/rust/avro/src/schema_compatibility.rs index 107a30a3745..09c302036e2 100644 --- a/lang/rust/avro/src/schema_compatibility.rs +++ b/lang/rust/avro/src/schema_compatibility.rs @@ -71,7 +71,7 @@ impl Checker { SchemaKind::Map => { if let Schema::Map(w_m) = writers_schema { match readers_schema { - Schema::Map(r_m) => self.full_match_schemas(w_m, r_m), + Schema::Map(r_m) => self.full_match_schemas(&w_m.types, &r_m.types), _ => Err(CompatibilityError::WrongType { writer_schema_type: format!("{:#?}", writers_schema), reader_schema_type: format!("{:#?}", readers_schema), @@ -87,7 +87,7 @@ impl Checker { SchemaKind::Array => { if let Schema::Array(w_a) = writers_schema { match readers_schema { - Schema::Array(r_a) => self.full_match_schemas(w_a, r_a), + Schema::Array(r_a) => self.full_match_schemas(&w_a.items, &r_a.items), _ => Err(CompatibilityError::WrongType { writer_schema_type: format!("{:#?}", writers_schema), reader_schema_type: format!("{:#?}", readers_schema), @@ -370,7 +370,7 @@ impl SchemaCompatibility { SchemaKind::Map => { if let Schema::Map(w_m) = writers_schema { if let Schema::Map(r_m) = readers_schema { - return SchemaCompatibility::match_schemas(w_m, r_m); + return SchemaCompatibility::match_schemas(&w_m.types, &r_m.types); } else { return Err(CompatibilityError::TypeExpected { schema_type: String::from("readers_schema"), @@ -387,7 +387,7 @@ impl SchemaCompatibility { SchemaKind::Array => { if let Schema::Array(w_a) = writers_schema { if let Schema::Array(r_a) = readers_schema { - return SchemaCompatibility::match_schemas(w_a, r_a); + return SchemaCompatibility::match_schemas(&w_a.items, &r_a.items); } else { return Err(CompatibilityError::TypeExpected { schema_type: String::from("readers_schema"), diff --git a/lang/rust/avro/src/types.rs b/lang/rust/avro/src/types.rs index 97d6b7174f7..62752bbbacf 100644 --- a/lang/rust/avro/src/types.rs +++ b/lang/rust/avro/src/types.rs @@ -523,14 +523,14 @@ impl Value { (Value::Array(items), Schema::Array(inner)) => items.iter().fold(None, |acc, item| { Value::accumulate( acc, - item.validate_internal(inner, names, enclosing_namespace), + item.validate_internal(&inner.items, names, enclosing_namespace), ) }), (Value::Map(items), Schema::Map(inner)) => { items.iter().fold(None, |acc, (_, value)| { Value::accumulate( acc, - value.validate_internal(inner, names, enclosing_namespace), + value.validate_internal(&inner.types, names, enclosing_namespace), ) }) } @@ -681,8 +681,10 @@ impl Value { ref default, .. }) => self.resolve_enum(symbols, default, field_default), - Schema::Array(ref inner) => self.resolve_array(inner, names, enclosing_namespace), - Schema::Map(ref inner) => self.resolve_map(inner, names, enclosing_namespace), + Schema::Array(ref inner) => { + self.resolve_array(&inner.items, names, enclosing_namespace) + } + Schema::Map(ref inner) => self.resolve_map(&inner.types, names, enclosing_namespace), Schema::Record(RecordSchema { ref fields, .. }) => { self.resolve_record(fields, names, enclosing_namespace) } @@ -1265,15 +1267,15 @@ mod tests { ), ( Value::Array(vec![Value::Long(42i64)]), - Schema::Array(Box::new(Schema::Long)), + Schema::array(Schema::Long), true, "", ), ( Value::Array(vec![Value::Boolean(true)]), - Schema::Array(Box::new(Schema::Long)), + Schema::array(Schema::Long), false, - "Invalid value: Array([Boolean(true)]) for schema: Array(Long). Reason: Unsupported value-schema combination", + "Invalid value: Array([Boolean(true)]) for schema: Array(ArraySchema { items: Long, custom_attributes: {} }). Reason: Unsupported value-schema combination", ), (Value::Record(vec![]), Schema::Null, false, "Invalid value: Record([]) for schema: Null. Reason: Unsupported value-schema combination"), ( diff --git a/lang/rust/avro/src/writer.rs b/lang/rust/avro/src/writer.rs index b820885c6e3..446a4c0ef39 100644 --- a/lang/rust/avro/src/writer.rs +++ b/lang/rust/avro/src/writer.rs @@ -376,11 +376,7 @@ impl<'a, W: Write> Writer<'a, W> { let mut header = Vec::new(); header.extend_from_slice(AVRO_OBJECT_HEADER); - encode( - &metadata.into(), - &Schema::Map(Box::new(Schema::Bytes)), - &mut header, - )?; + encode(&metadata.into(), &Schema::map(Schema::Bytes), &mut header)?; header.extend_from_slice(&self.marker); Ok(header) diff --git a/lang/rust/avro_derive/src/lib.rs b/lang/rust/avro_derive/src/lib.rs index 5b36839be4e..bee080ace3f 100644 --- a/lang/rust/avro_derive/src/lib.rs +++ b/lang/rust/avro_derive/src/lib.rs @@ -267,7 +267,7 @@ fn type_to_schema_expr(ty: &Type) -> Result> { Ok(schema) } else if let Type::Array(ta) = ty { let inner_schema_expr = type_to_schema_expr(&ta.elem)?; - Ok(quote! {apache_avro::schema::Schema::Array(Box::new(#inner_schema_expr))}) + Ok(quote! {apache_avro::schema::Schema::array(#inner_schema_expr)}) } else if let Type::Reference(tr) = ty { type_to_schema_expr(&tr.elem) } else {