Skip to content

Commit

Permalink
sidevm: Merge the two ocall(_fast_return) implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
kvinwang committed Oct 12, 2022
1 parent 92bdb40 commit 52d45ca
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 114 deletions.
19 changes: 2 additions & 17 deletions crates/sidevm/env/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ pub trait TestOcall {

#[ocall(id = 105, encode_output)]
fn sub_fi(a: i32, b: i32) -> Result<i32>;

#[ocall(id = 106)]
fn copy(dst: &mut [u8], src: &[u8]) -> Result<()>;
}

struct TestHost;
Expand All @@ -46,10 +43,6 @@ impl TestOcall for TestHost {
fn sub_fi(&mut self, a: i32, b: i32) -> Result<i32> {
Ok(a.wrapping_sub(b))
}
fn copy(&mut self, dst: &mut [u8], src: &[u8]) -> Result<()> {
dst.copy_from_slice(src);
Ok(())
}
}

impl OcallEnv for TestHost {
Expand Down Expand Up @@ -97,7 +90,7 @@ extern "C" fn sidevm_ocall(
p2: IntPtr,
p3: IntPtr,
) -> IntRet {
let result = dispatch_call(&mut TestHost, func_id, p0, p1, p2, p3);
let result = dispatch_ocall(false, &mut TestHost, &TestHost, func_id, p0, p1, p2, p3);
println!("sidevm_ocall {} result={:?}", func_id, result);
result.encode_ret()
}
Expand All @@ -111,7 +104,7 @@ extern "C" fn sidevm_ocall_fast_return(
p2: IntPtr,
p3: IntPtr,
) -> IntRet {
let result = dispatch_call_fast_return(&mut TestHost, func_id, p0, p1, p2, p3);
let result = dispatch_ocall(true, &mut TestHost, &TestHost, func_id, p0, p1, p2, p3);
println!("sidevm_ocall_fast_return {} result={:?}", func_id, result);
result.encode_ret()
}
Expand All @@ -137,14 +130,6 @@ fn test_fi_fo() {
assert_eq!(ocall::sub_fi(1, 4).unwrap(), -3);
}

#[test]
fn test_fi_fo_buf() {
let mut a = [0u8; 4];
let b = [1u8, 2, 3, 4];
ocall::copy(&mut a[..], &b[..]).unwrap();
assert_eq!(a, b);
}

#[test]
fn test_fi_fo_overflow() {
let a = u32::MAX;
Expand Down
41 changes: 17 additions & 24 deletions crates/sidevm/host-runtime/src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,46 +629,39 @@ async fn tcp_connect(host: &str, port: u16) -> std::io::Result<tokio::net::TcpSt
}

fn sidevm_ocall_fast_return(
mut func_env: FunctionEnvMut<Env>,
func_env: FunctionEnvMut<Env>,
task_id: i32,
func_id: i32,
p0: IntPtr,
p1: IntPtr,
p2: IntPtr,
p3: IntPtr,
) -> Result<IntRet, OcallAborted> {
let inner = func_env.data().inner.clone();
let mut guard = inner.lock().unwrap();
let env = &mut *guard;

env.current_task = task_id;
let result = set_task_env(env.awake_tasks.clone(), task_id, || {
let memory = env.memory.unwrap_ref().clone();
let vm = MemoryView(memory.view(&func_env));
let mut state = env.make_mut(&mut func_env);
env::dispatch_call_fast_return(&mut state, &vm, func_id, p0, p1, p2, p3)
});

if env.ocall_trace_enabled {
let func_name = env::ocall_id2name(func_id);
let vm_id = ShortId(&env.id);
log::trace!(
target: "sidevm",
"[{vm_id}][tid={task_id:<3}](F) {func_name}({p0}, {p1}, {p2}, {p3}) = {result:?}"
);
}
convert(result)
do_ocall(func_env, task_id, func_id, p0, p1, p2, p3, true)
}

// Support all ocalls. Put the result into a temporary vec and wait for next fetch_result ocall to fetch the result.
fn sidevm_ocall(
func_env: FunctionEnvMut<Env>,
task_id: i32,
func_id: i32,
p0: IntPtr,
p1: IntPtr,
p2: IntPtr,
p3: IntPtr,
) -> Result<IntRet, OcallAborted> {
do_ocall(func_env, task_id, func_id, p0, p1, p2, p3, false)
}

fn do_ocall(
mut func_env: FunctionEnvMut<Env>,
task_id: i32,
func_id: i32,
p0: IntPtr,
p1: IntPtr,
p2: IntPtr,
p3: IntPtr,
fast_return: bool,
) -> Result<IntRet, OcallAborted> {
let inner = func_env.data().inner.clone();
let mut guard = inner.lock().unwrap();
Expand All @@ -679,15 +672,15 @@ fn sidevm_ocall(
let memory = env.memory.unwrap_ref().clone();
let vm = MemoryView(memory.view(&func_env));
let mut state = env.make_mut(&mut func_env);
env::dispatch_call(&mut state, &vm, func_id, p0, p1, p2, p3)
env::dispatch_ocall(fast_return, &mut state, &vm, func_id, p0, p1, p2, p3)
});

if env.ocall_trace_enabled {
let func_name = env::ocall_id2name(func_id);
let vm_id = ShortId(&env.id);
log::trace!(
target: "sidevm",
"[{vm_id}][tid={task_id:<3}](S) {func_name}({p0}, {p1}, {p2}, {p3}) = {result:?}"
"[{vm_id}][tid={task_id:<3}] {func_name}({p0}, {p1}, {p2}, {p3}) = {result:?}"
);
}
convert(result)
Expand Down
35 changes: 14 additions & 21 deletions crates/sidevm/macro/src/macro_ocall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,36 +212,29 @@ fn gen_dispatcher(methods: &[OcallMethod], trait_name: &Ident) -> Result<TokenSt
};

