diff --git a/README.md b/README.md index e0edd4c..7dd2efc 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,26 @@ convert it directly to `Yojson.Safe.t`: See [Extism.Type.S](https://extism.github.io/ocaml-sdk/extism/Extism/Type/module-type-S/index.html) to define your own input/output types. +### Typed Plugins + +Plug-ins can also use pre-defined functions using `Plugin.Typed`: + +```ocaml +module Example = struct + include Plugin.Typed.Init () + + let count_vowels = exn @@ fn "count_vowels" Type.string Type.json +end +``` + +This can then be initialized using an existing `Plugin.t`: + +```ocaml +let example = Example.of_plugin_exn plugin in +let res = Example.count_vowels example "this is a test" in +print_endline (Yojson.Safe.to_string res) +``` + ### Plug-in State Plug-ins may be stateful or stateless. Plug-ins can maintain state b/w calls by the use of variables. Our count vowels plug-in remembers the total number of vowels it's ever counted in the "total" key in the result. You can see this by making subsequent calls to the export: diff --git a/examples/dune b/examples/dune index 77551a3..71fd889 100644 --- a/examples/dune +++ b/examples/dune @@ -1,10 +1,10 @@ (executables - (names runner kv) + (names runner kv typed) (libraries extism)) (alias (name examples) - (deps runner.exe kv.exe)) + (deps runner.exe kv.exe typed.exe)) (alias (name runtest) diff --git a/examples/typed.ml b/examples/typed.ml new file mode 100644 index 0000000..1489d23 --- /dev/null +++ b/examples/typed.ml @@ -0,0 +1,18 @@ +open Extism + +let url = + "https://github.com/extism/plugins/releases/latest/download/count_vowels.wasm" + +module Typed_example = struct + include Plugin.Typed.Init () + + let count_vowels = exn @@ fn "count_vowels" Type.string Type.json +end + +let () = + let wasm = Manifest.Wasm.url url in + let manifest = Manifest.create [ wasm ] in + let plugin = Plugin.of_manifest_exn manifest in + let plugin = Typed_example.of_plugin_exn plugin in + let res = Typed_example.count_vowels plugin "this is a test" in + print_endline (Yojson.Safe.to_string res) diff --git a/src/extism.mli b/src/extism.mli index b76766a..53699a0 100644 --- a/src/extism.mli +++ b/src/extism.mli @@ -30,6 +30,28 @@ print_endline res ]} + Using the typed plugin interface you can pre-define plug-in functions: + {@ocaml[ + open Extism + + module Example = struct + include Plugin.Typed.Init () + + let count_vowels = exn @@ fn "count_vowels" Type.string Type.string + end + + let () = + let plugin = + Example.of_plugin_exn + @@ Plugin.of_manifest_exn + @@ Manifest.create [ Manifest.Wasm.file "test/code.wasm" ] + in + let res = + Example.count_vowels plugin "input data" + in + print_endline res + ]} + {1 API} *) module Manifest = Extism_manifest @@ -69,7 +91,7 @@ module Val_type : sig val of_int : int -> t (** Convert from [int] to {!t}, - @raise Invalid_argument if the integer isn't valid *) + raises Invalid_argument if the integer isn't valid *) val to_int : t -> int (** Convert from {!t} to [int] *) @@ -421,6 +443,39 @@ module Plugin : sig val id : t -> Uuidm.t (** Get the plugin UUID *) + + (** Typed plugins allow for plugin functions to be check at initialization and called with static types *) + module Typed : sig + + (** Defines the interface for typed plugins *) + module type S = sig + type plugin := t + + type t + (** Opaque typed plugin type *) + + val of_plugin : plugin -> (t, Error.t) result + (** Load an instance of a typed plugin, returning a result *) + + val of_plugin_exn : plugin -> t + (** Load an instance of a typed plugin, raising an exception if an error occurs *) + + val fn : + string -> + (module Type.S with type t = 'a) -> + (module Type.S with type t = 'b) -> + t -> + 'a -> + ('b, Error.t) result + (** Pre-declare a function that returns a result type *) + + val exn : (t -> 'a -> ('b, Error.t) result) -> (t -> 'a -> 'b) + (** Convert a pre-declared function to raise an exception instead of a result type *) + end + + module Init () : S + (** Initialize a new typed plugin module *) + end end val set_log_file : diff --git a/src/plugin.ml b/src/plugin.ml index 69d4c6d..c6e1ed0 100644 --- a/src/plugin.ml +++ b/src/plugin.ml @@ -204,5 +204,73 @@ let id { pointer; _ } = let s = Ctypes.string_from_ptr id ~length:16 in Uuidm.unsafe_of_bytes s -let reset { pointer; _ } = - Bindings.extism_plugin_reset pointer +let reset { pointer; _ } = Bindings.extism_plugin_reset pointer + +type plugin = t + +module Typed = struct + module Functions = Set.Make (String) + + module type S = sig + type t + + val of_plugin : plugin -> (t, Error.t) result + val of_plugin_exn : plugin -> t + + val fn : + string -> + (module Type.S with type t = 'a) -> + (module Type.S with type t = 'b) -> + t -> + 'a -> + ('b, Error.t) result + + val exn : (t -> 'a -> ('b, Error.t) result) -> t -> 'a -> 'b + end + + type typed = { mutable functions : Functions.t; mutable sealed : bool } + + module Init () : S = struct + type nonrec t = t + + let state = { functions = Functions.empty; sealed = false } + + let fn name params results = + if state.sealed then invalid_arg "Typed function has already been sealed"; + state.functions <- Functions.add name state.functions; + let f = call params results ~name in + fun plugin params -> f plugin params + + let exn f (t : t) x = Error.unwrap (f t x) + let finish () = state.sealed <- true + + let of_plugin_exn plugin = + finish (); + Functions.iter + (fun name -> + if not (function_exists plugin name) then + Error.throw (`Msg ("invalid plugin function: " ^ name))) + state.functions; + plugin + + let of_plugin plugin = + finish (); + match of_plugin_exn plugin with + | exception Error.Error e -> Error e + | x -> Ok x + end +end + +let%test "typed" = + let module Test = struct + include Typed.Init () + + let count_vowels = exn @@ fn "count_vowels" Type.string Type.json + end in + let manifest = Manifest.(create [ Wasm.file "test/code.wasm" ]) in + let plugin = of_manifest manifest |> Error.unwrap in + let t = Test.of_plugin_exn plugin in + let res = Test.count_vowels t "aaa" in + let n = Yojson.Safe.Util.member "count" res |> Yojson.Safe.Util.to_number in + Printf.printf "count = %f\n" n; + n = 3.0