Skip to content

Commit

Permalink
feat: add wasmtime based optimizer for dependency-free Rust impl
Browse files Browse the repository at this point in the history
This commit adds a new feature-flagged implementation of `Optimizer`,
which uses the `wasmtime` runtime to run the WASM-compiled Stan model.
  • Loading branch information
sd2k committed Nov 7, 2024
1 parent 2a0a7bb commit 3880c2a
Show file tree
Hide file tree
Showing 15 changed files with 6,593 additions and 19 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ jobs:

test:
name: Tests
# We can only run tests if the triggering workflow succeeded, because
# we need the prophet-wasmstan.wasm file to be present.
if: ${{ github.event.workflow_run.conclusion == 'success' }}
runs-on: ubuntu-latest
steps:
- uses: actions/download-artifact@v4
Expand Down
1 change: 1 addition & 0 deletions components/cpp/prophet-wasmstan/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*-core.wasm
*-component.wasm
*_component_type.o
prophet-wasmstan.wasm
1 change: 1 addition & 0 deletions crates/augurs-prophet/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
/prophet_stan_model
prophet-wasmstan.wasm
30 changes: 28 additions & 2 deletions crates/augurs-prophet/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ tempfile = { version = "3.13.0", optional = true }
thiserror.workspace = true
tracing.workspace = true
ureq = { version = "2.10.1", optional = true }
wasmtime = { version = "26", features = [
"runtime",
"component-model",
], optional = true }
wasmtime-wasi = { version = "26", optional = true }
zip = { version = "2.2.0", optional = true }

[dev-dependencies]
Expand All @@ -45,25 +50,46 @@ download = ["dep:ureq", "dep:zip"]
# end users, or you may end up with a broken build where the
# Prophet model isn't available to be compiled into the binary.
internal-ignore-cmdstan-failure = []
# Ignore wasmstan compilation in the build script.
# This should only be used for developing the library, not by
# end users, or you may end up with a broken build where the
# Prophet model isn't available to be compiled into the binary.
internal-ignore-wasmstan-failure = []
serde = ["dep:serde"]
compile-wasmstan = ["wasmstan"]
wasmstan = ["dep:serde_json", "dep:wasmtime", "dep:wasmtime-wasi", "serde"]

[lib]
bench = false

[[bench]]
name = "prophet-linear"
name = "prophet-cmdstan-linear"
harness = false
required-features = ["cmdstan", "compile-cmdstan"]

[[bench]]
name = "prophet-logistic"
name = "prophet-cmdstan-logistic"
harness = false
required-features = ["cmdstan", "compile-cmdstan"]

[[bench]]
name = "prophet-wasmstan-linear"
harness = false
required-features = ["wasmstan"]

[[bench]]
name = "prophet-wasmstan-logistic"
harness = false
required-features = ["wasmstan"]

[[bin]]
name = "download-stan-model"
path = "src/bin/main.rs"
required-features = ["download"]

[[test]]
name = "wasmstan"
required-features = ["wasmstan"]

[lints]
workspace = true
2,117 changes: 2,117 additions & 0 deletions crates/augurs-prophet/benches/prophet-wasmstan-linear.rs

Large diffs are not rendered by default.

2,118 changes: 2,118 additions & 0 deletions crates/augurs-prophet/benches/prophet-wasmstan-logistic.rs

Large diffs are not rendered by default.

