diff --git a/Cargo.lock b/Cargo.lock index 7a394de..ad38ca2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -216,9 +216,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.84" +version = "1.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6" +checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23" dependencies = [ "unicode-ident", ] @@ -320,16 +320,16 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" +checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" dependencies = [ "bitflags 2.5.0", ] [[package]] name = "rlbot-flatbuffers-py" -version = "0.3.7" +version = "0.3.8" dependencies = [ "flatbuffers", "get-size", diff --git a/Cargo.toml b/Cargo.toml index 9ed52c3..eb75adb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rlbot-flatbuffers-py" -version = "0.3.7" +version = "0.3.8" edition = "2021" description = "A Python module implemented in Rust for serializing and deserializing RLBot's flatbuffers" repository = "https://github.com/VirxEC/rlbot_flatbuffers_py" diff --git a/build.rs b/build.rs index 3146a45..9d4bef8 100644 --- a/build.rs +++ b/build.rs @@ -22,11 +22,60 @@ struct PythonBindGenerator { file_contents: Vec>, bind_type: PythonBindType, has_complex_pack: bool, + is_all_base_types: bool, + is_frozen: bool, + frozen_needs_py: bool, +} + +macro_rules! write_str { + ($self:ident, $s:expr) => { + $self.file_contents.push(Cow::Borrowed($s)) + }; +} + +macro_rules! write_fmt { + ($self:ident, $($arg:tt)*) => { + $self.file_contents.push(Cow::Owned(format!($($arg)*))) + }; } impl PythonBindGenerator { const BASE_TYPES: [&'static str; 6] = ["bool", "i32", "u32", "f32", "String", "u8"]; const SPECIAL_BASE_TYPES: [&'static str; 2] = ["FloatT", "BoolT"]; + const FROZEN_TYPES: [&'static str; 24] = [ + "FieldInfo", + "BoostPad", + "GoalInfo", + "GameTickPacket", + "PlayerInfo", + "ScoreInfo", + "BallInfo", + "Touch", + "CollisionShape", + "BoxShape", + "SphereShape", + "CylinderShape", + "BoostPadState", + "GameInfo", + "TeamInfo", + "BallPrediction", + "PredictionSlice", + "Physics", + "MessagePacket", + "GameMessageWrapper", + "GameMessage", + "PlayerInputChange", + "PlayerSpectate", + "PlayerStatEvent", + ]; + const FROZEN_NEEDS_PY: [&'static str; 6] = [ + "GameTickPacket", + "BallInfo", + "CollisionShape", + "MessagePacket", + "GameMessageWrapper", + "GameMessage", + ]; fn new(path: &Path) -> Option { // get the filename without the extension @@ -61,12 +110,21 @@ impl PythonBindGenerator { return None; }; + let is_frozen = Self::FROZEN_TYPES.contains(&struct_name.as_str()); + let frozen_needs_py = Self::FROZEN_NEEDS_PY.contains(&struct_name.as_str()); + let is_all_base_types = types.iter().all(|t| Self::BASE_TYPES.contains(&t[1].as_str())); let has_complex_pack = contents.contains("pub fn pack<'b, A: flatbuffers::Allocator + 'b>("); let mut file_contents = vec![]; file_contents.push(Cow::Borrowed(match bind_type { - PythonBindType::Struct => "use crate::{flat_err_to_py, generated::rlbot::flat, FromGil, IntoGil};", - PythonBindType::Union => "use crate::{generated::rlbot::flat, FromGil, IntoGil};", + PythonBindType::Struct => { + if (is_frozen && !frozen_needs_py) || is_all_base_types || types.is_empty() { + "use crate::{flat_err_to_py, generated::rlbot::flat, FromGil};" + } else { + "use crate::{flat_err_to_py, generated::rlbot::flat, FromGil, IntoGil};" + } + } + PythonBindType::Union => "use crate::{generated::rlbot::flat, FromGil};", PythonBindType::Enum => "use crate::{flat_err_to_py, generated::rlbot::flat};", })); @@ -83,7 +141,13 @@ impl PythonBindGenerator { } file_contents.push(Cow::Borrowed(match bind_type { - PythonBindType::Struct => "use pyo3::{pyclass, pymethods, types::PyBytes, Bound, Py, PyResult, Python};", + PythonBindType::Struct => { + if is_frozen { + "use pyo3::{pyclass, pymethods, types::PyBytes, Bound, PyResult, Python};" + } else { + "use pyo3::{pyclass, pymethods, types::PyBytes, Bound, Py, PyResult, Python};" + } + } PythonBindType::Enum => { "use pyo3::{exceptions::PyValueError, pyclass, pymethods, types::PyBytes, Bound, PyResult, Python};" } @@ -100,6 +164,9 @@ impl PythonBindGenerator { file_contents, bind_type, has_complex_pack, + is_all_base_types, + is_frozen, + frozen_needs_py, }) } @@ -247,46 +314,63 @@ impl PythonBindGenerator { )) } - fn write_str(&mut self, s: &'static str) { - self.file_contents.push(Cow::Borrowed(s)); - } - - fn write_string(&mut self, s: String) { - self.file_contents.push(Cow::Owned(s)); - } - fn generate_struct_definition(&mut self) { - self.write_str("#[pyclass(module = \"rlbot_flatbuffers\", subclass, get_all, set_all)]"); + write_str!( + self, + if self.is_frozen { + "#[pyclass(module = \"rlbot_flatbuffers\", subclass, get_all, frozen)]" + } else if self.types.is_empty() { + "#[pyclass(module = \"rlbot_flatbuffers\", subclass, frozen)]" + } else { + "#[pyclass(module = \"rlbot_flatbuffers\", subclass, get_all, set_all)]" + } + ); if self.types.is_empty() { - self.write_str("#[derive(Debug, Default, Clone)]"); - self.write_string(format!("pub struct {} {{}}", self.struct_name)); - self.write_str(""); + write_str!(self, "#[derive(Debug, Default, Clone)]"); + write_fmt!(self, "pub struct {} {{}}", self.struct_name); + write_str!(self, ""); return; } - self.write_str("#[derive(Debug, Clone)]"); - self.write_string(format!("pub struct {} {{", self.struct_name)); + write_str!( + self, + if self.is_frozen || self.is_all_base_types { + if !self.is_all_base_types || self.types.iter().any(|t| t[1] == "String") { + "#[derive(Debug, Default, Clone)]" + } else { + "#[derive(Debug, Default, Clone, Copy)]" + } + } else { + "#[derive(Debug, Clone)]" + } + ); + write_fmt!(self, "pub struct {} {{", self.struct_name); for variable_info in &self.types { let variable_name = &variable_info[0]; let mut variable_type = variable_info[1].to_string(); if variable_type.starts_with("Vec<") && variable_type.ends_with("T>") { - variable_type = format!( - "Vec>", - variable_type.trim_start_matches("Vec<").trim_end_matches("T>") - ); + let inner_type = variable_type.trim_start_matches("Vec<").trim_end_matches("T>"); + variable_type = if self.is_frozen { + format!("Vec") + } else { + format!("Vec>") + }; } else if variable_type.starts_with("Vec<") && variable_type.ends_with('>') { variable_type = String::from("Py"); } else if variable_type.starts_with("Box<") && variable_type.ends_with('>') { - variable_type = format!( - "Py", - variable_type - .trim_start_matches("Box<") - .trim_end_matches('>') - .trim_end_matches('T') - ); + let inner_type = variable_type + .trim_start_matches("Box<") + .trim_end_matches('>') + .trim_end_matches('T'); + + variable_type = if self.is_frozen { + format!("super::{inner_type}") + } else { + format!("Py") + }; } else if variable_type.starts_with("Option<") && variable_type.ends_with('>') { let inner_type = variable_type .trim_start_matches("Option<") @@ -300,34 +384,47 @@ impl PythonBindGenerator { variable_type = format!("Option>"); } } else if variable_type.ends_with('T') { - variable_type = format!("Py", variable_type.trim_end_matches('T')); + let inner_type = variable_type.trim_end_matches('T'); + + variable_type = if self.is_frozen { + format!("super::{inner_type}") + } else { + format!("Py") + }; } else if !Self::BASE_TYPES.contains(&variable_type.as_str()) { variable_type = format!("super::{variable_type}"); } - self.file_contents - .push(Cow::Owned(format!(" pub {variable_name}: {variable_type},"))); + write_fmt!(self, " pub {variable_name}: {variable_type},"); + } + + write_str!(self, "}"); + write_str!(self, ""); + + if self.is_all_base_types { + return; + } + + if self.is_frozen { + return; } - self.write_str("}"); - self.write_str(""); - self.write_string(format!("impl crate::PyDefault for {} {{", self.struct_name)); - self.write_str(" fn py_default(py: Python) -> Py {"); - self.write_str(" Py::new(py, Self {"); + write_fmt!(self, "impl crate::PyDefault for {} {{", self.struct_name); + write_str!(self, " fn py_default(py: Python) -> Py {"); + write_str!(self, " Py::new(py, Self {"); for variable_info in &self.types { let variable_name = &variable_info[0]; let variable_type = &variable_info[1]; - if variable_type.starts_with("Vec<") { - self.file_contents.push(Cow::Owned(if variable_type == "Vec" { - format!(" {variable_name}: PyBytes::new_bound(py, &[]).unbind(),") + let end = if variable_type.starts_with("Vec<") { + if variable_type == "Vec" { + Cow::Borrowed("PyBytes::new_bound(py, &[]).unbind()") } else { - format!(" {variable_name}: Vec::new(),") - })); + Cow::Borrowed("Vec::new()") + } } else if variable_type.starts_with("Option<") { - self.file_contents - .push(Cow::Owned(format!(" {variable_name}: None,"))); + Cow::Borrowed("None") } else if !Self::BASE_TYPES.contains(&variable_type.as_str()) && (variable_type.starts_with("Box<") || variable_type.ends_with('T')) { @@ -335,28 +432,31 @@ impl PythonBindGenerator { .trim_start_matches("Box<") .trim_end_matches('>') .trim_end_matches('T'); - - self.file_contents.push(Cow::Owned(format!( - " {variable_name}: Py::new(py, super::{inner_type}::py_default(py)).unwrap()," - ))); + Cow::Owned(format!("super::{inner_type}::py_default(py)")) } else { - self.file_contents - .push(Cow::Owned(format!(" {variable_name}: Default::default(),"))); - } + Cow::Borrowed("Default::default()") + }; + + write_fmt!(self, " {variable_name}: {end},"); } - self.write_str(" }).unwrap()"); - self.write_str(" }"); - self.write_str("}"); - self.write_str(""); + if self.is_frozen { + write_str!(self, " }"); + } else { + write_str!(self, " }).unwrap()"); + } + + write_str!(self, " }"); + write_str!(self, "}"); + write_str!(self, ""); } fn generate_enum_definition(&mut self) { - self.write_str("#[allow(non_camel_case_types)]"); - self.write_str("#[pyclass(module = \"rlbot_flatbuffers\", get_all, set_all)]"); - self.write_str("#[derive(Debug, Default, Clone, Copy)]"); - self.write_string(format!("pub enum {} {{", self.struct_name)); - self.write_str(" #[default]"); + write_str!(self, "#[allow(non_camel_case_types)]"); + write_str!(self, "#[pyclass(module = \"rlbot_flatbuffers\", frozen)]"); + write_str!(self, "#[derive(Debug, Default, Clone, Copy)]"); + write_fmt!(self, "pub enum {} {{", self.struct_name); + write_str!(self, " #[default]"); for variable_info in &self.types { let variable_name = &variable_info[0]; @@ -366,8 +466,8 @@ impl PythonBindGenerator { .push(Cow::Owned(format!(" {variable_name} = {variable_value},"))); } - self.write_str("}"); - self.write_str(""); + write_str!(self, "}"); + write_str!(self, ""); } fn generate_definition(&mut self) { @@ -379,24 +479,33 @@ impl PythonBindGenerator { } fn generate_union_definition(&mut self) { - self.write_str("#[derive(Debug, Clone, pyo3::FromPyObject)]"); - self.write_string(format!("pub enum {}Union {{", self.struct_name)); + write_str!(self, "#[derive(Debug, Clone, pyo3::FromPyObject)]"); + write_fmt!(self, "pub enum {}Union {{", self.struct_name); for variable_info in self.types.iter().skip(1) { let variable_name = &variable_info[0]; - self.file_contents - .push(Cow::Owned(format!(" {variable_name}(Py),"))); + write_fmt!(self, " {variable_name}(Py),"); + } + + write_str!(self, "}"); + write_str!(self, ""); + + if self.is_frozen { + write_str!(self, "#[pyclass(module = \"rlbot_flatbuffers\", frozen)]"); + } else { + write_str!(self, "#[pyclass(module = \"rlbot_flatbuffers\")]"); } - self.write_str("}"); - self.write_str(""); - self.write_str("#[pyclass(module = \"rlbot_flatbuffers\")]"); - self.write_str("#[derive(Debug, Default, Clone)]"); - self.write_string(format!("pub struct {} {{", self.struct_name)); - self.write_str(" #[pyo3(set)]"); - self.write_string(format!(" pub item: Option<{}Union>,", self.struct_name)); - self.write_str("}"); - self.write_str(""); + write_str!(self, "#[derive(Debug, Default, Clone)]"); + write_fmt!(self, "pub struct {} {{", self.struct_name); + + if !self.is_frozen { + write_str!(self, " #[pyo3(set)]"); + } + + write_fmt!(self, " pub item: Option<{}Union>,", self.struct_name); + write_str!(self, "}"); + write_str!(self, ""); } fn generate_from_flat_impls(&mut self) { @@ -408,186 +517,216 @@ impl PythonBindGenerator { } fn generate_union_from_flat_impls(&mut self) { - let from_impl_types = [ - format!("flat::{}", self.struct_t_name), - format!("Box", self.struct_t_name), - ]; + write_fmt!(self, "impl FromGil for {} {{", self.struct_t_name, self.struct_name); + write_fmt!( + self, + " fn from_gil(py: Python, flat_t: flat::{}) -> Self {{", + self.struct_t_name + ); - for impl_type in from_impl_types { - self.write_string(format!("impl FromGil<{impl_type}> for Py<{}> {{", self.struct_name)); - self.write_string(format!(" fn from_gil(py: Python, flat_t: {impl_type}) -> Self {{")); + write_str!(self, " match flat_t {"); - if impl_type.starts_with("Box<") { - self.write_str(" Self::from_gil(py, *flat_t)"); - } else { - self.write_str(" Py::new("); - self.write_str(" py,"); - self.write_str(" match flat_t {"); + for variable_info in &self.types { + let variable_name = &variable_info[0]; - for variable_info in &self.types { - let variable_name = &variable_info[0]; + if variable_name == "NONE" { + write_fmt!( + self, + " flat::{}::NONE => {}::default(),", + self.struct_t_name, + self.struct_name + ); + } else { + write_fmt!( + self, + " flat::{}::{variable_name}(item) => {} {{", + self.struct_t_name, + self.struct_name, + ); - if variable_name == "NONE" { - self.file_contents.push(Cow::Owned(format!( - " flat::{}::NONE => {}::default(),", - self.struct_t_name, self.struct_name - ))); - } else { - self.file_contents.push(Cow::Owned(format!( - " flat::{}::{variable_name}(item) => {} {{", - self.struct_t_name, self.struct_name, - ))); - self.file_contents.push(Cow::Owned(format!( - " item: Some({}Union::{variable_name}(item.into_gil(py)))", - self.struct_name, - ))); - self.file_contents.push(Cow::Borrowed(" },")); - } - } + write_fmt!(self, " item: Some({}Union::{variable_name}(", self.struct_name); + write_fmt!( + self, + " Py::new(py, super::{variable_name}::from_gil(py, *item)).unwrap()," + ); + write_fmt!(self, " )),"); - self.write_str(" },"); - self.write_str(" )"); - self.write_str(" .unwrap()"); + write_fmt!(self, " }},"); } - - self.write_str(" }"); - self.write_str("}"); - self.write_str(""); } + + write_str!(self, " }"); + write_str!(self, " }"); + write_str!(self, "}"); + write_str!(self, ""); } fn generate_enum_from_flat_impls(&mut self) { - self.write_string(format!("impl From for {} {{", self.struct_name, self.struct_name)); - self.write_string(format!(" fn from(flat_t: flat::{}) -> Self {{", self.struct_name)); - self.write_str(" match flat_t {"); + write_fmt!(self, "impl From for {} {{", self.struct_name, self.struct_name); + write_fmt!(self, " fn from(flat_t: flat::{}) -> Self {{", self.struct_name); + write_str!(self, " match flat_t {"); for variable_info in &self.types { let variable_name = &variable_info[0]; - self.file_contents.push(Cow::Owned(format!( + write_fmt!( + self, " flat::{}::{variable_name} => Self::{variable_name},", self.struct_name - ))); + ); } - self.write_str(" v => unreachable!(\"Unknown value: {v:?}\"),"); + write_str!(self, " v => unreachable!(\"Unknown value: {v:?}\"),"); - self.write_str(" }"); - self.write_str(" }"); - self.write_str("}"); - self.write_str(""); + write_str!(self, " }"); + write_str!(self, " }"); + write_str!(self, "}"); + write_str!(self, ""); } fn generate_struct_from_flat_impls(&mut self) { - let from_impl_types = [ - format!("flat::{}", self.struct_t_name), - format!("Box", self.struct_t_name), - ]; - - for impl_type in from_impl_types { - self.write_string(format!("impl FromGil<{impl_type}> for Py<{}> {{", self.struct_name)); - - if self.types.is_empty() { - self.write_string(format!(" fn from_gil(py: Python, _: {impl_type}) -> Self {{")); - self.write_string(format!(" Py::new(py, {}::default()).unwrap()", self.struct_name)); - self.write_str(" }"); - self.write_str("}"); - self.write_str(""); - continue; - } + let impl_type = format!("flat::{}", self.struct_t_name); - self.write_string(format!(" fn from_gil(py: Python, flat_t: {impl_type}) -> Self {{")); + if self.types.is_empty() { + write_fmt!(self, "impl From<{impl_type}> for {} {{", self.struct_name); + write_fmt!(self, " fn from(_: {impl_type}) -> Self {{"); + write_fmt!(self, " {} {{}}", self.struct_name); + write_str!(self, " }"); + write_str!(self, "}"); + write_str!(self, ""); + return; + } - if impl_type.starts_with("Box<") { - self.write_str(" Self::from_gil(py, *flat_t)"); + let (trait_name, fn_name, python_arg) = + if (self.is_frozen && !Self::FROZEN_NEEDS_PY.contains(&self.struct_name.as_str())) || self.is_all_base_types { + ("From", "from", "") } else { - self.write_string(format!(" Py::new(py, {} {{", self.struct_name)); + ("FromGil", "from_gil", "py: Python, ") + }; - for variable_info in &self.types { - let variable_name = &variable_info[0]; - let variable_type = variable_info[1].as_str(); + write_fmt!(self, "impl {trait_name}<{impl_type}> for {} {{", self.struct_name); + write_str!(self, " #[allow(unused_variables)]"); + write_fmt!(self, " fn {fn_name}({python_arg}flat_t: {impl_type}) -> Self {{"); + write_fmt!(self, " {} {{", self.struct_name); - if variable_type.starts_with("Vec<") { - self.file_contents - .push(Cow::Owned(if variable_type == "Vec" { - format!(" {variable_name}: PyBytes::new_bound(py, &flat_t.{variable_name}).unbind(),") - } else { - format!( - " {variable_name}: flat_t.{variable_name}.into_iter().map(|x| x.into_gil(py)).collect(),", - ) - })); - } else if variable_type.starts_with("Option<") { - self.file_contents.push(Cow::Owned( - if variable_type.trim_start_matches("Option<").trim_end_matches('>') == "String" { - format!(" {variable_name}: flat_t.{variable_name},") - } else { - format!(" {variable_name}: flat_t.{variable_name}.map(|x| x.into_gil(py)),") - }, - )); - } else if variable_type.starts_with("Box<") || variable_type.ends_with('T') { - self.file_contents.push(Cow::Owned(format!( - " {variable_name}: flat_t.{variable_name}.into_gil(py),", - ))); - } else if Self::BASE_TYPES.contains(&variable_type) { - self.file_contents - .push(Cow::Owned(format!(" {variable_name}: flat_t.{variable_name},"))); + for variable_info in &self.types { + let variable_name = &variable_info[0]; + let variable_type = variable_info[1].as_str(); + + if variable_type.starts_with("Vec<") { + if variable_type == "Vec" { + write_fmt!( + self, + " {variable_name}: PyBytes::new_bound(py, &flat_t.{variable_name}).unbind()," + ) + } else if self.is_frozen { + let inner_type = variable_type + .trim_start_matches("Vec<") + .trim_end_matches('>') + .trim_end_matches('T'); + let map_out = if Self::FROZEN_NEEDS_PY.contains(&inner_type) { + "|x| x.into_gil(py)" + } else { + "Into::into" + }; + + write_fmt!( + self, + " {variable_name}: flat_t.{variable_name}.into_iter().map({map_out}).collect()," + ) + } else { + write_fmt!( + self, + " {variable_name}: flat_t.{variable_name}.into_iter().map(|x| crate::into_py_from(py, x)).collect(),", + ) + }; + } else if variable_type.starts_with("Option<") { + let inner_type = variable_type.trim_start_matches("Option<").trim_end_matches('>'); + let end = if inner_type == "String" { + "," + } else if inner_type.starts_with("Box<") { + ".map(|x| crate::into_py_from(py, *x))," + } else { + ".map(|x| crate::into_py_from(py, x))," + }; + + write_fmt!(self, " {variable_name}: flat_t.{variable_name}{end}"); + } else if variable_type.starts_with("Box<") { + let end = if self.is_frozen { + let inner_type = variable_type + .trim_start_matches("Box<") + .trim_end_matches('>') + .trim_end_matches('T'); + if Self::FROZEN_NEEDS_PY.contains(&inner_type) { + format!("(*flat_t.{variable_name}).into_gil(py)",) } else { - self.file_contents.push(Cow::Owned(format!( - " {variable_name}: flat_t.{variable_name}.into(),", - ))); + format!("(*flat_t.{variable_name}).into()",) } - } + } else { + format!("crate::into_py_from(py, *flat_t.{variable_name})") + }; + write_fmt!(self, " {variable_name}: {end},",); + } else if variable_type.ends_with('T') { + let inner_type = variable_type.trim_end_matches('T'); + let end = if self.is_frozen { + if Self::FROZEN_NEEDS_PY.contains(&inner_type) { + format!("flat_t.{variable_name}.into_gil(py)") + } else { + format!("flat_t.{variable_name}.into()") + } + } else { + format!("crate::into_py_from(py, flat_t.{variable_name})") + }; - self.write_str(" }).unwrap()"); + write_fmt!(self, " {variable_name}: {end},",); + } else if Self::BASE_TYPES.contains(&variable_type) { + write_fmt!(self, " {variable_name}: flat_t.{variable_name},"); + } else { + write_fmt!(self, " {variable_name}: flat_t.{variable_name}.into(),",); } - - self.write_str(" }"); - self.write_str("}"); - self.write_str(""); } + + write_str!(self, " }"); + write_str!(self, " }"); + write_str!(self, "}"); + write_str!(self, ""); } fn generate_union_to_flat_impls(&mut self) { - let from_impl_types = [ - format!("flat::{}", self.struct_t_name), - format!("Box", self.struct_t_name), - ]; - - for impl_type in from_impl_types { - self.write_string(format!("impl FromGil<&Py<{}>> for {impl_type} {{", self.struct_name)); - self.write_string(format!( - " fn from_gil(py: Python, py_type: &Py<{}>) -> Self {{", - self.struct_name - )); - - if impl_type.contains("Box<") { - self.write_string(format!( - " Self::new(flat::{}::from_gil(py, py_type))", - self.struct_t_name - )); - } else { - self.write_str(" match py_type.borrow(py).item.as_ref() {"); - for variable_info in &self.types { - let variable_name = &variable_info[0]; - let variable_value = &variable_info[1]; + write_fmt!( + self, + "impl FromGil<&{}> for flat::{} {{", + self.struct_name, + self.struct_t_name + ); + write_fmt!(self, " fn from_gil(py: Python, py_type: &{}) -> Self {{", self.struct_name); + write_str!(self, " match py_type.item.as_ref() {"); - if variable_value.is_empty() { - self.file_contents.push(Cow::Borrowed(" None => Self::NONE,")); - } else { - self.file_contents.push(Cow::Owned(format!( - " Some({}Union::{variable_value}(item)) => flat::{}::{variable_name}(item.into_gil(py)),", - self.struct_name, self.struct_t_name - ))); - } - } + for variable_info in &self.types { + let variable_name = &variable_info[0]; + let variable_value = &variable_info[1]; - self.write_str(" }"); + if variable_value.is_empty() { + write_str!(self, " None => Self::NONE,"); + } else { + write_fmt!( + self, + " Some({}Union::{variable_value}(item)) => {{", + self.struct_name, + ); + write_fmt!( + self, + " flat::{}::{variable_name}(Box::new(crate::from_py_into(py, item)))", + self.struct_t_name + ); + write_str!(self, " },"); } - - self.write_str(" }"); - self.write_str("}"); - self.write_str(""); } + + write_str!(self, " }"); + write_str!(self, " }"); + write_str!(self, "}"); + write_str!(self, ""); } fn generate_to_flat_impls(&mut self) { @@ -599,119 +738,139 @@ impl PythonBindGenerator { } fn generate_enum_to_flat_impls(&mut self) { - self.write_string(format!("impl From<&{}> for flat::{} {{", self.struct_name, self.struct_name)); - self.write_string(format!(" fn from(py_type: &{}) -> Self {{", self.struct_name)); - self.write_str(" match *py_type {"); + write_fmt!(self, "impl From<&{}> for flat::{} {{", self.struct_name, self.struct_name); + write_fmt!(self, " fn from(py_type: &{}) -> Self {{", self.struct_name); + write_str!(self, " match *py_type {"); for variable_info in &self.types { let variable_name = &variable_info[0]; - self.file_contents.push(Cow::Owned(format!( + write_fmt!( + self, " {}::{variable_name} => Self::{variable_name},", self.struct_name - ))); + ); } - self.write_str(" }"); - self.write_str(" }"); - self.write_str("}"); + write_str!(self, " }"); + write_str!(self, " }"); + write_str!(self, "}"); - self.write_str(""); + write_str!(self, ""); } fn generate_struct_to_flat_impls(&mut self) { - let from_impl_types = [ - format!("flat::{}", self.struct_t_name), - format!("Box", self.struct_t_name), - ]; - - self.write_string(format!( - "impl FromGil<&{}> for flat::{} {{", - self.struct_name, self.struct_t_name - )); + let impl_type = format!("flat::{}", self.struct_t_name); if self.types.is_empty() { - self.write_string(format!(" fn from_gil(_: Python, _: &{}) -> Self {{", self.struct_name)); - self.write_str(" Self::default()"); + write_fmt!(self, "impl From<&{}> for {impl_type} {{", self.struct_name); + write_fmt!(self, " fn from(_: &{}) -> Self {{", self.struct_name); + write_str!(self, " Self {}"); + write_str!(self, " }"); + write_str!(self, "}"); + write_str!(self, ""); + return; + } + + let (trait_name, fn_name, python_arg) = if (self.is_frozen && !self.frozen_needs_py) || self.is_all_base_types { + ("From", "from", "") } else { - self.write_str(" #[allow(unused_variables)]"); - self.write_string(format!( - " fn from_gil(py: Python, py_type: &{}) -> Self {{", - self.struct_name - )); - self.write_str(" Self {"); + ("FromGil", "from_gil", "py: Python, ") + }; - for variable_info in &self.types { - let variable_name = &variable_info[0]; - let variable_type = variable_info[1].as_str(); + write_fmt!(self, "impl {trait_name}<&{}> for {impl_type} {{", self.struct_name); + write_str!(self, " #[allow(unused_variables)]"); + write_fmt!( + self, + " fn {fn_name}({python_arg}py_type: &{}) -> Self {{", + self.struct_name + ); + write_str!(self, " Self {"); - if variable_type.starts_with("Vec<") { - self.file_contents.push(Cow::Owned(if variable_type == "Vec" { - format!(" {variable_name}: py_type.{variable_name}.as_bytes(py).to_vec(),") + for variable_info in &self.types { + let variable_name = &variable_info[0]; + let variable_type = variable_info[1].as_str(); + + if variable_type.starts_with("Vec<") { + if variable_type == "Vec" { + write_fmt!( + self, + " {variable_name}: py_type.{variable_name}.as_bytes(py).to_vec()," + ) + } else if self.is_frozen { + let inner_type = variable_type + .trim_start_matches("Vec<") + .trim_end_matches('>') + .trim_end_matches('T'); + let map_out = if Self::FROZEN_NEEDS_PY.contains(&inner_type) { + "|x| x.into_gil(py)" } else { - format!( - " {variable_name}: py_type.{variable_name}.iter().map(|x| x.into_gil(py)).collect(),", - ) - })); - } else if variable_type.starts_with("Option<") { - self.file_contents.push(Cow::Owned( - if variable_type.trim_start_matches("Option<").trim_end_matches('>') == "String" { - format!(" {variable_name}: py_type.{variable_name}.clone(),") - } else { - format!(" {variable_name}: py_type.{variable_name}.as_ref().map(|x| x.into_gil(py)),") - }, - )); - } else if variable_type == "String" { - self.file_contents.push(Cow::Owned(format!( - " {variable_name}: py_type.{variable_name}.clone(),", - ))); - } else if variable_type.ends_with('T') || variable_type.starts_with("Box<") { - self.file_contents.push(Cow::Owned(format!( - " {variable_name}: (&py_type.{variable_name}).into_gil(py),", - ))); - } else if Self::BASE_TYPES.contains(&variable_type) { - self.file_contents - .push(Cow::Owned(format!(" {variable_name}: py_type.{variable_name},"))); - } else { - self.file_contents.push(Cow::Owned(format!( - " {variable_name}: (&py_type.{variable_name}).into(),", - ))); - } - } + "Into::into" + }; - self.write_str(" }"); - } + write_fmt!( + self, + " {variable_name}: py_type.{variable_name}.iter().map({map_out}).collect(),", + ) + } else { + write_fmt!( + self, + " {variable_name}: py_type.{variable_name}.iter().map(|x| crate::from_py_into(py, x)).collect(),", + ) + }; + } else if variable_type.starts_with("Option<") { + let inner = variable_type.trim_start_matches("Option<").trim_end_matches('>'); + let end = if inner == "String" { + ".clone()" + } else if inner.starts_with("Box<") { + ".as_ref().map(|x| Box::new(crate::from_py_into(py, x)))" + } else { + ".as_ref().map(|x| crate::from_py_into(py, x))" + }; - self.write_str(" }"); - self.write_str("}"); - self.write_str(""); + write_fmt!(self, " {variable_name}: py_type.{variable_name}{end},"); + } else if variable_type == "String" { + write_fmt!(self, " {variable_name}: py_type.{variable_name}.clone(),",); + } else if variable_type.starts_with("Box<") { + let inner_type = variable_type + .trim_start_matches("Box<") + .trim_end_matches('>') + .trim_end_matches('T'); + let var_name = if self.is_frozen { + if Self::FROZEN_NEEDS_PY.contains(&inner_type) { + format!("(&py_type.{variable_name}).into_gil(py)") + } else { + format!("(&py_type.{variable_name}).into()") + } + } else { + format!("crate::from_py_into(py, &py_type.{variable_name})") + }; - for impl_type in from_impl_types { - self.write_string(format!("impl FromGil<&Py<{}>> for {impl_type} {{", self.struct_name)); + write_fmt!(self, " {variable_name}: Box::new({var_name}),",); + } else if variable_type.ends_with('T') { + let inner_type = variable_type.trim_end_matches('T'); + let end = if self.is_frozen { + if Self::FROZEN_NEEDS_PY.contains(&inner_type) { + format!("(&py_type.{variable_name}).into_gil(py)") + } else { + format!("(&py_type.{variable_name}).into()") + } + } else { + format!("crate::from_py_into(py, &py_type.{variable_name})") + }; - if self.types.is_empty() { - self.write_string(format!(" fn from_gil(_: Python, _: &Py<{}>) -> Self {{", self.struct_name)); - self.write_str(" Self::default()"); + write_fmt!(self, " {variable_name}: {end},",); + } else if Self::BASE_TYPES.contains(&variable_type) { + write_fmt!(self, " {variable_name}: py_type.{variable_name},"); } else { - self.write_string(format!( - " fn from_gil(py: Python, py_type: &Py<{}>) -> Self {{", - self.struct_name - )); - - if impl_type.contains("Box<") { - self.write_string(format!( - " Self::new(flat::{}::from_gil(py, py_type))", - self.struct_t_name - )); - } else { - self.write_str(" Self::from_gil(py, &*py_type.borrow(py))"); - } + write_fmt!(self, " {variable_name}: (&py_type.{variable_name}).into(),",); } - - self.write_str(" }"); - self.write_str("}"); - self.write_str(""); } + + write_str!(self, " }"); + write_str!(self, " }"); + write_str!(self, "}"); + write_str!(self, ""); } fn generate_new_method(&mut self) { @@ -725,66 +884,68 @@ impl PythonBindGenerator { fn generate_union_new_method(&mut self) { assert!(u8::try_from(self.types.len()).is_ok()); - self.write_str(" #[new]"); - self.write_str(" #[pyo3(signature = (item = None))]"); - self.write_string(format!(" pub fn new(item: Option<{}Union>) -> Self {{", self.struct_name)); - self.write_str(" Self { item }"); - self.write_str(" }"); - self.write_str(""); - self.write_str(" #[getter(item)]"); - self.write_str(" pub fn get(&self, py: Python) -> Option {"); - self.write_str(" match self.item.as_ref() {"); + write_str!(self, " #[new]"); + write_str!(self, " #[pyo3(signature = (item = None))]"); + write_fmt!(self, " pub fn new(item: Option<{}Union>) -> Self {{", self.struct_name); + write_str!(self, " Self { item }"); + write_str!(self, " }"); + write_str!(self, ""); + write_str!(self, " #[getter(item)]"); + write_str!(self, " pub fn get(&self, py: Python) -> Option {"); + write_str!(self, " match self.item.as_ref() {"); for variable_info in &self.types { let variable_name = &variable_info[0]; if variable_name == "NONE" { - self.file_contents.push(Cow::Borrowed(" None => None,")); + write_str!(self, " None => None,"); } else { - self.file_contents.push(Cow::Owned(format!( + write_fmt!( + self, " Some({}Union::{variable_name}(item)) => Some(item.to_object(py)),", self.struct_name - ))); + ); } } - self.write_str(" }"); - self.write_str(" }"); + write_str!(self, " }"); + write_str!(self, " }"); } fn generate_enum_new_method(&mut self) { - self.write_str(" #[new]"); + write_str!(self, " #[new]"); assert!(u8::try_from(self.types.len()).is_ok()); - self.write_str(" #[pyo3(signature = (value=Default::default()))]"); - self.write_str(" pub fn new(value: u8) -> PyResult {"); - self.write_str(" match value {"); + write_str!(self, " #[pyo3(signature = (value=Default::default()))]"); + write_str!(self, " pub fn new(value: u8) -> PyResult {"); + write_str!(self, " match value {"); for variable_info in &self.types { let variable_name = &variable_info[0]; let variable_value = &variable_info[1]; - self.file_contents.push(Cow::Owned(format!( - " {variable_value} => Ok(Self::{variable_name})," - ))); + write_fmt!(self, " {variable_value} => Ok(Self::{variable_name}),"); } if self.types.len() != usize::from(u8::MAX) { - self.write_str(" v => Err(PyValueError::new_err(format!(\"Unknown value of {v}\"))),"); + write_str!( + self, + " v => Err(PyValueError::new_err(format!(\"Unknown value of {v}\")))," + ); } - self.write_str(" }"); - self.write_str(" }"); + write_str!(self, " }"); + write_str!(self, " }"); } fn generate_struct_new_method(&mut self) { - self.write_str(" #[new]"); - self.write_str(" #[allow(clippy::too_many_arguments)]"); + write_str!(self, " #[new]"); + write_str!(self, " #[allow(clippy::too_many_arguments)]"); if self.types.is_empty() { - self.write_str(" pub fn new() -> Self {"); - self.write_str(" Self::default()"); - self.write_str(" }"); + write_str!(self, " pub fn new() -> Self {"); + write_str!(self, " Self::default()"); + write_str!(self, " }"); return; } @@ -810,7 +971,8 @@ impl PythonBindGenerator { } signature_parts.push(format!("{variable_name}=None")); - } else if !Self::BASE_TYPES.contains(&variable_type.as_str()) + } else if !self.is_frozen + && !Self::BASE_TYPES.contains(&variable_type.as_str()) && (variable_type.starts_with("Box<") || variable_type.ends_with('T')) { signature_parts.push(format!("{variable_name}=crate::get_py_default()")); @@ -821,11 +983,11 @@ impl PythonBindGenerator { } } - self.write_string(format!(" #[pyo3(signature = ({}))]", signature_parts.join(", "))); - self.write_str(" pub fn new("); + write_fmt!(self, " #[pyo3(signature = ({}))]", signature_parts.join(", ")); + write_str!(self, " pub fn new("); if needs_python { - self.write_str(" py: Python,"); + write_str!(self, " py: Python,"); } for variable_info in &self.types { @@ -833,20 +995,26 @@ impl PythonBindGenerator { let mut variable_type = variable_info[1].to_string(); if variable_type.starts_with("Vec<") && variable_type.ends_with("T>") { - variable_type = format!( - "Vec>", - variable_type.trim_start_matches("Vec<").trim_end_matches("T>") - ); + let inner_type = variable_type.trim_start_matches("Vec<").trim_end_matches("T>"); + + variable_type = if self.is_frozen { + format!("Vec",) + } else { + format!("Vec>",) + }; } else if variable_type == "Vec" { variable_type = String::from("Py"); } else if variable_type.starts_with("Box<") && variable_type.ends_with('>') { - variable_type = format!( - "Py", - variable_type - .trim_start_matches("Box<") - .trim_end_matches('>') - .trim_end_matches('T') - ); + let inner_type = variable_type + .trim_start_matches("Box<") + .trim_end_matches('>') + .trim_end_matches('T'); + + variable_type = if self.is_frozen { + format!("super::{inner_type}") + } else { + format!("Py") + }; } else if variable_type.starts_with("Option<") && variable_type.ends_with('>') { let inner_type = variable_type .trim_start_matches("Option<") @@ -854,29 +1022,30 @@ impl PythonBindGenerator { .trim_end_matches('>') .trim_end_matches('T'); - if Self::BASE_TYPES.contains(&inner_type) { - variable_type = format!("Option<{inner_type}>"); + variable_type = if Self::BASE_TYPES.contains(&inner_type) { + format!("Option<{inner_type}>") } else if inner_type == "Float" { - variable_type = String::from("Option"); + String::from("Option") } else if inner_type == "Bool" { - variable_type = String::from("Option"); + String::from("Option") } else { - variable_type = format!("Option>"); - } + format!("Option>") + }; } else if !Self::BASE_TYPES.contains(&variable_type.as_str()) { - if variable_type.ends_with('T') { - variable_type = format!("Py", variable_type.trim_end_matches('T')); + let inner_type = variable_type.trim_end_matches('T'); + + variable_type = if variable_type.ends_with('T') && !self.is_frozen { + format!("Py") } else { - variable_type = format!("super::{variable_type}"); + format!("super::{inner_type}") } } - self.file_contents - .push(Cow::Owned(format!(" {variable_name}: {variable_type},"))); + write_fmt!(self, " {variable_name}: {variable_type},"); } - self.write_str(" ) -> Self {"); - self.write_str(" Self {"); + write_str!(self, " ) -> Self {"); + write_str!(self, " Self {"); for variable_info in &self.types { let variable_name = &variable_info[0]; @@ -886,22 +1055,20 @@ impl PythonBindGenerator { .trim_end_matches('>'); if Self::SPECIAL_BASE_TYPES.contains(&variable_type) { - self.file_contents.push(Cow::Owned(format!( - " {variable_name}: {variable_name}.map(|x| x.into_gil(py))," - ))); + write_fmt!(self, " {variable_name}: {variable_name}.map(|x| x.into_gil(py)),"); } else { - self.file_contents.push(Cow::Owned(format!(" {variable_name},"))); + write_fmt!(self, " {variable_name},"); } } - self.write_str(" }"); - self.write_str(" }"); + write_str!(self, " }"); + write_str!(self, " }"); } fn generate_str_method(&mut self) { - self.write_str(" pub fn __str__(&self) -> String {"); - self.write_str(" format!(\"{self:?}\")"); - self.write_str(" }"); + write_str!(self, " pub fn __str__(&self) -> String {"); + write_str!(self, " format!(\"{self:?}\")"); + write_str!(self, " }"); } fn generate_repr_method(&mut self) { @@ -913,47 +1080,44 @@ impl PythonBindGenerator { } fn generate_union_repr_method(&mut self) { - self.write_str(" pub fn __repr__(&self, py: Python) -> String {"); - self.write_str(" match self.item.as_ref() {"); + write_str!(self, " pub fn __repr__(&self, py: Python) -> String {"); + write_str!(self, " match self.item.as_ref() {"); for variable_info in &self.types { let variable_name = &variable_info[0]; let variable_type = &variable_info[1]; if variable_type.is_empty() { - self.file_contents.push(Cow::Owned(format!( - " None => String::from(\"{}()\"),", - self.struct_name - ))); + write_fmt!(self, " None => String::from(\"{}()\"),", self.struct_name); } else { - self.file_contents.push(Cow::Owned(format!( + write_fmt!(self, " Some({}Union::{variable_name}(item)) => format!(\"{}({{}})\", item.borrow(py).__repr__(py)),", self.struct_name, self.struct_name - ))); + ); } } - self.write_str(" }"); - self.write_str(" }"); + write_str!(self, " }"); + write_str!(self, " }"); } fn generate_enum_repr_method(&mut self) { - self.write_str(" pub fn __repr__(&self) -> String {"); - self.write_string(format!(" format!(\"{}(value={{}})\", *self as u8)", self.struct_name)); - self.write_str(" }"); + write_str!(self, " pub fn __repr__(&self) -> String {"); + write_fmt!(self, " format!(\"{}(value={{}})\", *self as u8)", self.struct_name); + write_str!(self, " }"); } fn generate_struct_repr_method(&mut self) { if self.types.is_empty() { - self.write_str(" pub fn __repr__(&self, _py: Python) -> String {"); - self.write_string(format!(" String::from(\"{}()\")", self.struct_name)); - self.write_str(" }"); + write_str!(self, " pub fn __repr__(&self, _py: Python) -> String {"); + write_fmt!(self, " String::from(\"{}()\")", self.struct_name); + write_str!(self, " }"); return; } - self.write_str(" #[allow(unused_variables)]"); - self.write_str(" pub fn __repr__(&self, py: Python) -> String {"); - self.write_str(" format!("); + write_str!(self, " #[allow(unused_variables)]"); + write_str!(self, " pub fn __repr__(&self, py: Python) -> String {"); + write_str!(self, " format!("); let repr_signature = self .types @@ -974,63 +1138,67 @@ impl PythonBindGenerator { }) .collect::>() .join(", "); - self.write_string(format!(" \"{}({repr_signature})\",", self.struct_name)); + write_fmt!(self, " \"{}({repr_signature})\",", self.struct_name); for variable_info in &self.types { let variable_name = &variable_info[0]; let variable_type = variable_info[1].as_str(); if variable_type == "bool" { - self.file_contents - .push(Cow::Owned(format!(" crate::bool_to_str(self.{variable_name}),"))); + write_fmt!(self, " crate::bool_to_str(self.{variable_name}),"); } else if Self::BASE_TYPES.contains(&variable_type) { - self.file_contents - .push(Cow::Owned(format!(" self.{variable_name},"))); + write_fmt!(self, " self.{variable_name},"); } else if variable_type.starts_with("Option<") { - self.file_contents - .push(Cow::Owned(format!(" self.{variable_name}"))); - self.file_contents.push(Cow::Borrowed(" .as_ref()")); + write_fmt!(self, " self.{variable_name}"); + write_str!(self, " .as_ref()"); + if Self::BASE_TYPES.into_iter().any(|t| variable_type.contains(t)) { - self.file_contents - .push(Cow::Borrowed(" .map(|i| format!(\"{i:?}\"))")); + write_str!(self, " .map(|i| format!(\"{i:?}\"))"); } else { - self.file_contents - .push(Cow::Borrowed(" .map(|x| x.borrow(py).__repr__(py))")); + write_str!(self, " .map(|x| x.borrow(py).__repr__(py))"); } - self.file_contents - .push(Cow::Borrowed(" .unwrap_or_else(crate::none_str),")); + + write_str!(self, " .unwrap_or_else(crate::none_str),"); } else if variable_type.starts_with("Vec<") { - self.file_contents - .push(Cow::Owned(format!(" self.{variable_name}"))); + write_fmt!(self, " self.{variable_name}"); + if Self::BASE_TYPES.into_iter().any(|t| variable_type.contains(t)) { - self.file_contents.push(Cow::Borrowed(" .as_bytes(py)")); - self.file_contents.push(Cow::Borrowed(" .iter()")); - self.file_contents - .push(Cow::Borrowed(" .map(ToString::to_string)")); + write_str!(self, " .as_bytes(py)"); + write_str!(self, " .iter()"); + write_str!(self, " .map(ToString::to_string)"); } else { - self.file_contents.push(Cow::Borrowed(" .iter()")); - self.file_contents - .push(Cow::Borrowed(" .map(|x| x.borrow(py).__repr__(py))")); + write_str!(self, " .iter()"); + write_str!( + self, + if self.is_frozen { + " .map(|x| x.__repr__(py))" + } else { + " .map(|x| x.borrow(py).__repr__(py))" + } + ); } - self.file_contents - .push(Cow::Borrowed(" .collect::>()")); - self.file_contents.push(Cow::Borrowed(" .join(\", \"),")); + + write_str!(self, " .collect::>()"); + write_str!(self, " .join(\", \"),"); } else if variable_type.ends_with('T') || variable_type.starts_with("Box<") { - self.file_contents.push(Cow::Owned(format!( - " self.{variable_name}.borrow(py).__repr__(py)," - ))); + let repr_str = if self.is_frozen { + ".__repr__(py)" + } else { + ".borrow(py).__repr__(py)" + }; + + write_fmt!(self, " self.{variable_name}{repr_str},"); } else { - self.file_contents - .push(Cow::Owned(format!(" self.{variable_name}.__repr__(),"))); + write_fmt!(self, " self.{variable_name}.__repr__(),"); } } - self.write_str(" )"); - self.write_str(" }"); + write_str!(self, " )"); + write_str!(self, " }"); } fn generate_pack_method(&mut self) { - self.write_str(" fn pack<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {"); + write_str!(self, " fn pack<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {"); let name = if self.bind_type == PythonBindType::Enum { &self.struct_name @@ -1039,80 +1207,81 @@ impl PythonBindGenerator { }; if self.bind_type == PythonBindType::Enum { - self.write_string(format!(" let flat_t = flat::{name}::from(self);")); + write_fmt!(self, " let flat_t = flat::{name}::from(self);"); } else { - self.write_string(format!(" let flat_t = flat::{name}::from_gil(py, self);")); + write_fmt!(self, " let flat_t = flat::{name}::from_gil(py, self);"); } if self.has_complex_pack { - self.write_str(" let size = flat_t.get_size().next_power_of_two();"); - self.write_str(""); - self.write_str(" let mut builder = FlatBufferBuilder::with_capacity(size);"); - self.write_str(" let offset = flat_t.pack(&mut builder);"); - self.write_str(" builder.finish(offset, None);"); - self.write_str(""); - self.write_str(" PyBytes::new_bound(py, builder.finished_data())"); + write_str!(self, " let size = flat_t.get_size().next_power_of_two();"); + write_str!(self, ""); + write_str!(self, " let mut builder = FlatBufferBuilder::with_capacity(size);"); + write_str!(self, " let offset = flat_t.pack(&mut builder);"); + write_str!(self, " builder.finish(offset, None);"); + write_str!(self, ""); + write_str!(self, " PyBytes::new_bound(py, builder.finished_data())"); } else if self.bind_type == PythonBindType::Enum { - self.write_str(" PyBytes::new_bound(py, &[flat_t.0])"); + write_str!(self, " PyBytes::new_bound(py, &[flat_t.0])"); } else { - self.write_str(" let item = flat_t.pack();"); - self.write_str(""); - self.write_str(" PyBytes::new_bound(py, &item.0)"); + write_str!(self, " let item = flat_t.pack();"); + write_str!(self, ""); + write_str!(self, " PyBytes::new_bound(py, &item.0)"); } - self.write_str(" }"); + write_str!(self, " }"); } fn generate_unpack_method(&mut self) { - self.write_str(" #[staticmethod]"); - - if self.bind_type == PythonBindType::Enum { - self.write_str(" fn unpack(data: &[u8]) -> PyResult {"); - self.write_string(format!(" match root::(data) {{", self.struct_name)); - self.write_str(" Ok(flat_t) => Ok(flat_t.into()),"); - self.write_str(" Err(e) => Err(flat_err_to_py(e)),"); - self.write_str(" }"); + write_str!(self, " #[staticmethod]"); + + let (py_arg, return_val, out_map) = if self.bind_type == PythonBindType::Enum { + ("", "Self", "flat_t.into()") + } else if Self::FROZEN_NEEDS_PY.contains(&self.struct_name.as_str()) { + ("py: Python, ", "Self", "flat_t.unpack().into_gil(py)") + } else if self.is_frozen { + ("", "Self", "flat_t.unpack().into()") } else { - self.write_str(" fn unpack(py: Python, data: &[u8]) -> PyResult> {"); - self.write_string(format!(" match root::(data) {{", self.struct_name)); - self.write_str(" Ok(flat_t) => Ok(flat_t.unpack().into_gil(py)),"); - self.write_str(" Err(e) => Err(flat_err_to_py(e)),"); - self.write_str(" }"); - } + ("py: Python, ", "Py", "crate::into_py_from(py, flat_t.unpack())") + }; - self.write_str(" }"); + write_fmt!(self, " fn unpack({py_arg}data: &[u8]) -> PyResult<{return_val}> {{"); + write_fmt!(self, " match root::(data) {{", self.struct_name); + write_fmt!(self, " Ok(flat_t) => Ok({out_map}),"); + write_str!(self, " Err(e) => Err(flat_err_to_py(e)),"); + write_str!(self, " }"); + write_str!(self, " }"); } fn generate_enum_hash_method(&mut self) { - self.write_str(" pub fn __hash__(&self) -> u64 {"); - self.write_str(" crate::hash_u8(*self as u8)"); - self.write_str(" }"); + write_str!(self, " pub fn __hash__(&self) -> u64 {"); + write_str!(self, " crate::hash_u8(*self as u8)"); + write_str!(self, " }"); } fn generate_py_methods(&mut self) { - self.write_str("#[pymethods]"); - self.write_string(format!("impl {} {{", self.struct_name)); + write_str!(self, "#[pymethods]"); + write_fmt!(self, "impl {} {{", self.struct_name); self.generate_new_method(); - self.write_str(""); + write_str!(self, ""); self.generate_str_method(); - self.write_str(""); + write_str!(self, ""); self.generate_repr_method(); if self.bind_type == PythonBindType::Enum { - self.write_str(""); + write_str!(self, ""); self.generate_enum_hash_method(); } if self.bind_type != PythonBindType::Union { - self.write_str(""); + write_str!(self, ""); self.generate_pack_method(); - self.write_str(""); + write_str!(self, ""); self.generate_unpack_method(); } - self.write_str("}"); - self.write_str(""); + write_str!(self, "}"); + write_str!(self, ""); } fn finish(self) -> io::Result<(String, String, Vec>)> { @@ -1275,7 +1444,7 @@ fn pyi_generator(type_data: &[(String, String, Vec>)]) -> io::Result .trim_end_matches('>') .trim_end_matches('T'); - let python_type = if type_name == "bool" { + let mut python_type = if type_name == "bool" { "bool" } else if type_name == "i32" || type_name == "u32" { "int" @@ -1283,16 +1452,19 @@ fn pyi_generator(type_data: &[(String, String, Vec>)]) -> io::Result "float" } else if type_name == "String" { "str" - } else if type_name == "Float" { - "Float | float" - } else if type_name == "Bool" { - "Bool | bool" } else { type_name }; - python_types.push(format!("Optional[{python_type}]")); file_contents.push(Cow::Owned(format!(" {variable_name}: Optional[{python_type}]"))); + + if type_name == "Float" { + python_type = "Float | float"; + } else if type_name == "Bool" { + python_type = "Bool | bool"; + } + + python_types.push(format!("Optional[{python_type}]")); } else if variable_type.starts_with("Box<") && variable_type.ends_with("T>") { let type_name = variable_type.trim_start_matches("Box<").trim_end_matches("T>"); diff --git a/pybench.py b/pybench.py new file mode 100644 index 0000000..2e14cdf --- /dev/null +++ b/pybench.py @@ -0,0 +1,52 @@ +from time import time_ns + +import rlbot_flatbuffers as flat + + +def test_gtp(): + times = [] + + gtp = flat.GameTickPacket( + ball=flat.BallInfo( + shape=flat.CollisionShape(flat.SphereShape()), + ), + players=[flat.PlayerInfo() for _ in range(128)], + boost_pad_states=[flat.BoostPadState() for _ in range(128)], + teams=[flat.TeamInfo() for _ in range(2)], + ) + + for _ in range(20_000): + start = time_ns() + + packed = gtp.pack() + flat.GameTickPacket.unpack(packed) + + times.append(time_ns() - start) + + print(f"Total time: {sum(times) / 1_000_000_000:.3f}s") + avg_time_ns = sum(times) / len(times) + print(f"Average time per: {avg_time_ns / 1000:.1f}us") + + +def test_ballpred(): + times = [] + + ballPred = flat.BallPrediction([flat.PredictionSlice() for _ in range(720)]) + + for _ in range(10_000): + start = time_ns() + + ballpred_bytes = ballPred.pack() + flat.BallPrediction.unpack(ballpred_bytes) + + times.append(time_ns() - start) + + print(f"Total time: {sum(times) / 1_000_000_000:.3f}s") + avg_time_ns = sum(times) / len(times) + print(f"Average time per: {avg_time_ns / 1000:.1f}us") + + +if __name__ == "__main__": + test_gtp() + print() + test_ballpred() diff --git a/pytest.py b/pytest.py index 04f5846..6429877 100644 --- a/pytest.py +++ b/pytest.py @@ -62,7 +62,7 @@ def random_script_config(): print(repr(RenderType())) - render_type = RenderType(Line3D(Vector3(0, 0, 0), Vector3(1, 1, 1), Color(255))) + render_type = RenderType(Line3D(MyVector(0, 0, 0), Vector3(1, 1, 1), Color(255))) if isinstance(render_type.item, Line3D): render_type.item.color.a = 150 else: @@ -119,7 +119,7 @@ def random_script_config(): print("Running quick benchmark...") - num_trials = 100_000 + num_trials = 60_000 total_make_time = 0 total_pack_time = 0 @@ -145,7 +145,7 @@ def random_script_config(): ), 100, ) - for _ in range(8) + for _ in range(16) ], game_info_state=DesiredGameInfoState( game_speed=1, world_gravity_z=-650, end_match=True @@ -162,8 +162,8 @@ def random_script_config(): DesiredGameState.unpack(packed_bytes) total_unpack_time += time_ns() - start - print(f"Average time to make: {round(total_make_time / num_trials, 2)}ns") - print(f"Average time to pack: {round(total_pack_time / num_trials, 2)}ns") - print(f"Average time to unpack: {round(total_unpack_time / num_trials, 2)}ns") + print(f"Average time to make: {total_make_time / num_trials / 1_000:.2f}us") + print(f"Average time to pack: {total_pack_time / num_trials / 1_000:.2f}us") + print(f"Average time to unpack: {total_unpack_time / num_trials / 1_000:.2f}us") - print(f"Total time: {round((total_pack_time + total_unpack_time) / 1000000, 2)}ms") + print(f"Total time: {(total_pack_time + total_unpack_time) / 1_000_000_000:.3f}s") diff --git a/src/lib.rs b/src/lib.rs index 99de163..0899e8b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,7 +9,7 @@ )] pub mod generated; -#[allow(clippy::enum_variant_names)] +#[allow(clippy::enum_variant_names, unused_imports)] mod python; use pyo3::{create_exception, exceptions::PyValueError, prelude::*, types::PyBytes, PyClass}; @@ -38,6 +38,16 @@ pub trait FromGil { fn from_gil(py: Python, obj: T) -> Self; } +impl FromGil for U +where + U: From, +{ + #[inline] + fn from_gil(_py: Python, obj: T) -> Self { + Self::from(obj) + } +} + pub trait IntoGil: Sized { fn into_gil(self, py: Python) -> T; } @@ -52,6 +62,22 @@ where } } +fn into_py_from(py: Python, obj: T) -> Py +where + T: IntoGil, + U: pyo3::PyClass + Into>, +{ + Py::new(py, obj.into_gil(py)).unwrap() +} + +fn from_py_into(py: Python, obj: &Py) -> U +where + T: PyClass, + U: for<'a> FromGil<&'a T>, +{ + (&*obj.borrow(py)).into_gil(py) +} + pub trait PyDefault: Sized + PyClass { fn py_default(py: Python) -> Py; }