From 52d45cace354bc7dd0cd0c920fe00b4c62ca169c Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Wed, 12 Oct 2022 01:45:31 +0000 Subject: [PATCH] sidevm: Merge the two ocall(_fast_return) implementation --- crates/sidevm/env/src/tests.rs | 19 +--- crates/sidevm/host-runtime/src/env.rs | 41 +++---- crates/sidevm/macro/src/macro_ocall.rs | 35 +++--- .../snapshots/sidevm_macro__tests__ocall.snap | 102 +++++++++--------- 4 files changed, 83 insertions(+), 114 deletions(-) diff --git a/crates/sidevm/env/src/tests.rs b/crates/sidevm/env/src/tests.rs index 68e5187928..52ab84225f 100644 --- a/crates/sidevm/env/src/tests.rs +++ b/crates/sidevm/env/src/tests.rs @@ -21,9 +21,6 @@ pub trait TestOcall { #[ocall(id = 105, encode_output)] fn sub_fi(a: i32, b: i32) -> Result; - - #[ocall(id = 106)] - fn copy(dst: &mut [u8], src: &[u8]) -> Result<()>; } struct TestHost; @@ -46,10 +43,6 @@ impl TestOcall for TestHost { fn sub_fi(&mut self, a: i32, b: i32) -> Result { Ok(a.wrapping_sub(b)) } - fn copy(&mut self, dst: &mut [u8], src: &[u8]) -> Result<()> { - dst.copy_from_slice(src); - Ok(()) - } } impl OcallEnv for TestHost { @@ -97,7 +90,7 @@ extern "C" fn sidevm_ocall( p2: IntPtr, p3: IntPtr, ) -> IntRet { - let result = dispatch_call(&mut TestHost, func_id, p0, p1, p2, p3); + let result = dispatch_ocall(false, &mut TestHost, &TestHost, func_id, p0, p1, p2, p3); println!("sidevm_ocall {} result={:?}", func_id, result); result.encode_ret() } @@ -111,7 +104,7 @@ extern "C" fn sidevm_ocall_fast_return( p2: IntPtr, p3: IntPtr, ) -> IntRet { - let result = dispatch_call_fast_return(&mut TestHost, func_id, p0, p1, p2, p3); + let result = dispatch_ocall(true, &mut TestHost, &TestHost, func_id, p0, p1, p2, p3); println!("sidevm_ocall_fast_return {} result={:?}", func_id, result); result.encode_ret() } @@ -137,14 +130,6 @@ fn test_fi_fo() { assert_eq!(ocall::sub_fi(1, 4).unwrap(), -3); } -#[test] -fn test_fi_fo_buf() { - let mut a = [0u8; 4]; - let b = [1u8, 2, 3, 4]; - ocall::copy(&mut a[..], &b[..]).unwrap(); - assert_eq!(a, b); -} - #[test] fn test_fi_fo_overflow() { let a = u32::MAX; diff --git a/crates/sidevm/host-runtime/src/env.rs b/crates/sidevm/host-runtime/src/env.rs index 5552dad52d..8db8089802 100644 --- a/crates/sidevm/host-runtime/src/env.rs +++ b/crates/sidevm/host-runtime/src/env.rs @@ -629,7 +629,7 @@ async fn tcp_connect(host: &str, port: u16) -> std::io::Result, + func_env: FunctionEnvMut, task_id: i32, func_id: i32, p0: IntPtr, @@ -637,31 +637,23 @@ fn sidevm_ocall_fast_return( p2: IntPtr, p3: IntPtr, ) -> Result { - let inner = func_env.data().inner.clone(); - let mut guard = inner.lock().unwrap(); - let env = &mut *guard; - - env.current_task = task_id; - let result = set_task_env(env.awake_tasks.clone(), task_id, || { - let memory = env.memory.unwrap_ref().clone(); - let vm = MemoryView(memory.view(&func_env)); - let mut state = env.make_mut(&mut func_env); - env::dispatch_call_fast_return(&mut state, &vm, func_id, p0, p1, p2, p3) - }); - - if env.ocall_trace_enabled { - let func_name = env::ocall_id2name(func_id); - let vm_id = ShortId(&env.id); - log::trace!( - target: "sidevm", - "[{vm_id}][tid={task_id:<3}](F) {func_name}({p0}, {p1}, {p2}, {p3}) = {result:?}" - ); - } - convert(result) + do_ocall(func_env, task_id, func_id, p0, p1, p2, p3, true) } // Support all ocalls. Put the result into a temporary vec and wait for next fetch_result ocall to fetch the result. fn sidevm_ocall( + func_env: FunctionEnvMut, + task_id: i32, + func_id: i32, + p0: IntPtr, + p1: IntPtr, + p2: IntPtr, + p3: IntPtr, +) -> Result { + do_ocall(func_env, task_id, func_id, p0, p1, p2, p3, false) +} + +fn do_ocall( mut func_env: FunctionEnvMut, task_id: i32, func_id: i32, @@ -669,6 +661,7 @@ fn sidevm_ocall( p1: IntPtr, p2: IntPtr, p3: IntPtr, + fast_return: bool, ) -> Result { let inner = func_env.data().inner.clone(); let mut guard = inner.lock().unwrap(); @@ -679,7 +672,7 @@ fn sidevm_ocall( let memory = env.memory.unwrap_ref().clone(); let vm = MemoryView(memory.view(&func_env)); let mut state = env.make_mut(&mut func_env); - env::dispatch_call(&mut state, &vm, func_id, p0, p1, p2, p3) + env::dispatch_ocall(fast_return, &mut state, &vm, func_id, p0, p1, p2, p3) }); if env.ocall_trace_enabled { @@ -687,7 +680,7 @@ fn sidevm_ocall( let vm_id = ShortId(&env.id); log::trace!( target: "sidevm", - "[{vm_id}][tid={task_id:<3}](S) {func_name}({p0}, {p1}, {p2}, {p3}) = {result:?}" + "[{vm_id}][tid={task_id:<3}] {func_name}({p0}, {p1}, {p2}, {p3}) = {result:?}" ); } convert(result) diff --git a/crates/sidevm/macro/src/macro_ocall.rs b/crates/sidevm/macro/src/macro_ocall.rs index 52dfe2bfed..daba445b68 100644 --- a/crates/sidevm/macro/src/macro_ocall.rs +++ b/crates/sidevm/macro/src/macro_ocall.rs @@ -212,36 +212,29 @@ fn gen_dispatcher(methods: &[OcallMethod], trait_name: &Ident) -> Result( + pub fn dispatch_ocall( + fast_return: bool, env: &mut Env, vm: &Vm, id: i32, p0: IntPtr, p1: IntPtr, p2: IntPtr, - p3: IntPtr + p3: IntPtr, ) -> Result { - match id { - 0 => #call_get_return, - #(#fast_calls)* - _ => Err(OcallError::UnknownCallNumber), + if fast_return { + match id { + 0 => #call_get_return, + #(#fast_calls)* + _ => Err(OcallError::UnknownCallNumber), + } + } else { + Ok(match id { + #(#slow_calls)* + _ => return Err(OcallError::UnknownCallNumber), + }) } } - - pub fn dispatch_call( - env: &mut Env, - vm: &Vm, - id: i32, - p0: IntPtr, - p1: IntPtr, - p2: IntPtr, - p3: IntPtr - ) -> Result { - Ok(match id { - #(#slow_calls)* - _ => return Err(OcallError::UnknownCallNumber), - }) - } }) } diff --git a/crates/sidevm/macro/src/snapshots/sidevm_macro__tests__ocall.snap b/crates/sidevm/macro/src/snapshots/sidevm_macro__tests__ocall.snap index 0ca6515c8f..67bfadee3d 100644 --- a/crates/sidevm/macro/src/snapshots/sidevm_macro__tests__ocall.snap +++ b/crates/sidevm/macro/src/snapshots/sidevm_macro__tests__ocall.snap @@ -72,69 +72,67 @@ pub mod ocall_guest { } } } -pub fn dispatch_call_fast_return( +pub fn dispatch_ocall( + fast_return: bool, env: &mut Env, + vm: &Vm, id: i32, p0: IntPtr, p1: IntPtr, p2: IntPtr, p3: IntPtr, ) -> Result { - match id { - 0 => { - let buffer = env.take_return().ok_or(OcallError::NotFound)?; - let len = p1 as usize; - if buffer.len() != len { - return Err(OcallError::InvalidParameter); + if fast_return { + match id { + 0 => { + let buffer = env.take_return().ok_or(OcallError::NotFound)?; + let len = p1 as usize; + if buffer.len() != len { + return Err(OcallError::InvalidParameter); + } + vm.copy_to_vm(&buffer, p0)?; + Ok(len as i32) } - env.copy_to_vm(&buffer, p0)?; - Ok(len as i32) - } - 104 => { - let (a, b) = { - let mut buf = env.slice_from_vm(p0, p1)?; - Decode::decode(&mut buf).or(Err(OcallError::InvalidParameter))? - }; - env.call_fo(a, b).map(|x| x.to_i32()) - } - 102 => { - let stack = StackedArgs::load(&[p0, p1, p2, p3]).ok_or(OcallError::InvalidParameter)?; - let (b, stack) = stack.pop_arg(env)?; - let (a, stack) = stack.pop_arg(env)?; - let _: StackedArgs<()> = stack; - env.poll_fi_fo(a, b).map(|x| x.to_i32()) + 104 => { + let (a, b) = { + let mut buf = vm.slice_from_vm(p0, p1)?; + Decode::decode(&mut buf).or(Err(OcallError::InvalidParameter))? + }; + env.call_fo(a, b).map(|x| x.to_i32()) + } + 102 => { + let stack = + StackedArgs::load(&[p0, p1, p2, p3]).ok_or(OcallError::InvalidParameter)?; + let (b, stack) = stack.pop_arg(vm)?; + let (a, stack) = stack.pop_arg(vm)?; + let _: StackedArgs<()> = stack; + env.poll_fi_fo(a, b).map(|x| x.to_i32()) + } + _ => Err(OcallError::UnknownCallNumber), } - _ => Err(OcallError::UnknownCallNumber), + } else { + Ok(match id { + 101 => { + let (a, b) = { + let mut buf = vm.slice_from_vm(p0, p1)?; + Decode::decode(&mut buf).or(Err(OcallError::InvalidParameter))? + }; + let ret = env.call_slow(a, b); + env.put_return(ret?.encode()) as _ + } + 103 => { + let stack = + StackedArgs::load(&[p0, p1, p2, p3]).ok_or(OcallError::InvalidParameter)?; + let (b, stack) = stack.pop_arg(vm)?; + let (a, stack) = stack.pop_arg(vm)?; + let _: StackedArgs<()> = stack; + let ret = env.call_fi(a, b); + env.put_return(ret?.encode()) as _ + } + _ => return Err(OcallError::UnknownCallNumber), + }) } } -pub fn dispatch_call( - env: &mut Env, - id: i32, - p0: IntPtr, - p1: IntPtr, - p2: IntPtr, - p3: IntPtr, -) -> Result { - Ok(match id { - 101 => { - let (a, b) = { - let mut buf = env.slice_from_vm(p0, p1)?; - Decode::decode(&mut buf).or(Err(OcallError::InvalidParameter))? - }; - let ret = env.call_slow(a, b); - env.put_return(ret?.encode()) as _ - } - 103 => { - let stack = StackedArgs::load(&[p0, p1, p2, p3]).ok_or(OcallError::InvalidParameter)?; - let (b, stack) = stack.pop_arg(env)?; - let (a, stack) = stack.pop_arg(env)?; - let _: StackedArgs<()> = stack; - let ret = env.call_fi(a, b); - env.put_return(ret?.encode()) as _ - } - _ => return Err(OcallError::UnknownCallNumber), - }) -} pub fn ocall_id2name(id: i32) -> &'static str { match id { 0 => "get_return",