From 23fa1fb5da3ff1de9b0ddea8b500be3597ebcfde Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 19 Jun 2024 09:13:30 +0000 Subject: [PATCH 01/22] Add average metrics --- torchgeo/trainers/segmentation.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index afd71521002..33bd683d92d 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -190,23 +190,32 @@ def configure_metrics(self) -> None: .. note:: * 'Micro' averaging suits overall performance evaluation but may not reflect minority class accuracy. - * 'Macro' averaging, not used here, gives equal weight to each class, useful + * 'Macro' averaging gives equal weight to each class, useful for balanced performance assessment across imbalanced classes. """ num_classes: int = self.hparams['num_classes'] ignore_index: int | None = self.hparams['ignore_index'] metrics = MetricCollection( - [ - MulticlassAccuracy( + { + 'OverallAccuracy': MulticlassAccuracy( num_classes=num_classes, ignore_index=ignore_index, multidim_average='global', average='micro', ), - MulticlassJaccardIndex( + 'AverageAccuracy': MulticlassAccuracy( + num_classes=num_classes, + ignore_index=ignore_index, + multidim_average='global', + average='macro', + ), + 'OverallJaccardIndex': MulticlassJaccardIndex( num_classes=num_classes, ignore_index=ignore_index, average='micro' ), - ] + 'AverageJaccardIndex': MulticlassJaccardIndex( + num_classes=num_classes, ignore_index=ignore_index, average='macro' + ), + } ) self.train_metrics = metrics.clone(prefix='train_') self.val_metrics = metrics.clone(prefix='val_') From b7d83057770c5e0eb46a475da28d57cce36a23df Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 19 Jun 2024 09:52:17 +0000 Subject: [PATCH 02/22] Add average metrics --- torchgeo/trainers/segmentation.py | 117 +++++++++++++++++++++++++----- 1 file changed, 100 insertions(+), 17 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 33bd683d92d..738df612f49 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -4,7 +4,7 @@ """Trainers for semantic segmentation.""" import os -from typing import Any +from typing import Any, Optional, List import matplotlib.pyplot as plt import segmentation_models_pytorch as smp @@ -12,7 +12,14 @@ from matplotlib.figure import Figure from torch import Tensor from torchmetrics import MetricCollection -from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex +from torchmetrics.classification import ( + Accuracy, + FBetaScore, + JaccardIndex, + Precision, + Recall, +) +from torchmetrics.wrappers import ClasswiseWrapper from torchvision.models._api import WeightsEnum from ..datasets import RGBBandsMissingError, unbind_samples @@ -31,6 +38,7 @@ def __init__( weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, num_classes: int = 1000, + labels: Optional[List[str]] = None, num_filters: int = 3, loss: str = 'ce', class_weights: Tensor | None = None, @@ -55,6 +63,7 @@ def __init__( are not supported yet. in_channels: Number of input channels to model. num_classes: Number of prediction classes (including the background). + labels: List of class labels. num_filters: Number of filters. Only applicable when model='fcn'. loss: Name of the loss function, currently supports 'ce', 'jaccard' or 'focal' loss. @@ -195,31 +204,96 @@ def configure_metrics(self) -> None: """ num_classes: int = self.hparams['num_classes'] ignore_index: int | None = self.hparams['ignore_index'] - metrics = MetricCollection( + labels: Optional[List[str]] = self.hparams['labels'] + + self.train_metrics = MetricCollection( { - 'OverallAccuracy': MulticlassAccuracy( + 'OverallAccuracy': Accuracy( + task='multiclass', num_classes=num_classes, - ignore_index=ignore_index, + average='micro', multidim_average='global', + ), + 'OverallF1Score': FBetaScore( + task='multiclass', + num_classes=num_classes, + beta=1.0, average='micro', + multidim_average='global', ), - 'AverageAccuracy': MulticlassAccuracy( + 'OverallIoU': JaccardIndex( + task='multiclass', num_classes=num_classes, ignore_index=ignore_index, + average='micro', + ), + 'AverageAccuracy': Accuracy( + task='multiclass', + num_classes=num_classes, + average='macro', multidim_average='global', + ), + 'AverageF1Score': FBetaScore( + task='multiclass', + num_classes=num_classes, + beta=1.0, average='macro', + multidim_average='global', + ), + 'AverageIoU': JaccardIndex( + task='multiclass', + num_classes=num_classes, + ignore_index=ignore_index, + average='macro', + ), + 'Accuracy': ClasswiseWrapper( + Accuracy( + task='multiclass', + num_classes=num_classes, + average='none', + multidim_average='global', + ), + labels=labels, ), - 'OverallJaccardIndex': MulticlassJaccardIndex( - num_classes=num_classes, ignore_index=ignore_index, average='micro' + 'Precision': ClasswiseWrapper( + Precision( + task='multiclass', + num_classes=num_classes, + average='none', + multidim_average='global', + ), + labels=labels, ), - 'AverageJaccardIndex': MulticlassJaccardIndex( - num_classes=num_classes, ignore_index=ignore_index, average='macro' + 'Recall': ClasswiseWrapper( + Recall( + task='multiclass', + num_classes=num_classes, + average='none', + multidim_average='global', + ), + labels=labels, ), - } + 'F1Score': ClasswiseWrapper( + FBetaScore( + task='multiclass', + num_classes=num_classes, + beta=1.0, + average='none', + multidim_average='global', + ), + labels=labels, + ), + 'IoU': ClasswiseWrapper( + JaccardIndex( + task='multiclass', num_classes=num_classes, average='none' + ), + labels=labels, + ), + }, + prefix='train_', ) - self.train_metrics = metrics.clone(prefix='train_') - self.val_metrics = metrics.clone(prefix='val_') - self.test_metrics = metrics.clone(prefix='test_') + self.val_metrics = self.train_metrics.clone(prefix='val_') + self.test_metrics = self.train_metrics.clone(prefix='test_') def training_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 @@ -241,7 +315,10 @@ def training_step( loss: Tensor = self.criterion(y_hat, y) self.log('train_loss', loss, batch_size=batch_size) self.train_metrics(y_hat, y) - self.log_dict(self.train_metrics, batch_size=batch_size) + self.log_dict( + {f'{k}': v for k, v in self.train_metrics.compute().items()}, + batch_size=batch_size, + ) return loss def validation_step( @@ -261,7 +338,10 @@ def validation_step( loss = self.criterion(y_hat, y) self.log('val_loss', loss, batch_size=batch_size) self.val_metrics(y_hat, y) - self.log_dict(self.val_metrics, batch_size=batch_size) + self.log_dict( + {f'{k}': v for k, v in self.val_metrics.compute().items()}, + batch_size=batch_size, + ) if ( batch_idx < 10 @@ -305,7 +385,10 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None loss = self.criterion(y_hat, y) self.log('test_loss', loss, batch_size=batch_size) self.test_metrics(y_hat, y) - self.log_dict(self.test_metrics, batch_size=batch_size) + self.log_dict( + {f'{k}': v for k, v in self.test_metrics.compute().items()}, + batch_size=batch_size, + ) def predict_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 From b1526fa9bbea7274b17dd018b1f0296bbf4fa9b9 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 19 Jun 2024 10:05:15 +0000 Subject: [PATCH 03/22] refactor: Rename metrics in SemanticSegmentationTask --- torchgeo/trainers/segmentation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 738df612f49..600c71e126b 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -221,7 +221,7 @@ def configure_metrics(self) -> None: average='micro', multidim_average='global', ), - 'OverallIoU': JaccardIndex( + 'OverallJaccardIndex': JaccardIndex( task='multiclass', num_classes=num_classes, ignore_index=ignore_index, @@ -240,7 +240,7 @@ def configure_metrics(self) -> None: average='macro', multidim_average='global', ), - 'AverageIoU': JaccardIndex( + 'AverageJaccardIndex': JaccardIndex( task='multiclass', num_classes=num_classes, ignore_index=ignore_index, @@ -283,7 +283,7 @@ def configure_metrics(self) -> None: ), labels=labels, ), - 'IoU': ClasswiseWrapper( + 'JaccardIndex': ClasswiseWrapper( JaccardIndex( task='multiclass', num_classes=num_classes, average='none' ), From 341e2724f967b3c2439842afba1df2db7fa6cfef Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 19 Jun 2024 10:27:05 +0000 Subject: [PATCH 04/22] Ruff format --- torchgeo/trainers/segmentation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 600c71e126b..cb932a5a2cb 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -4,7 +4,7 @@ """Trainers for semantic segmentation.""" import os -from typing import Any, Optional, List +from typing import Any import matplotlib.pyplot as plt import segmentation_models_pytorch as smp @@ -38,7 +38,7 @@ def __init__( weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, num_classes: int = 1000, - labels: Optional[List[str]] = None, + labels: list[str] | None = None, num_filters: int = 3, loss: str = 'ce', class_weights: Tensor | None = None, @@ -204,7 +204,7 @@ def configure_metrics(self) -> None: """ num_classes: int = self.hparams['num_classes'] ignore_index: int | None = self.hparams['ignore_index'] - labels: Optional[List[str]] = self.hparams['labels'] + labels: list[str] | None = self.hparams['labels'] self.train_metrics = MetricCollection( { From 024feda9527ed11ee0204848d3413d773eef7fdf Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Thu, 20 Jun 2024 09:00:59 +0000 Subject: [PATCH 05/22] Use ignore_index --- torchgeo/trainers/segmentation.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index cb932a5a2cb..94bcb1bf18c 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -213,6 +213,7 @@ def configure_metrics(self) -> None: num_classes=num_classes, average='micro', multidim_average='global', + ignore_index=ignore_index, ), 'OverallF1Score': FBetaScore( task='multiclass', @@ -220,6 +221,7 @@ def configure_metrics(self) -> None: beta=1.0, average='micro', multidim_average='global', + ignore_index=ignore_index, ), 'OverallJaccardIndex': JaccardIndex( task='multiclass', @@ -232,6 +234,7 @@ def configure_metrics(self) -> None: num_classes=num_classes, average='macro', multidim_average='global', + ignore_index=ignore_index, ), 'AverageF1Score': FBetaScore( task='multiclass', @@ -239,6 +242,7 @@ def configure_metrics(self) -> None: beta=1.0, average='macro', multidim_average='global', + ignore_index=ignore_index, ), 'AverageJaccardIndex': JaccardIndex( task='multiclass', @@ -252,6 +256,7 @@ def configure_metrics(self) -> None: num_classes=num_classes, average='none', multidim_average='global', + ignore_index=ignore_index, ), labels=labels, ), @@ -261,6 +266,7 @@ def configure_metrics(self) -> None: num_classes=num_classes, average='none', multidim_average='global', + ignore_index=ignore_index, ), labels=labels, ), @@ -270,6 +276,7 @@ def configure_metrics(self) -> None: num_classes=num_classes, average='none', multidim_average='global', + ignore_index=ignore_index, ), labels=labels, ), @@ -280,12 +287,16 @@ def configure_metrics(self) -> None: beta=1.0, average='none', multidim_average='global', + ignore_index=ignore_index, ), labels=labels, ), 'JaccardIndex': ClasswiseWrapper( JaccardIndex( - task='multiclass', num_classes=num_classes, average='none' + task='multiclass', + num_classes=num_classes, + average='none', + ignore_index=ignore_index, ), labels=labels, ), From 04cac5925c9993d9bbfcdcad2f2dd4469321ae85 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Thu, 20 Jun 2024 09:10:36 +0000 Subject: [PATCH 06/22] pass on_epoch --- torchgeo/trainers/segmentation.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 94bcb1bf18c..29adcefaab1 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -293,8 +293,8 @@ def configure_metrics(self) -> None: ), 'JaccardIndex': ClasswiseWrapper( JaccardIndex( - task='multiclass', - num_classes=num_classes, + task='multiclass', + num_classes=num_classes, average='none', ignore_index=ignore_index, ), @@ -347,11 +347,12 @@ def validation_step( batch_size = x.shape[0] y_hat = self(x) loss = self.criterion(y_hat, y) - self.log('val_loss', loss, batch_size=batch_size) + self.log('val_loss', loss, batch_size=batch_size, on_epoch=True) self.val_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.val_metrics.compute().items()}, batch_size=batch_size, + on_epoch=True, ) if ( @@ -394,11 +395,12 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_size = x.shape[0] y_hat = self(x) loss = self.criterion(y_hat, y) - self.log('test_loss', loss, batch_size=batch_size) + self.log('test_loss', loss, batch_size=batch_size, on_epoch=True) self.test_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.test_metrics.compute().items()}, batch_size=batch_size, + on_epoch=True, ) def predict_step( From 56f20fc3ab88f20a0ca34206d744b6f1a226604d Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Thu, 20 Jun 2024 09:12:19 +0000 Subject: [PATCH 07/22] on_epoch to train too --- torchgeo/trainers/segmentation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 29adcefaab1..125fb4b64c6 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -324,11 +324,12 @@ def training_step( batch_size = x.shape[0] y_hat = self(x) loss: Tensor = self.criterion(y_hat, y) - self.log('train_loss', loss, batch_size=batch_size) + self.log('train_loss', loss, batch_size=batch_size, on_epoch=True) self.train_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.train_metrics.compute().items()}, batch_size=batch_size, + on_epoch=True, ) return loss From 3d2b309805d0e0bd4360b60c22527b431fcb2fdf Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Thu, 20 Jun 2024 13:40:57 +0000 Subject: [PATCH 08/22] Disable on_step for train metrics --- torchgeo/trainers/segmentation.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 125fb4b64c6..ee45d503bb0 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -324,12 +324,19 @@ def training_step( batch_size = x.shape[0] y_hat = self(x) loss: Tensor = self.criterion(y_hat, y) - self.log('train_loss', loss, batch_size=batch_size, on_epoch=True) + self.log( + 'train_loss', + loss, + batch_size=batch_size, + on_epoch=True, + on_step=True, + ) self.train_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.train_metrics.compute().items()}, batch_size=batch_size, on_epoch=True, + on_step=False, ) return loss From 192c4967fec6eccf5c4ba58ebad0f5485e778cd2 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Thu, 20 Jun 2024 13:44:32 +0000 Subject: [PATCH 09/22] ruff format --- torchgeo/trainers/segmentation.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index ee45d503bb0..fd4ffa8c294 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -324,13 +324,7 @@ def training_step( batch_size = x.shape[0] y_hat = self(x) loss: Tensor = self.criterion(y_hat, y) - self.log( - 'train_loss', - loss, - batch_size=batch_size, - on_epoch=True, - on_step=True, - ) + self.log('train_loss', loss, batch_size=batch_size, on_epoch=True, on_step=True) self.train_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.train_metrics.compute().items()}, From da887fe796ae757e277d7ad93712a68c900c2cf1 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 6 Aug 2024 13:56:17 +0000 Subject: [PATCH 10/22] Bump min torchmetrics --- requirements/min-reqs.old | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index 9475f2289f2..1a267bb7438 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -18,7 +18,7 @@ segmentation-models-pytorch==0.2.0 shapely==1.8.0 timm==0.4.12 torch==1.13.0 -torchmetrics==0.10.0 +torchmetrics==1.1.0 torchvision==0.14.0 # datasets From 479c7e3517ce0da3999a142cc39081dbef3a3b7f Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Tue, 6 Aug 2024 14:15:28 +0000 Subject: [PATCH 11/22] Raise torchmetrics min --- requirements/min-reqs.old | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index 8647adf4d3b..df38927e86b 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -18,7 +18,7 @@ segmentation-models-pytorch==0.2.0 shapely==1.8.0 timm==0.4.12 torch==1.13.0 -torchmetrics==1.1.0 +torchmetrics==1.2.0 torchvision==0.14.0 # datasets From 50b7d29c8d05c44080497abda6417992e8665f71 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 7 Aug 2024 13:35:01 +0000 Subject: [PATCH 12/22] remo on_epoch etc --- .pre-commit-config.yaml | 2 +- pyproject.toml | 4 ++-- torchgeo/trainers/segmentation.py | 6 +----- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b6cbc81a5e7..a236db00c76 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: - pyvista>=0.34.2 - scikit-image>=0.22.0 - torch>=2.3 - - torchmetrics>=0.10 + - torchmetrics>=1.2.0 - torchvision>=0.18 exclude: (build|data|dist|logo|logs|output)/ - repo: https://github.com/pre-commit/mirrors-prettier diff --git a/pyproject.toml b/pyproject.toml index fa90653782f..b6fddfd210e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,8 +72,8 @@ dependencies = [ "timm>=0.4.12", # torch 1.13+ required by torchvision "torch>=1.13", - # torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics - "torchmetrics>=0.10", + # torchmetrics 1.2.0+ required for MetricCollection + "torchmetrics>=1.2.0", # torchvision 0.14+ required for torchvision.models.swin_v2_b "torchvision>=0.14", ] diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index fd4ffa8c294..60db2d0d8e8 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -329,8 +329,6 @@ def training_step( self.log_dict( {f'{k}': v for k, v in self.train_metrics.compute().items()}, batch_size=batch_size, - on_epoch=True, - on_step=False, ) return loss @@ -349,12 +347,11 @@ def validation_step( batch_size = x.shape[0] y_hat = self(x) loss = self.criterion(y_hat, y) - self.log('val_loss', loss, batch_size=batch_size, on_epoch=True) + self.log('val_loss', loss, batch_size=batch_size) self.val_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.val_metrics.compute().items()}, batch_size=batch_size, - on_epoch=True, ) if ( @@ -402,7 +399,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None self.log_dict( {f'{k}': v for k, v in self.test_metrics.compute().items()}, batch_size=batch_size, - on_epoch=True, ) def predict_step( From 1cd436f36cfe58617d3d1d99c30ccb014355071d Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 7 Aug 2024 13:39:01 +0000 Subject: [PATCH 13/22] Remove on_epoch --- torchgeo/trainers/segmentation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 60db2d0d8e8..1522c96cdc4 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -324,7 +324,7 @@ def training_step( batch_size = x.shape[0] y_hat = self(x) loss: Tensor = self.criterion(y_hat, y) - self.log('train_loss', loss, batch_size=batch_size, on_epoch=True, on_step=True) + self.log('train_loss', loss, batch_size=batch_size) self.train_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.train_metrics.compute().items()}, @@ -394,7 +394,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_size = x.shape[0] y_hat = self(x) loss = self.criterion(y_hat, y) - self.log('test_loss', loss, batch_size=batch_size, on_epoch=True) + self.log('test_loss', loss, batch_size=batch_size) self.test_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.test_metrics.compute().items()}, From 9e985e2b648d947faf6134d471ae45375bfb62ee Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 7 Aug 2024 14:27:33 +0000 Subject: [PATCH 14/22] try torchmetrics==1.1.0 --- .pre-commit-config.yaml | 2 +- pyproject.toml | 4 ++-- requirements/min-reqs.old | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a236db00c76..836f13f38b1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: - pyvista>=0.34.2 - scikit-image>=0.22.0 - torch>=2.3 - - torchmetrics>=1.2.0 + - torchmetrics>=1.1.0 - torchvision>=0.18 exclude: (build|data|dist|logo|logs|output)/ - repo: https://github.com/pre-commit/mirrors-prettier diff --git a/pyproject.toml b/pyproject.toml index b6fddfd210e..4196de9b10f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,8 +72,8 @@ dependencies = [ "timm>=0.4.12", # torch 1.13+ required by torchvision "torch>=1.13", - # torchmetrics 1.2.0+ required for MetricCollection - "torchmetrics>=1.2.0", + # torchmetrics 1.1.0+ required + "torchmetrics>=1.1.0", # torchvision 0.14+ required for torchvision.models.swin_v2_b "torchvision>=0.14", ] diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index df38927e86b..8647adf4d3b 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -18,7 +18,7 @@ segmentation-models-pytorch==0.2.0 shapely==1.8.0 timm==0.4.12 torch==1.13.0 -torchmetrics==1.2.0 +torchmetrics==1.1.0 torchvision==0.14.0 # datasets From c773322af83e780dc1038c7ffaf7fd3c6e84f3d5 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Wed, 7 Aug 2024 14:50:34 +0000 Subject: [PATCH 15/22] try torchmetrics==1.1.1 --- .pre-commit-config.yaml | 2 +- pyproject.toml | 4 ++-- requirements/min-reqs.old | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 836f13f38b1..92ab9d4dd08 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: - pyvista>=0.34.2 - scikit-image>=0.22.0 - torch>=2.3 - - torchmetrics>=1.1.0 + - torchmetrics>=1.1.1 - torchvision>=0.18 exclude: (build|data|dist|logo|logs|output)/ - repo: https://github.com/pre-commit/mirrors-prettier diff --git a/pyproject.toml b/pyproject.toml index 4196de9b10f..c4afa7f637e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,8 +72,8 @@ dependencies = [ "timm>=0.4.12", # torch 1.13+ required by torchvision "torch>=1.13", - # torchmetrics 1.1.0+ required - "torchmetrics>=1.1.0", + # torchmetrics 1.1.1+ required for average argument to MeanAveragePrecision + "torchmetrics>=1.1.1", # torchvision 0.14+ required for torchvision.models.swin_v2_b "torchvision>=0.14", ] diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index 8647adf4d3b..a4034fc8585 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -18,7 +18,7 @@ segmentation-models-pytorch==0.2.0 shapely==1.8.0 timm==0.4.12 torch==1.13.0 -torchmetrics==1.1.0 +torchmetrics==1.1.1 torchvision==0.14.0 # datasets From e2640f506be10b7097b9d431936b6dafac9a6b6a Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Thu, 8 Aug 2024 08:06:05 +0000 Subject: [PATCH 16/22] Use loop to generate metrics --- torchgeo/trainers/segmentation.py | 123 ++++++++++-------------------- 1 file changed, 41 insertions(+), 82 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 1522c96cdc4..f9f6e91636e 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -206,72 +206,40 @@ def configure_metrics(self) -> None: ignore_index: int | None = self.hparams['ignore_index'] labels: list[str] | None = self.hparams['labels'] - self.train_metrics = MetricCollection( - { - 'OverallAccuracy': Accuracy( - task='multiclass', - num_classes=num_classes, - average='micro', - multidim_average='global', - ignore_index=ignore_index, - ), - 'OverallF1Score': FBetaScore( - task='multiclass', - num_classes=num_classes, - beta=1.0, - average='micro', - multidim_average='global', - ignore_index=ignore_index, - ), - 'OverallJaccardIndex': JaccardIndex( - task='multiclass', - num_classes=num_classes, - ignore_index=ignore_index, - average='micro', - ), - 'AverageAccuracy': Accuracy( - task='multiclass', - num_classes=num_classes, - average='macro', - multidim_average='global', - ignore_index=ignore_index, - ), - 'AverageF1Score': FBetaScore( - task='multiclass', - num_classes=num_classes, - beta=1.0, - average='macro', - multidim_average='global', - ignore_index=ignore_index, - ), - 'AverageJaccardIndex': JaccardIndex( - task='multiclass', - num_classes=num_classes, - ignore_index=ignore_index, - average='macro', - ), - 'Accuracy': ClasswiseWrapper( - Accuracy( - task='multiclass', - num_classes=num_classes, - average='none', - multidim_average='global', - ignore_index=ignore_index, - ), - labels=labels, - ), - 'Precision': ClasswiseWrapper( - Precision( - task='multiclass', - num_classes=num_classes, - average='none', - multidim_average='global', - ignore_index=ignore_index, - ), - labels=labels, - ), - 'Recall': ClasswiseWrapper( - Recall( + metric_classes = { + 'Accuracy': Accuracy, + 'F1Score': FBetaScore, + 'JaccardIndex': JaccardIndex, + 'Precision': Precision, + 'Recall': Recall, + } + + metrics_dict = {} + + # Loop through the types of averaging + for average in ['micro', 'macro']: + for metric_name, metric_class in metric_classes.items(): + name = ( + f'Overall{metric_name}' + if average == 'micro' + else f'Average{metric_name}' + ) + params = { + 'task': 'multiclass', + 'num_classes': num_classes, + 'average': average, + 'multidim_average': 'global', + 'ignore_index': ignore_index, + } + if metric_name == 'F1Score': + params['beta'] = 1.0 + metrics_dict[name] = metric_class(**params) + + # Loop through the classwise metrics + for metric_name, metric_class in metric_classes.items(): + if metric_name != 'F1Score': + metrics_dict[metric_name] = ClasswiseWrapper( + metric_class( task='multiclass', num_classes=num_classes, average='none', @@ -279,9 +247,10 @@ def configure_metrics(self) -> None: ignore_index=ignore_index, ), labels=labels, - ), - 'F1Score': ClasswiseWrapper( - FBetaScore( + ) + else: + metrics_dict[metric_name] = ClasswiseWrapper( + metric_class( task='multiclass', num_classes=num_classes, beta=1.0, @@ -290,19 +259,9 @@ def configure_metrics(self) -> None: ignore_index=ignore_index, ), labels=labels, - ), - 'JaccardIndex': ClasswiseWrapper( - JaccardIndex( - task='multiclass', - num_classes=num_classes, - average='none', - ignore_index=ignore_index, - ), - labels=labels, - ), - }, - prefix='train_', - ) + ) + + self.train_metrics = MetricCollection(metrics_dict, prefix='train_') self.val_metrics = self.train_metrics.clone(prefix='val_') self.test_metrics = self.train_metrics.clone(prefix='test_') From 19187a91052e76ba56d422d1cc39ed6bf445f94b Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Thu, 8 Aug 2024 08:24:19 +0000 Subject: [PATCH 17/22] Update --- .github/dependabot.yml | 4 ++++ pyproject.toml | 4 ++-- requirements/required.txt | 4 ++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index eb0571076dc..e9796f63b42 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -20,6 +20,10 @@ updates: - "torch" - "torchvision" ignore: + # lightning 2.3+ contains known bugs related to YAML parsing + # https://github.com/Lightning-AI/pytorch-lightning/issues/19977 + - dependency-name: "lightning" + version: ">=2.3" # setuptools releases new versions almost daily - dependency-name: "setuptools" update-types: ["version-update:semver-patch"] diff --git a/pyproject.toml b/pyproject.toml index adcbf96de3d..c4afa7f637e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,9 +47,9 @@ dependencies = [ # https://github.com/microsoft/torchgeo/issues/1824 "lightly>=1.4.5,!=1.4.26", # lightning 2+ required for LightningCLI args + sys.argv support - # lightning 2.3 contains known bugs related to YAML parsing + # lightning 2.3+ contains known bugs related to YAML parsing # https://github.com/Lightning-AI/pytorch-lightning/issues/19977 - "lightning[pytorch-extra]>=2,!=2.3.*", + "lightning[pytorch-extra]>=2,<2.3", # matplotlib 3.5+ required for Python 3.10 wheels "matplotlib>=3.5", # numpy 1.21.2+ required by Python 3.10 wheels diff --git a/requirements/required.txt b/requirements/required.txt index 23fd3650d17..46e79defbf5 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -5,8 +5,8 @@ setuptools==72.1.0 einops==0.8.0 fiona==1.9.6 kornia==0.7.3 -lightly==1.5.11 -lightning[pytorch-extra]==2.4.0 +lightly==1.5.10 +lightning[pytorch-extra]==2.2.5 matplotlib==3.9.0 numpy==1.26.4 pandas==2.2.2 From a3f7ffe6bf531d213cf412e9d1511454b2c53dcd Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Thu, 8 Aug 2024 08:47:16 +0000 Subject: [PATCH 18/22] Fix jaccard --- torchgeo/trainers/segmentation.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index f9f6e91636e..81cf75a9095 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -228,16 +228,17 @@ def configure_metrics(self) -> None: 'task': 'multiclass', 'num_classes': num_classes, 'average': average, - 'multidim_average': 'global', 'ignore_index': ignore_index, } + if metric_name in ['Accuracy', 'F1Score', 'Precision', 'Recall']: + params['multidim_average'] = 'global' if metric_name == 'F1Score': params['beta'] = 1.0 metrics_dict[name] = metric_class(**params) # Loop through the classwise metrics for metric_name, metric_class in metric_classes.items(): - if metric_name != 'F1Score': + if metric_name != 'JaccardIndex': metrics_dict[metric_name] = ClasswiseWrapper( metric_class( task='multiclass', @@ -253,9 +254,7 @@ def configure_metrics(self) -> None: metric_class( task='multiclass', num_classes=num_classes, - beta=1.0, average='none', - multidim_average='global', ignore_index=ignore_index, ), labels=labels, From 9a664426aa426b56d66ba533e33a21836d09fba2 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Thu, 8 Aug 2024 09:10:13 +0000 Subject: [PATCH 19/22] fix dependencies delta --- .github/dependabot.yml | 4 ---- requirements/required.txt | 4 ++-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index e9796f63b42..eb0571076dc 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -20,10 +20,6 @@ updates: - "torch" - "torchvision" ignore: - # lightning 2.3+ contains known bugs related to YAML parsing - # https://github.com/Lightning-AI/pytorch-lightning/issues/19977 - - dependency-name: "lightning" - version: ">=2.3" # setuptools releases new versions almost daily - dependency-name: "setuptools" update-types: ["version-update:semver-patch"] diff --git a/requirements/required.txt b/requirements/required.txt index 46e79defbf5..23fd3650d17 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -5,8 +5,8 @@ setuptools==72.1.0 einops==0.8.0 fiona==1.9.6 kornia==0.7.3 -lightly==1.5.10 -lightning[pytorch-extra]==2.2.5 +lightly==1.5.11 +lightning[pytorch-extra]==2.4.0 matplotlib==3.9.0 numpy==1.26.4 pandas==2.2.2 From 8381cb7e0f06147e704cc21e460881d638b1d19b Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Thu, 8 Aug 2024 09:11:49 +0000 Subject: [PATCH 20/22] fix pyproject --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c4afa7f637e..adcbf96de3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,9 +47,9 @@ dependencies = [ # https://github.com/microsoft/torchgeo/issues/1824 "lightly>=1.4.5,!=1.4.26", # lightning 2+ required for LightningCLI args + sys.argv support - # lightning 2.3+ contains known bugs related to YAML parsing + # lightning 2.3 contains known bugs related to YAML parsing # https://github.com/Lightning-AI/pytorch-lightning/issues/19977 - "lightning[pytorch-extra]>=2,<2.3", + "lightning[pytorch-extra]>=2,!=2.3.*", # matplotlib 3.5+ required for Python 3.10 wheels "matplotlib>=3.5", # numpy 1.21.2+ required by Python 3.10 wheels From 59ba3c8df1e29d9c5c9b278a05b0467f129b5352 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Thu, 2 Jan 2025 09:37:47 +0000 Subject: [PATCH 21/22] Specify on_epoch --- torchgeo/trainers/segmentation.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 3a58226ed92..997bfc09b6c 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -282,11 +282,19 @@ def training_step( batch_size = x.shape[0] y_hat = self(x) loss: Tensor = self.criterion(y_hat, y) - self.log('train_loss', loss, batch_size=batch_size) + self.log( + 'train_loss', + loss, + batch_size=batch_size, + on_step=True, + on_epoch=True, + ) self.train_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.train_metrics.compute().items()}, batch_size=batch_size, + on_step=True, + on_epoch=True, ) return loss @@ -305,11 +313,19 @@ def validation_step( batch_size = x.shape[0] y_hat = self(x) loss = self.criterion(y_hat, y) - self.log('val_loss', loss, batch_size=batch_size) + self.log( + 'val_loss', + loss, + batch_size=batch_size, + on_step=False, + on_epoch=True, + ) self.val_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.val_metrics.compute().items()}, batch_size=batch_size, + on_step=False, + on_epoch=True, ) if ( @@ -352,11 +368,19 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_size = x.shape[0] y_hat = self(x) loss = self.criterion(y_hat, y) - self.log('test_loss', loss, batch_size=batch_size) + self.log( + 'test_loss', + loss, + batch_size=batch_size, + on_step=False, + on_epoch=True, + ) self.test_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.test_metrics.compute().items()}, batch_size=batch_size, + on_step=False, + on_epoch=True, ) def predict_step( From 1184647bead2dbdd46bac1adef3a8594a163efff Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Thu, 2 Jan 2025 09:57:58 +0000 Subject: [PATCH 22/22] Ruff format --- torchgeo/trainers/segmentation.py | 24 +++--------------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 997bfc09b6c..c56a06d1092 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -282,13 +282,7 @@ def training_step( batch_size = x.shape[0] y_hat = self(x) loss: Tensor = self.criterion(y_hat, y) - self.log( - 'train_loss', - loss, - batch_size=batch_size, - on_step=True, - on_epoch=True, - ) + self.log('train_loss', loss, batch_size=batch_size, on_step=True, on_epoch=True) self.train_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.train_metrics.compute().items()}, @@ -313,13 +307,7 @@ def validation_step( batch_size = x.shape[0] y_hat = self(x) loss = self.criterion(y_hat, y) - self.log( - 'val_loss', - loss, - batch_size=batch_size, - on_step=False, - on_epoch=True, - ) + self.log('val_loss', loss, batch_size=batch_size, on_step=False, on_epoch=True) self.val_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.val_metrics.compute().items()}, @@ -368,13 +356,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_size = x.shape[0] y_hat = self(x) loss = self.criterion(y_hat, y) - self.log( - 'test_loss', - loss, - batch_size=batch_size, - on_step=False, - on_epoch=True, - ) + self.log('test_loss', loss, batch_size=batch_size, on_step=False, on_epoch=True) self.test_metrics(y_hat, y) self.log_dict( {f'{k}': v for k, v in self.test_metrics.compute().items()},