Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SemanticSegmentationTask: add class-wise metrics #2130

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Changes from 13 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
23fa1fb
Add average metrics
robmarkcole Jun 19, 2024
b7d8305
Add average metrics
robmarkcole Jun 19, 2024
b1526fa
refactor: Rename metrics in SemanticSegmentationTask
robmarkcole Jun 19, 2024
341e272
Ruff format
robmarkcole Jun 19, 2024
024feda
Use ignore_index
robmarkcole Jun 20, 2024
04cac59
pass on_epoch
robmarkcole Jun 20, 2024
56f20fc
on_epoch to train too
robmarkcole Jun 20, 2024
3d2b309
Disable on_step for train metrics
robmarkcole Jun 20, 2024
9af1493
Merge branch 'main' into update-metrics
robmarkcole Jun 20, 2024
192c496
ruff format
robmarkcole Jun 20, 2024
73b710f
Merge branch 'main' into update-metrics
robmarkcole Jun 21, 2024
8ce8c30
Merge branch 'main' into update-metrics
robmarkcole Jun 23, 2024
e4ed9fd
Merge branch 'main' into update-metrics
robmarkcole Jul 2, 2024
d9c2688
Merge branch 'main' into update-metrics
robmarkcole Jul 8, 2024
400fae3
Merge branch 'main' into update-metrics
robmarkcole Jul 21, 2024
f4c793e
Merge branch 'main' into update-metrics
robmarkcole Aug 1, 2024
3b629ea
Merge branch 'main' into update-metrics
robmarkcole Aug 5, 2024
e6abadd
Merge branch 'main' into update-metrics
robmarkcole Aug 6, 2024
da887fe
Bump min torchmetrics
robmarkcole Aug 6, 2024
5138ccb
Merge branch 'update-metrics' of https://github.com/robmarkcole/torch…
robmarkcole Aug 6, 2024
479c7e3
Raise torchmetrics min
robmarkcole Aug 6, 2024
50b7d29
remo on_epoch etc
robmarkcole Aug 7, 2024
1cd436f
Remove on_epoch
robmarkcole Aug 7, 2024
9e985e2
try torchmetrics==1.1.0
robmarkcole Aug 7, 2024
c773322
try torchmetrics==1.1.1
robmarkcole Aug 7, 2024
9d8c8e4
Merge branch 'main' into update-metrics
robmarkcole Aug 7, 2024
e2640f5
Use loop to generate metrics
robmarkcole Aug 8, 2024
19187a9
Update
robmarkcole Aug 8, 2024
a3f7ffe
Fix jaccard
robmarkcole Aug 8, 2024
9a66442
fix dependencies delta
robmarkcole Aug 8, 2024
8381cb7
fix pyproject
robmarkcole Aug 8, 2024
b5050ad
Merge branch 'main' into update-metrics
robmarkcole Sep 5, 2024
07e7c4d
Merge branch 'main' into update-metrics
robmarkcole Dec 31, 2024
59ba3c8
Specify on_epoch
robmarkcole Jan 2, 2025
1184647
Ruff format
robmarkcole Jan 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 124 additions & 17 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@
from matplotlib.figure import Figure
from torch import Tensor
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from torchmetrics.classification import (
Accuracy,
FBetaScore,
JaccardIndex,
Precision,
Recall,
)
from torchmetrics.wrappers import ClasswiseWrapper
from torchvision.models._api import WeightsEnum

