diff --git a/Cargo.lock b/Cargo.lock index 999a9c8388..907670d108 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -890,6 +890,7 @@ dependencies = [ "log", "nvml-wrapper", "ratatui", + "rstest", "serde", "sysinfo", "systemstat", diff --git a/burn-book/src/building-blocks/metric.md b/burn-book/src/building-blocks/metric.md index e1b2ab3321..bdaef635d0 100644 --- a/burn-book/src/building-blocks/metric.md +++ b/burn-book/src/building-blocks/metric.md @@ -7,6 +7,7 @@ throughout the training process. We currently offer a restricted range of metric | ---------------- | ------------------------------------------------------- | | Accuracy | Calculate the accuracy in percentage | | TopKAccuracy | Calculate the top-k accuracy in percentage | +| Precision | Calculate precision in percentage | | AUROC | Calculate the area under curve of ROC in percentage | | Loss | Output the loss used for the backward pass | | CPU Temperature | Fetch the temperature of CPUs | diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index d665af2585..0f02a0c060 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -478,6 +478,11 @@ impl TensorCheck { "Can't create a one hot tensor from ({index_tensor:?}) containing indexes greater or equal to the number of classes ({num_classes})", )), ); + } else if num_classes <= 1 { + check = check.register( + "One Hot", + TensorError::new("Can't create a one hot tensor with less then 2 classes"), + ) } check } diff --git a/crates/burn-train/Cargo.toml b/crates/burn-train/Cargo.toml index 04ddb66f2c..d65ef53691 100644 --- a/crates/burn-train/Cargo.toml +++ b/crates/burn-train/Cargo.toml @@ -40,6 +40,7 @@ ratatui = { workspace = true, optional = true, features = ["all-widgets", "cross derive-new = { workspace = true } serde = { workspace = true, features = ["std", "derive"] } async-channel = { workspace = true } +rstest.workspace = true [dev-dependencies] burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" } diff --git a/crates/burn-train/src/learner/classification.rs b/crates/burn-train/src/learner/classification.rs index ee86a05754..381cd3a96c 100644 --- a/crates/burn-train/src/learner/classification.rs +++ b/crates/burn-train/src/learner/classification.rs @@ -1,4 +1,4 @@ -use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput}; +use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput, PrecisionInput}; use burn_core::tensor::backend::Backend; use burn_core::tensor::{Int, Tensor}; @@ -27,6 +27,23 @@ impl Adaptor> for ClassificationOutput { } } +impl Adaptor> for ClassificationOutput { + fn adapt(&self) -> PrecisionInput { + let [_, num_classes] = self.output.dims(); + if num_classes > 1 { + PrecisionInput::new( + self.output.clone(), + self.targets.clone().one_hot(num_classes).bool(), + ) + } else { + PrecisionInput::new( + self.output.clone(), + self.targets.clone().unsqueeze_dim(1).bool(), + ) + } + } +} + /// Multi-label classification output adapted for multiple metrics. #[derive(new)] pub struct MultiLabelClassificationOutput { @@ -51,3 +68,9 @@ impl Adaptor> for MultiLabelClassificationOutput { LossInput::new(self.loss.clone()) } } + +impl Adaptor> for MultiLabelClassificationOutput { + fn adapt(&self) -> PrecisionInput { + PrecisionInput::new(self.output.clone(), self.targets.clone().bool()) + } +} diff --git a/crates/burn-train/src/lib.rs b/crates/burn-train/src/lib.rs index 23413fa2ef..24337498c5 100644 --- a/crates/burn-train/src/lib.rs +++ b/crates/burn-train/src/lib.rs @@ -26,3 +26,78 @@ pub use learner::*; #[cfg(test)] pub(crate) type TestBackend = burn_ndarray::NdArray; + +#[cfg(test)] +pub(crate) mod tests { + use crate::TestBackend; + use burn_core::{prelude::Tensor, tensor::Bool}; + use std::default::Default; + + /// Probability of tp before adding errors + pub const THRESHOLD: f64 = 0.5; + + #[derive(Debug)] + pub enum ClassificationType { + Binary, + Multiclass, + Multilabel, + } + + /// Sample x Class shaped matrix for use in + /// classification metrics testing + pub fn dummy_classification_input( + classification_type: &ClassificationType, + ) -> (Tensor, Tensor) { + match classification_type { + ClassificationType::Binary => { + ( + Tensor::from_data( + [[0.3], [0.2], [0.7], [0.1], [0.55]], + //[[0], [0], [1], [0], [1]] with threshold=0.5 + &Default::default(), + ), + Tensor::from_data([[0], [1], [0], [0], [1]], &Default::default()), + ) + } + ClassificationType::Multiclass => { + ( + Tensor::from_data( + [ + [0.2, 0.8, 0.0], + [0.3, 0.6, 0.1], + [0.7, 0.25, 0.05], + [0.1, 0.15, 0.8], + [0.9, 0.03, 0.07], + ], + //[[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]] with top_k=1 + //[[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]] with top_k=2 + &Default::default(), + ), + Tensor::from_data( + [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]], + &Default::default(), + ), + ) + } + ClassificationType::Multilabel => { + ( + Tensor::from_data( + [ + [0.1, 0.7, 0.6], + [0.3, 0.9, 0.05], + [0.8, 0.9, 0.4], + [0.7, 0.5, 0.9], + [1.0, 0.3, 0.2], + ], + //[[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]] with threshold=0.5 + &Default::default(), + ), + Tensor::from_data( + [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]], + &Default::default(), + ), + ) + } + } + } +} diff --git a/crates/burn-train/src/metric/classification.rs b/crates/burn-train/src/metric/classification.rs new file mode 100644 index 0000000000..1eb51a85d0 --- /dev/null +++ b/crates/burn-train/src/metric/classification.rs @@ -0,0 +1,9 @@ +/// The reduction strategy for classification metrics. +#[derive(Copy, Clone, Default)] +pub enum ClassReduction { + /// Computes the statistics over all classes before averaging + Micro, + /// Computes the statistics independently for each class before averaging + #[default] + Macro, +} diff --git a/crates/burn-train/src/metric/confusion_stats.rs b/crates/burn-train/src/metric/confusion_stats.rs new file mode 100644 index 0000000000..cdb01b1721 --- /dev/null +++ b/crates/burn-train/src/metric/confusion_stats.rs @@ -0,0 +1,351 @@ +use super::classification::ClassReduction; +use burn_core::prelude::{Backend, Bool, Int, Tensor}; +use std::fmt::{self, Debug}; +use std::num::NonZeroUsize; + +#[derive(Clone)] +pub struct ConfusionStats { + confusion_classes: Tensor, + class_reduction: ClassReduction, +} + +impl Debug for ConfusionStats { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let to_vec = |tensor_data: Tensor| { + tensor_data + .to_data() + .to_vec::() + .expect("A vector representation of the input Tensor is expected") + }; + let ratio_of_support_vec = + |metric: Tensor| to_vec(self.clone().ratio_of_support(metric)); + f.debug_struct("ConfusionStats") + .field("tp", &ratio_of_support_vec(self.clone().true_positive())) + .field("fp", &ratio_of_support_vec(self.clone().false_positive())) + .field("tn", &ratio_of_support_vec(self.clone().true_negative())) + .field("fn", &ratio_of_support_vec(self.clone().false_negative())) + .field("support", &to_vec(self.clone().support())) + .finish() + } +} + +impl ConfusionStats { + /// Expects `predictions` to be normalized. + pub fn new( + predictions: Tensor, + targets: Tensor, + threshold: Option, + top_k: Option, + class_reduction: ClassReduction, + ) -> Self { + let prediction_mask = match (threshold, top_k) { + (Some(threshold), None) => { + predictions.greater_elem(threshold) + }, + (None, Some(top_k)) => { + let mask = predictions.zeros_like(); + let indexes = predictions.argsort_descending(1).narrow(1, 0, top_k.get()); + let values = indexes.ones_like().float(); + mask.scatter(1, indexes, values).bool() + } + _ => panic!("Either threshold (for binary or multilabel) or top_k (for multiclass) must be set."), + }; + Self { + confusion_classes: prediction_mask.int() + targets.int() * 2, + class_reduction, + } + } + + /// sum over samples + fn aggregate( + sample_class_mask: Tensor, + class_reduction: ClassReduction, + ) -> Tensor { + use ClassReduction::*; + match class_reduction { + Micro => sample_class_mask.float().sum(), + Macro => sample_class_mask.float().sum_dim(0).squeeze(0), + } + } + + pub fn true_positive(self) -> Tensor { + Self::aggregate(self.confusion_classes.equal_elem(3), self.class_reduction) + } + + pub fn true_negative(self) -> Tensor { + Self::aggregate(self.confusion_classes.equal_elem(0), self.class_reduction) + } + + pub fn false_positive(self) -> Tensor { + Self::aggregate(self.confusion_classes.equal_elem(1), self.class_reduction) + } + + pub fn false_negative(self) -> Tensor { + Self::aggregate(self.confusion_classes.equal_elem(2), self.class_reduction) + } + + pub fn positive(self) -> Tensor { + self.clone().true_positive() + self.false_negative() + } + + pub fn negative(self) -> Tensor { + self.clone().true_negative() + self.false_positive() + } + + pub fn predicted_positive(self) -> Tensor { + self.clone().true_positive() + self.false_positive() + } + + pub fn support(self) -> Tensor { + self.clone().positive() + self.negative() + } + + pub fn ratio_of_support(self, metric: Tensor) -> Tensor { + metric / self.clone().support() + } +} + +#[cfg(test)] +mod tests { + use super::{ + ClassReduction::{self, *}, + ConfusionStats, + }; + use crate::tests::{ + dummy_classification_input, + ClassificationType::{self, *}, + THRESHOLD, + }; + use burn_core::prelude::TensorData; + use rstest::rstest; + use std::num::NonZeroUsize; + + #[rstest] + #[should_panic] + #[case::both_some(Some(THRESHOLD), Some(1))] + #[should_panic] + #[case::both_none(None, None)] + fn test_exclusive_threshold_top_k( + #[case] threshold: Option, + #[case] top_k: Option, + ) { + let (predictions, targets) = dummy_classification_input(&Binary).into(); + ConfusionStats::new( + predictions, + targets, + threshold, + top_k.map(NonZeroUsize::new).flatten(), + Micro, + ); + } + + #[rstest] + #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [1].into())] + #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [1].into())] + #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [3].into())] + #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [1, 1, 1].into())] + #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [4].into())] + #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [2, 1, 1].into())] + #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [5].into())] + #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [2, 2, 1].into())] + fn test_true_positive( + #[case] classification_type: ClassificationType, + #[case] class_reduction: ClassReduction, + #[case] threshold: Option, + #[case] top_k: Option, + #[case] expected: Vec, + ) { + let (predictions, targets) = dummy_classification_input(&classification_type).into(); + ConfusionStats::new( + predictions, + targets, + threshold, + top_k.map(NonZeroUsize::new).flatten(), + class_reduction, + ) + .true_positive() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); + } + + #[rstest] + #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [2].into())] + #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [2].into())] + #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [8].into())] + #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [2, 3, 3].into())] + #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [4].into())] + #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [1, 1, 2].into())] + #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [3].into())] + #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [0, 2, 1].into())] + fn test_true_negative( + #[case] classification_type: ClassificationType, + #[case] class_reduction: ClassReduction, + #[case] threshold: Option, + #[case] top_k: Option, + #[case] expected: Vec, + ) { + let (predictions, targets) = dummy_classification_input(&classification_type).into(); + ConfusionStats::new( + predictions, + targets, + threshold, + top_k.map(NonZeroUsize::new).flatten(), + class_reduction, + ) + .true_negative() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); + } + + #[rstest] + #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [1].into())] + #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [1].into())] + #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [2].into())] + #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [1, 1, 0].into())] + #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [6].into())] + #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [2, 3, 1].into())] + #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [3].into())] + #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [1, 1, 1].into())] + fn test_false_positive( + #[case] classification_type: ClassificationType, + #[case] class_reduction: ClassReduction, + #[case] threshold: Option, + #[case] top_k: Option, + #[case] expected: Vec, + ) { + let (predictions, targets) = dummy_classification_input(&classification_type).into(); + ConfusionStats::new( + predictions, + targets, + threshold, + top_k.map(NonZeroUsize::new).flatten(), + class_reduction, + ) + .false_positive() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); + } + + #[rstest] + #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [1].into())] + #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [1].into())] + #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [2].into())] + #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [1, 0, 1].into())] + #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [1].into())] + #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [0, 0, 1].into())] + #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [4].into())] + #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [2, 0, 2].into())] + fn test_false_negatives( + #[case] classification_type: ClassificationType, + #[case] class_reduction: ClassReduction, + #[case] threshold: Option, + #[case] top_k: Option, + #[case] expected: Vec, + ) { + let (predictions, targets) = dummy_classification_input(&classification_type).into(); + ConfusionStats::new( + predictions, + targets, + threshold, + top_k.map(NonZeroUsize::new).flatten(), + class_reduction, + ) + .false_negative() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); + } + + #[rstest] + #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [2].into())] + #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [2].into())] + #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [5].into())] + #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [2, 1, 2].into())] + #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [5].into())] + #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [2, 1, 2].into())] + #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [9].into())] + #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [4, 2, 3].into())] + fn test_positive( + #[case] classification_type: ClassificationType, + #[case] class_reduction: ClassReduction, + #[case] threshold: Option, + #[case] top_k: Option, + #[case] expected: Vec, + ) { + let (predictions, targets) = dummy_classification_input(&classification_type).into(); + ConfusionStats::new( + predictions, + targets, + threshold, + top_k.map(NonZeroUsize::new).flatten(), + class_reduction, + ) + .positive() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); + } + + #[rstest] + #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [3].into())] + #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [3].into())] + #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [10].into())] + #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [3, 4, 3].into())] + #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [10].into())] + #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [3, 4, 3].into())] + #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [6].into())] + #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [1, 3, 2].into())] + fn test_negative( + #[case] classification_type: ClassificationType, + #[case] class_reduction: ClassReduction, + #[case] threshold: Option, + #[case] top_k: Option, + #[case] expected: Vec, + ) { + let (predictions, targets) = dummy_classification_input(&classification_type).into(); + ConfusionStats::new( + predictions, + targets, + threshold, + top_k.map(NonZeroUsize::new).flatten(), + class_reduction, + ) + .negative() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); + } + + #[rstest] + #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [2].into())] + #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [2].into())] + #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [5].into())] + #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [2, 2, 1].into())] + #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [10].into())] + #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [4, 4, 2].into())] + #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [8].into())] + #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [3, 3, 2].into())] + fn test_predicted_positive( + #[case] classification_type: ClassificationType, + #[case] class_reduction: ClassReduction, + #[case] threshold: Option, + #[case] top_k: Option, + #[case] expected: Vec, + ) { + let (predictions, targets) = dummy_classification_input(&classification_type).into(); + ConfusionStats::new( + predictions, + targets, + threshold, + top_k.map(NonZeroUsize::new).flatten(), + class_reduction, + ) + .predicted_positive() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); + } +} diff --git a/crates/burn-train/src/metric/mod.rs b/crates/burn-train/src/metric/mod.rs index 2187734807..2b8d9cd801 100644 --- a/crates/burn-train/src/metric/mod.rs +++ b/crates/burn-train/src/metric/mod.rs @@ -43,3 +43,12 @@ pub use top_k_acc::*; pub(crate) mod processor; /// Module responsible to save and exposes data collected during training. pub mod store; + +pub(crate) mod classification; +#[cfg(feature = "metrics")] +pub use crate::metric::classification::ClassReduction; +mod confusion_stats; +#[cfg(feature = "metrics")] +mod precision; +#[cfg(feature = "metrics")] +pub use precision::*; diff --git a/crates/burn-train/src/metric/precision.rs b/crates/burn-train/src/metric/precision.rs new file mode 100644 index 0000000000..0b9efb3d8d --- /dev/null +++ b/crates/burn-train/src/metric/precision.rs @@ -0,0 +1,218 @@ +use super::{ + classification::ClassReduction, + confusion_stats::ConfusionStats, + state::{FormatOptions, NumericMetricState}, + Metric, MetricEntry, MetricMetadata, Numeric, +}; +use burn_core::{ + prelude::{Backend, Tensor}, + tensor::{cast::ToElement, Bool}, +}; +use core::marker::PhantomData; +use std::num::NonZeroUsize; + +/// Input for precision metric. +#[derive(new, Debug, Clone)] +pub struct PrecisionInput { + /// Sample x Class Non thresholded normalized predictions. + pub predictions: Tensor, + /// Sample x Class one-hot encoded target. + pub targets: Tensor, +} + +impl From> for (Tensor, Tensor) { + fn from(input: PrecisionInput) -> Self { + (input.predictions, input.targets) + } +} + +impl From<(Tensor, Tensor)> for PrecisionInput { + fn from(value: (Tensor, Tensor)) -> Self { + Self::new(value.0, value.1) + } +} + +enum PrecisionConfig { + Binary { threshold: f64 }, + Multiclass { top_k: NonZeroUsize }, + Multilabel { threshold: f64 }, +} + +impl Default for PrecisionConfig { + fn default() -> Self { + Self::Binary { threshold: 0.5 } + } +} + +///The Precision Metric +#[derive(Default)] +pub struct PrecisionMetric { + state: NumericMetricState, + _b: PhantomData, + class_reduction: ClassReduction, + config: PrecisionConfig, +} + +impl PrecisionMetric { + /// Precision metric for binary classification. + /// + /// # Arguments + /// + /// * `threshold` - The threshold to transform a probability into a binary prediction. + #[allow(dead_code)] + pub fn binary(threshold: f64) -> Self { + Self { + config: PrecisionConfig::Binary { threshold }, + ..Default::default() + } + } + + /// Precision metric for multiclass classification. + /// + /// # Arguments + /// + /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). + #[allow(dead_code)] + pub fn multiclass(top_k: usize) -> Self { + Self { + config: PrecisionConfig::Multiclass { + top_k: NonZeroUsize::new(top_k).expect("top_k must be non-zero"), + }, + ..Default::default() + } + } + + /// Precision metric for multi-label classification. + /// + /// # Arguments + /// + /// * `threshold` - The threshold to transform a probability into a binary prediction. + #[allow(dead_code)] + pub fn multilabel(threshold: f64) -> Self { + Self { + config: PrecisionConfig::Multilabel { threshold }, + ..Default::default() + } + } + + /// Sets the class reduction method. + #[allow(dead_code)] + pub fn with_class_reduction(mut self, class_reduction: ClassReduction) -> Self { + self.class_reduction = class_reduction; + self + } + + fn class_average(&self, mut aggregated_metric: Tensor) -> f64 { + use ClassReduction::*; + let avg_tensor = match self.class_reduction { + Micro => aggregated_metric, + Macro => { + if aggregated_metric.contains_nan().any().into_scalar() { + let nan_mask = aggregated_metric.is_nan(); + aggregated_metric = aggregated_metric + .clone() + .select(0, nan_mask.bool_not().argwhere().squeeze(1)) + } + aggregated_metric.mean() + } + }; + avg_tensor.into_scalar().to_f64() + } +} + +impl Metric for PrecisionMetric { + const NAME: &'static str = "Precision"; + type Input = PrecisionInput; + + fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { + let (predictions, targets) = input.clone().into(); + let [sample_size, _] = input.predictions.dims(); + + let (threshold, top_k) = match self.config { + PrecisionConfig::Binary { threshold } | PrecisionConfig::Multilabel { threshold } => { + (Some(threshold), None) + } + PrecisionConfig::Multiclass { top_k } => (None, Some(top_k)), + }; + + let cf_stats = + ConfusionStats::new(predictions, targets, threshold, top_k, self.class_reduction); + let metric = + self.class_average(cf_stats.clone().true_positive() / cf_stats.predicted_positive()); + + self.state.update( + 100.0 * metric, + sample_size, + FormatOptions::new(Self::NAME).unit("%").precision(2), + ) + } + + fn clear(&mut self) { + self.state.reset() + } +} + +impl Numeric for PrecisionMetric { + fn value(&self) -> f64 { + self.state.value() + } +} + +#[cfg(test)] +mod tests { + use super::{ + ClassReduction::{self, *}, + Metric, MetricMetadata, Numeric, PrecisionMetric, + }; + use crate::tests::{dummy_classification_input, ClassificationType, THRESHOLD}; + use burn_core::tensor::TensorData; + use rstest::rstest; + + #[rstest] + #[case::binary_micro(Micro, THRESHOLD, 0.5)] + #[case::binary_macro(Macro, THRESHOLD, 0.5)] + fn test_binary_precision( + #[case] class_reduction: ClassReduction, + #[case] threshold: f64, + #[case] expected: f64, + ) { + let input = dummy_classification_input(&ClassificationType::Binary).into(); + let mut metric = PrecisionMetric::binary(threshold).with_class_reduction(class_reduction); + let _entry = metric.update(&input, &MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } + + #[rstest] + #[case::multiclass_micro_k1(Micro, 1, 3.0/5.0)] + #[case::multiclass_micro_k2(Micro, 2, 4.0/10.0)] + #[case::multiclass_macro_k1(Macro, 1, (0.5 + 0.5 + 1.0)/3.0)] + #[case::multiclass_macro_k2(Macro, 2, (0.5 + 1.0/4.0 + 0.5)/3.0)] + fn test_multiclass_precision( + #[case] class_reduction: ClassReduction, + #[case] top_k: usize, + #[case] expected: f64, + ) { + let input = dummy_classification_input(&ClassificationType::Multiclass).into(); + let mut metric = PrecisionMetric::multiclass(top_k).with_class_reduction(class_reduction); + let _entry = metric.update(&input, &MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } + + #[rstest] + #[case::multilabel_micro(Micro, THRESHOLD, 5.0/8.0)] + #[case::multilabel_macro(Macro, THRESHOLD, (2.0/3.0 + 2.0/3.0 + 0.5)/3.0)] + fn test_precision( + #[case] class_reduction: ClassReduction, + #[case] threshold: f64, + #[case] expected: f64, + ) { + let input = dummy_classification_input(&ClassificationType::Multilabel).into(); + let mut metric = + PrecisionMetric::multilabel(threshold).with_class_reduction(class_reduction); + let _entry = metric.update(&input, &MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } +}