Skip to content

Commit

Permalink
feat: add prophet-wasmstan component and JS package stubs (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
sd2k authored Oct 16, 2024
1 parent b64c8fc commit e395566
Show file tree
Hide file tree
Showing 28 changed files with 8,883 additions and 0 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/wasmstan.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: prophet-wasmstan

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

env:
CARGO_TERM_COLOR: always

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
targets: wasm32-unknown-unknown,wasm32-wasip1
- uses: taiki-e/install-action@v2
with:
tool: cargo-binstall,just,wasmtime
- name: Install deps
run: just components/install-deps
- uses: actions/setup-node@v4
- name: Run node test
run: just components/test
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "components/cpp/prophet-wasmstan/stan"]
path = components/cpp/prophet-wasmstan/stan
url = https://github.com/stan-dev/stan
1 change: 1 addition & 0 deletions components/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/tools
3 changes: 3 additions & 0 deletions components/cpp/prophet-wasmstan/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*-core.wasm
*-component.wasm
*_component_type.o
49 changes: 49 additions & 0 deletions components/cpp/prophet-wasmstan/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# prophet-wasmstan - a WASM Component for the Prophet Stan model

`prophet-wasmstan` is a WASM component exposing the core model fitting
and sampling functionality of the [Prophet](https://github.com/facebook/prophet)
time series forecasting model. Specifically, this component uses the
generated Stan code (a C++ file) and the Stan library to expose
the `optimize` and `sample` functions of Stan using the Prophet model,
allowing it to be called from a WASM module.

## Building

To build the component you'll need to have several tools from the
WASM Component toolchain installed. The easiest way to do this is
using the `justfile` from the `components` directory of the repository,
which has an `install-dependencies` target that will install all
the necessary tools.

```bash
just install-dependencies
```

Once the dependencies are installed, you can build the component
with the `build-lib-component` target:

```bash
just build
```

This will generate a `prophet-wasmstan-component.wasm` file in the `prophet-wasmstan`
directory. This file can be used as a WASM component.

## Using the component

The interface exposed by the component is defined in the `prophet-wasmstan.wit`
file.

See the [Component Model docs](https://component-model.bytecodealliance.org/language-support.html)
for instructions on how to use the component in a WASM component in other
languages.

### Javascript

Run the following command to generate Javascript bindings for the component:

```bash
just transpile
```

You should now have Javascript bindings in the `js/prophet-wasmstan` directory.
1,343 changes: 1,343 additions & 0 deletions components/cpp/prophet-wasmstan/model/model.hpp

Large diffs are not rendered by default.

143 changes: 143 additions & 0 deletions components/cpp/prophet-wasmstan/model/model.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Copyright (c) Facebook, Inc. and its affiliates.

// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

functions {
matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) {
// Assumes t and t_change are sorted.
matrix[T, S] A;
row_vector[S] a_row;
int cp_idx;

// Start with an empty matrix.
A = rep_matrix(0, T, S);
a_row = rep_row_vector(0, S);
cp_idx = 1;

// Fill in each row of A.
for (i in 1:T) {
while ((cp_idx <= S) && (t[i] >= t_change[cp_idx])) {
a_row[cp_idx] = 1;
cp_idx = cp_idx + 1;
}
A[i] = a_row;
}
return A;
}

// Logistic trend functions

vector logistic_gamma(real k, real m, vector delta, vector t_change, int S) {
vector[S] gamma; // adjusted offsets, for piecewise continuity
vector[S + 1] k_s; // actual rate in each segment
real m_pr;

// Compute the rate in each segment
k_s = append_row(k, k + cumulative_sum(delta));

// Piecewise offsets
m_pr = m; // The offset in the previous segment
for (i in 1:S) {
gamma[i] = (t_change[i] - m_pr) * (1 - k_s[i] / k_s[i + 1]);
m_pr = m_pr + gamma[i]; // update for the next segment
}
return gamma;
}

vector logistic_trend(
real k,
real m,
vector delta,
vector t,
vector cap,
matrix A,
vector t_change,
int S
) {
vector[S] gamma;

gamma = logistic_gamma(k, m, delta, t_change, S);
return cap .* inv_logit((k + A * delta) .* (t - (m + A * gamma)));
}

// Linear trend function

vector linear_trend(
real k,
real m,
vector delta,
vector t,
matrix A,
vector t_change
) {
return (k + A * delta) .* t + (m + A * (-t_change .* delta));
}

// Flat trend function

vector flat_trend(
real m,
int T
) {
return rep_vector(m, T);
}
}

data {
int T; // Number of time periods
int<lower=1> K; // Number of regressors
vector[T] t; // Time
vector[T] cap; // Capacities for logistic trend
vector[T] y; // Time series
int S; // Number of changepoints
vector[S] t_change; // Times of trend changepoints
matrix[T,K] X; // Regressors
vector[K] sigmas; // Scale on seasonality prior
real<lower=0> tau; // Scale on changepoints prior
int trend_indicator; // 0 for linear, 1 for logistic, 2 for flat
vector[K] s_a; // Indicator of additive features
vector[K] s_m; // Indicator of multiplicative features
}

transformed data {
matrix[T, S] A = get_changepoint_matrix(t, t_change, T, S);
matrix[T, K] X_sa = X .* rep_matrix(s_a', T);
matrix[T, K] X_sm = X .* rep_matrix(s_m', T);
}

parameters {
real k; // Base trend growth rate
real m; // Trend offset
vector[S] delta; // Trend rate adjustments
real<lower=0> sigma_obs; // Observation noise
vector[K] beta; // Regressor coefficients
}

transformed parameters {
vector[T] trend;
if (trend_indicator == 0) {
trend = linear_trend(k, m, delta, t, A, t_change);
} else if (trend_indicator == 1) {
trend = logistic_trend(k, m, delta, t, cap, A, t_change, S);
} else if (trend_indicator == 2) {
trend = flat_trend(m, T);
}
}

model {
//priors
k ~ normal(0, 5);
m ~ normal(0, 5);
delta ~ double_exponential(0, tau);
sigma_obs ~ normal(0, 0.5);
beta ~ normal(0, sigmas);

// Likelihood
y ~ normal_id_glm(
X_sa,
trend .* (1 + X_sm * beta),
beta,
sigma_obs
);
}
Loading

0 comments on commit e395566

Please sign in to comment.