Skip to content

Commit

Permalink
implement cont bind
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuoguo committed Nov 27, 2024
1 parent 2c8121c commit f347473
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 41 deletions.
12 changes: 12 additions & 0 deletions benchmarks/wasm/wasmfx/cont_bind4.bin.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
(module definition binary
"\00\61\73\6d\01\00\00\00\01\8e\80\80\80\00\04\60"
"\00\01\7f\5d\00\60\01\7f\01\7f\5d\02\03\83\80\80"
"\80\00\02\02\02\07\88\80\80\80\00\01\04\6d\61\69"
"\6e\00\01\09\85\80\80\80\00\01\03\00\01\00\0a\a0"
"\80\80\80\00\02\87\80\80\80\00\00\20\00\41\2c\6a"
"\0b\8e\80\80\80\00\00\20\00\d2\00\e0\03\e1\03\01"
"\e3\01\00\0b"
)
(module instance)
(assert_return (invoke "main" (i32.const 0x16)) (i32.const 0x42))
(assert_return (invoke "main" (i32.const 0xffff_fe44)) (i32.const 0xffff_fe70))
21 changes: 21 additions & 0 deletions benchmarks/wasm/wasmfx/cont_bind4.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
(module
(type $f1 (func (result i32)))
(type $c1 (cont $f1))

(type $f2 (func (param i32) (result i32)))
(type $c2 (cont $f2))

(func $add44 (param i32) (result i32) (i32.add (local.get 0) (i32.const 44)))
(elem declare func $add44)

(func (export "main") (param i32) (result i32)
(resume $c1
(cont.bind $c2 $c1
(local.get 0)
(cont.new $c2 (ref.func $add44))))
)
)

(assert_return (invoke "main" (i32.const 22)) (i32.const 66))
(assert_return (invoke "main" (i32.const -444)) (i32.const -400))

31 changes: 31 additions & 0 deletions benchmarks/wasm/wasmfx/cont_bind5-strip.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
(module
(type (;0;) (func (param i32) (result i32)))
(type (;1;) (cont 0))
(type (;2;) (func (param i32 i32) (result i32)))
(type (;3;) (cont 2))
(type (;4;) (func (param i32)))
(type (;5;) (func))
(import "spectest" "print_i32" (func (;0;) (type 4)))
(export "main" (func 3))
(start 3)
(elem (;0;) declare func 1)
(func (;1;) (type 2) (param i32 i32) (result i32)
local.get 0
local.get 1
i32.sub
)
(func (;2;) (type 2) (param i32 i32) (result i32)
local.get 1
local.get 0
ref.func 1
cont.new 3
cont.bind 3 1
resume 1
)
(func (;3;) (type 5)
i32.const 22
i32.const 44
call 2
call 0
)
)
18 changes: 18 additions & 0 deletions benchmarks/wasm/wasmfx/cont_bind5.bin.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
(module definition binary
"\00\61\73\6d\01\00\00\00\01\90\80\80\80\00\04\60"
"\01\7f\01\7f\5d\00\60\02\7f\7f\01\7f\5d\02\03\83"
"\80\80\80\00\02\02\02\07\88\80\80\80\00\01\04\6d"
"\61\69\6e\00\01\09\85\80\80\80\00\01\03\00\01\00"
"\0a\a2\80\80\80\00\02\87\80\80\80\00\00\20\00\20"
"\01\6b\0b\90\80\80\80\00\00\20\01\20\00\d2\00\e0"
"\03\e1\03\01\e3\01\00\0b"
)
(module instance)
(assert_return
(invoke "main" (i32.const 0x16) (i32.const 0x2c))
(i32.const 0xffff_ffea)
)
(assert_return
(invoke "main" (i32.const 0xffff_fe44) (i32.const 0x6f))
(i32.const 0xffff_fdd5)
)
25 changes: 25 additions & 0 deletions benchmarks/wasm/wasmfx/cont_bind5.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
(module
;; (type $f1 (func (result i32)))
;; (type $c1 (cont $f1))

(type $f2 (func (param i32) (result i32)))
(type $c2 (cont $f2))

(type $f3 (func (param i32 i32) (result i32)))
(type $c3 (cont $f3))

(func $sub (param i32 i32) (result i32) (i32.sub (local.get 0) (local.get 1)))
(elem declare func $sub)

(func (export "main") (param i32 i32) (result i32)
(resume $c2
(local.get 1)
(cont.bind $c3 $c2
(local.get 0)
(cont.new $c3 (ref.func $sub))))
)
)

