Skip to content

Commit

Permalink
add test metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Jan 15, 2024
1 parent e3124b7 commit c0b2e41
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions zoobot/pytorch/estimators/define_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,12 @@ def setup_metrics(self, nan_strategy='error'): # may sometimes want to ignore n
self.loss_metrics = torch.nn.ModuleDict({
'train/supervised_loss': torchmetrics.MeanMetric(nan_strategy=nan_strategy),
'validation/supervised_loss': torchmetrics.MeanMetric(nan_strategy=nan_strategy),
'test/supervised_loss': torchmetrics.MeanMetric(nan_strategy=nan_strategy),
})

# TODO handle when schema doesn't exist
question_metric_dict = {}
for step_name in ['train', 'validation']: # TODO test
for step_name in ['train', 'validation', 'test']:
question_metric_dict.update({
step_name + '/question_loss/' + question.text: torchmetrics.MeanMetric(nan_strategy='ignore')
for question in self.schema.questions
Expand All @@ -77,7 +78,7 @@ def setup_metrics(self, nan_strategy='error'): # may sometimes want to ignore n

campaigns = schema_to_campaigns(self.schema)
campaign_metric_dict = {}
for step_name in ['train', 'validation']:
for step_name in ['train', 'validation', 'test']:
campaign_metric_dict.update({
step_name + '/campaign_loss/' + campaign: torchmetrics.MeanMetric(nan_strategy='ignore')
for campaign in campaigns
Expand Down

0 comments on commit c0b2e41

Please sign in to comment.