From f498f9629e99c9bbd3db37b3e962a31fb067c109 Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Mon, 20 Mar 2023 23:45:44 +0900 Subject: [PATCH] Implement realnum::rand (#251) Co-authored-by: Luis Moreno Co-authored-by: Lorenzo * Implement rand. Use the new derive #[default] * Use custom range * Use range seed * Bump version * Add array length checks for decision tree and random forest fit --- Cargo.toml | 2 +- src/algorithm/neighbour/mod.rs | 9 ++----- src/cluster/dbscan.rs | 2 +- src/ensemble/random_forest_classifier.rs | 30 +++++++++++++++++++++++- src/ensemble/random_forest_regressor.rs | 30 ++++++++++++++++++++++++ src/error/mod.rs | 2 +- src/linear/logistic_regression.rs | 9 ++----- src/linear/ridge_regression.rs | 9 ++----- src/neighbors/mod.rs | 9 ++----- src/numbers/realnum.rs | 29 ++++++++++++++++++++--- src/tree/decision_tree_classifier.rs | 26 +++++++++++++------- src/tree/decision_tree_regressor.rs | 5 +++- 12 files changed, 118 insertions(+), 44 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5da2fe8b..a30db160 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "smartcore" description = "Machine Learning in Rust." homepage = "https://smartcorelib.org" -version = "0.3.0" +version = "0.3.1" authors = ["smartcore Developers"] edition = "2021" license = "Apache-2.0" diff --git a/src/algorithm/neighbour/mod.rs b/src/algorithm/neighbour/mod.rs index e150d19f..3bee93aa 100644 --- a/src/algorithm/neighbour/mod.rs +++ b/src/algorithm/neighbour/mod.rs @@ -49,20 +49,15 @@ pub mod linear_search; /// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries. 
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html) #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub enum KNNAlgorithmName { /// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html) LinearSearch, /// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html) + #[default] CoverTree, } -impl Default for KNNAlgorithmName { - fn default() -> Self { - KNNAlgorithmName::CoverTree - } -} - #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] pub(crate) enum KNNAlgorithm>> { diff --git a/src/cluster/dbscan.rs b/src/cluster/dbscan.rs index e9e3329b..0d84a613 100644 --- a/src/cluster/dbscan.rs +++ b/src/cluster/dbscan.rs @@ -18,7 +18,7 @@ //! //! Example: //! -//! ``` +//! ```ignore //! use smartcore::linalg::basic::matrix::DenseMatrix; //! use smartcore::linalg::basic::arrays::Array2; //! 
use smartcore::cluster::dbscan::*; diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 8ea174b5..6448b52e 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -454,8 +454,12 @@ impl, Y: Array1 Result, Failed> { - let (_, num_attributes) = x.shape(); + let (x_nrows, num_attributes) = x.shape(); let y_ncols = y.shape(); + if x_nrows != y_ncols { + return Err(Failed::fit("Number of rows in X should = len(y)")); + } + let mut yi: Vec = vec![0; y_ncols]; let classes = y.unique(); @@ -678,6 +682,30 @@ mod tests { assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95); } + #[test] + fn test_random_matrix_with_wrong_rownum() { + let x_rand: DenseMatrix = DenseMatrix::::rand(21, 200); + + let y: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; + + let fail = RandomForestClassifier::fit( + &x_rand, + &y, + RandomForestClassifierParameters { + criterion: SplitCriterion::Gini, + max_depth: Option::None, + min_samples_leaf: 1, + min_samples_split: 2, + n_trees: 100, + m: Option::None, + keep_samples: false, + seed: 87, + }, + ); + + assert!(fail.is_err()); + } + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index 34b9ee10..926327e1 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -399,6 +399,10 @@ impl, Y: Array1 ) -> Result, Failed> { let (n_rows, num_attributes) = x.shape(); + if n_rows != y.shape() { + return Err(Failed::fit("Number of rows in X should = len(y)")); + } + let mtry = parameters .m .unwrap_or((num_attributes as f64).sqrt().floor() as usize); @@ -595,6 +599,32 @@ mod tests { assert!(mean_absolute_error(&y, &y_hat) < 1.0); } + #[test] + fn test_random_matrix_with_wrong_rownum() { + let x_rand: DenseMatrix = DenseMatrix::::rand(17, 
200); + + let y = vec![ + 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, + 114.2, 115.7, 116.9, + ]; + + let fail = RandomForestRegressor::fit( + &x_rand, + &y, + RandomForestRegressorParameters { + max_depth: Option::None, + min_samples_leaf: 1, + min_samples_split: 2, + n_trees: 1000, + m: Option::None, + keep_samples: false, + seed: 87, + }, + ); + + assert!(fail.is_err()); + } + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test diff --git a/src/error/mod.rs b/src/error/mod.rs index 5a99e856..838df085 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -30,7 +30,7 @@ pub enum FailedError { DecompositionFailed, /// Can't solve for x SolutionFailed, - /// Erro in input + /// Error in input parameters ParametersError, } diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs index 044a771d..4a4041bc 100644 --- a/src/linear/logistic_regression.rs +++ b/src/linear/logistic_regression.rs @@ -71,19 +71,14 @@ use crate::optimization::line_search::Backtracking; use crate::optimization::FunctionOrder; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq, Default)] /// Solver options for Logistic regression. Right now only LBFGS solver is supported. 
pub enum LogisticRegressionSolverName { /// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html) + #[default] LBFGS, } -impl Default for LogisticRegressionSolverName { - fn default() -> Self { - LogisticRegressionSolverName::LBFGS - } -} - /// Logistic Regression parameters #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] diff --git a/src/linear/ridge_regression.rs b/src/linear/ridge_regression.rs index 2cddb005..2c354299 100644 --- a/src/linear/ridge_regression.rs +++ b/src/linear/ridge_regression.rs @@ -71,21 +71,16 @@ use crate::numbers::basenum::Number; use crate::numbers::realnum::RealNumber; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq, Default)] /// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable. pub enum RidgeRegressionSolverName { /// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html) + #[default] Cholesky, /// SVD decomposition, see [SVD](../../linalg/svd/index.html) SVD, } -impl Default for RidgeRegressionSolverName { - fn default() -> Self { - RidgeRegressionSolverName::Cholesky - } -} - /// Ridge Regression parameters #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] diff --git a/src/neighbors/mod.rs b/src/neighbors/mod.rs index 40b854ab..0abe9bdc 100644 --- a/src/neighbors/mod.rs +++ b/src/neighbors/mod.rs @@ -49,20 +49,15 @@ pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName; /// Weight function that is used to determine estimated value. 
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub enum KNNWeightFunction { /// All k nearest points are weighted equally + #[default] Uniform, /// k nearest points are weighted by the inverse of their distance. Closer neighbors will have a greater influence than neighbors which are further away. Distance, } -impl Default for KNNWeightFunction { - fn default() -> Self { - KNNWeightFunction::Uniform - } -} - impl KNNWeightFunction { fn calc_weights(&self, distances: Vec) -> std::vec::Vec { match *self { diff --git a/src/numbers/realnum.rs b/src/numbers/realnum.rs index f4d9aec1..8ef71555 100644 --- a/src/numbers/realnum.rs +++ b/src/numbers/realnum.rs @@ -2,9 +2,13 @@ //! Most algorithms in `smartcore` rely on basic linear algebra operations like dot product, matrix decomposition and other subroutines that are defined for a set of real numbers, ℝ. //! This module defines real number and some useful functions that are used in [Linear Algebra](../../linalg/index.html) module. 
+use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; + use num_traits::Float; use crate::numbers::basenum::Number; +use crate::rand_custom::get_rng_impl; /// Defines real number /// @@ -63,8 +67,12 @@ impl RealNumber for f64 { } fn rand() -> f64 { - // TODO: to be implemented, see issue smartcore#214 - 1.0 + let mut small_rng = get_rng_impl(None); + + let mut rngs: Vec = (0..3) + .map(|_| SmallRng::from_rng(&mut small_rng).unwrap()) + .collect(); + rngs[0].gen::() } fn two() -> Self { @@ -108,7 +116,12 @@ impl RealNumber for f32 { } fn rand() -> f32 { - 1.0 + let mut small_rng = get_rng_impl(None); + + let mut rngs: Vec = (0..3) + .map(|_| SmallRng::from_rng(&mut small_rng).unwrap()) + .collect(); + rngs[0].gen::() } fn two() -> Self { @@ -149,4 +162,14 @@ mod tests { fn f64_from_string() { assert_eq!(f64::from_str("1.111111111").unwrap(), 1.111111111) } + + #[test] + fn f64_rand() { + f64::rand(); + } + + #[test] + fn f32_rand() { + f32::rand(); + } } diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 95c1d895..4f36e5b9 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -137,16 +137,17 @@ impl, Y: Array1> self.classes.as_ref() } /// Get depth of tree - fn depth(&self) -> u16 { + pub fn depth(&self) -> u16 { self.depth } } /// The function to measure the quality of a split. 
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub enum SplitCriterion { /// [Gini index](../decision_tree_classifier/index.html) + #[default] Gini, /// [Entropy](../decision_tree_classifier/index.html) Entropy, @@ -154,12 +155,6 @@ pub enum SplitCriterion { ClassificationError, } -impl Default for SplitCriterion { - fn default() -> Self { - SplitCriterion::Gini - } -} - #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] struct Node { @@ -543,6 +538,10 @@ impl, Y: Array1> parameters: DecisionTreeClassifierParameters, ) -> Result, Failed> { let (x_nrows, num_attributes) = x.shape(); + if x_nrows != y.shape() { + return Err(Failed::fit("Size of x should equal size of y")); + } + let samples = vec![1; x_nrows]; DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters) } @@ -968,6 +967,17 @@ mod tests { ); } + #[test] + fn test_random_matrix_with_wrong_rownum() { + let x_rand: DenseMatrix = DenseMatrix::::rand(21, 200); + + let y: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; + + let fail = DecisionTreeClassifier::fit(&x_rand, &y, Default::default()); + + assert!(fail.is_err()); + } + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs index 0146cbc5..d21c7490 100644 --- a/src/tree/decision_tree_regressor.rs +++ b/src/tree/decision_tree_regressor.rs @@ -18,7 +18,6 @@ //! Example: //! //! ``` -//! use rand::thread_rng; //! use smartcore::linalg::basic::matrix::DenseMatrix; //! use smartcore::tree::decision_tree_regressor::*; //! 
@@ -422,6 +421,10 @@ impl, Y: Array1> parameters: DecisionTreeRegressorParameters, ) -> Result, Failed> { let (x_nrows, num_attributes) = x.shape(); + if x_nrows != y.shape() { + return Err(Failed::fit("Size of x should equal size of y")); + } + let samples = vec![1; x_nrows]; DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters) }