Skip to content

Commit

Permalink
add model grid to pr-k plots
Browse files Browse the repository at this point in the history
  • Loading branch information
kasunamare committed Dec 21, 2023
1 parent 0f4a7c3 commit 7b5fb5f
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions src/triage/component/postmodeling/model_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ def plot_precision_threshold_curve(self, ax, matrix_uuid=None):
return ax

# TODO: Facilitate plotting pr-k with absolute thresholds
def plot_precision_recall_curve(self, ax=None, matrix_uuid=None, list_size_upper_bound_pct=1, pct_step_size=0.01, subset_hash=None):
def plot_precision_recall_curve(self, ax=None, matrix_uuid=None, list_size_upper_bound_pct=1, pct_step_size=0.01, subset_hash=None, title_string=None):
"""
Plots precision-recall curves at each train_end_time for all model groups
Expand Down Expand Up @@ -833,7 +833,7 @@ def plot_precision_recall_curve(self, ax=None, matrix_uuid=None, list_size_upper

else:
logging.debug(f'Found saved predictions for model id {self.model_id}.(group: {self.model_group_id})')
k_values = np.arange(0 + pct_step_size, list_size_upper_bound_pct + pct_step_size, pct_step_size)
k_values = np.arange(0, list_size_upper_bound_pct + pct_step_size, pct_step_size)

precisions = list()
recalls = list()
Expand All @@ -845,21 +845,31 @@ def plot_precision_recall_curve(self, ax=None, matrix_uuid=None, list_size_upper

pred_pos = pred_df.iloc[:num_above_thresh]

precision = pred_pos.label_value.sum() / num_above_thresh
recall = pred_pos.label_value.sum() / num_positives

precision = pred_pos.label_value.sum() / num_above_thresh if num_above_thresh > 0 else 0
recall = pred_pos.label_value.sum() / num_positives if num_above_thresh > 0 else 0

precisions.append(precision)
recalls.append(recall)
sns.lineplot(x=k_values * 100, y=precisions, ax=ax, label='Precision@k')
sns.lineplot(x=k_values * 100, y=recalls, ax=ax, label='Recall@k')

with sns.axes_style("darkgrid"):
sns.lineplot(x=k_values * 100, y=precisions, ax=ax, label='Precision@k')
sns.lineplot(x=k_values * 100, y=recalls, ax=ax, label='Recall@k')


ax.set_xlabel('Population percentage (k %)')
ax.set_ylabel('Metric Value')
ax.set_xlim(0, 100)
ax.set_ylim(0, 1)
# ax.set_title(f'Model {self.model_id}, group: {self.model_group_id}')
# ax.set_title('Precision-Recall Curve')
ax.set_title(f'Model: {self.model_id}, Group: {self.model_group_id}')


ax.legend(frameon=False)

if title_string is None:
ax.set_title(f'Model: {self.model_id}, Group: {self.model_group_id}')
else:
ax.set_title(title_string)
return ax

def plot_feature_importance(self, ax, n_top_features=20):
Expand Down

0 comments on commit 7b5fb5f

Please sign in to comment.