diff --git a/src/triage/component/postmodeling/model_analyzer.py b/src/triage/component/postmodeling/model_analyzer.py index 361e05d29..1fbb2bd3b 100644 --- a/src/triage/component/postmodeling/model_analyzer.py +++ b/src/triage/component/postmodeling/model_analyzer.py @@ -770,7 +770,7 @@ def plot_precision_threshold_curve(self, ax, matrix_uuid=None): return ax # TODO: Facilitate plotting pr-k with absolute thresholds - def plot_precision_recall_curve(self, ax=None, matrix_uuid=None, list_size_upper_bound_pct=1, pct_step_size=0.01, subset_hash=None): + def plot_precision_recall_curve(self, ax=None, matrix_uuid=None, list_size_upper_bound_pct=1, pct_step_size=0.01, subset_hash=None, title_string=None): """ Plots precision-recall curves at each train_end_time for all model groups @@ -833,7 +833,7 @@ def plot_precision_recall_curve(self, ax=None, matrix_uuid=None, list_size_upper else: logging.debug(f'Found saved predictions for model id {self.model_id}.(group: {self.model_group_id})') - k_values = np.arange(0 + pct_step_size, list_size_upper_bound_pct + pct_step_size, pct_step_size) + k_values = np.arange(0, list_size_upper_bound_pct + pct_step_size, pct_step_size) precisions = list() recalls = list() @@ -845,21 +845,31 @@ def plot_precision_recall_curve(self, ax=None, matrix_uuid=None, list_size_upper pred_pos = pred_df.iloc[:num_above_thresh] - precision = pred_pos.label_value.sum() / num_above_thresh - recall = pred_pos.label_value.sum() / num_positives + + precision = pred_pos.label_value.sum() / num_above_thresh if num_above_thresh > 0 else 0 + recall = pred_pos.label_value.sum() / num_positives if num_above_thresh > 0 else 0 precisions.append(precision) recalls.append(recall) - sns.lineplot(x=k_values * 100, y=precisions, ax=ax, label='Precision@k') - sns.lineplot(x=k_values * 100, y=recalls, ax=ax, label='Recall@k') + + with sns.axes_style("darkgrid"): + sns.lineplot(x=k_values * 100, y=precisions, ax=ax, label='Precision@k') + sns.lineplot(x=k_values * 100, y=recalls, ax=ax, label='Recall@k') ax.set_xlabel('Population percentage (k %)') ax.set_ylabel('Metric Value') + ax.set_xlim(0, 100) + ax.set_ylim(0, 1) # ax.set_title(f'Model {self.model_id}, group: {self.model_group_id}') # ax.set_title('Precision-Recall Curve') - ax.set_title(f'Model: {self.model_id}, Group: {self.model_group_id}') - + + ax.legend(frameon=False) + + if title_string is None: + ax.set_title(f'Model: {self.model_id}, Group: {self.model_group_id}') + else: + ax.set_title(title_string) return ax def plot_feature_importance(self, ax, n_top_features=20):