From e87d25f4c57ba7996f0e303e20d094b39140a826 Mon Sep 17 00:00:00 2001 From: Edoardo Vacchi Date: Fri, 30 Aug 2024 19:36:00 +0200 Subject: [PATCH] JNA callback should check inputs and outputs for null. This addresses https://github.com/extism/java-sdk/issues/27 once we merge https://github.com/extism/extism/pull/760 and release a new version of libextism. This PR simply check the parameters for null and the counts to be valid, in preparation for https://github.com/extism/extism/pull/760 which otherwise would cause a NullPointerException when outputs and inputs are empty. Signed-off-by: Edoardo Vacchi --- .../java/org/extism/sdk/HostFunction.java | 67 +++++++----- .../org/extism/sdk/HostFunctionTests.java | 25 +++++ src/test/java/org/extism/sdk/PluginTests.java | 102 +++++++++--------- 3 files changed, 119 insertions(+), 75 deletions(-) create mode 100644 src/test/java/org/extism/sdk/HostFunctionTests.java diff --git a/src/main/java/org/extism/sdk/HostFunction.java b/src/main/java/org/extism/sdk/HostFunction.java index 2c500b1..23d3d08 100644 --- a/src/main/java/org/extism/sdk/HostFunction.java +++ b/src/main/java/org/extism/sdk/HostFunction.java @@ -20,34 +20,12 @@ public class HostFunction { public final LibExtism.ExtismValType[] returns; - public final Optional userData; - public HostFunction(String name, LibExtism.ExtismValType[] params, LibExtism.ExtismValType[] returns, ExtismFunction f, Optional userData) { this.freed = false; this.name = name; this.params = params; this.returns = returns; - this.userData = userData; - this.callback = (Pointer currentPlugin, - LibExtism.ExtismVal inputs, - int nInputs, - LibExtism.ExtismVal outs, - int nOutputs, - Pointer data) -> { - - LibExtism.ExtismVal[] outputs = (LibExtism.ExtismVal[]) outs.toArray(nOutputs); - - f.invoke( - new ExtismCurrentPlugin(currentPlugin), - (LibExtism.ExtismVal[]) inputs.toArray(nInputs), - outputs, - userData - ); - - for (LibExtism.ExtismVal output : outputs) { - convertOutput(output, output); - } - }; + this.callback = new Callback(f, userData); this.pointer = LibExtism.INSTANCE.extism_function_new( this.name, @@ -61,7 +39,7 @@ public HostFunction(String name, LibExtism.ExtismValType[] params, LibExtism.Ext ); } - void convertOutput(LibExtism.ExtismVal original, LibExtism.ExtismVal fromHostFunction) { + static void convertOutput(LibExtism.ExtismVal original, LibExtism.ExtismVal fromHostFunction) { if (fromHostFunction.t != original.t) throw new ExtismException(String.format("Output type mismatch, got %d but expected %d", fromHostFunction.t, original.t)); @@ -103,4 +81,45 @@ public void free() { this.freed = true; } } + + static class Callback implements LibExtism.InternalExtismFunction { + private final ExtismFunction f; + private final Optional userData; + + public Callback(ExtismFunction f, Optional userData) { + this.f = f; + this.userData = userData; + } + + @Override + public void invoke(Pointer currentPlugin, LibExtism.ExtismVal ins, int nInputs, LibExtism.ExtismVal outs, int nOutputs, Pointer data) { + + LibExtism.ExtismVal[] inputs; + LibExtism.ExtismVal[] outputs; + + if (outs == null) { + if (nOutputs > 0) { + throw new ExtismException("Output array is null but nOutputs is greater than 0"); + } + outputs = new LibExtism.ExtismVal[0]; + } else { + outputs = (LibExtism.ExtismVal[]) outs.toArray(nOutputs); + } + + if (ins == null) { + if (nInputs > 0) { + throw new ExtismException("Input array is null but nInputs is greater than 0"); + } + inputs = new LibExtism.ExtismVal[0]; + } else { + inputs = (LibExtism.ExtismVal[]) ins.toArray(nInputs); + } + + f.invoke(new ExtismCurrentPlugin(currentPlugin), inputs, outputs, userData); + + for (LibExtism.ExtismVal output : outputs) { + convertOutput(output, output); + } + } + } } diff --git a/src/test/java/org/extism/sdk/HostFunctionTests.java b/src/test/java/org/extism/sdk/HostFunctionTests.java new file mode 100644 index 0000000..91143d8 --- /dev/null +++ b/src/test/java/org/extism/sdk/HostFunctionTests.java @@ -0,0 +1,25 @@ +package org.extism.sdk; + +import com.sun.jna.Pointer; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class HostFunctionTests { + @Test + public void callbackShouldAcceptNullParameters() { + var callback = new HostFunction.Callback<>( + (plugin, params, returns, userData) -> {/* NOOP */}, null); + callback.invoke(Pointer.NULL, null, 0, null, 0, Pointer.NULL); + } + + @Test + public void callbackShouldThrowOnNullParametersAndNonzeroCounts() { + var callback = new HostFunction.Callback<>( + (plugin, params, returns, userData) -> {/* NOOP */}, null); + assertThrows(ExtismException.class, () -> + callback.invoke(Pointer.NULL, null, 1, null, 0, Pointer.NULL)); + assertThrows(ExtismException.class, () -> + callback.invoke(Pointer.NULL, null, 0, null, 1, Pointer.NULL)); + } +} diff --git a/src/test/java/org/extism/sdk/PluginTests.java b/src/test/java/org/extism/sdk/PluginTests.java index 62da199..c11a78a 100644 --- a/src/test/java/org/extism/sdk/PluginTests.java +++ b/src/test/java/org/extism/sdk/PluginTests.java @@ -53,57 +53,57 @@ public void shouldInvokeFunctionFromUrlWasmSource() { assertThat(output).isEqualTo("{\"count\":4,\"total\":4,\"vowels\":\"aeiouyAEIOUY\"}"); } -// @Test -// public void shouldInvokeFunctionFromUrlWasmSourceHostFuncs() { -// var url = "https://github.com/extism/plugins/releases/latest/download/count_vowels_kvstore.wasm"; -// var manifest = new Manifest(List.of(UrlWasmSource.fromUrl(url))); -// -// // Our application KV store -// // Pretend this is redis or a database :) -// var kvStore = new HashMap(); -// -// ExtismFunction kvWrite = (plugin, params, returns, data) -> { -// System.out.println("Hello from Java Host Function!"); -// var key = plugin.inputString(params[0]); -// var value = plugin.inputBytes(params[1]); -// System.out.println("Writing to key " + key); -// kvStore.put(key, value); -// }; -// -// ExtismFunction kvRead = (plugin, params, returns, data) -> { -// System.out.println("Hello from Java Host Function!"); -// var key = plugin.inputString(params[0]); -// System.out.println("Reading from key " + key); -// var value = kvStore.get(key); -// if (value == null) { -// // default to zeroed bytes -// var zero = new byte[]{0,0,0,0}; -// plugin.returnBytes(returns[0], zero); -// } else { -// plugin.returnBytes(returns[0], value); -// } -// }; -// -// HostFunction kvWriteHostFn = new HostFunction<>( -// "kv_write", -// new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64, LibExtism.ExtismValType.I64}, -// new LibExtism.ExtismValType[0], -// kvWrite, -// Optional.empty() -// ); -// -// HostFunction kvReadHostFn = new HostFunction<>( -// "kv_read", -// new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64}, -// new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64}, -// kvRead, -// Optional.empty() -// ); -// -// HostFunction[] functions = {kvWriteHostFn, kvReadHostFn}; -// var plugin = new Plugin(manifest, false, functions); -// var output = plugin.call("count_vowels", "Hello, World!"); -// } + @Test + public void shouldInvokeFunctionFromUrlWasmSourceHostFuncs() { + var url = "https://github.com/extism/plugins/releases/latest/download/count_vowels_kvstore.wasm"; + var manifest = new Manifest(List.of(UrlWasmSource.fromUrl(url))); + + // Our application KV store + // Pretend this is redis or a database :) + var kvStore = new HashMap(); + + ExtismFunction kvWrite = (plugin, params, returns, data) -> { + System.out.println("Hello from Java Host Function!"); + var key = plugin.inputString(params[0]); + var value = plugin.inputBytes(params[1]); + System.out.println("Writing to key " + key); + kvStore.put(key, value); + }; + + ExtismFunction kvRead = (plugin, params, returns, data) -> { + System.out.println("Hello from Java Host Function!"); + var key = plugin.inputString(params[0]); + System.out.println("Reading from key " + key); + var value = kvStore.get(key); + if (value == null) { + // default to zeroed bytes + var zero = new byte[]{0,0,0,0}; + plugin.returnBytes(returns[0], zero); + } else { + plugin.returnBytes(returns[0], value); + } + }; + + HostFunction kvWriteHostFn = new HostFunction<>( + "kv_write", + new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64, LibExtism.ExtismValType.I64}, + new LibExtism.ExtismValType[0], + kvWrite, + Optional.empty() + ); + + HostFunction kvReadHostFn = new HostFunction<>( + "kv_read", + new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64}, + new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64}, + kvRead, + Optional.empty() + ); + + HostFunction[] functions = {kvWriteHostFn, kvReadHostFn}; + var plugin = new Plugin(manifest, false, functions); + var output = plugin.call("count_vowels", "Hello, World!"); + } @Test public void shouldInvokeFunctionFromByteArrayWasmSource() {