31 changes: 16 additions & 15 deletions crates/augurs-prophet/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
/// - The `STAN_PATH` environment variable to be set to the
/// path to the Stan installation.
#[cfg(all(feature = "cmdstan", feature = "compile-cmdstan"))]
fn compile_model() -> Result<(), Box<dyn std::error::Error>> {
fn compile_cmdstan_model() -> Result<(), Box<dyn std::error::Error>> {
use std::{fs, path::PathBuf, process::Command};
use tempfile::TempDir;

Expand Down Expand Up @@ -65,25 +65,21 @@ fn compile_model() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

fn fallback() -> Result<(), Box<dyn std::error::Error>> {
fn create_empty_files(names: &[&str]) -> Result<(), Box<dyn std::error::Error>> {
let out_dir = std::path::PathBuf::from(std::env::var("OUT_DIR")?);
std::fs::create_dir_all(&out_dir)?;
let prophet_path = out_dir.join("prophet");
let libtbb_path = out_dir.join("libtbb.so.12");
std::fs::File::create(&prophet_path)?;
std::fs::File::create(&libtbb_path)?;
eprintln!(
"Created empty files for prophet ({}) and libtbb ({})",
prophet_path.display(),
libtbb_path.display()
);
for name in names {
let path = out_dir.join(name);
std::fs::File::create(&path)?;
eprintln!("Created empty file for {}", path.display());
}
Ok(())
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
fn handle_cmdstan() -> Result<(), Box<dyn std::error::Error>> {
let _result = Ok::<(), &'static str>(());
#[cfg(all(feature = "cmdstan", feature = "compile-cmdstan"))]
let _result = compile_model();
let _result = compile_cmdstan_model();
// This is a complete hack but lets us get away with still using
// the `--all-features` flag of Cargo without everything failing
// if there isn't a Stan installation.
Expand All @@ -94,12 +90,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// feature.
#[cfg(feature = "internal-ignore-cmdstan-failure")]
if _result.is_err() {
fallback()?;
create_empty_files(&["prophet", "libtbb.so.12"])?;
}
// Do the same thing in docs.rs builds.
#[cfg(not(feature = "internal-ignore-cmdstan-failure"))]
if std::env::var("DOCS_RS").is_ok() {
fallback()?;
create_empty_files(&["prophet", "libtbb.so.12"])?;
}

// If we're not in a docs.rs build and we don't have the 'ignore'
Expand All @@ -110,3 +106,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}
Ok(())
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
handle_cmdstan()?;
Ok(())
}
1 change: 1 addition & 0 deletions crates/augurs-prophet/prophet-wasmstan.wit
2 changes: 1 addition & 1 deletion crates/augurs-prophet/src/cmdstan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ struct OptimizeCommand<'a> {
refresh: usize,
}

impl<'a> OptimizeCommand<'a> {
impl OptimizeCommand<'_> {
fn run(&self) -> Result<optimizer::OptimizedParams, Error> {
// Set up temp dir and files.
let tempdir = tempfile::tempdir()?;
Expand Down
2 changes: 2 additions & 0 deletions crates/augurs-prophet/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ mod util;

#[cfg(feature = "cmdstan")]
pub mod cmdstan;
#[cfg(feature = "wasmstan")]
pub mod wasmstan;

/// A timestamp represented as seconds since the epoch.
pub type TimestampSeconds = i64;
Expand Down
222 changes: 222 additions & 0 deletions crates/augurs-prophet/src/wasmstan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
//! Use `wasmtime` to run the Prophet model inside a WebAssembly component.
use std::fmt;

use wasmtime::{
component::{Component, Linker},
Engine, Store,
};
use wasmtime_wasi::{ResourceTable, WasiCtx, WasiView};

use crate::{
optimizer::{self, Data, InitialParams, OptimizeOpts, OptimizedParams},
Optimizer,
};

/// An error that can occur when using the `WasmStanOptimizer`.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// An error occurred while compiling the WebAssembly component.
#[error("Error compiling component: {0}")]
Compilation(wasmtime::Error),
/// An error occurred while instantiating the WebAssembly component.
#[error("Error instantiating component: {0}")]
Instantiation(wasmtime::Error),
/// An error occurred in wasmtime while running the WebAssembly component.
#[error("Error running component: {0}")]
Runtime(wasmtime::Error),
/// An error occurred in Stan while running the optimization.
#[error("Error running optimization: {0}")]
Optimize(String),
/// An invalid parameter value was received from the optimizer.
#[error("Invalid value ({value}) for parameter {param} received from optimizer")]
InvalidParam {
/// The parameter name.
param: String,
/// The value received from the optimizer.
value: f64,
},
}

#[allow(missing_docs)]
mod gen {
use wasmtime::component::bindgen;

bindgen!({
world: "prophet-wasmstan",
path: "prophet-wasmstan.wit",
// Uncomment this to include the pregenerated file in the `target` directory
// somewhere (search for `prophet-wasmstan0.rs`).
// include_generated_code_from_file: true,
});
}

use gen::*;

struct WasiState {
ctx: WasiCtx,
table: ResourceTable,
}

impl Default for WasiState {
fn default() -> Self {
Self {
ctx: WasiCtx::builder().build(),
table: Default::default(),
}
}
}

impl WasiView for WasiState {
fn ctx(&mut self) -> &mut WasiCtx {
&mut self.ctx
}
fn table(&mut self) -> &mut ResourceTable {
&mut self.table
}
}

/// An `Optimizer` which runs the Prophet model inside a WebAssembly
/// component.
#[derive(Clone)]
pub struct WasmstanOptimizer {
engine: Engine,
instance_pre: ProphetWasmstanPre<WasiState>,
}

impl fmt::Debug for WasmstanOptimizer {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WasmStanOptimizer").finish()
}
}

impl WasmstanOptimizer {
/// Create a new `WasmStanOptimizer`.
pub fn new() -> Result<Self, Error> {
// Create an engine in which to compile and run everything.
let engine = Engine::default();

// Create a component from the compiled and embedded WASM binary.
let component = Component::from_binary(
&engine,
include_bytes!(concat!(
std::env!("CARGO_MANIFEST_DIR"),
"/prophet-wasmstan.wasm"
)),
)
.map_err(Error::Compilation)?;

// Create a linker, which will add WASI imports to the component.
let mut linker = Linker::new(&engine);
wasmtime_wasi::add_to_linker_sync(&mut linker).map_err(Error::Instantiation)?;

// Create a pre-instantiated component.
// This does as much work as possible here, so that `optimize` can
// be called multiple times with the minimum amount of overhead.
let instance_pre = linker
.instantiate_pre(&component)
.and_then(ProphetWasmstanPre::new)
.map_err(Error::Instantiation)?;
Ok(Self {
engine,
instance_pre,
})
}

/// Optimize the model using the given parameters.
fn wasm_optimize(
&self,
init: &augurs::prophet_wasmstan::types::Inits,
data: &String,
opts: augurs::prophet_wasmstan::types::OptimizeOpts,
) -> Result<OptimizedParams, Error> {
let mut store = Store::new(&self.engine, WasiState::default());
let instance = self
.instance_pre
.instantiate(&mut store)
.map_err(Error::Instantiation)?;
instance
.augurs_prophet_wasmstan_optimizer()
.call_optimize(&mut store, init, data, opts)
.map_err(Error::Runtime)?
.map_err(Error::Optimize)
.and_then(|op| op.params.try_into())
}
}

impl TryFrom<augurs::prophet_wasmstan::types::OptimizedParams> for OptimizedParams {
type Error = Error;
fn try_from(
value: augurs::prophet_wasmstan::types::OptimizedParams,
) -> Result<Self, Self::Error> {
Ok(Self {
k: value.k,
m: value.m,
sigma_obs: value
.sigma_obs
.try_into()
.map_err(|_| Error::InvalidParam {
param: "sigma_obs".to_string(),
value: value.sigma_obs,
})?,
delta: value.delta,
beta: value.beta,
trend: value.trend,
})
}
}

impl Optimizer for WasmstanOptimizer {
fn optimize(
&self,
init: &InitialParams,
data: &Data,
opts: &OptimizeOpts,
) -> Result<OptimizedParams, optimizer::Error> {
let data = serde_json::to_string(&data).map_err(optimizer::Error::custom)?;
self.wasm_optimize(&init.into(), &data, opts.into())
.map_err(optimizer::Error::custom)
}
}

impl From<&InitialParams> for augurs::prophet_wasmstan::types::Inits {
fn from(init: &InitialParams) -> Self {
Self {
k: init.k,
m: init.m,
sigma_obs: init.sigma_obs.into(),
delta: init.delta.clone(),
beta: init.beta.clone(),
}
}
}

impl From<&OptimizeOpts> for augurs::prophet_wasmstan::types::OptimizeOpts {
fn from(opts: &OptimizeOpts) -> Self {
Self {
algorithm: opts.algorithm.map(Into::into),
seed: opts.seed,
chain: opts.chain,
init_alpha: opts.init_alpha,
tol_obj: opts.tol_obj,
tol_rel_obj: opts.tol_rel_obj,
tol_grad: opts.tol_grad,
tol_rel_grad: opts.tol_rel_grad,
tol_param: opts.tol_param,
history_size: opts.history_size,
iter: opts.iter,
jacobian: opts.jacobian,
refresh: opts.refresh,
}
}
}

impl From<optimizer::Algorithm> for augurs::prophet_wasmstan::types::Algorithm {
fn from(algo: optimizer::Algorithm) -> Self {
match algo {
optimizer::Algorithm::Bfgs => Self::Bfgs,
optimizer::Algorithm::Lbfgs => Self::Lbfgs,
optimizer::Algorithm::Newton => Self::Newton,
}
}
}
Loading

0 comments on commit 3880c2a

Please sign in to comment.