diff --git a/metrics/confusion_matrix/confusion_matrix.py b/metrics/confusion_matrix/confusion_matrix.py index 195254f4..0c2ebc3b 100644 --- a/metrics/confusion_matrix/confusion_matrix.py +++ b/metrics/confusion_matrix/confusion_matrix.py @@ -85,14 +85,21 @@ def _info(self): "references": datasets.Value("int32"), } ), - reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html", "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html"], + reference_urls=[ + "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html", + "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html", + ], ) def _compute(self, predictions, references, labels=None, sample_weight=None, normalize=None): if self.config_name == "multilabel": return { "confusion_matrix": multilabel_confusion_matrix( - references, predictions, sample_weight=sample_weight, labels=labels, samplewise=normalize == "samplewise" + references, + predictions, + sample_weight=sample_weight, + labels=labels, + samplewise=normalize == "samplewise", ), } return { @@ -100,4 +107,3 @@ def _compute(self, predictions, references, labels=None, sample_weight=None, nor references, predictions, labels=labels, sample_weight=sample_weight, normalize=normalize ), } -