-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add
prophet-wasmstan
component and JS package stubs (#126)
- Loading branch information
Showing
28 changed files
with
8,883 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
name: prophet-wasmstan | ||
|
||
on: | ||
push: | ||
branches: [ "main" ] | ||
pull_request: | ||
branches: [ "main" ] | ||
|
||
env: | ||
CARGO_TERM_COLOR: always | ||
|
||
jobs: | ||
build: | ||
|
||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: dtolnay/rust-toolchain@stable | ||
with: | ||
targets: wasm32-unknown-unknown,wasm32-wasip1 | ||
- uses: taiki-e/install-action@v2 | ||
with: | ||
tool: cargo-binstall,just,wasmtime | ||
- name: Install deps | ||
run: just components/install-deps | ||
- uses: actions/setup-node@v4 | ||
- name: Run node test | ||
run: just components/test |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[submodule "components/cpp/prophet-wasmstan/stan"] | ||
path = components/cpp/prophet-wasmstan/stan | ||
url = https://github.com/stan-dev/stan |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
/tools |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
*-core.wasm | ||
*-component.wasm | ||
*_component_type.o |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# prophet-wasmstan - a WASM Component for the Prophet Stan model | ||
|
||
`prophet-wasmstan` is a WASM component exposing the core model fitting | ||
and sampling functionality of the [Prophet](https://github.com/facebook/prophet) | ||
time series forecasting model. Specifically, this component uses the | ||
generated Stan code (a C++ file) and the Stan library to expose | ||
the `optimize` and `sample` functions of Stan using the Prophet model, | ||
allowing it to be called from a WASM module. | ||
|
||
## Building | ||
|
||
To build the component you'll need to have several tools from the | ||
WASM Component toolchain installed. The easiest way to do this is | ||
using the `justfile` from the `components` directory of the repository, | ||
which has an `install-dependencies` target that will install all | ||
the necessary tools. | ||
|
||
```bash | ||
just install-dependencies | ||
``` | ||
|
||
Once the dependencies are installed, you can build the component | ||
with the `build-lib-component` target: | ||
|
||
```bash | ||
just build | ||
``` | ||
|
||
This will generate a `prophet-wasmstan-component.wasm` file in the `prophet-wasmstan` | ||
directory. This file can be used as a WASM component. | ||
|
||
## Using the component | ||
|
||
The interface exposed by the component is defined in the `prophet-wasmstan.wit` | ||
file. | ||
|
||
See the [Component Model docs](https://component-model.bytecodealliance.org/language-support.html) | ||
for instructions on how to use the component in a WASM component in other | ||
languages. | ||
|
||
### Javascript | ||
|
||
Run the following command to generate Javascript bindings for the component: | ||
|
||
```bash | ||
just transpile | ||
``` | ||
|
||
You should now have Javascript bindings in the `js/prophet-wasmstan` directory. |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. | ||
|
||
// This source code is licensed under the MIT license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
functions { | ||
matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) { | ||
// Assumes t and t_change are sorted. | ||
matrix[T, S] A; | ||
row_vector[S] a_row; | ||
int cp_idx; | ||
|
||
// Start with an empty matrix. | ||
A = rep_matrix(0, T, S); | ||
a_row = rep_row_vector(0, S); | ||
cp_idx = 1; | ||
|
||
// Fill in each row of A. | ||
for (i in 1:T) { | ||
while ((cp_idx <= S) && (t[i] >= t_change[cp_idx])) { | ||
a_row[cp_idx] = 1; | ||
cp_idx = cp_idx + 1; | ||
} | ||
A[i] = a_row; | ||
} | ||
return A; | ||
} | ||
|
||
// Logistic trend functions | ||
|
||
vector logistic_gamma(real k, real m, vector delta, vector t_change, int S) { | ||
vector[S] gamma; // adjusted offsets, for piecewise continuity | ||
vector[S + 1] k_s; // actual rate in each segment | ||
real m_pr; | ||
|
||
// Compute the rate in each segment | ||
k_s = append_row(k, k + cumulative_sum(delta)); | ||
|
||
// Piecewise offsets | ||
m_pr = m; // The offset in the previous segment | ||
for (i in 1:S) { | ||
gamma[i] = (t_change[i] - m_pr) * (1 - k_s[i] / k_s[i + 1]); | ||
m_pr = m_pr + gamma[i]; // update for the next segment | ||
} | ||
return gamma; | ||
} | ||
|
||
vector logistic_trend( | ||
real k, | ||
real m, | ||
vector delta, | ||
vector t, | ||
vector cap, | ||
matrix A, | ||
vector t_change, | ||
int S | ||
) { | ||
vector[S] gamma; | ||
|
||
gamma = logistic_gamma(k, m, delta, t_change, S); | ||
return cap .* inv_logit((k + A * delta) .* (t - (m + A * gamma))); | ||
} | ||
|
||
// Linear trend function | ||
|
||
vector linear_trend( | ||
real k, | ||
real m, | ||
vector delta, | ||
vector t, | ||
matrix A, | ||
vector t_change | ||
) { | ||
return (k + A * delta) .* t + (m + A * (-t_change .* delta)); | ||
} | ||
|
||
// Flat trend function | ||
|
||
vector flat_trend( | ||
real m, | ||
int T | ||
) { | ||
return rep_vector(m, T); | ||
} | ||
} | ||
|
||
data { | ||
int T; // Number of time periods | ||
int<lower=1> K; // Number of regressors | ||
vector[T] t; // Time | ||
vector[T] cap; // Capacities for logistic trend | ||
vector[T] y; // Time series | ||
int S; // Number of changepoints | ||
vector[S] t_change; // Times of trend changepoints | ||
matrix[T,K] X; // Regressors | ||
vector[K] sigmas; // Scale on seasonality prior | ||
real<lower=0> tau; // Scale on changepoints prior | ||
int trend_indicator; // 0 for linear, 1 for logistic, 2 for flat | ||
vector[K] s_a; // Indicator of additive features | ||
vector[K] s_m; // Indicator of multiplicative features | ||
} | ||
|
||
transformed data { | ||
matrix[T, S] A = get_changepoint_matrix(t, t_change, T, S); | ||
matrix[T, K] X_sa = X .* rep_matrix(s_a', T); | ||
matrix[T, K] X_sm = X .* rep_matrix(s_m', T); | ||
} | ||
|
||
parameters { | ||
real k; // Base trend growth rate | ||
real m; // Trend offset | ||
vector[S] delta; // Trend rate adjustments | ||
real<lower=0> sigma_obs; // Observation noise | ||
vector[K] beta; // Regressor coefficients | ||
} | ||
|
||
transformed parameters { | ||
vector[T] trend; | ||
if (trend_indicator == 0) { | ||
trend = linear_trend(k, m, delta, t, A, t_change); | ||
} else if (trend_indicator == 1) { | ||
trend = logistic_trend(k, m, delta, t, cap, A, t_change, S); | ||
} else if (trend_indicator == 2) { | ||
trend = flat_trend(m, T); | ||
} | ||
} | ||
|
||
model { | ||
//priors | ||
k ~ normal(0, 5); | ||
m ~ normal(0, 5); | ||
delta ~ double_exponential(0, tau); | ||
sigma_obs ~ normal(0, 0.5); | ||
beta ~ normal(0, sigmas); | ||
|
||
// Likelihood | ||
y ~ normal_id_glm( | ||
X_sa, | ||
trend .* (1 + X_sm * beta), | ||
beta, | ||
sigma_obs | ||
); | ||
} |
Oops, something went wrong.