Skip to content

Commit

Permalink
Update confusion_matrix.py
Browse files Browse the repository at this point in the history
formatting file
  • Loading branch information
0ssamaak0 authored Jan 11, 2024
1 parent 7f93428 commit 51aa5dd
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions metrics/confusion_matrix/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,25 @@ def _info(self):
"references": datasets.Value("int32"),
}
),
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html", "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html"],
reference_urls=[
"https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html",
"https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html",
],
)

def _compute(self, predictions, references, labels=None, sample_weight=None, normalize=None):
if self.config_name == "multilabel":
return {
"confusion_matrix": multilabel_confusion_matrix(
references, predictions, sample_weight=sample_weight, labels=labels, samplewise=normalize == "samplewise"
references,
predictions,
sample_weight=sample_weight,
labels=labels,
samplewise=normalize == "samplewise",
),
}
return {
"confusion_matrix": confusion_matrix(
references, predictions, labels=labels, sample_weight=sample_weight, normalize=normalize
),
}

0 comments on commit 51aa5dd

Please sign in to comment.