Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deduplicate arbitrator RustBytes types #2830

Merged
merged 5 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 62 additions & 17 deletions arbitrator/prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use once_cell::sync::OnceCell;
use static_assertions::const_assert_eq;
use std::{
ffi::CStr,
marker::PhantomData,
num::NonZeroUsize,
os::raw::{c_char, c_int},
path::Path,
Expand All @@ -59,11 +60,67 @@ pub struct CByteArray {
}

#[repr(C)]
#[derive(Clone, Copy)]
pub struct RustByteArray {
pub struct RustSlice<'a> {
pub ptr: *const u8,
pub len: usize,
pub phantom: PhantomData<&'a [u8]>,
}

impl<'a> RustSlice<'a> {
pub fn new(slice: &'a [u8]) -> Self {
if slice.is_empty() {
return Self {
ptr: ptr::null(),
len: 0,
phantom: PhantomData,
};
}
Self {
ptr: slice.as_ptr(),
len: slice.len(),
phantom: PhantomData,
}
}
}

#[repr(C)]
pub struct RustBytes {
pub ptr: *mut u8,
pub len: usize,
pub capacity: usize,
pub cap: usize,
}

impl RustBytes {
pub unsafe fn into_vec(self) -> Vec<u8> {
Vec::from_raw_parts(self.ptr, self.len, self.cap)
}

pub unsafe fn write(&mut self, mut vec: Vec<u8>) {
if vec.capacity() == 0 {
*self = RustBytes {
ptr: ptr::null_mut(),
len: 0,
cap: 0,
};
return;
}
self.ptr = vec.as_mut_ptr();
self.len = vec.len();
self.cap = vec.capacity();
std::mem::forget(vec);
}
}

/// Frees the vector. Does nothing when the vector is null.
///
/// # Safety
///
/// Must only be called once per vec.
#[no_mangle]
pub unsafe extern "C" fn free_rust_bytes(vec: RustBytes) {
if !vec.ptr.is_null() {
drop(vec.into_vec())
}
}

#[no_mangle]
Expand Down Expand Up @@ -410,18 +467,6 @@ pub unsafe extern "C" fn arbitrator_module_root(mach: *mut Machine) -> Bytes32 {

#[no_mangle]
#[cfg(feature = "native")]
pub unsafe extern "C" fn arbitrator_gen_proof(mach: *mut Machine) -> RustByteArray {
let mut proof = (*mach).serialize_proof();
let ret = RustByteArray {
ptr: proof.as_mut_ptr(),
len: proof.len(),
capacity: proof.capacity(),
};
std::mem::forget(proof);
ret
}

#[no_mangle]
pub unsafe extern "C" fn arbitrator_free_proof(proof: RustByteArray) {
drop(Vec::from_raw_parts(proof.ptr, proof.len, proof.capacity))
pub unsafe extern "C" fn arbitrator_gen_proof(mach: *mut Machine, out: *mut RustBytes) {
(*out).write((*mach).serialize_proof());
}
3 changes: 2 additions & 1 deletion arbitrator/stylus/src/evm_api.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// Copyright 2022-2024, Offchain Labs, Inc.
// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE

use crate::{GoSliceData, RustSlice};
use crate::GoSliceData;
use arbutil::evm::{
api::{EvmApiMethod, Gas, EVM_API_METHOD_REQ_OFFSET},
req::RequestHandler,
};
use prover::RustSlice;

#[repr(C)]
pub struct NativeRequestHandler {
Expand Down
90 changes: 22 additions & 68 deletions arbitrator/stylus/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ use cache::{deserialize_module, CacheMetrics, InitCache};
use evm_api::NativeRequestHandler;
use eyre::ErrReport;
use native::NativeInstance;
use prover::programs::{prelude::*, StylusData};
use prover::{
programs::{prelude::*, StylusData},
RustBytes,
};
use run::RunProgram;
use std::{marker::PhantomData, mem, ptr};
use std::ptr;
use target_cache::{target_cache_get, target_cache_set};

pub use brotli;
Expand Down Expand Up @@ -76,52 +79,15 @@ impl DataReader for GoSliceData {
}
}

#[repr(C)]
pub struct RustSlice<'a> {
ptr: *const u8,
len: usize,
phantom: PhantomData<&'a [u8]>,
}

impl<'a> RustSlice<'a> {
fn new(slice: &'a [u8]) -> Self {
Self {
ptr: slice.as_ptr(),
len: slice.len(),
phantom: PhantomData,
}
}
}

#[repr(C)]
pub struct RustBytes {
ptr: *mut u8,
len: usize,
cap: usize,
unsafe fn write_err(output: &mut RustBytes, err: ErrReport) -> UserOutcomeKind {
output.write(err.debug_bytes());
UserOutcomeKind::Failure
}

impl RustBytes {
unsafe fn into_vec(self) -> Vec<u8> {
Vec::from_raw_parts(self.ptr, self.len, self.cap)
}

unsafe fn write(&mut self, mut vec: Vec<u8>) {
self.ptr = vec.as_mut_ptr();
self.len = vec.len();
self.cap = vec.capacity();
mem::forget(vec);
}

unsafe fn write_err(&mut self, err: ErrReport) -> UserOutcomeKind {
self.write(err.debug_bytes());
UserOutcomeKind::Failure
}

unsafe fn write_outcome(&mut self, outcome: UserOutcome) -> UserOutcomeKind {
let (status, outs) = outcome.into_data();
self.write(outs);
status
}
unsafe fn write_outcome(output: &mut RustBytes, outcome: UserOutcome) -> UserOutcomeKind {
let (status, outs) = outcome.into_data();
output.write(outs);
status
}

/// "activates" a user wasm.
Expand Down Expand Up @@ -164,7 +130,7 @@ pub unsafe extern "C" fn stylus_activate(
gas,
) {
Ok(val) => val,
Err(err) => return output.write_err(err),
Err(err) => return write_err(output, err),
};

*module_hash = module.hash();
Expand Down Expand Up @@ -194,16 +160,16 @@ pub unsafe extern "C" fn stylus_compile(
let output = &mut *output;
let name = match String::from_utf8(name.slice().to_vec()) {
Ok(val) => val,
Err(err) => return output.write_err(err.into()),
Err(err) => return write_err(output, err.into()),
};
let target = match target_cache_get(&name) {
Ok(val) => val,
Err(err) => return output.write_err(err),
Err(err) => return write_err(output, err),
};

let asm = match native::compile(wasm, version, debug, target) {
Ok(val) => val,
Err(err) => return output.write_err(err),
Err(err) => return write_err(output, err),
};

output.write(asm);
Expand All @@ -218,7 +184,7 @@ pub unsafe extern "C" fn wat_to_wasm(wat: GoSliceData, output: *mut RustBytes) -
let output = &mut *output;
let wasm = match wasmer::wat2wasm(wat.slice()) {
Ok(val) => val,
Err(err) => return output.write_err(err.into()),
Err(err) => return write_err(output, err.into()),
};
output.write(wasm.into_owned());
UserOutcomeKind::Success
Expand All @@ -241,16 +207,16 @@ pub unsafe extern "C" fn stylus_target_set(
let output = &mut *output;
let name = match String::from_utf8(name.slice().to_vec()) {
Ok(val) => val,
Err(err) => return output.write_err(err.into()),
Err(err) => return write_err(output, err.into()),
};

let desc_str = match String::from_utf8(description.slice().to_vec()) {
Ok(val) => val,
Err(err) => return output.write_err(err.into()),
Err(err) => return write_err(output, err.into()),
};

if let Err(err) = target_cache_set(name, desc_str, native) {
return output.write_err(err);
return write_err(output, err);
};

UserOutcomeKind::Success
Expand Down Expand Up @@ -298,8 +264,8 @@ pub unsafe extern "C" fn stylus_call(
};

let status = match instance.run_main(&calldata, config, ink) {
Err(e) | Ok(UserOutcome::Failure(e)) => output.write_err(e.wrap_err("call failed")),
Ok(outcome) => output.write_outcome(outcome),
Err(e) | Ok(UserOutcome::Failure(e)) => write_err(output, e.wrap_err("call failed")),
Ok(outcome) => write_outcome(output, outcome),
};
let ink_left = match status {
UserOutcomeKind::OutOfStack => Ink(0), // take all gas when out of stack
Expand Down Expand Up @@ -352,18 +318,6 @@ pub extern "C" fn stylus_reorg_vm(_block: u64, arbos_tag: u32) {
InitCache::clear_long_term(arbos_tag);
}

/// Frees the vector. Does nothing when the vector is null.
///
/// # Safety
///
/// Must only be called once per vec.
#[no_mangle]
pub unsafe extern "C" fn stylus_drop_vec(vec: RustBytes) {
if !vec.ptr.is_null() {
mem::drop(vec.into_vec())
}
}

/// Gets cache metrics.
///
/// # Safety
Expand Down
8 changes: 7 additions & 1 deletion arbos/programs/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,10 +450,16 @@ func addressToBytes20(addr common.Address) bytes20 {
}

func (slice *rustSlice) read() []byte {
if slice.len == 0 {
return nil
}
return arbutil.PointerToSlice((*byte)(slice.ptr), int(slice.len))
}

func (vec *rustBytes) read() []byte {
if vec.len == 0 {
return nil
}
return arbutil.PointerToSlice((*byte)(vec.ptr), int(vec.len))
}

Expand All @@ -464,7 +470,7 @@ func (vec *rustBytes) intoBytes() []byte {
}

func (vec *rustBytes) drop() {
C.stylus_drop_vec(*vec)
C.free_rust_bytes(*vec)
}

func goSlice(slice []byte) C.GoSliceData {
Expand Down
10 changes: 7 additions & 3 deletions validator/server_arb/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,13 @@ func (m *ArbitratorMachine) ProveNextStep() []byte {
m.mutex.Lock()
defer m.mutex.Unlock()

rustProof := C.arbitrator_gen_proof(m.ptr)
proofBytes := C.GoBytes(unsafe.Pointer(rustProof.ptr), C.int(rustProof.len))
C.arbitrator_free_proof(rustProof)
output := &C.RustBytes{}
C.arbitrator_gen_proof(m.ptr, output)
defer C.free_rust_bytes(*output)
if output.len == 0 {
return nil
}
proofBytes := C.GoBytes(unsafe.Pointer(output.ptr), C.int(output.len))

return proofBytes
}
Expand Down
Loading