Ok(parse_quote! {
pub fn dispatch_call_fast_return<Env: #trait_name + OcallEnv, Vm: VmMemory>(
pub fn dispatch_ocall<Env: #trait_name + OcallEnv, Vm: VmMemory>(
fast_return: bool,
env: &mut Env,
vm: &Vm,
id: i32,
p0: IntPtr,
p1: IntPtr,
p2: IntPtr,
p3: IntPtr
p3: IntPtr,
) -> Result<i32> {
match id {
0 => #call_get_return,
#(#fast_calls)*
_ => Err(OcallError::UnknownCallNumber),
if fast_return {
match id {
0 => #call_get_return,
#(#fast_calls)*
_ => Err(OcallError::UnknownCallNumber),
}
} else {
Ok(match id {
#(#slow_calls)*
_ => return Err(OcallError::UnknownCallNumber),
})
}
}

pub fn dispatch_call<Env: #trait_name + OcallEnv, Vm: VmMemory>(
env: &mut Env,
vm: &Vm,
id: i32,
p0: IntPtr,
p1: IntPtr,
p2: IntPtr,
p3: IntPtr
) -> Result<i32> {
Ok(match id {
#(#slow_calls)*
_ => return Err(OcallError::UnknownCallNumber),
})
}
})
}

Expand Down
102 changes: 50 additions & 52 deletions crates/sidevm/macro/src/snapshots/sidevm_macro__tests__ocall.snap
Original file line number Diff line number Diff line change
Expand Up @@ -72,69 +72,67 @@ pub mod ocall_guest {
}
}
}
pub fn dispatch_call_fast_return<Env: Ocall + OcallEnv + VmMemory>(
pub fn dispatch_ocall<Env: Ocall + OcallEnv, Vm: VmMemory>(
fast_return: bool,
env: &mut Env,
vm: &Vm,
id: i32,
p0: IntPtr,
p1: IntPtr,
p2: IntPtr,
p3: IntPtr,
) -> Result<i32> {
match id {
0 => {
let buffer = env.take_return().ok_or(OcallError::NotFound)?;
let len = p1 as usize;
if buffer.len() != len {
return Err(OcallError::InvalidParameter);
if fast_return {
match id {
0 => {
let buffer = env.take_return().ok_or(OcallError::NotFound)?;
let len = p1 as usize;
if buffer.len() != len {
return Err(OcallError::InvalidParameter);
}
vm.copy_to_vm(&buffer, p0)?;
Ok(len as i32)
}
env.copy_to_vm(&buffer, p0)?;
Ok(len as i32)
}
104 => {
let (a, b) = {
let mut buf = env.slice_from_vm(p0, p1)?;
Decode::decode(&mut buf).or(Err(OcallError::InvalidParameter))?
};
env.call_fo(a, b).map(|x| x.to_i32())
}
102 => {
let stack = StackedArgs::load(&[p0, p1, p2, p3]).ok_or(OcallError::InvalidParameter)?;
let (b, stack) = stack.pop_arg(env)?;
let (a, stack) = stack.pop_arg(env)?;
let _: StackedArgs<()> = stack;
env.poll_fi_fo(a, b).map(|x| x.to_i32())
104 => {
let (a, b) = {
let mut buf = vm.slice_from_vm(p0, p1)?;
Decode::decode(&mut buf).or(Err(OcallError::InvalidParameter))?
};
env.call_fo(a, b).map(|x| x.to_i32())
}
102 => {
let stack =
StackedArgs::load(&[p0, p1, p2, p3]).ok_or(OcallError::InvalidParameter)?;
let (b, stack) = stack.pop_arg(vm)?;
let (a, stack) = stack.pop_arg(vm)?;
let _: StackedArgs<()> = stack;
env.poll_fi_fo(a, b).map(|x| x.to_i32())
}
_ => Err(OcallError::UnknownCallNumber),
}
_ => Err(OcallError::UnknownCallNumber),
} else {
Ok(match id {
101 => {
let (a, b) = {
let mut buf = vm.slice_from_vm(p0, p1)?;
Decode::decode(&mut buf).or(Err(OcallError::InvalidParameter))?
};
let ret = env.call_slow(a, b);
env.put_return(ret?.encode()) as _
}
103 => {
let stack =
StackedArgs::load(&[p0, p1, p2, p3]).ok_or(OcallError::InvalidParameter)?;
let (b, stack) = stack.pop_arg(vm)?;
let (a, stack) = stack.pop_arg(vm)?;
let _: StackedArgs<()> = stack;
let ret = env.call_fi(a, b);
env.put_return(ret?.encode()) as _
}
_ => return Err(OcallError::UnknownCallNumber),
})
}
}
pub fn dispatch_call<Env: Ocall + OcallEnv + VmMemory>(
env: &mut Env,
id: i32,
p0: IntPtr,
p1: IntPtr,
p2: IntPtr,
p3: IntPtr,
) -> Result<i32> {
Ok(match id {
101 => {
let (a, b) = {
let mut buf = env.slice_from_vm(p0, p1)?;
Decode::decode(&mut buf).or(Err(OcallError::InvalidParameter))?
};
let ret = env.call_slow(a, b);
env.put_return(ret?.encode()) as _
}
103 => {
let stack = StackedArgs::load(&[p0, p1, p2, p3]).ok_or(OcallError::InvalidParameter)?;
let (b, stack) = stack.pop_arg(env)?;
let (a, stack) = stack.pop_arg(env)?;
let _: StackedArgs<()> = stack;
let ret = env.call_fi(a, b);
env.put_return(ret?.encode()) as _
}
_ => return Err(OcallError::UnknownCallNumber),
})
}
pub fn ocall_id2name(id: i32) -> &'static str {
match id {
0 => "get_return",
Expand Down

0 comments on commit 52d45ca

Please sign in to comment.