Skip to content

Commit

Permalink
Add __match_args__ for structs
Browse files Browse the repository at this point in the history
  • Loading branch information
VirxEC committed Jun 28, 2024
1 parent 0b9cb3b commit db80588
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
[package]
name = "rlbot-flatbuffers-py"
version = "0.4.0"
version = "0.4.1"
edition = "2021"
description = "A Python module implemented in Rust for serializing and deserializing RLBot's flatbuffers"
repository = "https://github.com/VirxEC/rlbot_flatbuffers_py"
build = "codegen/main.rs"
license = "MIT"
readme = "README.md"
exclude = [".github", "pytest.py", "rustfmt.toml", ".gitignore", ".gitmodules", "flatc_mac"]
exclude = [".github", "pytest.py", "pybench.py", "rustfmt.toml", ".gitignore", ".gitmodules", "flatc_mac"]
publish = false

[lints.clippy]
Expand Down
2 changes: 0 additions & 2 deletions codegen/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,6 @@ impl Generator for EnumBindGenerator {
write_str!(self, "");

self.generate_repr_method();
write_str!(self, "");

write_str!(self, "}");
write_str!(self, "");
}
Expand Down
10 changes: 10 additions & 0 deletions codegen/pyi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,16 @@ pub fn generator(type_data: &[PythonBindType]) -> io::Result<()> {
}
}

if !gen.types.is_empty() {
write_str!(file, "");
write_str!(file, " __match_args__ = (");

for variable_info in &gen.types {
write_fmt!(file, " \"{}\",", variable_info.name);
}
write_str!(file, " )");
}

if gen.types.is_empty() {
write_str!(file, " def __init__(self): ...");
} else {
Expand Down
51 changes: 49 additions & 2 deletions codegen/structs.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{generator::Generator, PythonBindType};
use std::{borrow::Cow, fs, path::Path};
use std::{borrow::Cow, fs, iter::repeat, path::Path};

#[derive(Debug, PartialEq, Eq)]
pub enum InnerVecType {
Expand Down Expand Up @@ -388,7 +388,10 @@ impl StructBindGenerator {
return;
}

write_str!(self, " #[allow(unused_variables)]");
if !self.frozen_needs_py {
write_str!(self, " #[allow(unused_variables)]");
}

write_str!(self, " pub fn __repr__(&self, py: Python) -> String {");
write_str!(self, " format!(");

Expand Down Expand Up @@ -496,6 +499,47 @@ impl StructBindGenerator {
write_str!(self, " }");
}

fn generate_long_match_args(&mut self) {
write_str!(self, " #[classattr]");
write_str!(
self,
" fn __match_args__(py: Python) -> Bound<pyo3::types::PyTuple> {"
);
write_str!(self, " pyo3::types::PyTuple::new_bound(py, [");

for variable_info in &self.types {
write_fmt!(self, " \"{}\",", variable_info.name);
}

write_str!(self, " ])");
write_str!(self, " }");
}

fn generate_match_args(&mut self) {
if self.types.is_empty() {
return;
}

if self.types.len() > 12 {
self.generate_long_match_args();
return;
}

let sig_parts: Vec<_> = repeat("&'static str").take(self.types.len()).collect();
let sig = sig_parts.join(", ");

write_str!(self, " #[classattr]");
write_fmt!(self, " fn __match_args__() -> ({sig},) {{",);
write_str!(self, " (");

for variable_info in &self.types {
write_fmt!(self, " \"{}\",", variable_info.name);
}

write_str!(self, " )");
write_str!(self, " }");
}

fn generate_pack_method(&mut self) {
write_str!(
self,
Expand Down Expand Up @@ -939,6 +983,9 @@ impl Generator for StructBindGenerator {
self.generate_repr_method();
write_str!(self, "");

self.generate_match_args();
write_str!(self, "");

self.generate_pack_method();
write_str!(self, "");

Expand Down
32 changes: 26 additions & 6 deletions pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,19 @@ def random_script_config():
dgs = DesiredGameState(
game_info_state=DesiredGameInfoState(game_speed=2, end_match=Bool())
)
dgs.game_info_state.world_gravity_z = Float(-650)
dgs.game_info_state.end_match.val = True

match dgs.game_info_state:
case DesiredGameInfoState():
dgs.game_info_state.world_gravity_z = Float(-650)
case _:
assert False

match dgs.game_info_state.end_match:
case Bool(val):
dgs.game_info_state.end_match.val = not val
case _:
assert False

dgs.console_commands = [ConsoleCommand("dump_items")]
dgs.ball_state = DesiredBallState()

Expand Down Expand Up @@ -88,7 +99,14 @@ def random_script_config():
print(comm.content.decode("utf-8"))
print()

print(hash(AirState.Dodging))
air_state = AirState.Dodging
print(hash(air_state))

match air_state:
case AirState.Dodging:
pass
case _:
assert False

try:
AirState(8)
Expand Down Expand Up @@ -117,12 +135,14 @@ def random_script_config():
renderPolyLine = RenderMessage(
PolyLine3D(
[Vector3() for _ in range(2048)],
)
Color(255),
),
)

match renderPolyLine.variety.item:
case PolyLine3D():
assert len(renderPolyLine.variety.item.points) == 2048
case PolyLine3D(points, clr):
assert len(points) == 2048
assert clr.a == 255
case _:
assert False

Expand Down

0 comments on commit db80588

Please sign in to comment.