Add precision classification metric (tracel-ai#2293)
* Implement confusion matrix and precision, first draft

* Implement confusion matrix

* format :D

* add agg type to cm, reformat debug representation, add testing.
improve dummy classification input.
reformat precision and add test with dummy data.

* formatting and tiny refactor

* add ClassificationMetric trait, rename variables and types, move test module to lib.rs, make precision a classification metric.

* change unwrap to expect

* update book

* remove unused code

* changes to make reusing code easier

* format :D

* change to static data tests

* remove classification metric trait, add auxiliary code for classification input, clarify descriptions, remove dead code, rename some objects

* move classification objects to classification.rs, use rstest, remove approx lib and use TensorData asserts, move aggregate and average functions to ConfusionStats implementation

* review docstring, add top_k for multiclass tasks.

* move class averaging and metric computation to metric implementation, make dummy data more predictable and add tests for top_k > 1

* change struct and var names

* rename params, enforce nonzero for top_k param, optimize one_hot for case num_class = 1, reformat dummy data, make use of derive(new) for metric init.

* add adaptor for classification input, correct one hot function

* define default for ClassReduction, derive new for Precision metric with class_reduction as default and new setter implementation, move NonZeroUsize boundary to confusion_stats

* expose PrecisionMetric, change metric initialization

* check one_hot input tensor has more than 1 class and correct its implementation, deal with classification output with 1 class, make macro average default, expose ClassReduction type and split precision implementations by classification type

* implement adaptor for MultilabelClassificationOutput and ClassificationInput

* change with_top_k to take usize

* Add precision config for binary, multiclass and multilabel

* Fix dummy_classification_input

* make PrecisionMetric public

---------

Co-authored-by: Tiago Sanona <[email protected]>
Co-authored-by: Guillaume Lagrange <[email protected]>
3 people authored Nov 20, 2024
1 parent 2132d47 commit 76e67bf
Showing 10 changed files with 694 additions and 1 deletion.
1 change: 1 addition & 0 deletions Cargo.lock


1 change: 1 addition & 0 deletions burn-book/src/building-blocks/metric.md
@@ -7,6 +7,7 @@ throughout the training process. We currently offer a restricted range of metric
| ---------------- | ------------------------------------------------------- |
| Accuracy | Calculate the accuracy in percentage |
| TopKAccuracy | Calculate the top-k accuracy in percentage |
| Precision | Calculate precision in percentage |
| AUROC | Calculate the area under curve of ROC in percentage |
| Loss | Output the loss used for the backward pass |
| CPU Temperature | Fetch the temperature of CPUs |
5 changes: 5 additions & 0 deletions crates/burn-tensor/src/tensor/api/check.rs
@@ -478,6 +478,11 @@ impl TensorCheck {
"Can't create a one hot tensor from ({index_tensor:?}) containing indexes greater or equal to the number of classes ({num_classes})",
)),
);
} else if num_classes <= 1 {
check = check.register(
"One Hot",
TensorError::new("Can't create a one hot tensor with less then 2 classes"),
)
}
check
}
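The new branch rejects `one_hot` calls that cannot produce a meaningful encoding. Below is a minimal sketch of the behavior, assuming the burn-ndarray backend and the rank-1 integer `one_hot` method used by the adaptors later in this diff (the example itself is not part of the commit):

```rust
use burn_ndarray::NdArray;
use burn_tensor::{Int, Tensor};

type B = NdArray<f32>;

fn main() {
    let device = Default::default();
    let targets = Tensor::<B, 1, Int>::from_data([0, 2, 1], &device);

    // Three classes: produces a [3, 3] one-hot matrix.
    let one_hot = targets.clone().one_hot(3);
    println!("{one_hot}");

    // With num_classes <= 1 the new check registers the "One Hot" error,
    // so the call is rejected instead of silently returning a degenerate tensor.
    // let _ = targets.one_hot(1);
}
```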
1 change: 1 addition & 0 deletions crates/burn-train/Cargo.toml
@@ -40,6 +40,7 @@ ratatui = { workspace = true, optional = true, features = ["all-widgets", "cross
derive-new = { workspace = true }
serde = { workspace = true, features = ["std", "derive"] }
async-channel = { workspace = true }
rstest.workspace = true

[dev-dependencies]
burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" }
25 changes: 24 additions & 1 deletion crates/burn-train/src/learner/classification.rs
@@ -1,4 +1,4 @@
use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput};
use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput, PrecisionInput};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor};

Expand Down Expand Up @@ -27,6 +27,23 @@ impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
}
}

impl<B: Backend> Adaptor<PrecisionInput<B>> for ClassificationOutput<B> {
    fn adapt(&self) -> PrecisionInput<B> {
        let [_, num_classes] = self.output.dims();
        if num_classes > 1 {
            // Multiclass: expand the integer class targets into a one-hot boolean matrix.
            PrecisionInput::new(
                self.output.clone(),
                self.targets.clone().one_hot(num_classes).bool(),
            )
        } else {
            // Binary: targets are already 0/1 per sample; only add the class dimension.
            PrecisionInput::new(
                self.output.clone(),
                self.targets.clone().unsqueeze_dim(1).bool(),
            )
        }
    }
}

/// Multi-label classification output adapted for multiple metrics.
#[derive(new)]
pub struct MultiLabelClassificationOutput<B: Backend> {
@@ -51,3 +68,9 @@ impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
        LossInput::new(self.loss.clone())
    }
}

impl<B: Backend> Adaptor<PrecisionInput<B>> for MultiLabelClassificationOutput<B> {
    fn adapt(&self) -> PrecisionInput<B> {
        PrecisionInput::new(self.output.clone(), self.targets.clone().bool())
    }
}
75 changes: 75 additions & 0 deletions crates/burn-train/src/lib.rs
@@ -26,3 +26,78 @@ pub use learner::*;

#[cfg(test)]
pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;

#[cfg(test)]
pub(crate) mod tests {
    use crate::TestBackend;
    use burn_core::{prelude::Tensor, tensor::Bool};
    use std::default::Default;

    /// Probability of tp before adding errors
    pub const THRESHOLD: f64 = 0.5;

    #[derive(Debug)]
    pub enum ClassificationType {
        Binary,
        Multiclass,
        Multilabel,
    }

    /// Sample x Class shaped matrix for use in
    /// classification metrics testing
    pub fn dummy_classification_input(
        classification_type: &ClassificationType,
    ) -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2, Bool>) {
        match classification_type {
            ClassificationType::Binary => {
                (
                    Tensor::from_data(
                        [[0.3], [0.2], [0.7], [0.1], [0.55]],
                        //[[0], [0], [1], [0], [1]] with threshold=0.5
                        &Default::default(),
                    ),
                    Tensor::from_data([[0], [1], [0], [0], [1]], &Default::default()),
                )
            }
            ClassificationType::Multiclass => {
                (
                    Tensor::from_data(
                        [
                            [0.2, 0.8, 0.0],
                            [0.3, 0.6, 0.1],
                            [0.7, 0.25, 0.05],
                            [0.1, 0.15, 0.8],
                            [0.9, 0.03, 0.07],
                        ],
                        //[[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]] with top_k=1
                        //[[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]] with top_k=2
                        &Default::default(),
                    ),
                    Tensor::from_data(
                        [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]],
                        &Default::default(),
                    ),
                )
            }
            ClassificationType::Multilabel => {
                (
                    Tensor::from_data(
                        [
                            [0.1, 0.7, 0.6],
                            [0.3, 0.9, 0.05],
                            [0.8, 0.9, 0.4],
                            [0.7, 0.5, 0.9],
                            [1.0, 0.3, 0.2],
                        ],
                        //[[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]] with threshold=0.5
                        &Default::default(),
                    ),
                    Tensor::from_data(
                        [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]],
                        &Default::default(),
                    ),
                )
            }
        }
    }
}
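Since `rstest` is now a dependency of `burn-train`, the precision tests can be parameterized over these fixtures. A minimal sketch of that pattern follows; the module name and the shape-only assertions are illustrative, not part of the commit, and the real tests compare metric values against the hand-computed decisions in the comments above:

```rust
#[cfg(test)]
mod dummy_input_shape_tests {
    use crate::tests::{dummy_classification_input, ClassificationType};
    use rstest::rstest;

    #[rstest]
    #[case::binary(ClassificationType::Binary, 1)]
    #[case::multiclass(ClassificationType::Multiclass, 3)]
    #[case::multilabel(ClassificationType::Multilabel, 3)]
    fn predictions_and_targets_align(
        #[case] classification_type: ClassificationType,
        #[case] num_classes: usize,
    ) {
        // Five samples per fixture, one column per class.
        let (predictions, targets) = dummy_classification_input(&classification_type);
        assert_eq!(predictions.dims(), [5, num_classes]);
        assert_eq!(targets.dims(), [5, num_classes]);
    }
}
```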
9 changes: 9 additions & 0 deletions crates/burn-train/src/metric/classification.rs
@@ -0,0 +1,9 @@
/// The reduction strategy for classification metrics.
#[derive(Copy, Clone, Default)]
pub enum ClassReduction {
    /// Computes the statistics over all classes before averaging
    Micro,
    /// Computes the statistics independently for each class before averaging
    #[default]
    Macro,
}
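The difference between the two reductions only shows up when classes are imbalanced. The following standalone sketch shows what each strategy computes for precision, using plain per-class counts rather than the crate's `ConfusionStats` API (whose exact interface is not shown in this diff):

```rust
// Micro: pool true positives and false positives across classes, then take one ratio.
fn micro_precision(tp: &[f64], fp: &[f64]) -> f64 {
    let tp_sum: f64 = tp.iter().sum();
    let fp_sum: f64 = fp.iter().sum();
    tp_sum / (tp_sum + fp_sum)
}

// Macro: compute one precision per class, then take the unweighted mean over classes.
fn macro_precision(tp: &[f64], fp: &[f64]) -> f64 {
    let per_class: Vec<f64> = tp.iter().zip(fp).map(|(&t, &f)| t / (t + f)).collect();
    per_class.iter().sum::<f64>() / per_class.len() as f64
}

fn main() {
    // Class 0: 9 TP, 1 FP (precision 0.9); class 1: 1 TP, 4 FP (precision 0.2).
    let tp = [9.0, 1.0];
    let fp = [1.0, 4.0];
    println!("micro: {:.2}", micro_precision(&tp, &fp)); // 0.67
    println!("macro: {:.2}", macro_precision(&tp, &fp)); // 0.55
}
```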