-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add wasmtime based optimizer for dependency-free Rust impl
This commit adds a new feature-flagged implementation of `Optimizer`, which uses the `wasmtime` runtime to run the WASM-compiled Stan model.
- Loading branch information
Showing
15 changed files
with
6,593 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
*-core.wasm | ||
*-component.wasm | ||
*_component_type.o | ||
prophet-wasmstan.wasm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
/prophet_stan_model | ||
prophet-wasmstan.wasm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
2,117 changes: 2,117 additions & 0 deletions
2,117
crates/augurs-prophet/benches/prophet-wasmstan-linear.rs
Large diffs are not rendered by default.
Oops, something went wrong.
2,118 changes: 2,118 additions & 0 deletions
2,118
crates/augurs-prophet/benches/prophet-wasmstan-logistic.rs
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../components/cpp/prophet-wasmstan/wit/prophet-wasmstan.wit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
//! Use `wasmtime` to run the Prophet model inside a WebAssembly component. | ||
use std::fmt; | ||
|
||
use wasmtime::{ | ||
component::{Component, Linker}, | ||
Engine, Store, | ||
}; | ||
use wasmtime_wasi::{ResourceTable, WasiCtx, WasiView}; | ||
|
||
use crate::{ | ||
optimizer::{self, Data, InitialParams, OptimizeOpts, OptimizedParams}, | ||
Optimizer, | ||
}; | ||
|
||
/// An error that can occur when using the `WasmStanOptimizer`. | ||
#[derive(Debug, thiserror::Error)] | ||
pub enum Error { | ||
/// An error occurred while compiling the WebAssembly component. | ||
#[error("Error compiling component: {0}")] | ||
Compilation(wasmtime::Error), | ||
/// An error occurred while instantiating the WebAssembly component. | ||
#[error("Error instantiating component: {0}")] | ||
Instantiation(wasmtime::Error), | ||
/// An error occurred in wasmtime while running the WebAssembly component. | ||
#[error("Error running component: {0}")] | ||
Runtime(wasmtime::Error), | ||
/// An error occurred in Stan while running the optimization. | ||
#[error("Error running optimization: {0}")] | ||
Optimize(String), | ||
/// An invalid parameter value was received from the optimizer. | ||
#[error("Invalid value ({value}) for parameter {param} received from optimizer")] | ||
InvalidParam { | ||
/// The parameter name. | ||
param: String, | ||
/// The value received from the optimizer. | ||
value: f64, | ||
}, | ||
} | ||
|
||
#[allow(missing_docs)] | ||
mod gen { | ||
use wasmtime::component::bindgen; | ||
|
||
bindgen!({ | ||
world: "prophet-wasmstan", | ||
path: "prophet-wasmstan.wit", | ||
// Uncomment this to include the pregenerated file in the `target` directory | ||
// somewhere (search for `prophet-wasmstan0.rs`). | ||
// include_generated_code_from_file: true, | ||
}); | ||
} | ||
|
||
use gen::*; | ||
|
||
struct WasiState { | ||
ctx: WasiCtx, | ||
table: ResourceTable, | ||
} | ||
|
||
impl Default for WasiState { | ||
fn default() -> Self { | ||
Self { | ||
ctx: WasiCtx::builder().build(), | ||
table: Default::default(), | ||
} | ||
} | ||
} | ||
|
||
impl WasiView for WasiState { | ||
fn ctx(&mut self) -> &mut WasiCtx { | ||
&mut self.ctx | ||
} | ||
fn table(&mut self) -> &mut ResourceTable { | ||
&mut self.table | ||
} | ||
} | ||
|
||
/// An `Optimizer` which runs the Prophet model inside a WebAssembly | ||
/// component. | ||
#[derive(Clone)] | ||
pub struct WasmstanOptimizer { | ||
engine: Engine, | ||
instance_pre: ProphetWasmstanPre<WasiState>, | ||
} | ||
|
||
impl fmt::Debug for WasmstanOptimizer { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
f.debug_struct("WasmStanOptimizer").finish() | ||
} | ||
} | ||
|
||
impl WasmstanOptimizer { | ||
/// Create a new `WasmStanOptimizer`. | ||
pub fn new() -> Result<Self, Error> { | ||
// Create an engine in which to compile and run everything. | ||
let engine = Engine::default(); | ||
|
||
// Create a component from the compiled and embedded WASM binary. | ||
let component = Component::from_binary( | ||
&engine, | ||
include_bytes!(concat!( | ||
std::env!("CARGO_MANIFEST_DIR"), | ||
"/prophet-wasmstan.wasm" | ||
)), | ||
) | ||
.map_err(Error::Compilation)?; | ||
|
||
// Create a linker, which will add WASI imports to the component. | ||
let mut linker = Linker::new(&engine); | ||
wasmtime_wasi::add_to_linker_sync(&mut linker).map_err(Error::Instantiation)?; | ||
|
||
// Create a pre-instantiated component. | ||
// This does as much work as possible here, so that `optimize` can | ||
// be called multiple times with the minimum amount of overhead. | ||
let instance_pre = linker | ||
.instantiate_pre(&component) | ||
.and_then(ProphetWasmstanPre::new) | ||
.map_err(Error::Instantiation)?; | ||
Ok(Self { | ||
engine, | ||
instance_pre, | ||
}) | ||
} | ||
|
||
/// Optimize the model using the given parameters. | ||
fn wasm_optimize( | ||
&self, | ||
init: &augurs::prophet_wasmstan::types::Inits, | ||
data: &String, | ||
opts: augurs::prophet_wasmstan::types::OptimizeOpts, | ||
) -> Result<OptimizedParams, Error> { | ||
let mut store = Store::new(&self.engine, WasiState::default()); | ||
let instance = self | ||
.instance_pre | ||
.instantiate(&mut store) | ||
.map_err(Error::Instantiation)?; | ||
instance | ||
.augurs_prophet_wasmstan_optimizer() | ||
.call_optimize(&mut store, init, data, opts) | ||
.map_err(Error::Runtime)? | ||
.map_err(Error::Optimize) | ||
.and_then(|op| op.params.try_into()) | ||
} | ||
} | ||
|
||
impl TryFrom<augurs::prophet_wasmstan::types::OptimizedParams> for OptimizedParams { | ||
type Error = Error; | ||
fn try_from( | ||
value: augurs::prophet_wasmstan::types::OptimizedParams, | ||
) -> Result<Self, Self::Error> { | ||
Ok(Self { | ||
k: value.k, | ||
m: value.m, | ||
sigma_obs: value | ||
.sigma_obs | ||
.try_into() | ||
.map_err(|_| Error::InvalidParam { | ||
param: "sigma_obs".to_string(), | ||
value: value.sigma_obs, | ||
})?, | ||
delta: value.delta, | ||
beta: value.beta, | ||
trend: value.trend, | ||
}) | ||
} | ||
} | ||
|
||
impl Optimizer for WasmstanOptimizer { | ||
fn optimize( | ||
&self, | ||
init: &InitialParams, | ||
data: &Data, | ||
opts: &OptimizeOpts, | ||
) -> Result<OptimizedParams, optimizer::Error> { | ||
let data = serde_json::to_string(&data).map_err(optimizer::Error::custom)?; | ||
self.wasm_optimize(&init.into(), &data, opts.into()) | ||
.map_err(optimizer::Error::custom) | ||
} | ||
} | ||
|
||
impl From<&InitialParams> for augurs::prophet_wasmstan::types::Inits { | ||
fn from(init: &InitialParams) -> Self { | ||
Self { | ||
k: init.k, | ||
m: init.m, | ||
sigma_obs: init.sigma_obs.into(), | ||
delta: init.delta.clone(), | ||
beta: init.beta.clone(), | ||
} | ||
} | ||
} | ||
|
||
impl From<&OptimizeOpts> for augurs::prophet_wasmstan::types::OptimizeOpts { | ||
fn from(opts: &OptimizeOpts) -> Self { | ||
Self { | ||
algorithm: opts.algorithm.map(Into::into), | ||
seed: opts.seed, | ||
chain: opts.chain, | ||
init_alpha: opts.init_alpha, | ||
tol_obj: opts.tol_obj, | ||
tol_rel_obj: opts.tol_rel_obj, | ||
tol_grad: opts.tol_grad, | ||
tol_rel_grad: opts.tol_rel_grad, | ||
tol_param: opts.tol_param, | ||
history_size: opts.history_size, | ||
iter: opts.iter, | ||
jacobian: opts.jacobian, | ||
refresh: opts.refresh, | ||
} | ||
} | ||
} | ||
|
||
impl From<optimizer::Algorithm> for augurs::prophet_wasmstan::types::Algorithm { | ||
fn from(algo: optimizer::Algorithm) -> Self { | ||
match algo { | ||
optimizer::Algorithm::Bfgs => Self::Bfgs, | ||
optimizer::Algorithm::Lbfgs => Self::Lbfgs, | ||
optimizer::Algorithm::Newton => Self::Newton, | ||
} | ||
} | ||
} |
Oops, something went wrong.