Skip to content

Commit

Permalink
feat: add 'transforms' JS crate and include in augurs JS bindings (#195)
Browse files Browse the repository at this point in the history
  • Loading branch information
sd2k authored Dec 12, 2024
1 parent 5f4fae4 commit 954bd91
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 6 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ resolver = "2"
[workspace.package]
license = "MIT OR Apache-2.0"
authors = [
"Ben Sully <[email protected]",
"Ben Sully <[email protected]>",
]
documentation = "https://docs.rs/crate/augurs"
repository = "https://github.com/grafana/augurs"
Expand Down Expand Up @@ -44,6 +44,7 @@ augurs-testing = { path = "crates/augurs-testing" }
augurs-core-js = { path = "js/augurs-core-js" }

anyhow = "1.0.89"
argmin = "0.10.0"
bytemuck = "1.18.0"
chrono = "0.4.38"
distrs = "0.2.1"
Expand Down
2 changes: 1 addition & 1 deletion crates/augurs-forecaster/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ description = "A high-level API for the augurs forecasting library."
bench = false

[dependencies]
argmin = "0.10.0"
argmin.workspace = true
augurs-core.workspace = true
itertools.workspace = true
thiserror.workspace = true
Expand Down
32 changes: 30 additions & 2 deletions crates/augurs-forecaster/src/transforms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,21 @@ impl Transform {
}

/// Apply the transformation to the given time series.
pub(crate) fn transform<'a, T>(&'a self, input: T) -> Box<dyn Iterator<Item = f64> + 'a>
///
/// # Returns
///
/// A boxed iterator over the transformed values.
///
/// # Example
///
/// ```
/// use augurs_forecaster::transforms::Transform;
///
/// let data = vec![1.0, 2.0, 3.0];
/// let transform = Transform::log();
/// let transformed: Vec<_> = transform.transform(data.into_iter()).collect();
/// ```
pub fn transform<'a, T>(&'a self, input: T) -> Box<dyn Iterator<Item = f64> + 'a>
where
T: Iterator<Item = f64> + 'a,
{
Expand All @@ -167,7 +181,21 @@ impl Transform {
}

/// Apply the inverse transformation to the given time series.
pub(crate) fn inverse_transform<'a, T>(&'a self, input: T) -> Box<dyn Iterator<Item = f64> + 'a>
///
/// # Returns
///
/// A boxed iterator over the inverse transformed values.
///
/// # Example
///
/// ```
/// use augurs_forecaster::transforms::Transform;
///
/// let data = vec![1.0, 2.0, 3.0];
/// let transform = Transform::log();
/// let transformed: Vec<_> = transform.inverse_transform(data.into_iter()).collect();
/// ```
pub fn inverse_transform<'a, T>(&'a self, input: T) -> Box<dyn Iterator<Item = f64> + 'a>
where
T: Iterator<Item = f64> + 'a,
{
Expand Down
35 changes: 35 additions & 0 deletions js/augurs-transforms-js/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
[package]
name = "augurs-transforms-js"
version.workspace = true
authors.workspace = true
documentation.workspace = true
repository.workspace = true
license.workspace = true
edition.workspace = true
keywords.workspace = true
description = "JavaScript bindings for augurs' data transformations."
publish = false

[lib]
bench = false
crate-type = ["cdylib", "rlib"]
doc = false
doctest = false
test = false

[dependencies]
argmin = { workspace = true, features = ["wasm-bindgen"] }
augurs-core-js.workspace = true
augurs-forecaster.workspace = true
getrandom.workspace = true
serde.workspace = true
serde-wasm-bindgen.workspace = true
tsify-next.workspace = true
wasm-bindgen.workspace = true

[package.metadata.wasm-pack.profile.release]
# previously had just ['-O4']
wasm-opt = ['-O4', '--enable-bulk-memory', '--enable-threads']

[lints]
workspace = true
71 changes: 71 additions & 0 deletions js/augurs-transforms-js/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//! JavaScript bindings for augurs transformations, such as power transforms, scaling, etc.
use serde::Deserialize;
use tsify_next::Tsify;
use wasm_bindgen::prelude::*;

use augurs_core_js::VecF64;
use augurs_forecaster::transforms::Transform;

/// A power transform.
///
/// This transform applies the power function to each item.
///
/// If all values are positive, it will use the Box-Cox transform.
/// If any values are negative or zero, it will use the Yeo-Johnson transform.
///
/// The optimal value of the `lambda` parameter is calculated from the data
/// using maximum likelihood estimation.
///
/// @experimental
#[derive(Debug)]
#[wasm_bindgen]
pub struct PowerTransform {
inner: Transform,
}

#[wasm_bindgen]
impl PowerTransform {
/// Create a new power transform for the given data.
///
/// @experimental
#[wasm_bindgen(constructor)]
pub fn new(opts: PowerTransformOptions) -> Result<PowerTransform, JsError> {
Ok(PowerTransform {
inner: Transform::power_transform(&opts.data)
.map_err(|e| JsError::new(&e.to_string()))?,
})
}

/// Transform the given data.
///
/// @experimental
#[wasm_bindgen]
pub fn transform(&self, data: VecF64) -> Result<Vec<f64>, JsError> {
Ok(self
.inner
.transform(data.convert()?.iter().copied())
.collect())
}

/// Inverse transform the given data.
///
/// @experimental
#[wasm_bindgen(js_name = "inverseTransform")]
pub fn inverse_transform(&self, data: VecF64) -> Result<Vec<f64>, JsError> {
Ok(self
.inner
.inverse_transform(data.convert()?.iter().copied())
.collect())
}
}

/// Options for the power transform.
#[derive(Debug, Default, Deserialize, Tsify)]
#[serde(rename_all = "camelCase")]
#[tsify(from_wasm_abi)]
pub struct PowerTransformOptions {
/// The data to transform. This is used to calculate the optimal value of 'lambda'.
#[tsify(type = "number[] | Float64Array")]
pub data: Vec<f64>,
}
3 changes: 2 additions & 1 deletion js/justfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ build: \
(build-inner "mstl") \
(build-inner "outlier") \
(build-inner "prophet") \
(build-inner "seasons")
(build-inner "seasons") \
(build-inner "transforms")
just fix-package-json

build-inner target args='':
Expand Down
3 changes: 2 additions & 1 deletion js/package.json.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
"./mstl": "./mstl.js",
"./prophet": "./prophet.js",
"./outlier": "./outlier.js",
"./seasons": "./seasons.js"
"./seasons": "./seasons.js",
"./transforms": "./transforms.js"
},
"types": "augurs.d.ts",
"sideEffects": [
Expand Down
63 changes: 63 additions & 0 deletions js/testpkg/transforms.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import { webcrypto } from 'node:crypto'
import { readFileSync } from "node:fs";

import { PowerTransform, initSync } from '@bsull/augurs/transforms';

import { describe, expect, it } from 'vitest';

// Required for Rust's `rand::thread_rng` to support NodeJS modules.
// See https://docs.rs/getrandom#nodejs-es-module-support.
// @ts-ignore
globalThis.crypto = webcrypto

initSync({ module: readFileSync('node_modules/@bsull/augurs/transforms_bg.wasm') });

describe('transforms', () => {
const y = [
0.1, 0.3, 0.8, 0.5,
0.1, 0.31, 0.79, 0.48,
0.09, 0.29, 0.81, 0.49,
0.11, 0.28, 0.78, 0.53,
0.1, 0.3, 0.8, 0.5,
0.1, 0.31, 0.79, 0.48,
0.09, 0.29, 0.81, 0.49,
0.11, 0.28, 0.78, 0.53,
];

expect.extend({
toAllBeCloseTo: (received, expected) => {
if (received.length !== expected.length) {
return {
message: () => `expected array lengths to match (got ${received.length}, wanted ${expected.length})`,
pass: false,
};
}
for (let index = 0; index < received.length; index++) {
const got = received[index];
const exp = expected[index];
if (Math.abs(got - exp) > 0.1) {
return {
message: () => `got (${got}) not close to expected (${exp}) at index ${index}`,
pass: false,
}
}
}
return { message: () => '', pass: true };
}
});


describe('power transform', () => {
it('works with arrays', () => {
const pt = new PowerTransform({ data: y });
const transformed = pt.transform(y);
expect(transformed).toBeInstanceOf(Float64Array);
expect(transformed).toHaveLength(y.length);
const inverse = pt.inverseTransform(transformed);
expect(inverse).toBeInstanceOf(Float64Array);
expect(inverse).toHaveLength(y.length);
//@ts-ignore
expect(Array.from(inverse)).toAllBeCloseTo(y);
});
})
})

0 comments on commit 954bd91

Please sign in to comment.