From 7be58729e3554ba6dd9851c864aa6a055a4e0ecb Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Fri, 27 Dec 2024 19:19:38 +0100 Subject: [PATCH] solve axes issue --- clinicadl/transforms/config/base.py | 12 ------------ clinicadl/transforms/config/enum.py | 2 +- .../config/intensity_augmentations.py | 9 ++++++--- .../transforms/config/spatial_augmentations.py | 12 ++++++++++-- .../config/test_intensity_augmentations.py | 17 +++++++---------- .../config/test_spatial_augmentations.py | 2 +- 6 files changed, 25 insertions(+), 29 deletions(-) diff --git a/clinicadl/transforms/config/base.py b/clinicadl/transforms/config/base.py index b3ff91d99..1347c0de5 100644 --- a/clinicadl/transforms/config/base.py +++ b/clinicadl/transforms/config/base.py @@ -131,15 +131,3 @@ def validator_masking_method(cls, v): except ValueError: pass return v - - -class _AnatomicalAxesConfig(ClinicaDLConfig): - """Config class for 'axes' option when it supports anatomical values.""" - - axes: Union[ - NumericalAxis, - Tuple[NumericalAxis, ...], - AnatomicalAxis, - Tuple[AnatomicalAxis, ...], - DefaultFromLibrary, - ] = DefaultFromLibrary.YES diff --git a/clinicadl/transforms/config/enum.py b/clinicadl/transforms/config/enum.py index 8072d261c..c832a3363 100644 --- a/clinicadl/transforms/config/enum.py +++ b/clinicadl/transforms/config/enum.py @@ -138,7 +138,7 @@ class AnatomicalAxis(str, Enum): """ LEFT_RIGHT = "LR" - ANTERIOR_POSTERIOR = "AP" + POSTERIOR_ANTERIOR = "PA" INFERIOR_SUPERIOR = "IS" diff --git a/clinicadl/transforms/config/intensity_augmentations.py b/clinicadl/transforms/config/intensity_augmentations.py index 7c367a19a..ea3faf7b1 100644 --- a/clinicadl/transforms/config/intensity_augmentations.py +++ b/clinicadl/transforms/config/intensity_augmentations.py @@ -10,8 +10,8 @@ from clinicadl.utils.factories import DefaultFromLibrary -from .base import TransformConfig, _AnatomicalAxesConfig -from .enum import ImplementedTransform, InterpolationMode +from .base import TransformConfig +from .enum import ImplementedTransform, InterpolationMode, NumericalAxis __all__ = [ "RandomMotionConfig", @@ -54,12 +54,15 @@ def validate_tuples(cls, v, field): return v -class RandomGhostingConfig(TransformConfig, _AnatomicalAxesConfig): +class RandomGhostingConfig(TransformConfig): """Config class for RandomGhosting augmentation.""" num_ghosts: Union[ NonNegativeInt, Tuple[NonNegativeInt, NonNegativeInt], DefaultFromLibrary ] = DefaultFromLibrary.YES + axes: Union[ + NumericalAxis, Tuple[NumericalAxis, ...], DefaultFromLibrary + ] = DefaultFromLibrary.YES intensity: Union[ NonNegativeFloat, Tuple[NonNegativeFloat, NonNegativeFloat], DefaultFromLibrary ] = DefaultFromLibrary.YES diff --git a/clinicadl/transforms/config/spatial_augmentations.py b/clinicadl/transforms/config/spatial_augmentations.py index 932e09baa..54261f373 100644 --- a/clinicadl/transforms/config/spatial_augmentations.py +++ b/clinicadl/transforms/config/spatial_augmentations.py @@ -10,8 +10,9 @@ from clinicadl.utils.factories import DefaultFromLibrary -from .base import TransformConfig, _AnatomicalAxesConfig +from .base import TransformConfig from .enum import ( + AnatomicalAxis, CenterMode, ImplementedTransform, InterpolationMode, @@ -28,9 +29,16 @@ ] -class RandomFlipConfig(TransformConfig, _AnatomicalAxesConfig): +class RandomFlipConfig(TransformConfig): """Config class for RandomFlip augmentation.""" + axes: Union[ + NumericalAxis, + Tuple[NumericalAxis, ...], + AnatomicalAxis, + Tuple[AnatomicalAxis, ...], + DefaultFromLibrary, + ] = DefaultFromLibrary.YES flip_probability: Union[float, DefaultFromLibrary] = DefaultFromLibrary.YES @computed_field diff --git a/tests/unittests/transforms/config/test_intensity_augmentations.py b/tests/unittests/transforms/config/test_intensity_augmentations.py index 5b73dc7f9..e3a1c21db 100644 --- a/tests/unittests/transforms/config/test_intensity_augmentations.py +++ b/tests/unittests/transforms/config/test_intensity_augmentations.py @@ -14,7 +14,8 @@ ({"num_ghosts": (-1, 1)}, "RandomGhosting"), ({"intensity": -0.1}, "RandomGhosting"), ({"intensity": (-0.1, 0.1)}, "RandomGhosting"), - ({"axes": "abc"}, "RandomGhosting"), + ({"axes": "R"}, "RandomGhosting"), + ({"axes": 3}, "RandomGhosting"), ({"restore": 1.1}, "RandomGhosting"), ({"num_spikes": 1.1}, "RandomSpike"), ({"num_spikes": -1}, "RandomSpike"), @@ -35,8 +36,11 @@ GOOD_INPUTS = [ ({"degrees": 0.5, "translation": 0.5, "num_transforms": 1}, "RandomMotion"), ({"degrees": (-0.5, 0.5), "translation": (-0.5, 0.5)}, "RandomMotion"), - ({"num_ghosts": 0, "intensity": 0.1, "restore": 0.5}, "RandomGhosting"), - ({"num_ghosts": (1, 5), "intensity": (0.1, 0.2), "restore": 0}, "RandomGhosting"), + ({"num_ghosts": 0, "axes": 0, "intensity": 0.1, "restore": 0.5}, "RandomGhosting"), + ( + {"num_ghosts": (1, 5), "axes": (0, 2), "intensity": (0.1, 0.2), "restore": 0}, + "RandomGhosting", + ), ({"num_spikes": (0, 1), "intensity": 1.0}, "RandomSpike"), ({"num_spikes": 1, "intensity": (-1.0, 1.0)}, "RandomSpike"), ({"coefficients": 0, "order": 0}, "RandomBiasField"), @@ -88,10 +92,3 @@ def test_interpolation(): for transform in ["RandomMotion"]: c = create_transform_config(transform)(image_interpolation=mode) assert c.image_interpolation == mode - - -def test_axes(): - axes = [0, 1, 2, (0, 1), "AP", "LR", "IS", ("AP", "LR", "IS")] - for ax in axes: - c = create_transform_config("RandomGhosting")(axes=ax) - assert c.axes == ax diff --git a/tests/unittests/transforms/config/test_spatial_augmentations.py b/tests/unittests/transforms/config/test_spatial_augmentations.py index d98baa4dc..afe7bc1dc 100644 --- a/tests/unittests/transforms/config/test_spatial_augmentations.py +++ b/tests/unittests/transforms/config/test_spatial_augmentations.py @@ -133,7 +133,7 @@ def test_interpolation(): def test_axes(): - axes = [0, 1, 2, (0, 1), "AP", "LR", "IS", ("AP", "LR", "IS")] + axes = [0, 1, 2, (0, 1), "LR", "PA", "IS", ("LR", "PA", "IS")] for ax in axes: c = create_transform_config("RandomFlip")(axes=ax) assert c.axes == ax