support multilabel confusion matrix #533

Open · wants to merge 7 commits into main · viewing changes from 4 commits
20 changes: 14 additions & 6 deletions metrics/confusion_matrix/README.md
@@ -58,14 +58,12 @@ At minimum, this metric requires predictions and references as inputs.


### Output Values
- **confusion_matrix**(`list` of `list` of `str`): Confusion matrix. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input, if `normalize` is set to `True`.
- **confusion_matrix**(`list` of `list` of `int`): Confusion matrix. In a single-label scenario, the entry in the i-th row and j-th column is the number of samples whose true label is the i-th class and whose predicted label is the j-th class. In a multilabel scenario, one 2x2 binary confusion matrix is returned per label: the i-th element is the matrix for the i-th label, with true negatives at position [0, 0], false positives at [0, 1], false negatives at [1, 0], and true positives at [1, 1]. The minimum possible value of an entry is 0; the maximum is the number of examples input, or 1.0 if `normalize` is set.

Output Example(s):
```python
{'confusion_matrix': [[2, 0, 0], [0, 1, 1], [1, 1, 1]]} # Single-label scenario
{'confusion_matrix': [[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[1, 0], [0, 1]]]} # Multilabel scenario
```

This metric outputs a dictionary containing the confusion matrix.
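Each 2x2 block in the multilabel output follows scikit-learn's per-label layout, so the four counts can be unpacked directly. A minimal sketch (the variable names here are illustrative, not part of the metric's API):

```python
import numpy as np

# Multilabel output from the example above: one 2x2 matrix per label
multilabel_cm = np.array([[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[1, 0], [0, 1]]])

for label_idx, per_label in enumerate(multilabel_cm):
    # scikit-learn layout: [[tn, fp], [fn, tp]]
    tn, fp, fn, tp = per_label.ravel()
    print(f"label {label_idx}: tn={tn} fp={fp} fn={fn} tp={tp}")
```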


### Examples
@@ -79,6 +77,16 @@ Example 1 - A simple example
{'confusion_matrix': [[2, 0, 0], [0, 1, 1], [1, 1, 1]]}
```

Example 2 - Multilabel scenario with binary labels

```python
>>> confusion_metric = evaluate.load("confusion_matrix", config_name="multilabel")
>>> results = confusion_metric.compute(references=[[0, 1], [1, 0], [0, 0], [0, 1], [1, 0], [0, 0]], predictions=[[0, 1], [1, 0], [1, 0], [0, 0], [1, 0], [0, 1]])
>>> print(results)
{'confusion_matrix': [[[3, 1], [0, 2]], [[3, 1], [1, 1]]]}
```
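Since the multilabel configuration delegates to scikit-learn, the numbers in Example 2 can be reproduced by calling `multilabel_confusion_matrix` directly (assuming scikit-learn is installed):

```python
from sklearn.metrics import multilabel_confusion_matrix

references = [[0, 1], [1, 0], [0, 0], [0, 1], [1, 0], [0, 0]]
predictions = [[0, 1], [1, 0], [1, 0], [0, 0], [1, 0], [0, 1]]

# One [[tn, fp], [fn, tp]] matrix per label
print(multilabel_confusion_matrix(references, predictions))
# -> per-label matrices [[3, 1], [0, 2]] and [[3, 1], [1, 1]]
```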


## Citation(s)
```bibtex
@article{scikit-learn,
@@ -98,4 +106,4 @@
## Further References

* https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html
* https://en.wikipedia.org/wiki/Confusion_matrix
27 changes: 24 additions & 3 deletions metrics/confusion_matrix/confusion_matrix.py
@@ -14,7 +14,7 @@
"""Confusion Matrix."""

import datasets
from sklearn.metrics import confusion_matrix
from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix

import evaluate

@@ -33,6 +33,7 @@

Returns:
confusion_matrix (`list` of `list` of `int`): Confusion matrix whose i-th row and j-th column entry indicates the number of samples with true label being i-th class and predicted label being j-th class.
In a multilabel scenario, one 2x2 confusion matrix is returned per label, with true negatives at [0, 0], false positives at [0, 1], false negatives at [1, 0], and true positives at [1, 1].

Examples:

@@ -41,6 +42,13 @@
>>> results = confusion_matrix_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0])
>>> print(results) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
{'confusion_matrix': array([[1, 0, 1], [0, 2, 0], [1, 1, 0]])}

Example 2 - Multilabel scenario (load with config_name="multilabel")
>>> confusion_matrix_metric = evaluate.load("confusion_matrix", config_name="multilabel")
>>> results = confusion_matrix_metric.compute(references=[[0, 1], [1, 0], [0, 0], [0, 1], [1, 0], [0, 0]], predictions=[[0, 1], [1, 0], [1, 0], [0, 0], [1, 0], [0, 1]])
>>> print(results) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
{'confusion_matrix': array([[[3, 1], [0, 2]], [[3, 1], [1, 1]]])}
"""


@@ -77,12 +85,25 @@ def _info(self):
"references": datasets.Value("int32"),
}
),
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html"],
reference_urls=[
"https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html",
"https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html",
],
)

def _compute(self, predictions, references, labels=None, sample_weight=None, normalize=None):
if self.config_name == "multilabel":
return {
"confusion_matrix": multilabel_confusion_matrix(
references,
predictions,
sample_weight=sample_weight,
labels=labels,
samplewise=normalize == "samplewise",
),
}
return {
"confusion_matrix": confusion_matrix(
references, predictions, labels=labels, sample_weight=sample_weight, normalize=normalize
)
),
}
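Two implementation details in `_compute` are worth flagging. First, `normalize="samplewise"` is repurposed to toggle scikit-learn's `samplewise` argument, since `multilabel_confusion_matrix` has no `normalize` parameter. Second, the hunk above still shows the feature declaration as scalar `int32` values, which would not accept the per-sample label lists the multilabel examples pass in; presumably `_info` would also need to branch on the configuration name. A hedged sketch of what that might look like, using a hypothetical helper not present in this PR:

```python
import datasets


def _build_features(config_name):
    # Hypothetical helper: multilabel inputs are per-sample sequences of 0/1 labels,
    # while the default configuration takes one scalar label per sample
    if config_name == "multilabel":
        value = datasets.Sequence(datasets.Value("int32"))
    else:
        value = datasets.Value("int32")
    return datasets.Features({"predictions": value, "references": value})
```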