diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index df3896de..9c99837a 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -10,7 +10,7 @@ use std::os::windows::ffi::OsStrExt; #[cfg(feature = "model-fetching")] use std::env; -use ndarray::Array; +use ndarray::{Array, Dimension}; use tracing::{debug, error}; use onnxruntime_sys as sys; @@ -21,6 +21,7 @@ use crate::{ error::{status_to_result, NonMatchingDimensionsError, OrtError, Result}, g_ort, memory::MemoryInfo, + runner::{Element, RunnerBuilder}, tensor::{ ort_owned_tensor::{OrtOwnedTensor, OrtOwnedTensorExtractor}, OrtTensor, @@ -361,6 +362,13 @@ impl Drop for Session { } impl Session { + pub fn make_runner>>( + &self, + input_arrays: I, + ) -> RunnerBuilder<'_, T, D> { + RunnerBuilder::new(self, input_arrays) + } + /// Run the input data through the ONNX graph, performing inference. /// /// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus