* Add confusion matrix
* Fix docstring
* Fix doctest
* Fix
* Update confusion_matrix.py
* Add whitespace
* Return to fix
* Quick test for windows
* Make tests not depend on each other
* Update confusion_matrix.py
* Ellipsis
* Fix whitespace
1 parent dfbbf15, commit 8dfe057
Showing 9 changed files with 202 additions and 4 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
--- | ||
title: Confusion Matrix | ||
emoji: 🤗 | ||
colorFrom: blue | ||
colorTo: red | ||
sdk: gradio | ||
sdk_version: 3.19.1 | ||
app_file: app.py | ||
pinned: false | ||
tags: | ||
- evaluate | ||
- metric | ||
description: >- | ||
The confusion matrix evaluates classification accuracy. | ||
Each row in a confusion matrix represents a true class and each column represents the instances in a predicted class. | ||
--- | ||
|
||
# Metric Card for Confusion Matrix | ||
|
||
|
||
## Metric Description | ||
|
||
The confusion matrix evaluates classification accuracy. Each row in a confusion matrix represents a true class and each column represents the instances in a predicted class. Let's look at an example: | ||
|
||
| | setosa | versicolor | virginica | | ||
| ---------- | ------ | ---------- | --------- | | ||
| setosa | 13 | 0 | 0 | | ||
| versicolor | 0 | 10 | 6 | | ||
| virginica | 0 | 0 | 9 | | ||
|
||
What information does this confusion matrix provide? | ||
|
||
* All 13 setosa instances were predicted as setosa (true positives). | ||
* No instance of another class was predicted as setosa (there are no false positives for setosa). | ||
* 10 versicolor instances were properly classified, but 6 were misclassified as virginica. | ||
* All 9 virginica instances were properly classified as such. | ||
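
The same reading can be done mechanically. The sketch below is plain Python and is not part of the metric; the helper name `per_class_counts` is hypothetical. It derives true positives, false positives, and false negatives for each class from the example matrix, assuming rows are true classes and columns are predicted classes as described above.

```python
# Example matrix from the table above: rows = true class, columns = predicted class.
labels = ["setosa", "versicolor", "virginica"]
matrix = [
    [13, 0, 0],
    [0, 10, 6],
    [0, 0, 9],
]


def per_class_counts(matrix, labels):
    """Hypothetical helper: per-class TP/FP/FN read off a confusion matrix."""
    summary = {}
    for i, label in enumerate(labels):
        tp = matrix[i][i]                          # diagonal entry
        fn = sum(matrix[i]) - tp                   # rest of the row (true class i, predicted otherwise)
        fp = sum(row[i] for row in matrix) - tp    # rest of the column (predicted i, true class otherwise)
        summary[label] = {"tp": tp, "fp": fp, "fn": fn}
    return summary


summary = per_class_counts(matrix, labels)
# Overall accuracy is the diagonal sum divided by the total count.
accuracy = sum(matrix[i][i] for i in range(len(labels))) / sum(map(sum, matrix))

print(summary["versicolor"])  # {'tp': 10, 'fp': 0, 'fn': 6}
print(summary["virginica"])   # {'tp': 9, 'fp': 6, 'fn': 0}
print(round(accuracy, 3))     # 0.842
```

Note that virginica has no false negatives but does have 6 false positives: the misclassified versicolor instances land in its column.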
|
||
|
||
## How to Use | ||
|
||
At minimum, this metric requires predictions and references as inputs. | ||
|
||
```python | ||
>>> confusion_metric = evaluate.load("confusion_matrix") | ||
>>> results = confusion_metric.compute(references=[0, 1, 1, 2, 0, 2, 2], predictions=[0, 2, 1, 1, 0, 2, 0]) | ||
>>> print(results) | ||
{'confusion_matrix': [[2, 0, 0], [0, 1, 1], [1, 1, 1]]} | ||
``` | ||
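
Because rows are true classes and columns are predicted classes, the result depends on which list is passed as `references` and which as `predictions`. The following is a minimal pure-Python sketch of the counting, not the metric's implementation (which delegates to scikit-learn); the helper name `count_pairs` is hypothetical. It reproduces the matrix above and shows that swapping the two arguments transposes it.

```python
def count_pairs(references, predictions):
    """Hypothetical helper: count (true, predicted) pairs into a nested list."""
    # Use every label seen in either list, in sorted order, as the matrix index.
    labels = sorted(set(references) | set(predictions))
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for ref, pred in zip(references, predictions):
        matrix[index[ref]][index[pred]] += 1  # row = true label, column = predicted label
    return matrix


references = [0, 1, 1, 2, 0, 2, 2]
predictions = [0, 2, 1, 1, 0, 2, 0]

matrix = count_pairs(references, predictions)
swapped = count_pairs(predictions, references)
transposed = [list(column) for column in zip(*matrix)]

print(matrix)                 # [[2, 0, 0], [0, 1, 1], [1, 1, 1]]
print(swapped == transposed)  # True
```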
|
||
|
||
### Inputs | ||
- **predictions** (`list` of `int`): Predicted labels. | ||
- **references** (`list` of `int`): Ground truth labels. | ||
- **labels** (`list` of `int`): List of labels to index the matrix. This may be used to reorder or select a subset of labels. | ||
- **sample_weight** (`list` of `float`): Sample weights. | ||
- **normalize** (`str`): Normalizes the confusion matrix over the true labels (rows, `"true"`), the predicted labels (columns, `"pred"`), or the whole population (`"all"`). Defaults to `None` (no normalization). | ||
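
To make the optional inputs concrete, here is a pure-Python sketch of what `labels`, `sample_weight`, and `normalize` do to the counts. It is an illustration under stated assumptions, not the metric's code: the function name `confusion_counts` is hypothetical, pairs that mention a label outside `labels` are simply skipped, and a row or column that sums to zero is normalized to 0.0.

```python
def confusion_counts(references, predictions, labels=None, sample_weight=None, normalize=None):
    """Hypothetical sketch of the documented inputs; not the library's implementation."""
    if labels is None:
        # Default index: every label seen in either list, sorted.
        labels = sorted(set(references) | set(predictions))
    if sample_weight is None:
        sample_weight = [1] * len(references)
    index = {label: i for i, label in enumerate(labels)}
    n = len(labels)
    matrix = [[0] * n for _ in range(n)]
    for ref, pred, weight in zip(references, predictions, sample_weight):
        # Assumption: pairs with a label outside `labels` are ignored.
        if ref in index and pred in index:
            matrix[index[ref]][index[pred]] += weight

    if normalize is None:
        return matrix
    if normalize == "true":  # each row (true class) sums to 1
        return [[v / sum(row) if sum(row) else 0.0 for v in row] for row in matrix]
    if normalize == "pred":  # each column (predicted class) sums to 1
        col_sums = [sum(row[j] for row in matrix) for j in range(n)]
        return [[v / col_sums[j] if col_sums[j] else 0.0 for j, v in enumerate(row)] for row in matrix]
    if normalize == "all":  # the whole matrix sums to 1
        total = sum(sum(row) for row in matrix)
        return [[v / total if total else 0.0 for v in row] for row in matrix]
    raise ValueError("normalize must be None, 'true', 'pred', or 'all'")


references = [0, 1, 1, 2, 0, 2, 2]
predictions = [0, 2, 1, 1, 0, 2, 0]

by_true = confusion_counts(references, predictions, normalize="true")
subset = confusion_counts(references, predictions, labels=[2, 1])
weighted = confusion_counts(references, predictions, sample_weight=[0.5] * 7)

print(by_true[1])  # [0.0, 0.5, 0.5]
print(subset)      # [[1, 1], [1, 1]]
print(weighted)    # [[1.0, 0, 0], [0, 0.5, 0.5], [0.5, 0.5, 0.5]]
```

With `labels=[2, 1]` the first row and column correspond to class 2, so the matrix is both reordered and restricted to the two listed classes.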
|
||
|
||
### Output Values | ||
- **confusion_matrix** (`list` of `list` of `int`, or of `float` when `normalize` is set): Confusion matrix. Minimum possible value is 0. Maximum possible value is the number of input examples when `normalize` is unset, or 1.0 when `normalize` is set. | ||
|
||
Output Example(s): | ||
```python | ||
{'confusion_matrix': [[2, 0, 0], [0, 1, 1], [1, 1, 1]]} | ||
``` | ||
|
||
This metric outputs a dictionary, containing the confusion matrix. | ||
|
||
|
||
### Examples | ||
|
||
Example 1 - A simple example | ||
|
||
```python | ||
>>> confusion_metric = evaluate.load("confusion_matrix") | ||
>>> results = confusion_metric.compute(references=[0, 1, 1, 2, 0, 2, 2], predictions=[0, 2, 1, 1, 0, 2, 0]) | ||
>>> print(results) | ||
{'confusion_matrix': [[2, 0, 0], [0, 1, 1], [1, 1, 1]]} | ||
``` | ||
|
||
## Citation(s) | ||
```bibtex | ||
@article{scikit-learn, | ||
title={Scikit-learn: Machine Learning in {P}ython}, | ||
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V. | ||
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P. | ||
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and | ||
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.}, | ||
journal={Journal of Machine Learning Research}, | ||
volume={12}, | ||
pages={2825--2830}, | ||
year={2011} | ||
} | ||
``` | ||
|
||
|
||
## Further References | ||
|
||
* https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html | ||
* https://en.wikipedia.org/wiki/Confusion_matrix |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import evaluate | ||
from evaluate.utils import launch_gradio_widget | ||
|
||
|
||
module = evaluate.load("confusion_matrix") | ||
launch_gradio_widget(module) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Confusion Matrix.""" | ||
|
||
import datasets | ||
from sklearn.metrics import confusion_matrix | ||
|
||
import evaluate | ||
|
||
|
||
_DESCRIPTION = """ | ||
The confusion matrix evaluates classification accuracy. Each row in a confusion matrix represents a true class and each column represents the instances in a predicted class. | ||
""" | ||
|
||
_KWARGS_DESCRIPTION = """ | ||
Args: | ||
predictions (`list` of `int`): Predicted labels. | ||
references (`list` of `int`): Ground truth labels. | ||
labels (`list` of `int`): List of labels to index the matrix. This may be used to reorder or select a subset of labels. | ||
sample_weight (`list` of `float`): Sample weights. | ||
normalize (`str`): Normalizes the confusion matrix over the true labels (rows, `"true"`), the predicted labels (columns, `"pred"`), or the whole population (`"all"`). Defaults to `None` (no normalization). | ||
Returns: | ||
confusion_matrix (`list` of `list` of `int`): Confusion matrix whose i-th row and j-th column entry indicates the number of samples with true label being i-th class and predicted label being j-th class. | ||
Examples: | ||
Example 1-A simple example | ||
>>> confusion_matrix_metric = evaluate.load("confusion_matrix") | ||
>>> results = confusion_matrix_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0]) | ||
>>> print(results) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE | ||
{'confusion_matrix': array([[1, 0, 1], [0, 2, 0], [1, 1, 0]][...])} | ||
""" | ||
|
||
|
||
_CITATION = """ | ||
@article{scikit-learn, | ||
title={Scikit-learn: Machine Learning in {P}ython}, | ||
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V. | ||
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P. | ||
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and | ||
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.}, | ||
journal={Journal of Machine Learning Research}, | ||
volume={12}, | ||
pages={2825--2830}, | ||
year={2011} | ||
} | ||
""" | ||
|
||
|
||
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) | ||
class ConfusionMatrix(evaluate.Metric): | ||
def _info(self): | ||
return evaluate.MetricInfo( | ||
description=_DESCRIPTION, | ||
citation=_CITATION, | ||
inputs_description=_KWARGS_DESCRIPTION, | ||
features=datasets.Features( | ||
{ | ||
"predictions": datasets.Sequence(datasets.Value("int32")), | ||
"references": datasets.Sequence(datasets.Value("int32")), | ||
} | ||
if self.config_name == "multilabel" | ||
else { | ||
"predictions": datasets.Value("int32"), | ||
"references": datasets.Value("int32"), | ||
} | ||
), | ||
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html"], | ||
) | ||
|
||
def _compute(self, predictions, references, labels=None, sample_weight=None, normalize=None): | ||
return { | ||
"confusion_matrix": confusion_matrix( | ||
references, predictions, labels=labels, sample_weight=sample_weight, normalize=normalize | ||
) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
git+https://github.com/huggingface/evaluate@{COMMIT_PLACEHOLDER} | ||
scikit-learn |