From f347473ce08a1ecc544aa3d000e614cecd108263 Mon Sep 17 00:00:00 2001 From: ahuoguo Date: Wed, 27 Nov 2024 13:55:07 -0500 Subject: [PATCH] implement cont bind --- benchmarks/wasm/wasmfx/cont_bind4.bin.wast | 12 ++++++ benchmarks/wasm/wasmfx/cont_bind4.wast | 21 ++++++++++ benchmarks/wasm/wasmfx/cont_bind5-strip.wast | 31 +++++++++++++++ benchmarks/wasm/wasmfx/cont_bind5.bin.wast | 18 +++++++++ benchmarks/wasm/wasmfx/cont_bind5.wast | 25 ++++++++++++ src/main/scala/wasm/MiniWasmFX.scala | 41 +++++++++++--------- src/main/scala/wasm/MiniWasmScript.scala | 26 +++++++------ src/main/scala/wasm/Parser.scala | 1 - src/test/scala/genwasym/TestFx.scala | 23 +++++------ 9 files changed, 157 insertions(+), 41 deletions(-) create mode 100644 benchmarks/wasm/wasmfx/cont_bind4.bin.wast create mode 100644 benchmarks/wasm/wasmfx/cont_bind4.wast create mode 100644 benchmarks/wasm/wasmfx/cont_bind5-strip.wast create mode 100644 benchmarks/wasm/wasmfx/cont_bind5.bin.wast create mode 100644 benchmarks/wasm/wasmfx/cont_bind5.wast diff --git a/benchmarks/wasm/wasmfx/cont_bind4.bin.wast b/benchmarks/wasm/wasmfx/cont_bind4.bin.wast new file mode 100644 index 00000000..3c936870 --- /dev/null +++ b/benchmarks/wasm/wasmfx/cont_bind4.bin.wast @@ -0,0 +1,12 @@ +(module definition binary + "\00\61\73\6d\01\00\00\00\01\8e\80\80\80\00\04\60" + "\00\01\7f\5d\00\60\01\7f\01\7f\5d\02\03\83\80\80" + "\80\00\02\02\02\07\88\80\80\80\00\01\04\6d\61\69" + "\6e\00\01\09\85\80\80\80\00\01\03\00\01\00\0a\a0" + "\80\80\80\00\02\87\80\80\80\00\00\20\00\41\2c\6a" + "\0b\8e\80\80\80\00\00\20\00\d2\00\e0\03\e1\03\01" + "\e3\01\00\0b" +) +(module instance) +(assert_return (invoke "main" (i32.const 0x16)) (i32.const 0x42)) +(assert_return (invoke "main" (i32.const 0xffff_fe44)) (i32.const 0xffff_fe70)) diff --git a/benchmarks/wasm/wasmfx/cont_bind4.wast b/benchmarks/wasm/wasmfx/cont_bind4.wast new file mode 100644 index 00000000..214a287d --- /dev/null +++ b/benchmarks/wasm/wasmfx/cont_bind4.wast @@ -0,0 +1,21 @@ +(module + (type $f1 (func (result i32))) + (type $c1 (cont $f1)) + + (type $f2 (func (param i32) (result i32))) + (type $c2 (cont $f2)) + + (func $add44 (param i32) (result i32) (i32.add (local.get 0) (i32.const 44))) + (elem declare func $add44) + + (func (export "main") (param i32) (result i32) + (resume $c1 + (cont.bind $c2 $c1 + (local.get 0) + (cont.new $c2 (ref.func $add44)))) + ) +) + +(assert_return (invoke "main" (i32.const 22)) (i32.const 66)) +(assert_return (invoke "main" (i32.const -444)) (i32.const -400)) + diff --git a/benchmarks/wasm/wasmfx/cont_bind5-strip.wast b/benchmarks/wasm/wasmfx/cont_bind5-strip.wast new file mode 100644 index 00000000..4e56e37a --- /dev/null +++ b/benchmarks/wasm/wasmfx/cont_bind5-strip.wast @@ -0,0 +1,31 @@ +(module + (type (;0;) (func (param i32) (result i32))) + (type (;1;) (cont 0)) + (type (;2;) (func (param i32 i32) (result i32))) + (type (;3;) (cont 2)) + (type (;4;) (func (param i32))) + (type (;5;) (func)) + (import "spectest" "print_i32" (func (;0;) (type 4))) + (export "main" (func 3)) + (start 3) + (elem (;0;) declare func 1) + (func (;1;) (type 2) (param i32 i32) (result i32) + local.get 0 + local.get 1 + i32.sub + ) + (func (;2;) (type 2) (param i32 i32) (result i32) + local.get 1 + local.get 0 + ref.func 1 + cont.new 3 + cont.bind 3 1 + resume 1 + ) + (func (;3;) (type 5) + i32.const 22 + i32.const 44 + call 2 + call 0 + ) +) diff --git a/benchmarks/wasm/wasmfx/cont_bind5.bin.wast b/benchmarks/wasm/wasmfx/cont_bind5.bin.wast new file mode 100644 index 00000000..3f710cae --- /dev/null +++ b/benchmarks/wasm/wasmfx/cont_bind5.bin.wast @@ -0,0 +1,18 @@ +(module definition binary + "\00\61\73\6d\01\00\00\00\01\90\80\80\80\00\04\60" + "\01\7f\01\7f\5d\00\60\02\7f\7f\01\7f\5d\02\03\83" + "\80\80\80\00\02\02\02\07\88\80\80\80\00\01\04\6d" + "\61\69\6e\00\01\09\85\80\80\80\00\01\03\00\01\00" + "\0a\a2\80\80\80\00\02\87\80\80\80\00\00\20\00\20" + "\01\6b\0b\90\80\80\80\00\00\20\01\20\00\d2\00\e0" + "\03\e1\03\01\e3\01\00\0b" +) +(module instance) +(assert_return + (invoke "main" (i32.const 0x16) (i32.const 0x2c)) + (i32.const 0xffff_ffea) +) +(assert_return + (invoke "main" (i32.const 0xffff_fe44) (i32.const 0x6f)) + (i32.const 0xffff_fdd5) +) diff --git a/benchmarks/wasm/wasmfx/cont_bind5.wast b/benchmarks/wasm/wasmfx/cont_bind5.wast new file mode 100644 index 00000000..c0034207 --- /dev/null +++ b/benchmarks/wasm/wasmfx/cont_bind5.wast @@ -0,0 +1,25 @@ +(module +;; (type $f1 (func (result i32))) +;; (type $c1 (cont $f1)) + + (type $f2 (func (param i32) (result i32))) + (type $c2 (cont $f2)) + + (type $f3 (func (param i32 i32) (result i32))) + (type $c3 (cont $f3)) + + (func $sub (param i32 i32) (result i32) (i32.sub (local.get 0) (local.get 1))) + (elem declare func $sub) + + (func (export "main") (param i32 i32) (result i32) + (resume $c2 + (local.get 1) + (cont.bind $c3 $c2 + (local.get 0) + (cont.new $c3 (ref.func $sub)))) + ) +) + +;; (assert_return (invoke "main" (i32.const 22) (i32.const 44)) (i32.const -22)) +;; (assert_return (invoke "main" (i32.const -444) (i32.const 111)) (i32.const -555)) + diff --git a/src/main/scala/wasm/MiniWasmFX.scala b/src/main/scala/wasm/MiniWasmFX.scala index 06432077..af7e47c6 100644 --- a/src/main/scala/wasm/MiniWasmFX.scala +++ b/src/main/scala/wasm/MiniWasmFX.scala @@ -263,10 +263,8 @@ case class EvaluatorFX(module: ModuleInstance) { // add the continuation on the stack val k = (s: Stack, k1: Cont[Ans], m: MCont[Ans], handler: Handler[Ans]) => { - // the 3 lines below doens't work - // k1 is the default handler - // so it should be konk ++ k1 - // val kontK: Cont[Ans] = (s1, m1) => kont(s1, s2 => k1(s2, m1)) + // TODO: does the following work? + // val kontK: Cont[Ans] = (s1, m1) => kont(s1, s2 => k1(s2, m1)) // Q: is it okay to forget `k1` and `mkont` here? // Ans: No! Because the resumable continuation might be install by @@ -290,7 +288,7 @@ case class EvaluatorFX(module: ModuleInstance) { if (handler.length == 0) { // the metacontinuation contains the default handler - val mk: MCont[Ans] = (s) => eval(rest, s, frame, kont, mkont, trail, h) + val mk: MCont[Ans] = (s) => eval(rest, s, frame, kont, mkont, trail, h) val emptyK: Cont[Ans] = (s, m) => m(s) f.k(inputs, emptyK, mk, h) } else { @@ -304,7 +302,7 @@ case class EvaluatorFX(module: ModuleInstance) { // f might be handled by the default handler (namely kont), or by the // handler specified by tags (newhandler, which has the same type as meta-continuation) - val mk: MCont[Ans] = (s) => eval(rest, s, frame, kont, mkont, trail, h) + val mk: MCont[Ans] = (s) => eval(rest, s, frame, kont, mkont, trail, h) val emptyK: Cont[Ans] = (s, m) => m(s) f.k(inputs, emptyK, mk, newHandler) @@ -312,18 +310,25 @@ case class EvaluatorFX(module: ModuleInstance) { } - case ContBind(oldContTy, newConTy) => - // val RefContV(oldContAddr) :: newStack = stack - // // use oldParamTy - newParamTy to get how many values to pop from the stack - // val oldParamTy = module.types(oldContTy).inps - // val newParamTy = module.types(newConTy).inps - // val (inputs, restStack) = newStack.splitAt(oldParamTy.size) - // // partially apply the old continuation - // val oldCont = module.funcs(oldContAddr) match { - // case RefContV(f) => f - // case _ => throw new Exception("Continuation is not a function") - // } - throw new Exception("ContBind unimplemented") + case ContBind(oldContTyId, newConTyId) => + val (f: ContV[Ans]) :: newStack = stack + // use oldParamTy - newParamTy to get how many values to pop from the stack + val ContType(oldId) = module.types(oldContTyId) + val FuncType(_, oldParamTy, _) = module.types(oldId) + val ContType(newId) = module.types(newConTyId) + val FuncType(_, newParamTy, _) = module.types(newId) + + // get oldParamTy - newParamTy (there's no type checking at all) + val inputSize = oldParamTy.size - newParamTy.size + + val (inputs, restStack) = newStack.splitAt(inputSize) + + // partially apply the old continuation + def kr(s: Stack, k1: Cont[Ans], mk: MCont[Ans], handler: Handler[Ans]): Ans = { + f.k(s ++ inputs, k1, mk, handler) + } + + eval(rest, ContV(kr) :: restStack, frame, kont, mkont, trail, h) case CallRef(ty) => val RefFuncV(f) :: newStack = stack diff --git a/src/main/scala/wasm/MiniWasmScript.scala b/src/main/scala/wasm/MiniWasmScript.scala index 01d3a469..010ea5f4 100644 --- a/src/main/scala/wasm/MiniWasmScript.scala +++ b/src/main/scala/wasm/MiniWasmScript.scala @@ -10,8 +10,8 @@ sealed class ScriptRunner { def getInstance(instName: Option[String]): ModuleInstance = { instName match { - case Some(name) => instanceMap(name) - case None => instances.head + case Some(name) => instanceMap(name) + case None => instances.head } } @@ -19,11 +19,13 @@ sealed class ScriptRunner { action match { case Invoke(instName, name, args) => val module = getInstance(instName) - val func = module.exports.collectFirst({ - case Export(`name`, ExportFunc(index)) => - module.funcs(index) - case _ => throw new RuntimeException("Not Supported") - }).get + val func = module.exports + .collectFirst({ + case Export(`name`, ExportFunc(index)) => + module.funcs(index) + case _ => throw new RuntimeException("Not Supported") + }) + .get val instrs = func match { case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => body } @@ -36,16 +38,18 @@ sealed class ScriptRunner { val h0: Handler = stack => throw new Exception(s"Uncaught exception: $stack") // TODO: change this back to Evaluator if we are just testing original stuff val actual = evaluator.eval(instrs, List(), Frame(ArrayBuffer(args: _*)), k, mk, List(k), h0) + println(s"expect = $expect") + println(s"actual = $actual") assert(actual == expect) } } def runCmd(cmd: Cmd): Unit = { cmd match { - case CmdModule(module) => instances += ModuleInstance(module) + case CmdModule(module) => instances += ModuleInstance(module) case AssertReturn(action, expect) => assertReturn(action, expect) - case CMdInstnace() => () - case AssertTrap(action, message) => ??? + case CMdInstnace() => () + case AssertTrap(action, message) => ??? } } @@ -54,4 +58,4 @@ sealed class ScriptRunner { runCmd(cmd) } } -} \ No newline at end of file +} diff --git a/src/main/scala/wasm/Parser.scala b/src/main/scala/wasm/Parser.scala index 7dc93441..5dbed47d 100644 --- a/src/main/scala/wasm/Parser.scala +++ b/src/main/scala/wasm/Parser.scala @@ -778,7 +778,6 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] { val Array(ty, _) = constCtx.CONST.getText.split("\\.") visitLiteralWithType(constCtx.literal, toNumType(ty)) } - println(s"expect = $expect") AssertReturn(action, expect.toList) } else { throw new RuntimeException("Unsupported") diff --git a/src/test/scala/genwasym/TestFx.scala b/src/test/scala/genwasym/TestFx.scala index 0abdb22f..9e6d72e7 100644 --- a/src/test/scala/genwasym/TestFx.scala +++ b/src/test/scala/genwasym/TestFx.scala @@ -30,11 +30,11 @@ class TestFx extends FunSuite { type MCont = evaluator.MCont[Unit] val haltK: Cont = (stack, m) => m(stack) val haltMK: MCont = (stack) => { - //println(s"halt cont: $stack") + // println(s"halt cont: $stack") expected match { - case ExpInt(e) => assert(stack(0) == I32V(e)) + case ExpInt(e) => assert(stack(0) == I32V(e)) case ExpStack(e) => assert(stack == e) - case Ignore => () + case Ignore => () } } evaluator.evalTop(haltK, haltMK, main) @@ -142,11 +142,6 @@ class TestFx extends FunSuite { /* REAL WASMFX STUFF */ - // TODO: test after implemented cont_bind3 - // test("simple script") { - // testWastFile("./benchmarks/wasm/wasmfx/cont_bind3.bin.wast") - // } - test("cont") { // testFile("./benchmarks/wasm/wasmfx/callcont.wast", None, ExpInt(11)) testWastFile("./benchmarks/wasm/wasmfx/callcont.bin.wast") @@ -167,9 +162,7 @@ class TestFx extends FunSuite { // printing 0 not 1 test("nested suspend") { - testFile("./benchmarks/wasm/wasmfx/nested_suspend-strip.wat") - - // testFileOutput("./benchmarks/wasm/wasmfx/nested_suspend-strip.wat", List(0)) + testFileOutput("./benchmarks/wasm/wasmfx/nested_suspend-strip.wat", List(0)) } // going to print 100 to 1 and then print 42 @@ -181,4 +174,12 @@ class TestFx extends FunSuite { testFileOutput("./benchmarks/wasm/wasmfx/diff_resume-strip.wat", List(10, 11, 42)) } + test("cont_bind_4") { + testWastFile("./benchmarks/wasm/wasmfx/cont_bind4.bin.wast") + } + + test("cont_bind_5") { + testWastFile("./benchmarks/wasm/wasmfx/cont_bind5.bin.wast") + } + }