From c5dceea0a8b306ffc4119ab2ea2cf86a612454d8 Mon Sep 17 00:00:00 2001 From: Millione Date: Wed, 20 Nov 2024 16:32:32 +0800 Subject: [PATCH] feat: support generate BTreeSet/BTreeMap for thrift set/map --- Cargo.lock | 4 +- pilota-build/Cargo.toml | 2 +- pilota-build/src/codegen/mod.rs | 6 +- pilota-build/src/codegen/thrift/ty.rs | 319 +++++--- pilota-build/src/middle/context.rs | 76 +- pilota-build/src/middle/ty.rs | 80 +- pilota-build/src/resolve.rs | 54 +- pilota-build/test_data/thrift/btree.rs | 847 +++++++++++++++++++++ pilota-build/test_data/thrift/btree.thrift | 25 + pilota/Cargo.toml | 2 +- pilota/src/thrift/mod.rs | 285 ++++--- 11 files changed, 1390 insertions(+), 310 deletions(-) create mode 100644 pilota-build/test_data/thrift/btree.rs create mode 100644 pilota-build/test_data/thrift/btree.thrift diff --git a/Cargo.lock b/Cargo.lock index 40ee21ce..c40179fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -841,7 +841,7 @@ dependencies = [ [[package]] name = "pilota" -version = "0.11.7" +version = "0.11.8" dependencies = [ "ahash", "anyhow", @@ -865,7 +865,7 @@ dependencies = [ [[package]] name = "pilota-build" -version = "0.11.25" +version = "0.11.26" dependencies = [ "ahash", "anyhow", diff --git a/pilota-build/Cargo.toml b/pilota-build/Cargo.toml index f363468d..eba818f5 100644 --- a/pilota-build/Cargo.toml +++ b/pilota-build/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pilota-build" -version = "0.11.25" +version = "0.11.26" edition = "2021" description = "Compile thrift and protobuf idl into rust code at compile-time." documentation = "https://docs.rs/pilota-build" diff --git a/pilota-build/src/codegen/mod.rs b/pilota-build/src/codegen/mod.rs index 3ffdb689..9a0feb03 100644 --- a/pilota-build/src/codegen/mod.rs +++ b/pilota-build/src/codegen/mod.rs @@ -592,12 +592,12 @@ where To avoid problems when generating files for services with similar names, e.g. testService and TestService, such names are de-duplicated by adding a number to their nam5e */ - fn generate_unique_name(existing_names: &AHashSet, simple_name: &String) -> String { + fn generate_unique_name(existing_names: &AHashSet, simple_name: &str) -> String { let mut counter = 1; - let mut name = simple_name.clone(); + let mut name = simple_name.to_string(); while existing_names.contains(name.to_ascii_lowercase().as_str()) { counter += 1; - name = format!("{}_{}", simple_name.clone(), counter) + name = format!("{}_{}", simple_name, counter) } name } diff --git a/pilota-build/src/codegen/thrift/ty.rs b/pilota-build/src/codegen/thrift/ty.rs index b8376443..22ebff8c 100644 --- a/pilota-build/src/codegen/thrift/ty.rs +++ b/pilota-build/src/codegen/thrift/ty.rs @@ -23,8 +23,8 @@ impl ThriftBackend { ty::F64 | ty::OrderedF64 => "::pilota::thrift::TType::Double".into(), ty::Uuid => "::pilota::thrift::TType::Uuid".into(), ty::Vec(_) => "::pilota::thrift::TType::List".into(), - ty::Set(_) => "::pilota::thrift::TType::Set".into(), - ty::Map(_, _) => "::pilota::thrift::TType::Map".into(), + ty::Set(_) | ty::BTreeSet(_) => "::pilota::thrift::TType::Set".into(), + ty::Map(_, _) | ty::BTreeMap(_, _) => "::pilota::thrift::TType::Map".into(), ty::Path(path) => { let item = self.expect_item(path.did); match &*item { @@ -73,34 +73,17 @@ impl ThriftBackend { } .into() } - ty::Set(ty) => { - let write_el = self.codegen_encode_ty(ty, "val".into()); - let el_ttype = self.ttype(ty); - - format! { - r#"__protocol.write_set({el_ttype}, &{ident}, |__protocol, val| {{ - {write_el} - ::std::result::Result::Ok(()) - }})?;"# - } - .into() + ty::Set(k) => { + self.encode_set(k, ident, "set") + } + ty::BTreeSet(k) => { + self.encode_set(k, ident, "btree_set") } ty::Map(k, v) => { - let key_ttype = self.ttype(k); - let val_ttype = self.ttype(v); - let write_key = self.codegen_encode_ty(k, "key".into()); - let write_val = self.codegen_encode_ty(v, "val".into()); - - format! { - r#"__protocol.write_map({key_ttype}, {val_ttype}, &{ident}, |__protocol, key| {{ - {write_key} - ::std::result::Result::Ok(()) - }}, |__protocol, val| {{ - {write_val} - ::std::result::Result::Ok(()) - }})?;"# - } - .into() + self.encode_map(k, v, ident, "map") + } + ty::BTreeMap(k, v) => { + self.encode_map(k, v, ident, "btree_map") } ty::Path(_) => format!("__protocol.write_struct({ident})?;").into(), ty::Arc(ty) => self.codegen_encode_ty(ty, ident), @@ -108,6 +91,38 @@ impl ThriftBackend { } } + #[inline] + fn encode_set(&self, ty: &Ty, ident: FastStr, name: &str) -> FastStr { + let write_el = self.codegen_encode_ty(ty, "val".into()); + let el_ttype = self.ttype(ty); + format! { + r#"__protocol.write_{name}({el_ttype}, &{ident}, |__protocol, val| {{ + {write_el} + ::std::result::Result::Ok(()) + }})?;"# + } + .into() + } + + #[inline] + fn encode_map(&self, k: &Ty, v: &Ty, ident: FastStr, name: &str) -> FastStr { + let key_ttype = self.ttype(k); + let val_ttype = self.ttype(v); + let write_key = self.codegen_encode_ty(k, "key".into()); + let write_val = self.codegen_encode_ty(v, "val".into()); + + format! { + r#"__protocol.write_{name}({key_ttype}, {val_ttype}, &{ident}, |__protocol, key| {{ + {write_key} + ::std::result::Result::Ok(()) + }}, |__protocol, val| {{ + {write_val} + ::std::result::Result::Ok(()) + }})?;"# + } + .into() + } + fn is_i32_enum(&self, def_id: DefId) -> bool { let item = self.expect_item(def_id); match &*item { @@ -148,34 +163,17 @@ impl ThriftBackend { } .into() } - ty::Set(ty) => { - let write_el = self.codegen_encode_ty(ty, "val".into()); - let el_ttype = self.ttype(ty); - - format! { - r#"__protocol.write_set_field({id}, {el_ttype}, &{ident}, |__protocol, val| {{ - {write_el} - ::std::result::Result::Ok(()) - }})?;"# - } - .into() + ty::Set(k) => { + self.encode_set_field(k, id, ident, "set") + } + ty::BTreeSet(k) => { + self.encode_set_field(k, id, ident, "btree_set") } ty::Map(k, v) => { - let key_ttype = self.ttype(k); - let val_ttype = self.ttype(v); - let write_key = self.codegen_encode_ty(k, "key".into()); - let write_val = self.codegen_encode_ty(v, "val".into()); - - format! { - r#"__protocol.write_map_field({id}, {key_ttype}, {val_ttype}, &{ident}, |__protocol, key| {{ - {write_key} - ::std::result::Result::Ok(()) - }}, |__protocol, val| {{ - {write_val} - ::std::result::Result::Ok(()) - }})?;"# - } - .into() + self.encode_map_field(k, v, id, ident, "map") + } + ty::BTreeMap(k, v) => { + self.encode_map_field(k, v, id, ident, "btree_map") } ty::Path(p) if self.is_i32_enum(p.did) => { format!("__protocol.write_i32_field({id}, ({ident}).inner())?;").into() @@ -195,6 +193,38 @@ impl ThriftBackend { } } + #[inline] + fn encode_set_field(&self, ty: &Ty, id: i16, ident: FastStr, name: &str) -> FastStr { + let write_el = self.codegen_encode_ty(ty, "val".into()); + let el_ttype = self.ttype(ty); + format! { + r#"__protocol.write_{name}_field({id}, {el_ttype}, &{ident}, |__protocol, val| {{ + {write_el} + ::std::result::Result::Ok(()) + }})?;"# + } + .into() + } + + #[inline] + fn encode_map_field(&self, k: &Ty, v: &Ty, id: i16, ident: FastStr, name: &str) -> FastStr { + let key_ttype = self.ttype(k); + let val_ttype = self.ttype(v); + let write_key = self.codegen_encode_ty(k, "key".into()); + let write_val = self.codegen_encode_ty(v, "val".into()); + + format! { + r#"__protocol.write_{name}_field({id}, {key_ttype}, {val_ttype}, &{ident}, |__protocol, key| {{ + {write_key} + ::std::result::Result::Ok(()) + }}, |__protocol, val| {{ + {write_val} + ::std::result::Result::Ok(()) + }})?;"# + } + .into() + } + pub(crate) fn codegen_ty_size(&self, ty: &Ty, ident: FastStr) -> FastStr { match &ty.kind { ty::String => format!("__protocol.string_len({ident})").into(), @@ -221,37 +251,45 @@ impl ThriftBackend { } .into() } - ty::Set(el) => { - let add_el = self.codegen_ty_size(el, "el".into()); - let el_ttype = self.ttype(el); - format! { - r#"__protocol.set_len({el_ttype}, {ident}, |__protocol, el| {{ - {add_el} - }})"# - } - .into() - } - ty::Map(k, v) => { - let add_key = self.codegen_ty_size(k, "key".into()); - let add_val = self.codegen_ty_size(v, "val".into()); - let k_ttype = self.ttype(k); - let v_ttype = self.ttype(v); - - format! { - r#"__protocol.map_len({k_ttype}, {v_ttype}, {ident}, |__protocol, key| {{ - {add_key} - }}, |__protocol, val| {{ - {add_val} - }})"# - } - .into() - } + ty::Set(k) => self.set_size(k, ident, "set"), + ty::BTreeSet(k) => self.set_size(k, ident, "btree_set"), + ty::Map(k, v) => self.map_size(k, v, ident, "map"), + ty::BTreeMap(k, v) => self.map_size(k, v, ident, "btree_map"), ty::Path(_) => format!("__protocol.struct_len({ident})").into(), ty::Arc(ty) => self.codegen_ty_size(ty, ident), _ => unimplemented!(), } } + #[inline] + fn set_size(&self, ty: &Ty, ident: FastStr, name: &str) -> FastStr { + let add_el = self.codegen_ty_size(ty, "el".into()); + let el_ttype = self.ttype(ty); + format! { + r#"__protocol.{name}_len({el_ttype}, {ident}, |__protocol, el| {{ + {add_el} + }})"# + } + .into() + } + + #[inline] + fn map_size(&self, k: &Ty, v: &Ty, ident: FastStr, name: &str) -> FastStr { + let add_key = self.codegen_ty_size(k, "key".into()); + let add_val = self.codegen_ty_size(v, "val".into()); + let k_ttype = self.ttype(k); + let v_ttype = self.ttype(v); + + format! { + r#"__protocol.{name}_len({k_ttype}, {v_ttype}, {ident}, |__protocol, key| {{ + {add_key} + }}, |__protocol, val| {{ + {add_val} + }})"# + } + .into() + } + pub(crate) fn codegen_field_size(&self, ty: &Ty, id: i16, ident: FastStr) -> FastStr { match &ty.kind { ty::String => format!("__protocol.string_field_len(Some({id}), &{ident})").into(), @@ -278,24 +316,10 @@ impl ThriftBackend { } .into() } - ty::Set(el) => { - let add_el = self.codegen_ty_size(el, "el".into()); - let el_ttype = self.ttype(el); - format! { - r#"__protocol.set_field_len(Some({id}), {el_ttype}, {ident}, |__protocol, el| {{ - {add_el} - }})"# - } - .into() - } - ty::Map(k, v) => { - let add_key = self.codegen_ty_size(k, "key".into()); - let add_val = self.codegen_ty_size(v, "val".into()); - let k_ttype = self.ttype(k); - let v_ttype = self.ttype(v); - - format!("__protocol.map_field_len(Some({id}), {k_ttype}, {v_ttype}, {ident}, |__protocol, key| {{ {add_key} }}, |__protocol, val| {{ {add_val} }})").into() - } + ty::Set(k) => self.set_field_size(k, id, ident, "set"), + ty::BTreeSet(k) => self.set_field_size(k, id, ident, "btree_set"), + ty::Map(k, v) => self.map_field_size(k, v, id, ident, "map"), + ty::BTreeMap(k, v) => self.map_field_size(k, v, id, ident, "btree_map"), ty::Path(p) if self.is_i32_enum(p.did) => { format!("__protocol.i32_field_len(Some({id}), ({ident}).inner())").into() } @@ -305,6 +329,35 @@ impl ThriftBackend { } } + #[inline] + fn set_field_size(&self, ty: &Ty, id: i16, ident: FastStr, name: &str) -> FastStr { + let add_el = self.codegen_ty_size(ty, "el".into()); + let el_ttype = self.ttype(ty); + format! { + r#"__protocol.{name}_field_len(Some({id}), {el_ttype}, {ident}, |__protocol, el| {{ + {add_el} + }})"# + } + .into() + } + + #[inline] + fn map_field_size(&self, k: &Ty, v: &Ty, id: i16, ident: FastStr, name: &str) -> FastStr { + let add_key = self.codegen_ty_size(k, "key".into()); + let add_val = self.codegen_ty_size(v, "val".into()); + let k_ttype = self.ttype(k); + let v_ttype = self.ttype(v); + + format! { + r#"__protocol.{name}_field_len(Some({id}), {k_ttype}, {v_ttype}, {ident}, |__protocol, key| {{ + {add_key} + }}, |__protocol, val| {{ + {add_val} + }})"# + } + .into() + } + pub(crate) fn codegen_decode_ty(&self, helper: &DecodeHelper, ty: &Ty) -> FastStr { match &ty.kind { ty::String => helper.codegen_read_string(), @@ -369,46 +422,66 @@ impl ThriftBackend { .into() } } - ty::Set(ty) => { - let read_set_begin = helper.codegen_read_set_begin(); - let read_set_end = helper.codegen_read_set_end(); - let read_el = self.codegen_decode_ty(helper, ty); - format! {r#"{{let list_ident = {read_set_begin}; - let mut val = ::pilota::AHashSet::with_capacity(list_ident.size); + ty::Set(ty) => self.decode_set( + ty, + helper, + "::pilota::AHashSet::with_capacity(list_ident.size)", + ), + ty::BTreeSet(ty) => self.decode_set(ty, helper, "::std::collections::BTreeSet::new()"), + ty::Map(key_ty, val_ty) => self.decode_map( + key_ty, + val_ty, + helper, + "::pilota::AHashMap::with_capacity(map_ident.size)", + ), + ty::BTreeMap(key_ty, val_ty) => self.decode_map( + key_ty, + val_ty, + helper, + "::std::collections::BTreeMap::new()", + ), + ty::Path(_) => helper + .codegen_item_decode(format!("{}", self.codegen_item_ty(ty.kind.clone())).into()), + ty::Arc(ty) => { + let inner = self.codegen_decode_ty(helper, ty); + format!("::std::sync::Arc::new({inner})").into() + } + _ => unimplemented!(), + } + } + + #[inline] + fn decode_set(&self, ty: &Ty, helper: &DecodeHelper, new: &str) -> FastStr { + let read_set_begin = helper.codegen_read_set_begin(); + let read_set_end = helper.codegen_read_set_end(); + let read_el = self.codegen_decode_ty(helper, ty); + format! {r#"{{let list_ident = {read_set_begin}; + let mut val = {new}; for _ in 0..list_ident.size {{ val.insert({read_el}); }}; {read_set_end}; val}}"#} - .into() - } - ty::Map(key_ty, val_ty) => { - let read_el_key = self.codegen_decode_ty(helper, key_ty); - let read_el_val = self.codegen_decode_ty(helper, val_ty); - - let read_map_begin = helper.codegen_read_map_begin(); - let read_map_end = helper.codegen_read_map_end(); + .into() + } - format! { + #[inline] + fn decode_map(&self, key_ty: &Ty, val_ty: &Ty, helper: &DecodeHelper, new: &str) -> FastStr { + let read_el_key = self.codegen_decode_ty(helper, key_ty); + let read_el_val = self.codegen_decode_ty(helper, val_ty); + let read_map_begin = helper.codegen_read_map_begin(); + let read_map_end = helper.codegen_read_map_end(); + format! { r#"{{ let map_ident = {read_map_begin}; - let mut val = ::pilota::AHashMap::with_capacity(map_ident.size); + let mut val = {new}; for _ in 0..map_ident.size {{ val.insert({read_el_key}, {read_el_val}); }} {read_map_end}; val }}"# - } - .into() - } - ty::Path(_) => helper - .codegen_item_decode(format!("{}", self.codegen_item_ty(ty.kind.clone())).into()), - ty::Arc(ty) => { - let inner = self.codegen_decode_ty(helper, ty); - format!("::std::sync::Arc::new({inner})").into() - } - _ => unimplemented!(), } + .into() } } diff --git a/pilota-build/src/middle/context.rs b/pilota-build/src/middle/context.rs index a2521ce2..28a16dd1 100644 --- a/pilota-build/src/middle/context.rs +++ b/pilota-build/src/middle/context.rs @@ -468,7 +468,10 @@ impl Context { lit: &Literal, ty: &CodegenTy, ) -> anyhow::Result<(FastStr, bool /* const? */)> { - let mk_map = |m: &Vec<(Literal, Literal)>, k_ty: &Arc, v_ty: &Arc| { + let mk_map = |m: &Vec<(Literal, Literal)>, + k_ty: &Arc, + v_ty: &Arc, + btree: bool| { let k_ty = &**k_ty; let v_ty = &**v_ty; let len = m.len(); @@ -481,9 +484,14 @@ impl Context { }) .try_collect::<_, Vec<_>, _>()? .join(""); + let new = if btree { + "::std::collections::BTreeMap::new()".to_string() + } else { + format!("::pilota::AHashMap::with_capacity({len})") + }; anyhow::Ok( format! {r#"{{ - let mut map = ::pilota::AHashMap::with_capacity({len}); + let mut map = {new}; {kvs} map }}"#} @@ -493,14 +501,32 @@ impl Context { anyhow::Ok(match (lit, ty) { (Literal::Map(m), CodegenTy::LazyStaticRef(map)) => match &**map { - CodegenTy::Map(k_ty, v_ty) => (mk_map(m, k_ty, v_ty)?, false), + CodegenTy::Map(k_ty, v_ty) => (mk_map(m, k_ty, v_ty, false)?, false), + CodegenTy::BTreeMap(k_ty, v_ty) => (mk_map(m, k_ty, v_ty, true)?, false), _ => panic!("invalid map type {:?}", map), }, - (Literal::Map(m), CodegenTy::Map(k_ty, v_ty)) => (mk_map(m, k_ty, v_ty)?, false), - (Literal::List(m), CodegenTy::Map(_, _) | CodegenTy::LazyStaticRef(_)) => { - assert!(m.is_empty()); + (Literal::Map(m), CodegenTy::Map(k_ty, v_ty)) => (mk_map(m, k_ty, v_ty, false)?, false), + (Literal::Map(m), CodegenTy::BTreeMap(k_ty, v_ty)) => { + (mk_map(m, k_ty, v_ty, true)?, false) + } + (Literal::List(l), CodegenTy::LazyStaticRef(map)) => { + assert!(l.is_empty()); + match &**map { + CodegenTy::Map(_, _) => ("::pilota::AHashMap::new()".into(), false), + CodegenTy::BTreeMap(_, _) => { + ("::std::collections::BTreeMap::new()".into(), false) + } + _ => panic!("invalid map type {:?}", map), + } + } + (Literal::List(l), CodegenTy::Map(_, _)) => { + assert!(l.is_empty()); ("::pilota::AHashMap::new()".into(), false) } + (Literal::List(l), CodegenTy::BTreeMap(_, _)) => { + assert!(l.is_empty()); + ("::std::collections::BTreeMap::new()".into(), false) + } _ => self.lit_into_ty(lit, ty)?, }) } @@ -637,7 +663,7 @@ impl Context { (format! { "{ident}({stream})" }.into(), is_const) } (Literal::Map(_), CodegenTy::StaticRef(map)) => match &**map { - CodegenTy::Map(_, _) => { + CodegenTy::Map(_, _) | CodegenTy::BTreeMap(_, _) => { let lazy_map = self.def_lit("INNER_MAP", lit, &mut CodegenTy::LazyStaticRef(map.clone()))?; let stream = format! { @@ -662,30 +688,23 @@ impl Context { (format! {"[{stream}]" }.into(), is_const) } (Literal::List(els), CodegenTy::Vec(inner)) => { - let stream = els - .iter() - .map(|el| self.lit_into_ty(el, inner)) - .try_collect::<_, Vec<_>, _>()? - .into_iter() - .map(|(s, _)| s) - .join(","); - + let stream = self.list_stream(els, inner)?; (format! { "::std::vec![{stream}]" }.into(), false) } (Literal::List(els), CodegenTy::Set(inner)) => { - let stream = els - .iter() - .map(|el| self.lit_into_ty(el, inner)) - .try_collect::<_, Vec<_>, _>()? - .into_iter() - .map(|(s, _)| s) - .join(","); - + let stream = self.list_stream(els, inner)?; ( format! { "::pilota::AHashSet::from([{stream}])" }.into(), false, ) } + (Literal::List(els), CodegenTy::BTreeSet(inner)) => { + let stream = self.list_stream(els, inner)?; + ( + format! { "::std::collections::BTreeSet::from([{stream}])" }.into(), + false, + ) + } (Literal::Bool(b), CodegenTy::Bool) => (format! { "{b}" }.into(), true), (Literal::Int(i), CodegenTy::Bool) => { let b = *i != 0; @@ -763,6 +782,17 @@ impl Context { }) } + #[inline] + fn list_stream(&self, els: &[Literal], inner: &Arc) -> anyhow::Result { + Ok(els + .iter() + .map(|el| self.lit_into_ty(el, inner)) + .try_collect::<_, Vec<_>, _>()? + .into_iter() + .map(|(s, _)| s) + .join(",")) + } + pub(crate) fn def_lit( &self, name: &str, diff --git a/pilota-build/src/middle/ty.rs b/pilota-build/src/middle/ty.rs index eeb094a4..8a080c08 100644 --- a/pilota-build/src/middle/ty.rs +++ b/pilota-build/src/middle/ty.rs @@ -29,7 +29,9 @@ pub enum TyKind { Uuid, Vec(Arc), Set(Arc), + BTreeSet(Arc), Map(Arc, Arc), + BTreeMap(Arc, Arc), Arc(Arc), Path(Path), } @@ -77,7 +79,9 @@ pub enum CodegenTy { Vec(Arc), Array(Arc, usize), Set(Arc), + BTreeSet(Arc), Map(Arc, Arc), + BTreeMap(Arc, Arc), Adt(AdtDef), Arc(Arc), } @@ -89,7 +93,8 @@ impl CodegenTy { | CodegenTy::LazyStaticRef(_) | CodegenTy::StaticRef(_) | CodegenTy::Vec(_) - | CodegenTy::Map(_, _) => true, + | CodegenTy::Map(_, _) + | CodegenTy::BTreeMap(_, _) => true, CodegenTy::Adt(AdtDef { did: _, kind: AdtKind::NewType(inner), @@ -133,6 +138,14 @@ impl CodegenTy { let ty = &**ty; format!("::pilota::AHashSet<{}>", ty.global_path(adt_prefix)).into() } + CodegenTy::BTreeSet(ty) => { + let ty = &**ty; + format!( + "::std::collections::BTreeSet<{}>", + ty.global_path(adt_prefix) + ) + .into() + } CodegenTy::Map(k, v) => { let k = &**k; let v = &**v; @@ -143,6 +156,16 @@ impl CodegenTy { ) .into() } + CodegenTy::BTreeMap(k, v) => { + let k = &**k; + let v = &**v; + format!( + "::std::collections::BTreeMap<{}, {}>", + k.global_path(adt_prefix), + v.global_path(adt_prefix) + ) + .into() + } CodegenTy::Adt(def) => with_cx(|cx| { let path = cx .item_path(def.did) @@ -197,11 +220,20 @@ impl Display for CodegenTy { let ty = &**ty; write!(f, "::pilota::AHashSet<{ty}>") } + CodegenTy::BTreeSet(ty) => { + let ty = &**ty; + write!(f, "::std::collections::BTreeSet<{ty}>") + } CodegenTy::Map(k, v) => { let k = &**k; let v = &**v; write!(f, "::pilota::AHashMap<{k}, {v}>") } + CodegenTy::BTreeMap(k, v) => { + let k = &**k; + let v = &**v; + write!(f, "::std::collections::BTreeMap<{k}, {v}>") + } CodegenTy::Adt(def) => with_cx(|cx| { let path = cx.cur_related_item_path(def.did); @@ -330,6 +362,11 @@ pub trait TyTransformer { CodegenTy::Set(Arc::from(self.codegen_item_ty(&ty.kind))) } + #[inline] + fn btree_set(&self, ty: &Ty) -> CodegenTy { + CodegenTy::BTreeSet(Arc::from(self.codegen_item_ty(&ty.kind))) + } + #[inline] fn map(&self, key: &Ty, value: &Ty) -> CodegenTy { let key = self.codegen_item_ty(&key.kind); @@ -337,6 +374,13 @@ pub trait TyTransformer { CodegenTy::Map(Arc::from(key), Arc::from(value)) } + #[inline] + fn btree_map(&self, key: &Ty, value: &Ty) -> CodegenTy { + let key = self.codegen_item_ty(&key.kind); + let value = self.codegen_item_ty(&value.kind); + CodegenTy::BTreeMap(Arc::from(key), Arc::from(value)) + } + #[inline] fn path(&self, path: &Path) -> CodegenTy { let did = path.did; @@ -368,7 +412,9 @@ pub trait TyTransformer { Uuid => self.uuid(), Vec(ty) => self.vec(ty), Set(ty) => self.set(ty), + BTreeSet(ty) => self.btree_set(ty), Map(k, v) => self.map(k, v), + BTreeMap(k, v) => self.btree_map(k, v), Path(path) => self.path(path), UInt32 => self.uint32(), UInt64 => self.uint64(), @@ -423,6 +469,13 @@ impl TyTransformer for ConstTyTransformer<'_> { )))) } + #[inline] + fn btree_set(&self, ty: &Ty) -> CodegenTy { + CodegenTy::StaticRef(Arc::from(CodegenTy::BTreeSet(Arc::from( + self.dyn_codegen_item_ty(&ty.kind), + )))) + } + #[inline] fn map(&self, key: &Ty, value: &Ty) -> CodegenTy { let key = self.dyn_codegen_item_ty(&key.kind); @@ -430,6 +483,16 @@ impl TyTransformer for ConstTyTransformer<'_> { CodegenTy::StaticRef(Arc::from(CodegenTy::Map(Arc::from(key), Arc::from(value)))) } + #[inline] + fn btree_map(&self, key: &Ty, value: &Ty) -> CodegenTy { + let key = self.dyn_codegen_item_ty(&key.kind); + let value = self.dyn_codegen_item_ty(&value.kind); + CodegenTy::StaticRef(Arc::from(CodegenTy::BTreeMap( + Arc::from(key), + Arc::from(value), + ))) + } + fn get_db(&self) -> &dyn RirDatabase { self.0 } @@ -446,11 +509,20 @@ pub(crate) trait Visitor: Sized { self.visit(el) } + fn visit_btree_set(&mut self, el: &Ty) { + self.visit(el) + } + fn visit_map(&mut self, k: &Ty, v: &Ty) { self.visit(k); self.visit(v); } + fn visit_btree_map(&mut self, k: &Ty, v: &Ty) { + self.visit(k); + self.visit(v); + } + fn visit(&mut self, ty: &Ty) { walk_ty(self, ty) } @@ -480,7 +552,9 @@ pub(crate) fn fold_ty(f: &mut F, ty: &Ty) -> Ty { Uuid => TyKind::Uuid, Vec(ty) => TyKind::Vec(f.fold_ty(ty).into()), Set(ty) => TyKind::Set(f.fold_ty(ty).into()), + BTreeSet(ty) => TyKind::BTreeSet(f.fold_ty(ty).into()), Map(k, v) => TyKind::Map(fold_ty(f, k).into(), fold_ty(f, v).into()), + BTreeMap(k, v) => TyKind::BTreeMap(fold_ty(f, k).into(), fold_ty(f, v).into()), Path(path) => TyKind::Path(path.clone()), UInt32 => TyKind::UInt32, UInt64 => TyKind::UInt64, @@ -498,7 +572,9 @@ pub(crate) fn walk_ty(v: &mut V, ty: &Ty) { match &ty.kind { Vec(el) => v.visit_vec(el), Set(el) => v.visit_set(el), + BTreeSet(el) => v.visit_btree_set(el), Map(key, value) => v.visit_map(key, value), + BTreeMap(key, value) => v.visit_btree_map(key, value), Path(p) => v.visit_path(p), Arc(p) => v.visit(p), _ => {} @@ -510,7 +586,7 @@ mod tests { #[test] fn test_global_path() { use super::CodegenTy::*; - let ty = Vec(std::sync::Arc::new(U8).into()); + let ty = Vec(std::sync::Arc::new(U8)); assert_eq!(ty.global_path("adt_prefix"), "::std::vec::Vec"); let ty = Set(std::sync::Arc::new(U8)); diff --git a/pilota-build/src/resolve.rs b/pilota-build/src/resolve.rs index 2e008baa..00375f84 100644 --- a/pilota-build/src/resolve.rs +++ b/pilota-build/src/resolve.rs @@ -46,7 +46,7 @@ pub struct CollectDef<'a> { } impl<'a> CollectDef<'a> { - pub fn new(resolver: &'a mut Resolver) -> CollectDef { + pub fn new(resolver: &'a mut Resolver) -> Self { CollectDef { resolver, parent: None, @@ -295,6 +295,34 @@ impl Resolver { _ => {} } + if let Some(repr) = tags.get::() { + if repr == "btree" { + struct BTreeFolder<'a>(&'a mut Resolver); + impl Folder for BTreeFolder<'_> { + fn fold_ty(&mut self, ty: &Ty) -> Ty { + let kind = match &ty.kind { + TyKind::Vec(inner) => { + TyKind::Vec(Arc::new(self.fold_ty(inner.as_ref()))) + } + TyKind::Set(inner) => { + TyKind::BTreeSet(Arc::new(self.fold_ty(inner.as_ref()))) + } + TyKind::Map(k, v) => TyKind::BTreeMap( + Arc::new(self.fold_ty(k.as_ref())), + Arc::new(self.fold_ty(v.as_ref())), + ), + kind => kind.clone(), + }; + Ty { + kind, + tags_id: self.0.tags_id_counter.inc_one(), + } + } + } + ty = BTreeFolder(self).fold_ty(&ty); + } + }; + if let Some(RustWrapperArc(true)) = tags.get::() { struct ArcFolder<'a>(&'a mut Resolver); impl Folder for ArcFolder<'_> { @@ -302,9 +330,15 @@ impl Resolver { let kind = match &ty.kind { TyKind::Vec(inner) => TyKind::Vec(Arc::new(self.fold_ty(inner.as_ref()))), TyKind::Set(inner) => TyKind::Set(Arc::new(self.fold_ty(inner.as_ref()))), + TyKind::BTreeSet(inner) => { + TyKind::BTreeSet(Arc::new(self.fold_ty(inner.as_ref()))) + } TyKind::Map(k, v) => { TyKind::Map(k.clone(), Arc::new(self.fold_ty(v.as_ref()))) } + TyKind::BTreeMap(k, v) => { + TyKind::BTreeMap(k.clone(), Arc::new(self.fold_ty(v.as_ref()))) + } TyKind::Path(_) | TyKind::String | TyKind::BytesVec => { TyKind::Arc(Arc::new(ty.clone())) } @@ -660,10 +694,13 @@ impl Resolver { } } - fn lower_type_alias(&mut self, t: &ir::NewType) -> NewType { + fn lower_type_alias(&mut self, t: &ir::NewType, tags: &Tags) -> NewType { NewType { name: t.name.clone(), - ty: self.lower_type(&t.ty, false), + ty: { + let ty = self.lower_type(&t.ty, false); + self.modify_ty_by_tags(ty, tags) + }, } } @@ -683,10 +720,13 @@ impl Resolver { } } - fn lower_const(&mut self, c: &ir::Const) -> Const { + fn lower_const(&mut self, c: &ir::Const, tags: &Tags) -> Const { Const { name: c.name.clone(), - ty: self.lower_type(&c.ty, false), + ty: { + let ty = self.lower_type(&c.ty, false); + self.modify_ty_by_tags(ty, tags) + }, lit: self.lower_lit(&c.lit), } } @@ -734,8 +774,8 @@ impl Resolver { ir::ItemKind::Message(s) => Item::Message(self.lower_message(s)), ir::ItemKind::Enum(e) => Item::Enum(self.lower_enum(e)), ir::ItemKind::Service(s) => Item::Service(self.lower_service(s)), - ir::ItemKind::NewType(t) => Item::NewType(self.lower_type_alias(t)), - ir::ItemKind::Const(c) => Item::Const(self.lower_const(c)), + ir::ItemKind::NewType(t) => Item::NewType(self.lower_type_alias(t, tags)), + ir::ItemKind::Const(c) => Item::Const(self.lower_const(c, tags)), ir::ItemKind::Mod(m) => Item::Mod(self.lower_mod(m, def_id)), ir::ItemKind::Use(_) => unreachable!(), }); diff --git a/pilota-build/test_data/thrift/btree.rs b/pilota-build/test_data/thrift/btree.rs new file mode 100644 index 00000000..ed23456f --- /dev/null +++ b/pilota-build/test_data/thrift/btree.rs @@ -0,0 +1,847 @@ +pub mod btree { + #![allow(warnings, clippy::all)] + + pub mod btree { + #[derive(PartialOrd, Hash, Eq, Ord, Debug, Default, Clone, PartialEq)] + pub struct A {} + impl ::pilota::thrift::Message for A { + fn encode( + &self, + __protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> { + #[allow(unused_imports)] + use ::pilota::thrift::TOutputProtocolExt; + let struct_ident = ::pilota::thrift::TStructIdentifier { name: "A" }; + + __protocol.write_struct_begin(&struct_ident)?; + + __protocol.write_field_stop()?; + __protocol.write_struct_end()?; + ::std::result::Result::Ok(()) + } + + fn decode( + __protocol: &mut T, + ) -> ::std::result::Result { + #[allow(unused_imports)] + use ::pilota::{thrift::TLengthProtocolExt, Buf}; + + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin()?; + if let ::std::result::Result::Err(mut err) = (|| { + loop { + let field_ident = __protocol.read_field_begin()?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + __protocol.field_stop_len(); + break; + } else { + __protocol.field_begin_len(field_ident.field_type, field_ident.id); + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + _ => { + __protocol.skip(field_ident.field_type)?; + } + } + + __protocol.read_field_end()?; + __protocol.field_end_len(); + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + })() { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!( + "decode struct `A` field(#{}) failed, caused by: ", + field_id + )); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end()?; + + let data = Self {}; + ::std::result::Result::Ok(data) + } + + fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>( + __protocol: &'a mut T, + ) -> ::std::pin::Pin< + ::std::boxed::Box< + dyn ::std::future::Future< + Output = ::std::result::Result, + > + Send + + 'a, + >, + > { + ::std::boxed::Box::pin(async move { + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin().await?; + if let ::std::result::Result::Err(mut err) = async { + loop { + let field_ident = __protocol.read_field_begin().await?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + break; + } else { + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + _ => { + __protocol.skip(field_ident.field_type).await?; + } + } + + __protocol.read_field_end().await?; + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + } + .await + { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!( + "decode struct `A` field(#{}) failed, caused by: ", + field_id + )); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end().await?; + + let data = Self {}; + ::std::result::Result::Ok(data) + }) + } + + fn size(&self, __protocol: &mut T) -> usize { + #[allow(unused_imports)] + use ::pilota::thrift::TLengthProtocolExt; + __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier { name: "A" }) + + __protocol.field_stop_len() + + __protocol.struct_end_len() + } + } + #[derive(PartialOrd, Hash, Eq, Ord, Debug, Default, Clone, PartialEq)] + pub struct TypeA( + pub ::std::collections::BTreeMap<::std::collections::BTreeSet, ::pilota::FastStr>, + ); + + impl ::std::ops::Deref for TypeA { + type Target = + ::std::collections::BTreeMap<::std::collections::BTreeSet, ::pilota::FastStr>; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl + From<::std::collections::BTreeMap<::std::collections::BTreeSet, ::pilota::FastStr>> + for TypeA + { + fn from( + v: ::std::collections::BTreeMap< + ::std::collections::BTreeSet, + ::pilota::FastStr, + >, + ) -> Self { + Self(v) + } + } + + impl ::pilota::thrift::Message for TypeA { + fn encode( + &self, + __protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> { + #[allow(unused_imports)] + use ::pilota::thrift::TOutputProtocolExt; + __protocol.write_btree_map( + ::pilota::thrift::TType::Set, + ::pilota::thrift::TType::Binary, + &(&**self), + |__protocol, key| { + __protocol.write_btree_set( + ::pilota::thrift::TType::I32, + &key, + |__protocol, val| { + __protocol.write_i32(*val)?; + ::std::result::Result::Ok(()) + }, + )?; + ::std::result::Result::Ok(()) + }, + |__protocol, val| { + __protocol.write_faststr((val).clone())?; + ::std::result::Result::Ok(()) + }, + )?; + ::std::result::Result::Ok(()) + } + + fn decode( + __protocol: &mut T, + ) -> ::std::result::Result { + #[allow(unused_imports)] + use ::pilota::{thrift::TLengthProtocolExt, Buf}; + ::std::result::Result::Ok(TypeA({ + let map_ident = __protocol.read_map_begin()?; + let mut val = ::std::collections::BTreeMap::new(); + for _ in 0..map_ident.size { + val.insert( + { + let list_ident = __protocol.read_set_begin()?; + let mut val = ::std::collections::BTreeSet::new(); + for _ in 0..list_ident.size { + val.insert(__protocol.read_i32()?); + } + __protocol.read_set_end()?; + val + }, + __protocol.read_faststr()?, + ); + } + __protocol.read_map_end()?; + val + })) + } + + fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>( + __protocol: &'a mut T, + ) -> ::std::pin::Pin< + ::std::boxed::Box< + dyn ::std::future::Future< + Output = ::std::result::Result, + > + Send + + 'a, + >, + > { + ::std::boxed::Box::pin(async move { + ::std::result::Result::Ok(TypeA({ + let map_ident = __protocol.read_map_begin().await?; + let mut val = ::std::collections::BTreeMap::new(); + for _ in 0..map_ident.size { + val.insert( + { + let list_ident = __protocol.read_set_begin().await?; + let mut val = ::std::collections::BTreeSet::new(); + for _ in 0..list_ident.size { + val.insert(__protocol.read_i32().await?); + } + __protocol.read_set_end().await?; + val + }, + __protocol.read_faststr().await?, + ); + } + __protocol.read_map_end().await?; + val + })) + }) + } + + fn size(&self, __protocol: &mut T) -> usize { + #[allow(unused_imports)] + use ::pilota::thrift::TLengthProtocolExt; + __protocol.btree_map_len( + ::pilota::thrift::TType::Set, + ::pilota::thrift::TType::Binary, + &**self, + |__protocol, key| { + __protocol.btree_set_len( + ::pilota::thrift::TType::I32, + key, + |__protocol, el| __protocol.i32_len(*el), + ) + }, + |__protocol, val| __protocol.faststr_len(val), + ) + } + } + #[derive(PartialOrd, Hash, Eq, Ord, Debug, Default, Clone, PartialEq)] + pub struct B { + pub m: ::std::collections::BTreeMap>>, + + pub s: ::std::collections::BTreeSet, + + pub m2: ::std::collections::BTreeMap< + ::std::vec::Vec< + ::std::collections::BTreeMap<::std::collections::BTreeSet, i32>, + >, + ::std::collections::BTreeSet, + >, + } + impl ::pilota::thrift::Message for B { + fn encode( + &self, + __protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> { + #[allow(unused_imports)] + use ::pilota::thrift::TOutputProtocolExt; + let struct_ident = ::pilota::thrift::TStructIdentifier { name: "B" }; + + __protocol.write_struct_begin(&struct_ident)?; + __protocol.write_btree_map_field( + 1, + ::pilota::thrift::TType::I32, + ::pilota::thrift::TType::List, + &&self.m, + |__protocol, key| { + __protocol.write_i32(*key)?; + ::std::result::Result::Ok(()) + }, + |__protocol, val| { + __protocol.write_list( + ::pilota::thrift::TType::Struct, + &val, + |__protocol, val| { + __protocol.write_struct(val)?; + ::std::result::Result::Ok(()) + }, + )?; + ::std::result::Result::Ok(()) + }, + )?; + __protocol.write_btree_set_field( + 2, + ::pilota::thrift::TType::I32, + &&self.s, + |__protocol, val| { + __protocol.write_i32(*val)?; + ::std::result::Result::Ok(()) + }, + )?; + __protocol.write_btree_map_field( + 3, + ::pilota::thrift::TType::List, + ::pilota::thrift::TType::Set, + &&self.m2, + |__protocol, key| { + __protocol.write_list( + ::pilota::thrift::TType::Map, + &key, + |__protocol, val| { + __protocol.write_btree_map( + ::pilota::thrift::TType::Set, + ::pilota::thrift::TType::I32, + &val, + |__protocol, key| { + __protocol.write_btree_set( + ::pilota::thrift::TType::I32, + &key, + |__protocol, val| { + __protocol.write_i32(*val)?; + ::std::result::Result::Ok(()) + }, + )?; + ::std::result::Result::Ok(()) + }, + |__protocol, val| { + __protocol.write_i32(*val)?; + ::std::result::Result::Ok(()) + }, + )?; + ::std::result::Result::Ok(()) + }, + )?; + ::std::result::Result::Ok(()) + }, + |__protocol, val| { + __protocol.write_btree_set( + ::pilota::thrift::TType::I32, + &val, + |__protocol, val| { + __protocol.write_i32(*val)?; + ::std::result::Result::Ok(()) + }, + )?; + ::std::result::Result::Ok(()) + }, + )?; + __protocol.write_field_stop()?; + __protocol.write_struct_end()?; + ::std::result::Result::Ok(()) + } + + fn decode( + __protocol: &mut T, + ) -> ::std::result::Result { + #[allow(unused_imports)] + use ::pilota::{thrift::TLengthProtocolExt, Buf}; + + let mut var_1 = None; + let mut var_2 = None; + let mut var_3 = None; + + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin()?; + if let ::std::result::Result::Err(mut err) = (|| { + loop { + let field_ident = __protocol.read_field_begin()?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + __protocol.field_stop_len(); + break; + } else { + __protocol.field_begin_len(field_ident.field_type, field_ident.id); + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + Some(1) if field_ident.field_type == ::pilota::thrift::TType::Map => { + var_1 = Some({ + let map_ident = __protocol.read_map_begin()?; + let mut val = ::std::collections::BTreeMap::new(); + for _ in 0..map_ident.size { + val.insert(__protocol.read_i32()?, unsafe { + let list_ident = __protocol.read_list_begin()?; + let mut val: Vec<::std::sync::Arc> = + Vec::with_capacity(list_ident.size); + for i in 0..list_ident.size { + val.as_mut_ptr().offset(i as isize).write( + ::std::sync::Arc::new( + ::pilota::thrift::Message::decode( + __protocol, + )?, + ), + ); + } + val.set_len(list_ident.size); + __protocol.read_list_end()?; + val + }); + } + __protocol.read_map_end()?; + val + }); + } + Some(2) if field_ident.field_type == ::pilota::thrift::TType::Set => { + var_2 = Some({ + let list_ident = __protocol.read_set_begin()?; + let mut val = ::std::collections::BTreeSet::new(); + for _ in 0..list_ident.size { + val.insert(__protocol.read_i32()?); + } + __protocol.read_set_end()?; + val + }); + } + Some(3) if field_ident.field_type == ::pilota::thrift::TType::Map => { + var_3 = Some({ + let map_ident = __protocol.read_map_begin()?; + let mut val = ::std::collections::BTreeMap::new(); + for _ in 0..map_ident.size { + val.insert( + unsafe { + let list_ident = __protocol.read_list_begin()?; + let mut val: Vec< + ::std::collections::BTreeMap< + ::std::collections::BTreeSet, + i32, + >, + > = Vec::with_capacity(list_ident.size); + for i in 0..list_ident.size { + val.as_mut_ptr().offset(i as isize).write({ + let map_ident = __protocol.read_map_begin()?; + let mut val = ::std::collections::BTreeMap::new(); + for _ in 0..map_ident.size { + val.insert({let list_ident = __protocol.read_set_begin()?; + let mut val = ::std::collections::BTreeSet::new(); + for _ in 0..list_ident.size { + val.insert(__protocol.read_i32()?); + }; + __protocol.read_set_end()?; + val}, __protocol.read_i32()?); + } + __protocol.read_map_end()?; + val + }); + } + val.set_len(list_ident.size); + __protocol.read_list_end()?; + val + }, + { + let list_ident = __protocol.read_set_begin()?; + let mut val = ::std::collections::BTreeSet::new(); + for _ in 0..list_ident.size { + val.insert(__protocol.read_i32()?); + } + __protocol.read_set_end()?; + val + }, + ); + } + __protocol.read_map_end()?; + val + }); + } + _ => { + __protocol.skip(field_ident.field_type)?; + } + } + + __protocol.read_field_end()?; + __protocol.field_end_len(); + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + })() { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!( + "decode struct `B` field(#{}) failed, caused by: ", + field_id + )); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end()?; + + let Some(var_1) = var_1 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field m is required".to_string(), + )); + }; + let Some(var_2) = var_2 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field s is required".to_string(), + )); + }; + let Some(var_3) = var_3 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field m2 is required".to_string(), + )); + }; + + let data = Self { + m: var_1, + s: var_2, + m2: var_3, + }; + ::std::result::Result::Ok(data) + } + + fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>( + __protocol: &'a mut T, + ) -> ::std::pin::Pin< + ::std::boxed::Box< + dyn ::std::future::Future< + Output = ::std::result::Result, + > + Send + + 'a, + >, + > { + ::std::boxed::Box::pin(async move { + let mut var_1 = None; + let mut var_2 = None; + let mut var_3 = None; + + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin().await?; + if let ::std::result::Result::Err(mut err) = async { + loop { + + + let field_ident = __protocol.read_field_begin().await?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + + break; + } else { + + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + Some(1) if field_ident.field_type == ::pilota::thrift::TType::Map => { + var_1 = Some({ + let map_ident = __protocol.read_map_begin().await?; + let mut val = ::std::collections::BTreeMap::new(); + for _ in 0..map_ident.size { + val.insert(__protocol.read_i32().await?, { + let list_ident = __protocol.read_list_begin().await?; + let mut val = Vec::with_capacity(list_ident.size); + for _ in 0..list_ident.size { + val.push(::std::sync::Arc::new(::decode_async(__protocol).await?)); + }; + __protocol.read_list_end().await?; + val + }); + } + __protocol.read_map_end().await?; + val + }); + + },Some(2) if field_ident.field_type == ::pilota::thrift::TType::Set => { + var_2 = Some({let list_ident = __protocol.read_set_begin().await?; + let mut val = ::std::collections::BTreeSet::new(); + for _ in 0..list_ident.size { + val.insert(__protocol.read_i32().await?); + }; + __protocol.read_set_end().await?; + val}); + + },Some(3) if field_ident.field_type == ::pilota::thrift::TType::Map => { + var_3 = Some({ + let map_ident = __protocol.read_map_begin().await?; + let mut val = ::std::collections::BTreeMap::new(); + for _ in 0..map_ident.size { + val.insert({ + let list_ident = __protocol.read_list_begin().await?; + let mut val = Vec::with_capacity(list_ident.size); + for _ in 0..list_ident.size { + val.push({ + let map_ident = __protocol.read_map_begin().await?; + let mut val = ::std::collections::BTreeMap::new(); + for _ in 0..map_ident.size { + val.insert({let list_ident = __protocol.read_set_begin().await?; + let mut val = ::std::collections::BTreeSet::new(); + for _ in 0..list_ident.size { + val.insert(__protocol.read_i32().await?); + }; + __protocol.read_set_end().await?; + val}, __protocol.read_i32().await?); + } + __protocol.read_map_end().await?; + val + }); + }; + __protocol.read_list_end().await?; + val + }, {let list_ident = __protocol.read_set_begin().await?; + let mut val = ::std::collections::BTreeSet::new(); + for _ in 0..list_ident.size { + val.insert(__protocol.read_i32().await?); + }; + __protocol.read_set_end().await?; + val}); + } + __protocol.read_map_end().await?; + val + }); + + }, + _ => { + __protocol.skip(field_ident.field_type).await?; + + }, + } + + __protocol.read_field_end().await?; + + + }; + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + }.await { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!("decode struct `B` field(#{}) failed, caused by: ", field_id)); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end().await?; + + let Some(var_1) = var_1 else { + return ::std::result::Result::Err( + ::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field m is required".to_string(), + ), + ); + }; + let Some(var_2) = var_2 else { + return ::std::result::Result::Err( + ::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field s is required".to_string(), + ), + ); + }; + let Some(var_3) = var_3 else { + return ::std::result::Result::Err( + ::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field m2 is required".to_string(), + ), + ); + }; + + let data = Self { + m: var_1, + s: var_2, + m2: var_3, + }; + ::std::result::Result::Ok(data) + }) + } + + fn size(&self, __protocol: &mut T) -> usize { + #[allow(unused_imports)] + use ::pilota::thrift::TLengthProtocolExt; + __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier { name: "B" }) + + __protocol.btree_map_field_len( + Some(1), + ::pilota::thrift::TType::I32, + ::pilota::thrift::TType::List, + &self.m, + |__protocol, key| __protocol.i32_len(*key), + |__protocol, val| { + __protocol.list_len( + ::pilota::thrift::TType::Struct, + val, + |__protocol, el| __protocol.struct_len(el), + ) + }, + ) + + __protocol.btree_set_field_len( + Some(2), + ::pilota::thrift::TType::I32, + &self.s, + |__protocol, el| __protocol.i32_len(*el), + ) + + __protocol.btree_map_field_len( + Some(3), + ::pilota::thrift::TType::List, + ::pilota::thrift::TType::Set, + &self.m2, + |__protocol, key| { + __protocol.list_len( + ::pilota::thrift::TType::Map, + key, + |__protocol, el| { + __protocol.btree_map_len( + ::pilota::thrift::TType::Set, + ::pilota::thrift::TType::I32, + el, + |__protocol, key| { + __protocol.btree_set_len( + ::pilota::thrift::TType::I32, + key, + |__protocol, el| __protocol.i32_len(*el), + ) + }, + |__protocol, val| __protocol.i32_len(*val), + ) + }, + ) + }, + |__protocol, val| { + __protocol.btree_set_len( + ::pilota::thrift::TType::I32, + val, + |__protocol, el| __protocol.i32_len(*el), + ) + }, + ) + + __protocol.field_stop_len() + + __protocol.struct_end_len() + } + } + pub static TEST_MAP_LIST: ::std::sync::LazyLock< + ::std::collections::BTreeMap>, + > = ::std::sync::LazyLock::new(|| { + let mut map = ::std::collections::BTreeMap::new(); + map.insert(1i32, ::std::vec!["hello"]); + map + }); + + pub static TEST_MAP: ::std::sync::LazyLock< + ::std::collections::BTreeMap, + > = ::std::sync::LazyLock::new(|| { + let mut map = ::std::collections::BTreeMap::new(); + map.insert(Index::A, "hello"); + map.insert(Index::B, "world"); + map + }); + #[derive(PartialOrd, Hash, Eq, Ord, Debug, ::pilota::derivative::Derivative)] + #[derivative(Default)] + #[derive(Clone, PartialEq, Copy)] + #[repr(transparent)] + pub struct Index(i32); + + impl Index { + pub const A: Self = Self(0); + pub const B: Self = Self(1); + + pub fn inner(&self) -> i32 { + self.0 + } + + pub fn to_string(&self) -> ::std::string::String { + match self { + Self(0) => ::std::string::String::from("A"), + Self(1) => ::std::string::String::from("B"), + Self(val) => val.to_string(), + } + } + } + + impl ::std::convert::From for Index { + fn from(value: i32) -> Self { + Self(value) + } + } + + impl ::std::convert::From for i32 { + fn from(value: Index) -> i32 { + value.0 + } + } + + impl ::pilota::thrift::Message for Index { + fn encode( + &self, + __protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> { + #[allow(unused_imports)] + use ::pilota::thrift::TOutputProtocolExt; + __protocol.write_i32(self.inner())?; + ::std::result::Result::Ok(()) + } + + fn decode( + __protocol: &mut T, + ) -> ::std::result::Result { + #[allow(unused_imports)] + use ::pilota::{thrift::TLengthProtocolExt, Buf}; + let value = __protocol.read_i32()?; + ::std::result::Result::Ok(::std::convert::TryFrom::try_from(value).map_err( + |err| { + ::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + format!("invalid enum value for Index, value: {}", value), + ) + }, + )?) + } + + fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>( + __protocol: &'a mut T, + ) -> ::std::pin::Pin< + ::std::boxed::Box< + dyn ::std::future::Future< + Output = ::std::result::Result, + > + Send + + 'a, + >, + > { + ::std::boxed::Box::pin(async move { + let value = __protocol.read_i32().await?; + ::std::result::Result::Ok(::std::convert::TryFrom::try_from(value).map_err( + |err| { + ::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + format!("invalid enum value for Index, value: {}", value), + ) + }, + )?) + }) + } + + fn size(&self, __protocol: &mut T) -> usize { + #[allow(unused_imports)] + use ::pilota::thrift::TLengthProtocolExt; + __protocol.i32_len(self.inner()) + } + } + } +} diff --git a/pilota-build/test_data/thrift/btree.thrift b/pilota-build/test_data/thrift/btree.thrift new file mode 100644 index 00000000..6ca9953b --- /dev/null +++ b/pilota-build/test_data/thrift/btree.thrift @@ -0,0 +1,25 @@ +struct A { + +} + +struct B { + 1: required map> m(pilota.rust_type = "btree", pilota.rust_wrapper_arc = "true"), + 2: required set s(pilota.rust_type = "btree"), + 3: required map, i32>>, set> m2(pilota.rust_type = "btree"), +} + +const map> TEST_MAP_LIST = { + 1: ["hello"] +}(pilota.rust_type = "btree") + +enum Index { + A = 0, + B = 1, +} + +const map TEST_MAP = { + Index.A: "hello", + Index.B: "world", +}(pilota.rust_type = "btree") + +typedef map, string> TypeA(pilota.rust_type = "btree") \ No newline at end of file diff --git a/pilota/Cargo.toml b/pilota/Cargo.toml index 33a73341..608da0c3 100644 --- a/pilota/Cargo.toml +++ b/pilota/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pilota" -version = "0.11.7" +version = "0.11.8" edition = "2021" description = "Pilota is a thrift and protobuf implementation in pure rust with high performance and extensibility." documentation = "https://docs.rs/pilota" diff --git a/pilota/src/thrift/mod.rs b/pilota/src/thrift/mod.rs index b1f187b7..e888c2df 100644 --- a/pilota/src/thrift/mod.rs +++ b/pilota/src/thrift/mod.rs @@ -7,7 +7,12 @@ pub mod rw_ext; pub mod unknown; pub mod varint_ext; -use std::{future::Future, ops::Deref, sync::Arc}; +use std::{ + collections::{BTreeMap, BTreeSet}, + future::Future, + ops::Deref, + sync::Arc, +}; use bytes::{Buf, BufMut, Bytes}; pub use error::*; @@ -278,6 +283,63 @@ macro_rules! field_len { }; } +macro_rules! set_field_len { + ($name:ident($t:ty)) => { + paste::paste! { + #[inline] + fn [<$name _field_len>](&mut self, id: Option, el_ttype: TType, els: &$t, el_len: F,) -> usize + where + F: Fn(&mut Self, &T) -> usize, + { + self.field_begin_len(TType::Set, id) + + self.[<$name _len>](el_ttype, els, el_len) + + self.field_end_len() + } + + #[inline] + fn [<$name _len>](&mut self, el_ttype: TType, els: &$t, el_len: F) -> usize + where + F: Fn(&mut Self, &T) -> usize, + { + self.set_begin_len(TSetIdentifier { + element_type: el_ttype, + size: els.len(), + }) + els.iter().map(|el| el_len(self, el)).sum::() + self.set_end_len() + } + } + }; +} + +macro_rules! map_field_len { + ($name:ident($t:ty)) => { + paste::paste! { + #[inline] + fn [<$name _field_len>](&mut self, id: Option, key_ttype: TType, val_ttype: TType, els: &$t, key_len: FK, val_len: FV,) -> usize + where + FK: Fn(&mut Self, &K) -> usize, + FV: Fn(&mut Self, &V) -> usize, + { + self.field_begin_len(TType::Map, id) + + self.[<$name _len>](key_ttype, val_ttype, els, key_len, val_len) + + self.field_end_len() + } + + #[inline] + fn [<$name _len>](&mut self, key_ttype: TType, val_ttype: TType, els: &$t, key_len: FK, val_len: FV,) -> usize + where + FK: Fn(&mut Self, &K) -> usize, + FV: Fn(&mut Self, &V) -> usize, + { + self.map_begin_len(TMapIdentifier { + key_type: key_ttype, + value_type: val_ttype, + size: els.len(), + }) + els.iter().map(|(k, v)| key_len(self, k) + val_len(self, v)).sum::() + self.map_end_len() + } + } + }; +} + pub trait TLengthProtocolExt: TLengthProtocol + Sized { field_len!(TType::Bool, bool(b: bool)); field_len!(TType::I8, i8(i: i8)); @@ -321,81 +383,16 @@ pub trait TLengthProtocolExt: TLengthProtocol + Sized { + self.list_end_len() } - #[inline] - fn set_field_len( - &mut self, - id: Option, - el_ttype: TType, - els: &AHashSet, - el_len: F, - ) -> usize - where - F: Fn(&mut Self, &T) -> usize, - { - self.field_begin_len(TType::Set, id) - + self.set_len(el_ttype, els, el_len) - + self.field_end_len() - } - - #[inline] - fn set_len(&mut self, el_ttype: TType, els: &AHashSet, el_len: F) -> usize - where - F: Fn(&mut Self, &T) -> usize, - { - self.set_begin_len(TSetIdentifier { - element_type: el_ttype, - size: els.len(), - }) + els.iter().map(|el| el_len(self, el)).sum::() - + self.set_end_len() - } + set_field_len!(set(AHashSet)); + set_field_len!(btree_set(BTreeSet)); #[inline] fn message_len(&mut self, id: Option, m: &M) -> usize { self.field_begin_len(TType::Struct, id) + m.size(self) + self.field_end_len() } - #[inline] - fn map_field_len( - &mut self, - id: Option, - key_ttype: TType, - val_ttype: TType, - els: &AHashMap, - key_len: FK, - val_len: FV, - ) -> usize - where - FK: Fn(&mut Self, &K) -> usize, - FV: Fn(&mut Self, &V) -> usize, - { - self.field_begin_len(TType::Map, id) - + self.map_len(key_ttype, val_ttype, els, key_len, val_len) - + self.field_end_len() - } - - #[inline] - fn map_len( - &mut self, - key_ttype: TType, - val_ttype: TType, - els: &AHashMap, - key_len: FK, - val_len: FV, - ) -> usize - where - FK: Fn(&mut Self, &K) -> usize, - FV: Fn(&mut Self, &V) -> usize, - { - self.map_begin_len(TMapIdentifier { - key_type: key_ttype, - value_type: val_ttype, - size: els.len(), - }) + els - .iter() - .map(|(k, v)| key_len(self, k) + val_len(self, v)) - .sum::() - + self.map_end_len() - } + map_field_len!(map(AHashMap)); + map_field_len!(btree_map(BTreeMap)); #[inline] fn void_len(&mut self) -> usize { @@ -491,6 +488,72 @@ macro_rules! write_field { }; } +macro_rules! write_set_field { + ($name:ident($t:ty)) => { + paste::paste! { + #[inline] + fn [](&mut self, id: i16, el_ttype: TType, els: &$t, encode: F,) -> Result<(), ThriftException> + where + F: Fn(&mut Self, &T) -> Result<(), ThriftException>, + { + self.write_field_begin(TType::Set, id)?; + self.[](el_ttype, els, encode)?; + self.write_field_end() + } + + #[inline] + fn [](&mut self, el_ttype: TType, els: &$t, encode: F,) -> Result<(), ThriftException> + where + F: Fn(&mut Self, &T) -> Result<(), ThriftException>, + { + self.write_set_begin(TSetIdentifier { + element_type: el_ttype, + size: els.len(), + })?; + for el in els { + encode(self, el)?; + } + self.write_set_end() + } + } + }; +} + +macro_rules! write_map_field { + ($name:ident($t:ty)) => { + paste::paste! { + #[inline] + fn [](&mut self, id: i16, key_ttype: TType, val_ttype: TType, els: &$t, key_encode: FK, val_encode: FV,) -> Result<(), ThriftException> + where + FK: Fn(&mut Self, &K) -> Result<(), ThriftException>, + FV: Fn(&mut Self, &V) -> Result<(), ThriftException>, + { + self.write_field_begin(TType::Map, id)?; + self.[](key_ttype, val_ttype, els, key_encode, val_encode)?; + self.write_field_end() + } + + #[inline] + fn [](&mut self, key_ttype: TType, val_ttype: TType, els: &$t, key_encode: FK, val_encode: FV,) -> Result<(), ThriftException> + where + FK: Fn(&mut Self, &K) -> Result<(), ThriftException>, + FV: Fn(&mut Self, &V) -> Result<(), ThriftException>, + { + self.write_map_begin(TMapIdentifier { + key_type: key_ttype, + value_type: val_ttype, + size: els.len(), + })?; + for (k, v) in els { + key_encode(self, k)?; + val_encode(self, v)?; + } + self.write_map_end() + } + } + }; +} + pub trait TOutputProtocolExt: TOutputProtocol + Sized { write_field!(TType::Bool, bool(b: bool)); write_field!(TType::I8, i8(i: i8)); @@ -542,41 +605,8 @@ pub trait TOutputProtocolExt: TOutputProtocol + Sized { self.write_list_end() } - #[inline] - fn write_set_field( - &mut self, - id: i16, - el_ttype: TType, - els: &AHashSet, - encode: F, - ) -> Result<(), ThriftException> - where - F: Fn(&mut Self, &T) -> Result<(), ThriftException>, - { - self.write_field_begin(TType::Set, id)?; - self.write_set(el_ttype, els, encode)?; - self.write_field_end() - } - - #[inline] - fn write_set( - &mut self, - el_ttype: TType, - els: &AHashSet, - encode: F, - ) -> Result<(), ThriftException> - where - F: Fn(&mut Self, &T) -> Result<(), ThriftException>, - { - self.write_set_begin(TSetIdentifier { - element_type: el_ttype, - size: els.len(), - })?; - for el in els { - encode(self, el)? - } - self.write_set_end() - } + write_set_field!(set(AHashSet)); + write_set_field!(btree_set(BTreeSet)); #[inline] fn write_struct_field( @@ -595,49 +625,8 @@ pub trait TOutputProtocolExt: TOutputProtocol + Sized { m.encode(self) } - #[inline] - fn write_map_field( - &mut self, - id: i16, - key_ttype: TType, - val_ttype: TType, - els: &AHashMap, - key_encode: FK, - val_encode: FV, - ) -> Result<(), ThriftException> - where - FK: Fn(&mut Self, &K) -> Result<(), ThriftException>, - FV: Fn(&mut Self, &V) -> Result<(), ThriftException>, - { - self.write_field_begin(TType::Map, id)?; - self.write_map(key_ttype, val_ttype, els, key_encode, val_encode)?; - self.write_field_end() - } - - #[inline] - fn write_map( - &mut self, - key_ttype: TType, - val_ttype: TType, - els: &AHashMap, - key_encode: FK, - val_encode: FV, - ) -> Result<(), ThriftException> - where - FK: Fn(&mut Self, &K) -> Result<(), ThriftException>, - FV: Fn(&mut Self, &V) -> Result<(), ThriftException>, - { - self.write_map_begin(TMapIdentifier { - key_type: key_ttype, - value_type: val_ttype, - size: els.len(), - })?; - for (k, v) in els { - key_encode(self, k)?; - val_encode(self, v)?; - } - self.write_map_end() - } + write_map_field!(map(AHashMap)); + write_map_field!(btree_map(BTreeMap)); #[inline] fn write_void(&mut self) -> Result<(), ThriftException> {