from ..datasets import RGBBandsMissingError, unbind_samples
Expand All @@ -31,6 +38,7 @@ def __init__(
weights: WeightsEnum | str | bool | None = None,
in_channels: int = 3,
num_classes: int = 1000,
labels: list[str] | None = None,
num_filters: int = 3,
loss: str = 'ce',
class_weights: Tensor | None = None,
Expand All @@ -55,6 +63,7 @@ def __init__(
are not supported yet.
in_channels: Number of input channels to model.
num_classes: Number of prediction classes (including the background).
labels: List of class labels.
num_filters: Number of filters. Only applicable when model='fcn'.
loss: Name of the loss function, currently supports
'ce', 'jaccard' or 'focal' loss.
Expand Down Expand Up @@ -190,27 +199,112 @@ def configure_metrics(self) -> None:
.. note::
* 'Micro' averaging suits overall performance evaluation but may not reflect
minority class accuracy.
* 'Macro' averaging, not used here, gives equal weight to each class, useful
* 'Macro' averaging gives equal weight to each class, useful
for balanced performance assessment across imbalanced classes.
"""
num_classes: int = self.hparams['num_classes']
ignore_index: int | None = self.hparams['ignore_index']
metrics = MetricCollection(
[
MulticlassAccuracy(
labels: list[str] | None = self.hparams['labels']

self.train_metrics = MetricCollection(
robmarkcole marked this conversation as resolved.
Show resolved Hide resolved
{
'OverallAccuracy': Accuracy(
task='multiclass',
num_classes=num_classes,
average='micro',
multidim_average='global',
ignore_index=ignore_index,
),
'OverallF1Score': FBetaScore(
task='multiclass',
num_classes=num_classes,
beta=1.0,
average='micro',
multidim_average='global',
ignore_index=ignore_index,
),
'OverallJaccardIndex': JaccardIndex(
task='multiclass',
num_classes=num_classes,
ignore_index=ignore_index,
average='micro',
),
MulticlassJaccardIndex(
num_classes=num_classes, ignore_index=ignore_index, average='micro'
'AverageAccuracy': Accuracy(
task='multiclass',
num_classes=num_classes,
average='macro',
multidim_average='global',
ignore_index=ignore_index,
),
]
'AverageF1Score': FBetaScore(
task='multiclass',
num_classes=num_classes,
beta=1.0,
average='macro',
multidim_average='global',
ignore_index=ignore_index,
),
'AverageJaccardIndex': JaccardIndex(
task='multiclass',
num_classes=num_classes,
ignore_index=ignore_index,
average='macro',
),
'Accuracy': ClasswiseWrapper(
Accuracy(
task='multiclass',
num_classes=num_classes,
average='none',
multidim_average='global',
ignore_index=ignore_index,
),
labels=labels,
),
'Precision': ClasswiseWrapper(
Precision(
task='multiclass',
num_classes=num_classes,
average='none',
multidim_average='global',
ignore_index=ignore_index,
),
labels=labels,
),
'Recall': ClasswiseWrapper(
Recall(
task='multiclass',
num_classes=num_classes,
average='none',
multidim_average='global',
ignore_index=ignore_index,
),
labels=labels,
),
'F1Score': ClasswiseWrapper(
FBetaScore(
task='multiclass',
num_classes=num_classes,
beta=1.0,
average='none',
multidim_average='global',
ignore_index=ignore_index,
),
labels=labels,
),
'JaccardIndex': ClasswiseWrapper(
JaccardIndex(
task='multiclass',
num_classes=num_classes,
average='none',
ignore_index=ignore_index,
),
labels=labels,
),
},
prefix='train_',
)
self.train_metrics = metrics.clone(prefix='train_')
self.val_metrics = metrics.clone(prefix='val_')
self.test_metrics = metrics.clone(prefix='test_')
self.val_metrics = self.train_metrics.clone(prefix='val_')
self.test_metrics = self.train_metrics.clone(prefix='test_')

def training_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
Expand All @@ -230,9 +324,14 @@ def training_step(
batch_size = x.shape[0]
y_hat = self(x)
loss: Tensor = self.criterion(y_hat, y)
self.log('train_loss', loss, batch_size=batch_size)
self.log('train_loss', loss, batch_size=batch_size, on_epoch=True, on_step=True)
self.train_metrics(y_hat, y)
self.log_dict(self.train_metrics, batch_size=batch_size)
self.log_dict(
robmarkcole marked this conversation as resolved.
Show resolved Hide resolved
{f'{k}': v for k, v in self.train_metrics.compute().items()},
robmarkcole marked this conversation as resolved.
Show resolved Hide resolved
robmarkcole marked this conversation as resolved.
Show resolved Hide resolved
batch_size=batch_size,
on_epoch=True,
on_step=False,
)
return loss

def validation_step(
Expand All @@ -250,9 +349,13 @@ def validation_step(
batch_size = x.shape[0]
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log('val_loss', loss, batch_size=batch_size)
self.log('val_loss', loss, batch_size=batch_size, on_epoch=True)
robmarkcole marked this conversation as resolved.
Show resolved Hide resolved
self.val_metrics(y_hat, y)
self.log_dict(self.val_metrics, batch_size=batch_size)
self.log_dict(
{f'{k}': v for k, v in self.val_metrics.compute().items()},
robmarkcole marked this conversation as resolved.
Show resolved Hide resolved
batch_size=batch_size,
on_epoch=True,
)

if (
batch_idx < 10
Expand Down Expand Up @@ -294,9 +397,13 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
batch_size = x.shape[0]
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log('test_loss', loss, batch_size=batch_size)
self.log('test_loss', loss, batch_size=batch_size, on_epoch=True)
self.test_metrics(y_hat, y)
self.log_dict(self.test_metrics, batch_size=batch_size)
self.log_dict(
{f'{k}': v for k, v in self.test_metrics.compute().items()},
batch_size=batch_size,
on_epoch=True,
)

def predict_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
Expand Down
Loading