Skip to content

Commit

Permalink
JNA callback should check inputs and outputs for null.
Browse files Browse the repository at this point in the history
This addresses #27
once we merge extism/extism#760
and release a new version of libextism.

This PR simply check the parameters for null and the counts
to be valid, in preparation for extism/extism#760
which otherwise would cause a NullPointerException when
outputs and inputs are empty.

Signed-off-by: Edoardo Vacchi <[email protected]>
  • Loading branch information
evacchi committed Aug 30, 2024
1 parent 1448042 commit 16a9d1e
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 75 deletions.
67 changes: 43 additions & 24 deletions src/main/java/org/extism/sdk/HostFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,12 @@ public class HostFunction<T extends HostUserData> {

public final LibExtism.ExtismValType[] returns;

public final Optional<T> userData;

public HostFunction(String name, LibExtism.ExtismValType[] params, LibExtism.ExtismValType[] returns, ExtismFunction f, Optional<T> userData) {
this.freed = false;
this.name = name;
this.params = params;
this.returns = returns;
this.userData = userData;
this.callback = (Pointer currentPlugin,
LibExtism.ExtismVal inputs,
int nInputs,
LibExtism.ExtismVal outs,
int nOutputs,
Pointer data) -> {

LibExtism.ExtismVal[] outputs = (LibExtism.ExtismVal[]) outs.toArray(nOutputs);

f.invoke(
new ExtismCurrentPlugin(currentPlugin),
(LibExtism.ExtismVal[]) inputs.toArray(nInputs),
outputs,
userData
);

for (LibExtism.ExtismVal output : outputs) {
convertOutput(output, output);
}
};
this.callback = new Callback(f, userData);

this.pointer = LibExtism.INSTANCE.extism_function_new(
this.name,
Expand All @@ -61,7 +39,7 @@ public HostFunction(String name, LibExtism.ExtismValType[] params, LibExtism.Ext
);
}

void convertOutput(LibExtism.ExtismVal original, LibExtism.ExtismVal fromHostFunction) {
static void convertOutput(LibExtism.ExtismVal original, LibExtism.ExtismVal fromHostFunction) {
if (fromHostFunction.t != original.t)
throw new ExtismException(String.format("Output type mismatch, got %d but expected %d", fromHostFunction.t, original.t));

Expand Down Expand Up @@ -103,4 +81,45 @@ public void free() {
this.freed = true;
}
}

static class Callback<T> implements LibExtism.InternalExtismFunction {
private final ExtismFunction f;
private final Optional<T> userData;

public Callback(ExtismFunction f, Optional<T> userData) {
this.f = f;
this.userData = userData;
}

@Override
public void invoke(Pointer currentPlugin, LibExtism.ExtismVal ins, int nInputs, LibExtism.ExtismVal outs, int nOutputs, Pointer data) {

LibExtism.ExtismVal[] inputs;
LibExtism.ExtismVal[] outputs;

if (outs == null) {
if (nOutputs > 0) {
throw new ExtismException("Output array is null but nOutputs is greater than 0");
}
outputs = new LibExtism.ExtismVal[0];
} else {
outputs = (LibExtism.ExtismVal[]) outs.toArray(nOutputs);
}

if (ins == null) {
if (nInputs > 0) {
throw new ExtismException("Input array is null but nInputs is greater than 0");
}
inputs = new LibExtism.ExtismVal[0];
} else {
inputs = (LibExtism.ExtismVal[]) ins.toArray(nInputs);
}

f.invoke(new ExtismCurrentPlugin(currentPlugin), inputs, outputs, userData);

for (LibExtism.ExtismVal output : outputs) {
convertOutput(output, output);
}
}
}
}
25 changes: 25 additions & 0 deletions src/test/java/org/extism/sdk/HostFunctionTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package org.extism.sdk;

import com.sun.jna.Pointer;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertThrows;

public class HostFunctionTests {
@Test
public void callbackShouldAcceptNullParameters() {
var callback = new HostFunction.Callback<>(
(plugin, params, returns, userData) -> {/* NOOP */}, null);
callback.invoke(Pointer.NULL, null, 0, null, 0, Pointer.NULL);
}

@Test
public void callbackShouldThrowOnNullParametersAndNonzeroCounts() {
var callback = new HostFunction.Callback<>(
(plugin, params, returns, userData) -> {/* NOOP */}, null);
assertThrows(ExtismException.class, () ->
callback.invoke(Pointer.NULL, null, 1, null, 0, Pointer.NULL));
assertThrows(ExtismException.class, () ->
callback.invoke(Pointer.NULL, null, 0, null, 1, Pointer.NULL));
}
}
102 changes: 51 additions & 51 deletions src/test/java/org/extism/sdk/PluginTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,57 +53,57 @@ public void shouldInvokeFunctionFromUrlWasmSource() {
assertThat(output).isEqualTo("{\"count\":4,\"total\":4,\"vowels\":\"aeiouyAEIOUY\"}");
}

// @Test
// public void shouldInvokeFunctionFromUrlWasmSourceHostFuncs() {
// var url = "https://github.com/extism/plugins/releases/latest/download/count_vowels_kvstore.wasm";
// var manifest = new Manifest(List.of(UrlWasmSource.fromUrl(url)));
//
// // Our application KV store
// // Pretend this is redis or a database :)
// var kvStore = new HashMap<String, byte[]>();
//
// ExtismFunction kvWrite = (plugin, params, returns, data) -> {
// System.out.println("Hello from Java Host Function!");
// var key = plugin.inputString(params[0]);
// var value = plugin.inputBytes(params[1]);
// System.out.println("Writing to key " + key);
// kvStore.put(key, value);
// };
//
// ExtismFunction kvRead = (plugin, params, returns, data) -> {
// System.out.println("Hello from Java Host Function!");
// var key = plugin.inputString(params[0]);
// System.out.println("Reading from key " + key);
// var value = kvStore.get(key);
// if (value == null) {
// // default to zeroed bytes
// var zero = new byte[]{0,0,0,0};
// plugin.returnBytes(returns[0], zero);
// } else {
// plugin.returnBytes(returns[0], value);
// }
// };
//
// HostFunction kvWriteHostFn = new HostFunction<>(
// "kv_write",
// new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64, LibExtism.ExtismValType.I64},
// new LibExtism.ExtismValType[0],
// kvWrite,
// Optional.empty()
// );
//
// HostFunction kvReadHostFn = new HostFunction<>(
// "kv_read",
// new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64},
// new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64},
// kvRead,
// Optional.empty()
// );
//
// HostFunction[] functions = {kvWriteHostFn, kvReadHostFn};
// var plugin = new Plugin(manifest, false, functions);
// var output = plugin.call("count_vowels", "Hello, World!");
// }
@Test
public void shouldInvokeFunctionFromUrlWasmSourceHostFuncs() {
var url = "https://github.com/extism/plugins/releases/latest/download/count_vowels_kvstore.wasm";
var manifest = new Manifest(List.of(UrlWasmSource.fromUrl(url)));

// Our application KV store
// Pretend this is redis or a database :)
var kvStore = new HashMap<String, byte[]>();

ExtismFunction kvWrite = (plugin, params, returns, data) -> {
System.out.println("Hello from Java Host Function!");
var key = plugin.inputString(params[0]);
var value = plugin.inputBytes(params[1]);
System.out.println("Writing to key " + key);
kvStore.put(key, value);
};

ExtismFunction kvRead = (plugin, params, returns, data) -> {
System.out.println("Hello from Java Host Function!");
var key = plugin.inputString(params[0]);
System.out.println("Reading from key " + key);
var value = kvStore.get(key);
if (value == null) {
// default to zeroed bytes
var zero = new byte[]{0,0,0,0};
plugin.returnBytes(returns[0], zero);
} else {
plugin.returnBytes(returns[0], value);
}
};

HostFunction kvWriteHostFn = new HostFunction<>(
"kv_write",
new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64, LibExtism.ExtismValType.I64},
new LibExtism.ExtismValType[0],
kvWrite,
Optional.empty()
);

HostFunction kvReadHostFn = new HostFunction<>(
"kv_read",
new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64},
new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64},
kvRead,
Optional.empty()
);

HostFunction[] functions = {kvWriteHostFn, kvReadHostFn};
var plugin = new Plugin(manifest, false, functions);
var output = plugin.call("count_vowels", "Hello, World!");
}

@Test
public void shouldInvokeFunctionFromByteArrayWasmSource() {
Expand Down

0 comments on commit 16a9d1e

Please sign in to comment.