Skip to content

Commit

Permalink
Refactor and add a couple more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sd2k committed Oct 4, 2024
1 parent fca7053 commit 02d45d7
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 58 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ augurs-testing = { path = "crates/augurs-testing" }
chrono = "0.4.38"
distrs = "0.2.1"
itertools = "0.13.0"
num-traits = "0.2.19"
roots = "0.0.8"
serde = { version = "1.0.166", features = ["derive"] }
thiserror = "1.0.40"
Expand Down
1 change: 1 addition & 0 deletions crates/augurs-prophet/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ keywords.workspace = true

[dependencies]
itertools.workspace = true
num-traits.workspace = true
thiserror.workspace = true
tracing.workspace = true

Expand Down
24 changes: 24 additions & 0 deletions crates/augurs-prophet/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,30 @@ impl TrainingData {
}
self
}

#[cfg(test)]
pub(crate) fn tail(mut self, n: usize) -> Self {
self.ds = self.ds.split_off(n);
self.y = self.y.split_off(n);
if let Some(cap) = self.cap.as_mut() {
*cap = cap.split_off(n);
}
if let Some(floor) = self.floor.as_mut() {
*floor = floor.split_off(n);
}
for (_, v) in self.x.iter_mut() {
*v = v.split_off(n);
}
for (_, v) in self.seasonality_conditions.iter_mut() {
*v = v.split_off(n);
}
self
}

#[cfg(test)]
pub(crate) fn len(&self) -> usize {
self.ds.len()
}
}

/// The data needed to predict with a Prophet model.
Expand Down
2 changes: 2 additions & 0 deletions crates/augurs-prophet/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ mod positive_float;
mod prophet;
#[cfg(test)]
mod testdata;
mod util;

/// A timestamp represented as seconds since the epoch.
pub type TimestampSeconds = u64;
Expand All @@ -32,3 +33,4 @@ pub use prophet::{
},
Prophet,
};
use util::FloatIterExt;
2 changes: 1 addition & 1 deletion crates/augurs-prophet/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ pub enum Algorithm {
}

/// Arguments for optimization.
#[derive(Debug, Clone, Copy)]
#[derive(Default, Debug, Clone, Copy)]
pub struct OptimizeOpts {
/// Algorithm to use.
pub algorithm: Option<Algorithm>,
Expand Down
Loading

0 comments on commit 02d45d7

Please sign in to comment.