Skip to content

Commit

Permalink
Merge pull request #82 from caspark/no-bytemuck-bounds-on-input
Browse files Browse the repository at this point in the history
RFC: feat: Input type must be Default+Serde but not POD
  • Loading branch information
gschup authored Dec 14, 2024
2 parents b66d18f + a06b63b commit 58ed06e
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 59 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ rand = "0.8"
bitfield-rle = "0.2.1"
parking_lot = "0.12"
instant = "0.1"
bytemuck = {version = "1.9", features = ["derive"]}
getrandom = {version = "0.2", optional = true}

[target.'cfg(target_arch = "wasm32")'.dependencies]
Expand Down
3 changes: 1 addition & 2 deletions examples/ex_game/ex_game.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::net::SocketAddr;

use bytemuck::{Pod, Zeroable};
use ggrs::{Config, Frame, GameStateCell, GgrsRequest, InputStatus, PlayerHandle, NULL_FRAME};
use macroquad::prelude::*;
use serde::{Deserialize, Serialize};
Expand All @@ -24,7 +23,7 @@ const MAX_SPEED: f32 = 7.0;
const FRICTION: f32 = 0.98;

#[repr(C)]
#[derive(Copy, Clone, PartialEq, Pod, Zeroable)]
#[derive(Copy, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct Input {
pub inp: u8,
}
Expand Down
22 changes: 4 additions & 18 deletions src/frame_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,36 +27,23 @@ impl<S> Default for GameState<S> {
#[derive(Debug, Copy, Clone, PartialEq)]
pub(crate) struct PlayerInput<I>
where
I: Copy
+ Clone
+ PartialEq
+ bytemuck::NoUninit
+ bytemuck::CheckedBitPattern
+ bytemuck::Zeroable,
I: Copy + Clone + PartialEq,
{
/// The frame to which this info belongs to. -1/[`NULL_FRAME`] represents an invalid frame
pub frame: Frame,
/// The input struct given by the user
pub input: I,
}

impl<
I: Copy
+ Clone
+ PartialEq
+ bytemuck::NoUninit
+ bytemuck::Zeroable
+ bytemuck::CheckedBitPattern,
> PlayerInput<I>
{
impl<I: Copy + Clone + PartialEq + Default> PlayerInput<I> {
pub(crate) fn new(frame: Frame, input: I) -> Self {
Self { frame, input }
}

pub(crate) fn blank_input(frame: Frame) -> Self {
Self {
frame,
input: I::zeroed(),
input: I::default(),
}
}

Expand All @@ -72,10 +59,9 @@ impl<
#[cfg(test)]
mod game_input_tests {
use super::*;
use bytemuck::{Pod, Zeroable};

#[repr(C)]
#[derive(Copy, Clone, PartialEq, Pod, Zeroable)]
#[derive(Copy, Clone, PartialEq, Default)]
struct TestInput {
inp: u8,
}
Expand Down
6 changes: 3 additions & 3 deletions src/input_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub(crate) struct InputQueue<T>
where
T: Config,
{
/// The head of the queue. The newest `PlayerInput` is saved here
/// The head of the queue. The newest `PlayerInput` is saved here
head: usize,
/// The tail of the queue. The oldest `PlayerInput` still valid is saved here.
tail: usize,
Expand Down Expand Up @@ -250,12 +250,12 @@ mod input_queue_tests {

use std::net::SocketAddr;

use bytemuck::{Pod, Zeroable};
use serde::{Deserialize, Serialize};

use super::*;

#[repr(C)]
#[derive(Copy, Clone, PartialEq, Pod, Zeroable)]
#[derive(Copy, Clone, PartialEq, Default, Serialize, Deserialize)]
struct TestInput {
inp: u8,
}
Expand Down
24 changes: 7 additions & 17 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub use error::GgrsError;
pub use network::messages::Message;
pub use network::network_stats::NetworkStats;
pub use network::udp_socket::UdpNonBlockingSocket;
use serde::{de::DeserializeOwned, Serialize};
pub use sessions::builder::SessionBuilder;
pub use sessions::p2p_session::P2PSession;
pub use sessions::p2p_spectator_session::SpectatorSession;
Expand Down Expand Up @@ -205,12 +206,9 @@ pub trait Config: 'static + Send + Sync {
/// The input type for a session. This is the only game-related data
/// transmitted over the network.
///
/// Reminder: Types implementing [Pod] may not have the same byte representation
/// on platforms with different endianness. GGRS assumes that all players are
/// running with the same endianness when encoding and decoding inputs.
///
/// [Pod]: bytemuck::Pod
type Input: Copy + Clone + PartialEq + bytemuck::Pod + bytemuck::Zeroable + Send + Sync;
/// The implementation of [Default] is used for representing "no input" for
/// a player, including when a player is disconnected.
type Input: Copy + Clone + PartialEq + Default + Serialize + DeserializeOwned + Send + Sync;

/// The save state type for the session.
type State: Clone + Send + Sync;
Expand Down Expand Up @@ -242,17 +240,9 @@ pub trait Config: 'static {
/// The input type for a session. This is the only game-related data
/// transmitted over the network.
///
/// Reminder: Types implementing [Pod] may not have the same byte representation
/// on platforms with different endianness. GGRS assumes that all players are
/// running with the same endianness when encoding and decoding inputs.
///
/// [Pod]: bytemuck::Pod
type Input: Copy
+ Clone
+ PartialEq
+ bytemuck::NoUninit
+ bytemuck::CheckedBitPattern
+ bytemuck::Zeroable;
/// The implementation of [Default] is used for representing "no input" for
/// a player, including when a player is disconnected.
type Input: Copy + Clone + PartialEq + Default + Serialize + DeserializeOwned;

/// The save state type for the session.
type State;
Expand Down
10 changes: 6 additions & 4 deletions src/network/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ impl InputBytes {
if input.frame != NULL_FRAME {
frame = input.frame;
}
let byte_vec = bytemuck::bytes_of(&input.input);
bytes.extend_from_slice(byte_vec);

bincode::serialize_into(&mut bytes, &input.input)
.expect("input serialization failed");
}
}
Self { frame, bytes }
Expand All @@ -87,8 +88,9 @@ impl InputBytes {
for p in 0..num_players {
let start = p * size;
let end = start + size;
let input = *bytemuck::checked::try_from_bytes::<T::Input>(&self.bytes[start..end])
.expect("Expected received data to be valid.");
let player_byte_slice = &self.bytes[start..end];
let input: T::Input =
bincode::deserialize(player_byte_slice).expect("input deserialization failed");
player_inputs.push(PlayerInput::new(self.frame, input));
}
player_inputs
Expand Down
7 changes: 3 additions & 4 deletions src/sync_layer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use bytemuck::Zeroable;
use parking_lot::{MappedMutexGuard, Mutex};
use std::ops::Deref;
use std::sync::Arc;
Expand Down Expand Up @@ -282,7 +281,7 @@ impl<T: Config> SyncLayer<T> {
let mut inputs = Vec::new();
for (i, con_stat) in connect_status.iter().enumerate() {
if con_stat.disconnected && con_stat.last_frame < self.current_frame {
inputs.push((T::Input::zeroed(), InputStatus::Disconnected));
inputs.push((T::Input::default(), InputStatus::Disconnected));
} else {
inputs.push(self.input_queues[i].input(self.current_frame));
}
Expand Down Expand Up @@ -380,11 +379,11 @@ impl<T: Config> SyncLayer<T> {
mod sync_layer_tests {

use super::*;
use bytemuck::{Pod, Zeroable};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;

#[repr(C)]
#[derive(Copy, Clone, PartialEq, Pod, Zeroable)]
#[derive(Copy, Clone, PartialEq, Default, Serialize, Deserialize)]
struct TestInput {
inp: u8,
}
Expand Down
4 changes: 2 additions & 2 deletions tests/stubs.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use rand::{prelude::ThreadRng, thread_rng, Rng};
use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::net::SocketAddr;
Expand All @@ -14,10 +15,9 @@ fn calculate_hash<T: Hash>(t: &T) -> u64 {
pub struct GameStub {
pub gs: StateStub,
}
use bytemuck::{Pod, Zeroable};

#[repr(C)]
#[derive(Copy, Clone, PartialEq, Pod, Zeroable)]
#[derive(Copy, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct StubInput {
pub inp: u32,
}
Expand Down
11 changes: 3 additions & 8 deletions tests/stubs_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::hash::{Hash, Hasher};
use std::net::SocketAddr;

use ggrs::{Config, Frame, GameStateCell, GgrsRequest, InputStatus};
use serde::{Deserialize, Serialize};

fn calculate_hash<T: Hash>(t: &T) -> u64 {
let mut s = DefaultHasher::new();
Expand All @@ -13,22 +14,16 @@ fn calculate_hash<T: Hash>(t: &T) -> u64 {
pub struct GameStubEnum {
pub gs: StateStubEnum,
}
use bytemuck::{CheckedBitPattern, NoUninit, Zeroable};

#[allow(dead_code)]
#[repr(u8)]
#[derive(Copy, Clone, PartialEq, CheckedBitPattern, NoUninit)]
#[derive(Copy, Clone, PartialEq, Default, Serialize, Deserialize)]
pub enum EnumInput {
#[default]
Val1,
Val2,
}

unsafe impl Zeroable for EnumInput {
fn zeroed() -> Self {
unsafe { core::mem::zeroed() }
}
}

pub struct StubEnumConfig;

impl Config for StubEnumConfig {
Expand Down

0 comments on commit 58ed06e

Please sign in to comment.