diff --git a/src/vm.rs b/src/vm.rs index 1b893925..2fd759c6 100644 --- a/src/vm.rs +++ b/src/vm.rs @@ -92,7 +92,7 @@ pub type BuiltinFunction = fn(&mut C, u64, u64, u64, u64, u64, &mut MemoryMapping, &mut ProgramResult); /// Represents the interface to a fixed functionality program -#[derive(PartialEq, Eq)] +#[derive(Eq)] pub struct BuiltinProgram { /// Holds the Config if this is a loader program config: Option>, @@ -100,6 +100,12 @@ pub struct BuiltinProgram { functions: FunctionRegistry>, } +impl PartialEq for BuiltinProgram { + fn eq(&self, other: &Self) -> bool { + self.config.eq(&other.config) && self.functions.eq(&other.functions) + } +} + impl BuiltinProgram { /// Constructs a loader built-in program pub fn new_loader(config: Config, functions: FunctionRegistry>) -> Self { @@ -262,7 +268,7 @@ pub trait ContextObject { } /// Simple instruction meter for testing -#[derive(Debug, Clone, Default, PartialEq, Eq)] +#[derive(Debug, Clone, Default)] pub struct TestContextObject { /// Contains the register state at every instruction in order of execution pub trace_log: Vec, @@ -543,6 +549,7 @@ impl<'a, C: ContextObject> EbpfVm<'a, C> { #[cfg(test)] mod tests { use super::*; + use crate::syscalls; #[test] fn test_program_result_is_stable() { @@ -551,4 +558,34 @@ mod tests { let err = ProgramResult::Err(Box::new(EbpfError::JitNotCompiled)); assert_eq!(unsafe { *(&err as *const _ as *const u64) }, 1); } + + #[test] + fn test_builtin_program_eq() { + let mut function_registry_a = + FunctionRegistry::>::default(); + function_registry_a + .register_function_hashed(*b"log", syscalls::bpf_syscall_string) + .unwrap(); + function_registry_a + .register_function_hashed(*b"log_64", syscalls::bpf_syscall_u64) + .unwrap(); + let mut function_registry_b = + FunctionRegistry::>::default(); + function_registry_b + .register_function_hashed(*b"log_64", syscalls::bpf_syscall_u64) + .unwrap(); + function_registry_b + .register_function_hashed(*b"log", syscalls::bpf_syscall_string) + .unwrap(); + let mut function_registry_c = + FunctionRegistry::>::default(); + function_registry_c + .register_function_hashed(*b"log_64", syscalls::bpf_syscall_u64) + .unwrap(); + let builtin_program_a = BuiltinProgram::new_loader(Config::default(), function_registry_a); + let builtin_program_b = BuiltinProgram::new_loader(Config::default(), function_registry_b); + assert_eq!(builtin_program_a, builtin_program_b); + let builtin_program_c = BuiltinProgram::new_loader(Config::default(), function_registry_c); + assert_ne!(builtin_program_a, builtin_program_c); + } }