diff --git a/sliceguard/sliceguard.py b/sliceguard/sliceguard.py index be2f066..92a3782 100644 --- a/sliceguard/sliceguard.py +++ b/sliceguard/sliceguard.py @@ -428,7 +428,7 @@ def report( df["sg_y_pred"] = self._generated_y_pred[selected_dataframe_rows] if hasattr(self, "_generated_y_probs") and hasattr(self, "classes"): - for class_idx, label in enumerate(self._classes): + for class_idx, label in enumerate(self.classes): df[f"sg_p_{label}"] = self._generated_y_probs[:, class_idx].tolist() spotlight_issue_list = np.array(data_issues)[data_issue_order].tolist()