Skip to content

Commit

Permalink
Specify on_epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
robmarkcole committed Jan 2, 2025
1 parent 07e7c4d commit 59ba3c8
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,19 @@ def training_step(
batch_size = x.shape[0]
y_hat = self(x)
loss: Tensor = self.criterion(y_hat, y)
self.log('train_loss', loss, batch_size=batch_size)
self.log(
'train_loss',
loss,
batch_size=batch_size,
on_step=True,
on_epoch=True,
)
self.train_metrics(y_hat, y)
self.log_dict(
{f'{k}': v for k, v in self.train_metrics.compute().items()},
batch_size=batch_size,
on_step=True,
on_epoch=True,
)
return loss

Expand All @@ -305,11 +313,19 @@ def validation_step(
batch_size = x.shape[0]
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log('val_loss', loss, batch_size=batch_size)
self.log(
'val_loss',
loss,
batch_size=batch_size,
on_step=False,
on_epoch=True,
)
self.val_metrics(y_hat, y)
self.log_dict(
{f'{k}': v for k, v in self.val_metrics.compute().items()},
batch_size=batch_size,
on_step=False,
on_epoch=True,
)

if (
Expand Down Expand Up @@ -352,11 +368,19 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
batch_size = x.shape[0]
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log('test_loss', loss, batch_size=batch_size)
self.log(
'test_loss',
loss,
batch_size=batch_size,
on_step=False,
on_epoch=True,
)
self.test_metrics(y_hat, y)
self.log_dict(
{f'{k}': v for k, v in self.test_metrics.compute().items()},
batch_size=batch_size,
on_step=False,
on_epoch=True,
)

def predict_step(
Expand Down

0 comments on commit 59ba3c8

Please sign in to comment.