Skip to content

Commit

Permalink
feat: add Plugin.Typed to define typed plugins (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
zshipko authored Jan 4, 2024
1 parent e6e67a7 commit 1e398ac
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 5 deletions.
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,26 @@ convert it directly to `Yojson.Safe.t`:

See [Extism.Type.S](https://extism.github.io/ocaml-sdk/extism/Extism/Type/module-type-S/index.html) to define your own input/output types.

### Typed Plugins

Plug-ins can also use pre-defined functions using `Plugin.Typed`:

```ocaml
module Example = struct
include Plugin.Typed.Init ()
let count_vowels = exn @@ fn "count_vowels" Type.string Type.json
end
```

This can then be initialized using an existing `Plugin.t`:

```ocaml
let example = Example.of_plugin_exn plugin in
let res = Example.count_vowels example "this is a test" in
print_endline (Yojson.Safe.to_string res)
```

### Plug-in State

Plug-ins may be stateful or stateless. Plug-ins can maintain state b/w calls by the use of variables. Our count vowels plug-in remembers the total number of vowels it's ever counted in the "total" key in the result. You can see this by making subsequent calls to the export:
Expand Down
4 changes: 2 additions & 2 deletions examples/dune
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
(executables
(names runner kv)
(names runner kv typed)
(libraries extism))

(alias
(name examples)
(deps runner.exe kv.exe))
(deps runner.exe kv.exe typed.exe))

(alias
(name runtest)
Expand Down
18 changes: 18 additions & 0 deletions examples/typed.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
open Extism

let url =
"https://github.com/extism/plugins/releases/latest/download/count_vowels.wasm"

module Typed_example = struct
include Plugin.Typed.Init ()

let count_vowels = exn @@ fn "count_vowels" Type.string Type.json
end

let () =
let wasm = Manifest.Wasm.url url in
let manifest = Manifest.create [ wasm ] in
let plugin = Plugin.of_manifest_exn manifest in
let plugin = Typed_example.of_plugin_exn plugin in
let res = Typed_example.count_vowels plugin "this is a test" in
print_endline (Yojson.Safe.to_string res)
57 changes: 56 additions & 1 deletion src/extism.mli
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,28 @@
print_endline res
]}
Using the typed plugin interface you can pre-define plug-in functions:
{@ocaml[
open Extism
module Example = struct
include Plugin.Typed.Init ()
let count_vowels = exn @@ fn "count_vowels" Type.string Type.string
end
let () =
let plugin =
Example.of_plugin_exn
@@ Plugin.of_manifest_exn
@@ Manifest.create [ Manifest.Wasm.file "test/code.wasm" ]
in
let res =
Example.count_vowels plugin "input data"
in
print_endline res
]}
{1 API} *)

module Manifest = Extism_manifest
Expand Down Expand Up @@ -69,7 +91,7 @@ module Val_type : sig

val of_int : int -> t
(** Convert from [int] to {!t},
@raise Invalid_argument if the integer isn't valid *)
raises Invalid_argument if the integer isn't valid *)

val to_int : t -> int
(** Convert from {!t} to [int] *)
Expand Down Expand Up @@ -421,6 +443,39 @@ module Plugin : sig

val id : t -> Uuidm.t
(** Get the plugin UUID *)

(** Typed plugins allow for plugin functions to be check at initialization and called with static types *)
module Typed : sig

(** Defines the interface for typed plugins *)
module type S = sig
type plugin := t

type t
(** Opaque typed plugin type *)

val of_plugin : plugin -> (t, Error.t) result
(** Load an instance of a typed plugin, returning a result *)

val of_plugin_exn : plugin -> t
(** Load an instance of a typed plugin, raising an exception if an error occurs *)

val fn :
string ->
(module Type.S with type t = 'a) ->
(module Type.S with type t = 'b) ->
t ->
'a ->
('b, Error.t) result
(** Pre-declare a function that returns a result type *)

val exn : (t -> 'a -> ('b, Error.t) result) -> (t -> 'a -> 'b)
(** Convert a pre-declared function to raise an exception instead of a result type *)
end

module Init () : S
(** Initialize a new typed plugin module *)
end
end

val set_log_file :
Expand Down
72 changes: 70 additions & 2 deletions src/plugin.ml
Original file line number Diff line number Diff line change
Expand Up @@ -204,5 +204,73 @@ let id { pointer; _ } =
let s = Ctypes.string_from_ptr id ~length:16 in
Uuidm.unsafe_of_bytes s

let reset { pointer; _ } =
Bindings.extism_plugin_reset pointer
let reset { pointer; _ } = Bindings.extism_plugin_reset pointer

type plugin = t

module Typed = struct
module Functions = Set.Make (String)

module type S = sig
type t

val of_plugin : plugin -> (t, Error.t) result
val of_plugin_exn : plugin -> t

val fn :
string ->
(module Type.S with type t = 'a) ->
(module Type.S with type t = 'b) ->
t ->
'a ->
('b, Error.t) result

val exn : (t -> 'a -> ('b, Error.t) result) -> t -> 'a -> 'b
end

type typed = { mutable functions : Functions.t; mutable sealed : bool }

module Init () : S = struct
type nonrec t = t

let state = { functions = Functions.empty; sealed = false }

let fn name params results =
if state.sealed then invalid_arg "Typed function has already been sealed";
state.functions <- Functions.add name state.functions;
let f = call params results ~name in
fun plugin params -> f plugin params

let exn f (t : t) x = Error.unwrap (f t x)
let finish () = state.sealed <- true

let of_plugin_exn plugin =
finish ();
Functions.iter
(fun name ->
if not (function_exists plugin name) then
Error.throw (`Msg ("invalid plugin function: " ^ name)))
state.functions;
plugin

let of_plugin plugin =
finish ();
match of_plugin_exn plugin with
| exception Error.Error e -> Error e
| x -> Ok x
end
end

let%test "typed" =
let module Test = struct
include Typed.Init ()

let count_vowels = exn @@ fn "count_vowels" Type.string Type.json
end in
let manifest = Manifest.(create [ Wasm.file "test/code.wasm" ]) in
let plugin = of_manifest manifest |> Error.unwrap in
let t = Test.of_plugin_exn plugin in
let res = Test.count_vowels t "aaa" in
let n = Yojson.Safe.Util.member "count" res |> Yojson.Safe.Util.to_number in
Printf.printf "count = %f\n" n;
n = 3.0

0 comments on commit 1e398ac

Please sign in to comment.