diff --git a/.vscode/settings.json b/.vscode/settings.json index 5fd6496..336f106 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,6 @@ { "editor.formatOnSave": true, - "editor.formatOnSaveMode": "modifications", + "editor.formatOnSaveMode": "file", "editor.codeActionsOnSave": { "source.organizeImports": "explicit", "source.generate.finalModifiers": "explicit", diff --git a/src/main/java/org/itsallcode/luava/LowLevelLua.java b/src/main/java/org/itsallcode/luava/LowLevelLua.java index 0c6ab05..6f4bafb 100644 --- a/src/main/java/org/itsallcode/luava/LowLevelLua.java +++ b/src/main/java/org/itsallcode/luava/LowLevelLua.java @@ -2,17 +2,17 @@ import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; +import java.util.function.Supplier; import org.itsallcode.luava.ffi.Lua; import org.itsallcode.luava.ffi.lua_KFunction; -import org.itsallcode.luava.ffi.lua_KFunction.Function; class LowLevelLua implements AutoCloseable { private final Arena arena; - private final MemorySegment state; + final MemorySegment state; private final LuaStack stack; - private LowLevelLua(final Arena arena, final MemorySegment state) { + LowLevelLua(final Arena arena, final MemorySegment state) { this.arena = arena; this.state = state; this.stack = new LuaStack(state, arena); @@ -24,34 +24,57 @@ static LowLevelLua create() { return new LowLevelLua(arena, state); } + LowLevelLua forState(final MemorySegment newState) { + return new LowLevelLua(arena, newState); + } + void openLibs() { Lua.luaL_openlibs(state); } - void pcall(final int nargs, final int nresults, final int errfunc, final long ctx) { - final Function function = (final MemorySegment l, final int status, final long ctx1) -> { - return 0; - }; - pcall(nargs, nresults, errfunc, ctx, function); + void pcall(final int nargs, final int nresults) { + pcallk(nargs, nresults, 0, 0, null); } - void pcall(final int nargs, final int nresults, final int errfunc, final long ctx, + void pcall(final int nargs, final int nresults, final int errfunc) { + pcallk(nargs, nresults, errfunc, 0, null); + } + + /** + * This function behaves exactly like {@link #pcall(int, int, int, long)}, + * except that it allows the + * called function to yield. + * + * @param nargs + * @param nresults + * @param msgHandler + * @param ctx + * @param upcallFunction + */ + void pcallk(final int nargs, final int nresults, final int msgHandler, final long ctx, final lua_KFunction.Function upcallFunction) { - final MemorySegment k = lua_KFunction.allocate(upcallFunction, arena); - final int error = Lua.lua_pcallk(state, nargs, nresults, errfunc, ctx, k); - if (error != 0) { - final String message = stack.toString(-1); - stack.pop(1); - throw new FunctionCallException("lua_pcallk", error, message); - } + final MemorySegment k = upcallFunction == null ? Lua.NULL() : lua_KFunction.allocate(upcallFunction, arena); + checkStatus("lua_pcallk", () -> { + final int status = Lua.lua_pcallk(state, nargs, nresults, msgHandler, ctx, k); + if (msgHandler != 0) { + this.stack.pop(1); + } + return status; + }); } void loadString(final String chunk) { - final int error = Lua.luaL_loadstring(state, arena.allocateFrom(chunk)); - if (error != 0) { + checkStatus("luaL_loadstring", () -> Lua.luaL_loadstring(state, arena.allocateFrom(chunk))); + } + + void checkStatus(final String functionName, final Supplier nativeFunctionCall) { + final int status = nativeFunctionCall.get(); + if (status != Lua.LUA_OK()) { + System.out.println(stack.printStack()); + System.out.println("Getting error message.."); final String message = stack.toString(-1); stack.pop(1); - throw new FunctionCallException("luaL_loadstring", error, message); + throw new FunctionCallException(functionName, status, message); } } @@ -72,9 +95,9 @@ LuaTable table(final int idx) { return new LuaTable(state, stack, arena, idx); } - public LuaFunction function(final int idx) { + public LuaFunction function(final int idx, final Integer errorHandlerIdx) { assertType(idx, LuaType.FUNCTION); - return new LuaFunction(state, stack, arena, idx); + return new LuaFunction(this, arena, idx, errorHandlerIdx); } private void assertType(final int idx, final LuaType expectedType) { diff --git a/src/main/java/org/itsallcode/luava/LuaFunction.java b/src/main/java/org/itsallcode/luava/LuaFunction.java index 30df166..a0a1dcb 100644 --- a/src/main/java/org/itsallcode/luava/LuaFunction.java +++ b/src/main/java/org/itsallcode/luava/LuaFunction.java @@ -1,44 +1,32 @@ package org.itsallcode.luava; import java.lang.foreign.Arena; -import java.lang.foreign.MemorySegment; - -import org.itsallcode.luava.ffi.Lua; public class LuaFunction { - private final MemorySegment state; - private final LuaStack stack; + private final LowLevelLua lowLevelLua; private final Arena arena; private final int idx; + private final int errorHandlerIdx; - LuaFunction(final MemorySegment state, final LuaStack stack, final Arena arena, final int idx) { - this.state = state; - this.stack = stack; + LuaFunction(final LowLevelLua lowLevelLua, final Arena arena, final int idx, final int errorHandlerIdx) { + this.lowLevelLua = lowLevelLua; this.arena = arena; this.idx = idx; + this.errorHandlerIdx = errorHandlerIdx; } public void addArgInteger(final int value) { - stack.pushInteger(value); + lowLevelLua.stack().pushInteger(value); } public void call(final int nargs, final int nresults) { - call(nargs, nresults, 0, 0, Lua.NULL()); - } - - private void call(final int nargs, final int nresults, final int errfunc, final long ctx, final MemorySegment k) { - final int status = Lua.lua_pcallk(state, nargs, nresults, errfunc, ctx, k); - if (status != Lua.LUA_OK()) { - final String errorMessage = stack.toString(-1); - stack.pop(1); - throw new FunctionCallException("lua_pcallk", status, errorMessage); - } + lowLevelLua.pcall(nargs, nresults, errorHandlerIdx); } public long getIntegerResult() { - final long value = stack.toInteger(-1); - stack.pop(1); + final long value = lowLevelLua.stack().toInteger(-1); + lowLevelLua.stack().pop(1); return value; } } diff --git a/src/main/java/org/itsallcode/luava/LuaInterpreter.java b/src/main/java/org/itsallcode/luava/LuaInterpreter.java index 2d2476c..2f3701d 100644 --- a/src/main/java/org/itsallcode/luava/LuaInterpreter.java +++ b/src/main/java/org/itsallcode/luava/LuaInterpreter.java @@ -1,10 +1,13 @@ package org.itsallcode.luava; +import java.lang.foreign.MemorySegment; +import java.util.function.Function; + public class LuaInterpreter implements AutoCloseable { private final LowLevelLua lua; - public LuaInterpreter(final LowLevelLua lua) { + private LuaInterpreter(final LowLevelLua lua) { this.lua = lua; } @@ -18,6 +21,10 @@ public void close() { this.lua.close(); } + LuaStack stack() { + return lua.stack(); + } + public String getGlobalString(final String name) { lua.getGlobal(name); final String value = lua.stack().toString(-1); @@ -52,8 +59,19 @@ public LuaTable getGlobalTable(final String name) { } public LuaFunction getGlobalFunction(final String name) { + return this.getGlobalFunction(name, null); + } + + public LuaFunction getGlobalFunction(final String name, final Function messageHandler) { + int errorHandlerIdx = 0; + if (messageHandler != null) { + lua.stack().pushCFunction((final MemorySegment newState) -> { + return messageHandler.apply(new LuaInterpreter(this.lua.forState(newState))); + }); + errorHandlerIdx = lua.stack().getTop(); + } lua.getGlobal(name); - return lua.function(-1); + return lua.function(-1, errorHandlerIdx); } public void setGlobalString(final String name, final String value) { @@ -78,7 +96,6 @@ public void setGlobalBoolean(final String name, final boolean value) { public void exec(final String chunk) { lua.loadString(chunk); - lua.pcall(0, 0, 0, 0); + lua.pcall(0, 0); } - } diff --git a/src/main/java/org/itsallcode/luava/LuaStack.java b/src/main/java/org/itsallcode/luava/LuaStack.java index 04c9764..0925f93 100644 --- a/src/main/java/org/itsallcode/luava/LuaStack.java +++ b/src/main/java/org/itsallcode/luava/LuaStack.java @@ -5,6 +5,7 @@ import java.nio.charset.StandardCharsets; import org.itsallcode.luava.ffi.Lua; +import org.itsallcode.luava.ffi.lua_CFunction; class LuaStack { private final MemorySegment state; @@ -53,12 +54,28 @@ void pushString(final String value) { Lua.lua_pushlstring(state, segment, segment.byteSize() - 1); } + void pushCFunction(final lua_CFunction.Function fn) { + pushClosure(fn, 0); + } + + void pushClosure(final lua_CFunction.Function fn, final int n) { + final MemorySegment functionSegment = lua_CFunction.allocate(fn, arena); + Lua.lua_pushcclosure(state, functionSegment, n); + } + void pop(final int n) { + final int size = getTop(); + if (n > size) { + throw new IllegalStateException("Trying to pop " + n + " elements but stack has size " + size); + } + System.out.println("Popping " + n + " elements from stack with size " + size); setTop(-n - 1); } void setTop(final int n) { + System.out.println(this.printStack()); Lua.lua_settop(state, n); + System.out.println(this.printStack()); } boolean toBoolean(final int idx) { @@ -102,4 +119,32 @@ long toInteger(final int idx) { int getTop() { return Lua.lua_gettop(state); } + + String printStack() { + final StringBuilder b = new StringBuilder(); + final int top = this.getTop(); + b.append("Stack size: " + top + ": "); + for (int idx = 1; idx <= top; idx++) { + b.append(format(idx)); + if (idx < top) { + b.append(", "); + } + } + return b.toString(); + } + + private String format(final int idx) { + final LuaType type = getType(idx); + final String result = "#" + idx + " " + type; + final String value = switch (type) { + case STRING -> toString(idx); + case NUMBER -> String.valueOf(toNumber(idx)); + case BOOLEAN -> String.valueOf(toBoolean(idx)); + default -> ""; + }; + if (!value.isEmpty()) { + return result + ": " + value; + } + return result; + } } diff --git a/src/test/java/org/itsallcode/luava/LowLevelLuaTest.java b/src/test/java/org/itsallcode/luava/LowLevelLuaTest.java new file mode 100644 index 0000000..a533860 --- /dev/null +++ b/src/test/java/org/itsallcode/luava/LowLevelLuaTest.java @@ -0,0 +1,22 @@ +package org.itsallcode.luava; + +import org.junit.jupiter.api.*; + +class LowLevelLuaTest { + private LowLevelLua lua; + + @BeforeEach + void setup() { + lua = LowLevelLua.create(); + } + + @AfterEach + void stop() { + lua.close(); + } + + @Test + void functionWithErrorLeavesStackEmpty() { + + } +} diff --git a/src/test/java/org/itsallcode/luava/LuaInterpreterTest.java b/src/test/java/org/itsallcode/luava/LuaInterpreterTest.java index 25ccfa9..c05616c 100644 --- a/src/test/java/org/itsallcode/luava/LuaInterpreterTest.java +++ b/src/test/java/org/itsallcode/luava/LuaInterpreterTest.java @@ -92,22 +92,55 @@ void setGetGlobalBoolean(final boolean value) { @Test void getCallGlobalFunction() { lua.exec("function increment(x) return x+1 end"); + assertStackSize(0); final LuaFunction function = lua.getGlobalFunction("increment"); function.addArgInteger(42); function.call(1, 1); assertThat(function.getIntegerResult(), equalTo(43L)); + assertStackSize(0); } @Test void getCallGlobalFunctionFails() { lua.exec("function increment(x) error('failure') end"); + assertStackSize(0); final LuaFunction function = lua.getGlobalFunction("increment"); function.addArgInteger(42); final LuaException exception = assertThrows(LuaException.class, () -> function.call(1, 1)); + assertStackSize(0); assertThat(exception.getMessage(), equalTo( "Function 'lua_pcallk' failed with error 2: [string \"function increment(x) error('failure') end\"]:1: failure")); } + @Test + void getCallGlobalFunctionWithMessageHandler() { + System.out.println(lua.stack().printStack()); + lua.exec("function increment(x) error('failure') end"); + System.out.println(lua.stack().printStack()); + assertStackSize(0); + final LuaFunction function = lua.getGlobalFunction("increment", (final LuaInterpreter l) -> { + System.out.println(l.stack().printStack()); + final String msg = l.stack().toString(-1); + // l.stack().pop(1); + System.out.println("Updated error: " + msg); + l.stack().pushString("Updated error: " + msg); + System.out.println(l.stack().printStack()); + return 1; + }); + System.out.println("before add arg" + lua.stack().printStack()); + function.addArgInteger(42); + System.out.println("after add arg" + lua.stack().printStack()); + final FunctionCallException exception = assertThrows(FunctionCallException.class, () -> function.call(1, 1)); + System.out.println("after call" + lua.stack().printStack()); + assertThat(exception.getMessage(), equalTo( + "Function 'lua_pcallk' failed with error 2: [string \"function increment(x) error('failure') end\"]:1: failure")); + assertStackSize(0); + } + + void assertStackSize(final int expectedSize) { + assertThat("stack size", lua.stack().getTop(), equalTo(expectedSize)); + } + @Test void getTableStringValue() { lua.exec("result = { key = 'value' }"); @@ -159,7 +192,6 @@ void getTableStringFailsWrongType() { assertThat(exception.getMessage(), equalTo("Expected TABLE at -1 but got STRING")); } - void assertFails(final Executable executable, final String expectedErrorMessage) { final FunctionCallException exception = assertThrows(FunctionCallException.class, executable); assertThat(exception.getRootError(), equalTo(expectedErrorMessage));