Skip to content

Commit

Permalink
Add confusion matrix (#528)
Browse files Browse the repository at this point in the history
* Add confusion matrix

* Fix docstring

* Fix doctest

* Fix

* Update confusion_matrix.py

* Add whitespace

* Return to fix

* Quick test for windows

* Make tests not depend on each other

* Update confusion_matrix.py

* Ellipsis

* Fix whitespace
  • Loading branch information
osanseviero authored Dec 27, 2023
1 parent dfbbf15 commit 8dfe057
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 4 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ jobs:
test:
needs: check_code_quality
strategy:
fail-fast: false
matrix:
test: ['unit', 'parity']
os: [ubuntu-latest, windows-latest]
Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ unsure, it is always a good idea to open an issue to get some feedback.
that can't be automated in one go with:

```bash
$ make fixup
$ make quality
```

This target is also optimized to only work with files modified by the PR you're working on.
Expand Down
2 changes: 1 addition & 1 deletion metrics/accuracy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ At minimum, this metric requires predictions and references as inputs.


### Output Values
- **accuracy**(`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input, if `normalize` is set to `True`.. A higher score means higher accuracy.
- **accuracy**(`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input, if `normalize` is set to `True`. A higher score means higher accuracy.

Output Example(s):
```python
Expand Down
101 changes: 101 additions & 0 deletions metrics/confusion_matrix/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
---
title: Confusion Matrix
emoji: 🤗
colorFrom: blue
colorTo: red
sdk: gradio
sdk_version: 3.19.1
app_file: app.py
pinned: false
tags:
- evaluate
- metric
description: >-
The confusion matrix evaluates classification accuracy.
Each row in a confusion matrix represents a true class and each column represents the instances in a predicted class.
---

# Metric Card for Confusion Matrix


## Metric Description

The confusion matrix evaluates classification accuracy. Each row in a confusion matrix represents a true class and each column represents the instances in a predicted class. Let's look at an example:

| | setosa | versicolor | virginica |
| ---------- | ------ | ---------- | --------- |
| setosa | 13 | 0 | 0 |
| versicolor | 0 | 10 | 6 |
| virginica | 0 | 0 | 9 |

What information does this confusion matrix provide?

* All setosa instances were properly predicted as such (true positives).
* The model always correctly classifies the setosa class (there are no false positives).
* 10 versicolor instances were properly classified, but 6 instances were misclassified as virginica.
* All virginica insances were properly classified as such.


## How to Use

At minimum, this metric requires predictions and references as inputs.

```python
>>> confusion_metric = evaluate.load("confusion_matrix")
>>> results = confusion_metric.compute(references=[0, 1, 1, 2, 0, 2, 2], predictions=[0, 2, 1, 1, 0, 2, 0])
>>> print(results)
{'confusion_matrix': [[2, 0, 0], [0, 1, 1], [1, 1, 1]]}
```


### Inputs
- **predictions** (`list` of `int`): Predicted labels.
- **references** (`list` of `int`): Ground truth labels.
- **labels** (`list` of `int`): List of labels to index the matrix. This may be used to reorder or select a subset of labels.
- **sample_weight** (`list` of `float`): Sample weights.
- **normalize** (`str`): Normalizes confusion matrix over the true (rows), predicted (columns) conditions or all the population.


### Output Values
- **confusion_matrix**(`list` of `list` of `str`): Confusion matrix. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input, if `normalize` is set to `True`.

Output Example(s):
```python
{'confusion_matrix': [[2, 0, 0], [0, 1, 1], [1, 1, 1]]}
```

This metric outputs a dictionary, containing the confusion matrix.


### Examples

Example 1 - A simple example

```python
>>> confusion_metric = evaluate.load("confusion_matrix")
>>> results = confusion_metric.compute(references=[0, 1, 1, 2, 0, 2, 2], predictions=[0, 2, 1, 1, 0, 2, 0])
>>> print(results)
{'confusion_matrix': [[2, 0, 0], [0, 1, 1], [1, 1, 1]]}
```

## Citation(s)
```bibtex
@article{scikit-learn,
title={Scikit-learn: Machine Learning in {P}ython},
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
journal={Journal of Machine Learning Research},
volume={12},
pages={2825--2830},
year={2011}
}
```


## Further References

* https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html
* https://en.wikipedia.org/wiki/Confusion_matrix
6 changes: 6 additions & 0 deletions metrics/confusion_matrix/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import evaluate
from evaluate.utils import launch_gradio_widget


module = evaluate.load("confusion_matrix")
launch_gradio_widget(module)
88 changes: 88 additions & 0 deletions metrics/confusion_matrix/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Confusion Matrix."""

import datasets
from sklearn.metrics import confusion_matrix

import evaluate


_DESCRIPTION = """
The confusion matrix evaluates classification accuracy. Each row in a confusion matrix represents a true class and each column represents the instances in a predicted class
"""

_KWARGS_DESCRIPTION = """
Args:
predictions (`list` of `int`): Predicted labels.
references (`list` of `int`): Ground truth labels.
labels (`list` of `int`): List of labels to index the matrix. This may be used to reorder or select a subset of labels.
sample_weight (`list` of `float`): Sample weights.
normalize (`str`): Normalizes confusion matrix over the true (rows), predicted (columns) conditions or all the population.
Returns:
confusion_matrix (`list` of `list` of `int`): Confusion matrix whose i-th row and j-th column entry indicates the number of samples with true label being i-th class and predicted label being j-th class.
Examples:
Example 1-A simple example
>>> confusion_matrix_metric = evaluate.load("confusion_matrix")
>>> results = confusion_matrix_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0])
>>> print(results) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
{'confusion_matrix': array([[1, 0, 1], [0, 2, 0], [1, 1, 0]][...])}
"""


_CITATION = """
@article{scikit-learn,
title={Scikit-learn: Machine Learning in {P}ython},
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
journal={Journal of Machine Learning Research},
volume={12},
pages={2825--2830},
year={2011}
}
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class ConfusionMatrix(evaluate.Metric):
def _info(self):
return evaluate.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Sequence(datasets.Value("int32")),
"references": datasets.Sequence(datasets.Value("int32")),
}
if self.config_name == "multilabel"
else {
"predictions": datasets.Value("int32"),
"references": datasets.Value("int32"),
}
),
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html"],
)

def _compute(self, predictions, references, labels=None, sample_weight=None, normalize=None):
return {
"confusion_matrix": confusion_matrix(
references, predictions, labels=labels, sample_weight=sample_weight, normalize=normalize
)
}
2 changes: 2 additions & 0 deletions metrics/confusion_matrix/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
git+https://github.com/huggingface/evaluate@{COMMIT_PLACEHOLDER}
scikit-learn
2 changes: 1 addition & 1 deletion src/evaluate/evaluator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def compute(
# TODO: To clarify why `wer` and `cer` return float
# even though metric.compute contract says that it
# returns Optional[dict].
if type(metric_results) == float:
if type(metric_results) is float:
metric_results = {metric.name: metric_results}

result.update(metric_results)
Expand Down
2 changes: 1 addition & 1 deletion src/evaluate/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _release(self):
# lists - summarize long lists similarly to NumPy
# arrays/tensors - let the frameworks control formatting
def summarize_if_long_list(obj):
if not type(obj) == list or len(obj) <= 6:
if type(obj) is not list or len(obj) <= 6:
return f"{obj}"

def format_chunk(chunk):
Expand Down

0 comments on commit 8dfe057

Please sign in to comment.