Skip to content

Commit

Permalink
handle edge case of no evaluations
Browse files Browse the repository at this point in the history
  • Loading branch information
kasunamare committed Dec 5, 2023
1 parent faaa5de commit 6fc7b6e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
7 changes: 6 additions & 1 deletion src/triage/component/postmodeling/model_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,9 +792,13 @@ def plot_precision_recall_curve(self, ax=None, matrix_uuid=None, list_size_upper


if pred_df.empty:
logging.warning('No predictions were found. Using the evaluations to generate the plot. Zoomed PR-K not supported!')
logging.warning(f'No predictions were found for model id {self.model_id} (group: {self.model_group_id}). Using the evaluations to generate the plot. Zoomed PR-K not supported!')
eval_df = self.get_evaluations(matrix_uuid=matrix_uuid)

if eval_df.empty:
logging.error('No evaluations were found! Returning empty axes!')
return ax

eval_df['perc_points'] = [x.split('_')[0] for x in eval_df['parameter'].tolist()]
eval_df['perc_points'] = pd.to_numeric(eval_df['perc_points'])

Expand Down Expand Up @@ -823,6 +827,7 @@ def plot_precision_recall_curve(self, ax=None, matrix_uuid=None, list_size_upper
)

else:
logging.debug(f'Found saved predictions for model id {self.model_id}.(group: {self.model_group_id})')
k_values = np.arange(0 + pct_step_size, list_size_upper_bound_pct + pct_step_size, pct_step_size)

precisions = list()
Expand Down
6 changes: 3 additions & 3 deletions src/triage/component/postmodeling/report_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def _make_plot_grid(self, plot_type, subplot_width=3, subplot_len=None, sharey=F
else:
ax = axes[i, j]

plot_func(ax=ax, **kw)
ax = plot_func(ax=ax, **kw)

if j==0:
ax.set_ylabel(f'{train_end_time}')
Expand All @@ -212,8 +212,8 @@ def plot_calibration_curves(self):
"""calibration curves for all models"""
self._make_plot_grid(plot_type='plot_calibration_curve')

def plot_prk_curves(self):
self._make_plot_grid(plot_type='plot_precision_recall_curve')
def plot_prk_curves(self, **kw):
self._make_plot_grid(plot_type='plot_precision_recall_curve', **kw)

def plot_bias_threshold(self, attribute_name, attribute_values, bias_metric):
"""
Expand Down

0 comments on commit 6fc7b6e

Please sign in to comment.