Skip to content

Commit

Permalink
Fix #245: return error for NaN in naive bayes
Browse files Browse the repository at this point in the history
  • Loading branch information
Mec-iS committed Nov 21, 2022
1 parent 83dcf9a commit 8881439
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 4 deletions.
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
clippy::upper_case_acronyms
)]
#![warn(missing_docs)]
#![warn(rustdoc::missing_doc_code_examples)]

//! # smartcore
//!
Expand Down
21 changes: 21 additions & 0 deletions src/naive_bayes/gaussian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,27 @@ mod tests {
);
}

#[test]
fn run_gaussian_naive_bayes_with_few_samples() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[-1., -1.],
&[-2., -1.],
&[-3., -2.],
&[1., 1.],
]);
let y: Vec<u32> = vec![1, 1, 1, 2];

let gnb = GaussianNB::fit(&x, &y, Default::default());

match gnb.unwrap().predict(&x) {
Ok(_) => assert!(false, "test should return Failed"),
Err(err) => {
assert!(err.to_string() == "Can't find solution: log_likelihood for distribution of one of the rows is NaN");
assert!(true)
},
}
}

#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
Expand Down
18 changes: 15 additions & 3 deletions src/naive_bayes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use crate::error::Failed;
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
use crate::numbers::basenum::Number;
#[cfg(feature = "serde")]
Expand Down Expand Up @@ -93,24 +93,36 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let y_classes = self.distribution.classes();
let (rows, _) = x.shape();
let mut log_likehood_is_nan = false;
let predictions = (0..rows)
.map(|row_index| {
let row = x.get_row(row_index);
let (prediction, _probability) = y_classes
.iter()
.enumerate()
.map(|(class_index, class)| {
let mut log_likelihood = self.distribution.log_likelihood(class_index, &row);
if log_likelihood.is_nan() {
log_likelihood = 0f64;
log_likehood_is_nan = true;
}
(
class,
self.distribution.log_likelihood(class_index, &row)
log_likelihood
+ self.distribution.prior(class_index).ln(),
)
})
})
.max_by(|(_, p1), (_, p2)| p1.partial_cmp(p2).unwrap())
.unwrap();
*prediction
})
.collect::<Vec<TY>>();
if log_likehood_is_nan {
return Err(Failed::because(
FailedError::SolutionFailed,
"log_likelihood for distribution of one of the rows is NaN",
));
}
let y_hat = Y::from_vec_slice(&predictions);
Ok(y_hat)
}
Expand Down

0 comments on commit 8881439

Please sign in to comment.