Skip to content

Commit

Permalink
Eval options and improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
dwiddows committed Apr 9, 2024
1 parent 05042fc commit ee685a4
Show file tree
Hide file tree
Showing 5 changed files with 696 additions and 64 deletions.
1 change: 1 addition & 0 deletions experiments/classification_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def nullsafe_classification_report(y_label: List[str], y_pred: List[str]):
num_pred_labels = len({y for y in y_pred if y in label_set})
y_pred = [y if y in label_set else dummy_val for y in y_pred]
report = classification_report(y_label, y_pred, output_dict=True, zero_division=0.0)

if dummy_val in report:
del report[dummy_val]
report["macro avg"]["precision"] = report["macro avg"]["precision"] * (num_pred_labels + 1) / num_pred_labels
Expand Down
Loading

0 comments on commit ee685a4

Please sign in to comment.