From 888143931278f165d3456d4be3d66471fe2a0ca4 Mon Sep 17 00:00:00 2001 From: "Lorenzo (Mec-iS)" Date: Mon, 21 Nov 2022 11:32:59 +0000 Subject: [PATCH] Fix #245: return error for NaN in naive bayes --- src/lib.rs | 1 - src/naive_bayes/gaussian.rs | 21 +++++++++++++++++++++ src/naive_bayes/mod.rs | 18 +++++++++++++++--- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 03bfc03b..b8fe5b08 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,6 @@ clippy::upper_case_acronyms )] #![warn(missing_docs)] -#![warn(rustdoc::missing_doc_code_examples)] //! # smartcore //! diff --git a/src/naive_bayes/gaussian.rs b/src/naive_bayes/gaussian.rs index a9c1d4fe..8a0a8e19 100644 --- a/src/naive_bayes/gaussian.rs +++ b/src/naive_bayes/gaussian.rs @@ -425,6 +425,27 @@ mod tests { ); } + #[test] + fn run_gaussian_naive_bayes_with_few_samples() { + let x = DenseMatrix::::from_2d_array(&[ + &[-1., -1.], + &[-2., -1.], + &[-3., -2.], + &[1., 1.], + ]); + let y: Vec = vec![1, 1, 1, 2]; + + let gnb = GaussianNB::fit(&x, &y, Default::default()); + + match gnb.unwrap().predict(&x) { + Ok(_) => assert!(false, "test should return Failed"), + Err(err) => { + assert!(err.to_string() == "Can't find solution: log_likelihood for distribution of one of the rows is NaN"); + assert!(true) + }, + } + } + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test diff --git a/src/naive_bayes/mod.rs b/src/naive_bayes/mod.rs index e7ab7f6d..5ce6af2f 100644 --- a/src/naive_bayes/mod.rs +++ b/src/naive_bayes/mod.rs @@ -35,7 +35,7 @@ //! //! //! -use crate::error::Failed; +use crate::error::{Failed, FailedError}; use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1}; use crate::numbers::basenum::Number; #[cfg(feature = "serde")] @@ -93,6 +93,7 @@ impl, Y: Array1, D: NBDistribution Result { let y_classes = self.distribution.classes(); let (rows, _) = x.shape(); + let mut log_likehood_is_nan = false; let predictions = (0..rows) .map(|row_index| { let row = x.get_row(row_index); @@ -100,17 +101,28 @@ impl, Y: Array1, D: NBDistribution>(); + if log_likehood_is_nan { + return Err(Failed::because( + FailedError::SolutionFailed, + "log_likelihood for distribution of one of the rows is NaN", + )); + } let y_hat = Y::from_vec_slice(&predictions); Ok(y_hat) }