;; (assert_return (invoke "main" (i32.const 22) (i32.const 44)) (i32.const -22))
;; (assert_return (invoke "main" (i32.const -444) (i32.const 111)) (i32.const -555))

41 changes: 23 additions & 18 deletions src/main/scala/wasm/MiniWasmFX.scala
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,8 @@ case class EvaluatorFX(module: ModuleInstance) {
// add the continuation on the stack

val k = (s: Stack, k1: Cont[Ans], m: MCont[Ans], handler: Handler[Ans]) => {
// the 3 lines below doens't work
// k1 is the default handler
// so it should be konk ++ k1
// val kontK: Cont[Ans] = (s1, m1) => kont(s1, s2 => k1(s2, m1))
// TODO: does the following work?
// val kontK: Cont[Ans] = (s1, m1) => kont(s1, s2 => k1(s2, m1))

// Q: is it okay to forget `k1` and `mkont` here?
// Ans: No! Because the resumable continuation might be install by
Expand All @@ -290,7 +288,7 @@ case class EvaluatorFX(module: ModuleInstance) {

if (handler.length == 0) {
// the metacontinuation contains the default handler
val mk: MCont[Ans] = (s) => eval(rest, s, frame, kont, mkont, trail, h)
val mk: MCont[Ans] = (s) => eval(rest, s, frame, kont, mkont, trail, h)
val emptyK: Cont[Ans] = (s, m) => m(s)
f.k(inputs, emptyK, mk, h)
} else {
Expand All @@ -304,26 +302,33 @@ case class EvaluatorFX(module: ModuleInstance) {

// f might be handled by the default handler (namely kont), or by the
// handler specified by tags (newhandler, which has the same type as meta-continuation)
val mk: MCont[Ans] = (s) => eval(rest, s, frame, kont, mkont, trail, h)
val mk: MCont[Ans] = (s) => eval(rest, s, frame, kont, mkont, trail, h)
val emptyK: Cont[Ans] = (s, m) => m(s)
f.k(inputs, emptyK, mk, newHandler)

}

}

case ContBind(oldContTy, newConTy) =>
// val RefContV(oldContAddr) :: newStack = stack
// // use oldParamTy - newParamTy to get how many values to pop from the stack
// val oldParamTy = module.types(oldContTy).inps
// val newParamTy = module.types(newConTy).inps
// val (inputs, restStack) = newStack.splitAt(oldParamTy.size)
// // partially apply the old continuation
// val oldCont = module.funcs(oldContAddr) match {
// case RefContV(f) => f
// case _ => throw new Exception("Continuation is not a function")
// }
throw new Exception("ContBind unimplemented")
case ContBind(oldContTyId, newConTyId) =>
val (f: ContV[Ans]) :: newStack = stack
// use oldParamTy - newParamTy to get how many values to pop from the stack
val ContType(oldId) = module.types(oldContTyId)
val FuncType(_, oldParamTy, _) = module.types(oldId)
val ContType(newId) = module.types(newConTyId)
val FuncType(_, newParamTy, _) = module.types(newId)

// get oldParamTy - newParamTy (there's no type checking at all)
val inputSize = oldParamTy.size - newParamTy.size

val (inputs, restStack) = newStack.splitAt(inputSize)

// partially apply the old continuation
def kr(s: Stack, k1: Cont[Ans], mk: MCont[Ans], handler: Handler[Ans]): Ans = {
f.k(s ++ inputs, k1, mk, handler)
}

eval(rest, ContV(kr) :: restStack, frame, kont, mkont, trail, h)

case CallRef(ty) =>
val RefFuncV(f) :: newStack = stack
Expand Down
26 changes: 15 additions & 11 deletions src/main/scala/wasm/MiniWasmScript.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,22 @@ sealed class ScriptRunner {

def getInstance(instName: Option[String]): ModuleInstance = {
instName match {
case Some(name) => instanceMap(name)
case None => instances.head
case Some(name) => instanceMap(name)
case None => instances.head
}
}

def assertReturn(action: Action, expect: List[Value]): Unit = {
action match {
case Invoke(instName, name, args) =>
val module = getInstance(instName)
val func = module.exports.collectFirst({
case Export(`name`, ExportFunc(index)) =>
module.funcs(index)
case _ => throw new RuntimeException("Not Supported")
}).get
val func = module.exports
.collectFirst({
case Export(`name`, ExportFunc(index)) =>
module.funcs(index)
case _ => throw new RuntimeException("Not Supported")
})
.get
val instrs = func match {
case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => body
}
Expand All @@ -36,16 +38,18 @@ sealed class ScriptRunner {
val h0: Handler = stack => throw new Exception(s"Uncaught exception: $stack")
// TODO: change this back to Evaluator if we are just testing original stuff
val actual = evaluator.eval(instrs, List(), Frame(ArrayBuffer(args: _*)), k, mk, List(k), h0)
println(s"expect = $expect")
println(s"actual = $actual")
assert(actual == expect)
}
}

def runCmd(cmd: Cmd): Unit = {
cmd match {
case CmdModule(module) => instances += ModuleInstance(module)
case CmdModule(module) => instances += ModuleInstance(module)
case AssertReturn(action, expect) => assertReturn(action, expect)
case CMdInstnace() => ()
case AssertTrap(action, message) => ???
case CMdInstnace() => ()
case AssertTrap(action, message) => ???
}
}

Expand All @@ -54,4 +58,4 @@ sealed class ScriptRunner {
runCmd(cmd)
}
}
}
}
1 change: 0 additions & 1 deletion src/main/scala/wasm/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,6 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] {
val Array(ty, _) = constCtx.CONST.getText.split("\\.")
visitLiteralWithType(constCtx.literal, toNumType(ty))
}
println(s"expect = $expect")
AssertReturn(action, expect.toList)
} else {
throw new RuntimeException("Unsupported")
Expand Down
23 changes: 12 additions & 11 deletions src/test/scala/genwasym/TestFx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ class TestFx extends FunSuite {
type MCont = evaluator.MCont[Unit]
val haltK: Cont = (stack, m) => m(stack)
val haltMK: MCont = (stack) => {
//println(s"halt cont: $stack")
// println(s"halt cont: $stack")
expected match {
case ExpInt(e) => assert(stack(0) == I32V(e))
case ExpInt(e) => assert(stack(0) == I32V(e))
case ExpStack(e) => assert(stack == e)
case Ignore => ()
case Ignore => ()
}
}
evaluator.evalTop(haltK, haltMK, main)
Expand Down Expand Up @@ -142,11 +142,6 @@ class TestFx extends FunSuite {

/* REAL WASMFX STUFF */

// TODO: test after implemented cont_bind3
// test("simple script") {
// testWastFile("./benchmarks/wasm/wasmfx/cont_bind3.bin.wast")
// }

test("cont") {
// testFile("./benchmarks/wasm/wasmfx/callcont.wast", None, ExpInt(11))
testWastFile("./benchmarks/wasm/wasmfx/callcont.bin.wast")
Expand All @@ -167,9 +162,7 @@ class TestFx extends FunSuite {

// printing 0 not 1
test("nested suspend") {
testFile("./benchmarks/wasm/wasmfx/nested_suspend-strip.wat")

// testFileOutput("./benchmarks/wasm/wasmfx/nested_suspend-strip.wat", List(0))
testFileOutput("./benchmarks/wasm/wasmfx/nested_suspend-strip.wat", List(0))
}

// going to print 100 to 1 and then print 42
Expand All @@ -181,4 +174,12 @@ class TestFx extends FunSuite {
testFileOutput("./benchmarks/wasm/wasmfx/diff_resume-strip.wat", List(10, 11, 42))
}

test("cont_bind_4") {
testWastFile("./benchmarks/wasm/wasmfx/cont_bind4.bin.wast")
}

test("cont_bind_5") {
testWastFile("./benchmarks/wasm/wasmfx/cont_bind5.bin.wast")
}

}

0 comments on commit f347473

Please sign in to comment.