From 0afaf50119b4a9b174b8e09bf4373f011dfd530a Mon Sep 17 00:00:00 2001 From: Lee Bousfield Date: Wed, 11 Dec 2024 20:15:25 -0700 Subject: [PATCH] Deduplicate arbitrator RustBytes types --- arbitrator/prover/src/lib.rs | 79 ++++++++++++++++++++++------ arbitrator/stylus/src/evm_api.rs | 3 +- arbitrator/stylus/src/lib.rs | 90 ++++++++------------------------ arbos/programs/native.go | 2 +- validator/server_arb/machine.go | 7 +-- 5 files changed, 91 insertions(+), 90 deletions(-) diff --git a/arbitrator/prover/src/lib.rs b/arbitrator/prover/src/lib.rs index bc2bd4bc48..a147786086 100644 --- a/arbitrator/prover/src/lib.rs +++ b/arbitrator/prover/src/lib.rs @@ -36,6 +36,7 @@ use once_cell::sync::OnceCell; use static_assertions::const_assert_eq; use std::{ ffi::CStr, + marker::PhantomData, num::NonZeroUsize, os::raw::{c_char, c_int}, path::Path, @@ -59,11 +60,67 @@ pub struct CByteArray { } #[repr(C)] -#[derive(Clone, Copy)] -pub struct RustByteArray { +pub struct RustSlice<'a> { + pub ptr: *const u8, + pub len: usize, + pub phantom: PhantomData<&'a [u8]>, +} + +impl<'a> RustSlice<'a> { + pub fn new(slice: &'a [u8]) -> Self { + if slice.is_empty() { + return Self { + ptr: ptr::null(), + len: 0, + phantom: PhantomData, + }; + } + Self { + ptr: slice.as_ptr(), + len: slice.len(), + phantom: PhantomData, + } + } +} + +#[repr(C)] +pub struct RustBytes { pub ptr: *mut u8, pub len: usize, - pub capacity: usize, + pub cap: usize, +} + +impl RustBytes { + pub unsafe fn into_vec(self) -> Vec { + Vec::from_raw_parts(self.ptr, self.len, self.cap) + } + + pub unsafe fn write(&mut self, mut vec: Vec) { + if vec.capacity() == 0 { + *self = RustBytes { + ptr: ptr::null_mut(), + len: 0, + cap: 0, + }; + return; + } + self.ptr = vec.as_mut_ptr(); + self.len = vec.len(); + self.cap = vec.capacity(); + std::mem::forget(vec); + } +} + +/// Frees the vector. Does nothing when the vector is null. +/// +/// # Safety +/// +/// Must only be called once per vec. +#[no_mangle] +pub unsafe extern "C" fn free_rust_bytes(vec: RustBytes) { + if !vec.ptr.is_null() { + drop(vec.into_vec()) + } } #[no_mangle] @@ -410,18 +467,6 @@ pub unsafe extern "C" fn arbitrator_module_root(mach: *mut Machine) -> Bytes32 { #[no_mangle] #[cfg(feature = "native")] -pub unsafe extern "C" fn arbitrator_gen_proof(mach: *mut Machine) -> RustByteArray { - let mut proof = (*mach).serialize_proof(); - let ret = RustByteArray { - ptr: proof.as_mut_ptr(), - len: proof.len(), - capacity: proof.capacity(), - }; - std::mem::forget(proof); - ret -} - -#[no_mangle] -pub unsafe extern "C" fn arbitrator_free_proof(proof: RustByteArray) { - drop(Vec::from_raw_parts(proof.ptr, proof.len, proof.capacity)) +pub unsafe extern "C" fn arbitrator_gen_proof(mach: *mut Machine, out: *mut RustBytes) { + (*out).write((*mach).serialize_proof()); } diff --git a/arbitrator/stylus/src/evm_api.rs b/arbitrator/stylus/src/evm_api.rs index 0dd27e3f8c..7aa605dfe7 100644 --- a/arbitrator/stylus/src/evm_api.rs +++ b/arbitrator/stylus/src/evm_api.rs @@ -1,11 +1,12 @@ // Copyright 2022-2024, Offchain Labs, Inc. // For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE -use crate::{GoSliceData, RustSlice}; +use crate::GoSliceData; use arbutil::evm::{ api::{EvmApiMethod, Gas, EVM_API_METHOD_REQ_OFFSET}, req::RequestHandler, }; +use prover::RustSlice; #[repr(C)] pub struct NativeRequestHandler { diff --git a/arbitrator/stylus/src/lib.rs b/arbitrator/stylus/src/lib.rs index e7f10c2400..c73c4b2c2e 100644 --- a/arbitrator/stylus/src/lib.rs +++ b/arbitrator/stylus/src/lib.rs @@ -15,9 +15,12 @@ use cache::{deserialize_module, CacheMetrics, InitCache}; use evm_api::NativeRequestHandler; use eyre::ErrReport; use native::NativeInstance; -use prover::programs::{prelude::*, StylusData}; +use prover::{ + programs::{prelude::*, StylusData}, + RustBytes, +}; use run::RunProgram; -use std::{marker::PhantomData, mem, ptr}; +use std::ptr; use target_cache::{target_cache_get, target_cache_set}; pub use brotli; @@ -76,52 +79,15 @@ impl DataReader for GoSliceData { } } -#[repr(C)] -pub struct RustSlice<'a> { - ptr: *const u8, - len: usize, - phantom: PhantomData<&'a [u8]>, -} - -impl<'a> RustSlice<'a> { - fn new(slice: &'a [u8]) -> Self { - Self { - ptr: slice.as_ptr(), - len: slice.len(), - phantom: PhantomData, - } - } -} - -#[repr(C)] -pub struct RustBytes { - ptr: *mut u8, - len: usize, - cap: usize, +unsafe fn write_err(output: &mut RustBytes, err: ErrReport) -> UserOutcomeKind { + output.write(err.debug_bytes()); + UserOutcomeKind::Failure } -impl RustBytes { - unsafe fn into_vec(self) -> Vec { - Vec::from_raw_parts(self.ptr, self.len, self.cap) - } - - unsafe fn write(&mut self, mut vec: Vec) { - self.ptr = vec.as_mut_ptr(); - self.len = vec.len(); - self.cap = vec.capacity(); - mem::forget(vec); - } - - unsafe fn write_err(&mut self, err: ErrReport) -> UserOutcomeKind { - self.write(err.debug_bytes()); - UserOutcomeKind::Failure - } - - unsafe fn write_outcome(&mut self, outcome: UserOutcome) -> UserOutcomeKind { - let (status, outs) = outcome.into_data(); - self.write(outs); - status - } +unsafe fn write_outcome(output: &mut RustBytes, outcome: UserOutcome) -> UserOutcomeKind { + let (status, outs) = outcome.into_data(); + output.write(outs); + status } /// "activates" a user wasm. @@ -164,7 +130,7 @@ pub unsafe extern "C" fn stylus_activate( gas, ) { Ok(val) => val, - Err(err) => return output.write_err(err), + Err(err) => return write_err(output, err), }; *module_hash = module.hash(); @@ -194,16 +160,16 @@ pub unsafe extern "C" fn stylus_compile( let output = &mut *output; let name = match String::from_utf8(name.slice().to_vec()) { Ok(val) => val, - Err(err) => return output.write_err(err.into()), + Err(err) => return write_err(output, err.into()), }; let target = match target_cache_get(&name) { Ok(val) => val, - Err(err) => return output.write_err(err), + Err(err) => return write_err(output, err), }; let asm = match native::compile(wasm, version, debug, target) { Ok(val) => val, - Err(err) => return output.write_err(err), + Err(err) => return write_err(output, err), }; output.write(asm); @@ -218,7 +184,7 @@ pub unsafe extern "C" fn wat_to_wasm(wat: GoSliceData, output: *mut RustBytes) - let output = &mut *output; let wasm = match wasmer::wat2wasm(wat.slice()) { Ok(val) => val, - Err(err) => return output.write_err(err.into()), + Err(err) => return write_err(output, err.into()), }; output.write(wasm.into_owned()); UserOutcomeKind::Success @@ -241,16 +207,16 @@ pub unsafe extern "C" fn stylus_target_set( let output = &mut *output; let name = match String::from_utf8(name.slice().to_vec()) { Ok(val) => val, - Err(err) => return output.write_err(err.into()), + Err(err) => return write_err(output, err.into()), }; let desc_str = match String::from_utf8(description.slice().to_vec()) { Ok(val) => val, - Err(err) => return output.write_err(err.into()), + Err(err) => return write_err(output, err.into()), }; if let Err(err) = target_cache_set(name, desc_str, native) { - return output.write_err(err); + return write_err(output, err); }; UserOutcomeKind::Success @@ -298,8 +264,8 @@ pub unsafe extern "C" fn stylus_call( }; let status = match instance.run_main(&calldata, config, ink) { - Err(e) | Ok(UserOutcome::Failure(e)) => output.write_err(e.wrap_err("call failed")), - Ok(outcome) => output.write_outcome(outcome), + Err(e) | Ok(UserOutcome::Failure(e)) => write_err(output, e.wrap_err("call failed")), + Ok(outcome) => write_outcome(output, outcome), }; let ink_left = match status { UserOutcomeKind::OutOfStack => Ink(0), // take all gas when out of stack @@ -352,18 +318,6 @@ pub extern "C" fn stylus_reorg_vm(_block: u64, arbos_tag: u32) { InitCache::clear_long_term(arbos_tag); } -/// Frees the vector. Does nothing when the vector is null. -/// -/// # Safety -/// -/// Must only be called once per vec. -#[no_mangle] -pub unsafe extern "C" fn stylus_drop_vec(vec: RustBytes) { - if !vec.ptr.is_null() { - mem::drop(vec.into_vec()) - } -} - /// Gets cache metrics. /// /// # Safety diff --git a/arbos/programs/native.go b/arbos/programs/native.go index f162704995..73d3fe83d7 100644 --- a/arbos/programs/native.go +++ b/arbos/programs/native.go @@ -464,7 +464,7 @@ func (vec *rustBytes) intoBytes() []byte { } func (vec *rustBytes) drop() { - C.stylus_drop_vec(*vec) + C.free_rust_bytes(*vec) } func goSlice(slice []byte) C.GoSliceData { diff --git a/validator/server_arb/machine.go b/validator/server_arb/machine.go index 09a00635fb..c781234124 100644 --- a/validator/server_arb/machine.go +++ b/validator/server_arb/machine.go @@ -304,9 +304,10 @@ func (m *ArbitratorMachine) ProveNextStep() []byte { m.mutex.Lock() defer m.mutex.Unlock() - rustProof := C.arbitrator_gen_proof(m.ptr) - proofBytes := C.GoBytes(unsafe.Pointer(rustProof.ptr), C.int(rustProof.len)) - C.arbitrator_free_proof(rustProof) + output := &C.RustBytes{} + C.arbitrator_gen_proof(m.ptr, output) + proofBytes := C.GoBytes(unsafe.Pointer(output.ptr), C.int(output.len)) + C.free_rust_bytes(*output) return proofBytes }