Skip to content

Commit

Permalink
solve axes issue
Browse files Browse the repository at this point in the history
  • Loading branch information
thibaultdvx committed Dec 27, 2024
1 parent 1711484 commit 7be5872
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 29 deletions.
12 changes: 0 additions & 12 deletions clinicadl/transforms/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,3 @@ def validator_masking_method(cls, v):
except ValueError:
pass
return v


class _AnatomicalAxesConfig(ClinicaDLConfig):
"""Config class for 'axes' option when it supports anatomical values."""

axes: Union[
NumericalAxis,
Tuple[NumericalAxis, ...],
AnatomicalAxis,
Tuple[AnatomicalAxis, ...],
DefaultFromLibrary,
] = DefaultFromLibrary.YES
2 changes: 1 addition & 1 deletion clinicadl/transforms/config/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class AnatomicalAxis(str, Enum):
"""

LEFT_RIGHT = "LR"
ANTERIOR_POSTERIOR = "AP"
POSTERIOR_ANTERIOR = "PA"
INFERIOR_SUPERIOR = "IS"


Expand Down
9 changes: 6 additions & 3 deletions clinicadl/transforms/config/intensity_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

from clinicadl.utils.factories import DefaultFromLibrary

from .base import TransformConfig, _AnatomicalAxesConfig
from .enum import ImplementedTransform, InterpolationMode
from .base import TransformConfig
from .enum import ImplementedTransform, InterpolationMode, NumericalAxis

__all__ = [
"RandomMotionConfig",
Expand Down Expand Up @@ -54,12 +54,15 @@ def validate_tuples(cls, v, field):
return v


class RandomGhostingConfig(TransformConfig, _AnatomicalAxesConfig):
class RandomGhostingConfig(TransformConfig):
"""Config class for RandomGhosting augmentation."""

num_ghosts: Union[
NonNegativeInt, Tuple[NonNegativeInt, NonNegativeInt], DefaultFromLibrary
] = DefaultFromLibrary.YES
axes: Union[
NumericalAxis, Tuple[NumericalAxis, ...], DefaultFromLibrary
] = DefaultFromLibrary.YES
intensity: Union[
NonNegativeFloat, Tuple[NonNegativeFloat, NonNegativeFloat], DefaultFromLibrary
] = DefaultFromLibrary.YES
Expand Down
12 changes: 10 additions & 2 deletions clinicadl/transforms/config/spatial_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

from clinicadl.utils.factories import DefaultFromLibrary

from .base import TransformConfig, _AnatomicalAxesConfig
from .base import TransformConfig
from .enum import (
AnatomicalAxis,
CenterMode,
ImplementedTransform,
InterpolationMode,
Expand All @@ -28,9 +29,16 @@
]


class RandomFlipConfig(TransformConfig, _AnatomicalAxesConfig):
class RandomFlipConfig(TransformConfig):
"""Config class for RandomFlip augmentation."""

axes: Union[
NumericalAxis,
Tuple[NumericalAxis, ...],
AnatomicalAxis,
Tuple[AnatomicalAxis, ...],
DefaultFromLibrary,
] = DefaultFromLibrary.YES
flip_probability: Union[float, DefaultFromLibrary] = DefaultFromLibrary.YES

@computed_field
Expand Down
17 changes: 7 additions & 10 deletions tests/unittests/transforms/config/test_intensity_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
({"num_ghosts": (-1, 1)}, "RandomGhosting"),
({"intensity": -0.1}, "RandomGhosting"),
({"intensity": (-0.1, 0.1)}, "RandomGhosting"),
({"axes": "abc"}, "RandomGhosting"),
({"axes": "R"}, "RandomGhosting"),
({"axes": 3}, "RandomGhosting"),
({"restore": 1.1}, "RandomGhosting"),
({"num_spikes": 1.1}, "RandomSpike"),
({"num_spikes": -1}, "RandomSpike"),
Expand All @@ -35,8 +36,11 @@
GOOD_INPUTS = [
({"degrees": 0.5, "translation": 0.5, "num_transforms": 1}, "RandomMotion"),
({"degrees": (-0.5, 0.5), "translation": (-0.5, 0.5)}, "RandomMotion"),
({"num_ghosts": 0, "intensity": 0.1, "restore": 0.5}, "RandomGhosting"),
({"num_ghosts": (1, 5), "intensity": (0.1, 0.2), "restore": 0}, "RandomGhosting"),
({"num_ghosts": 0, "axes": 0, "intensity": 0.1, "restore": 0.5}, "RandomGhosting"),
(
{"num_ghosts": (1, 5), "axes": (0, 2), "intensity": (0.1, 0.2), "restore": 0},
"RandomGhosting",
),
({"num_spikes": (0, 1), "intensity": 1.0}, "RandomSpike"),
({"num_spikes": 1, "intensity": (-1.0, 1.0)}, "RandomSpike"),
({"coefficients": 0, "order": 0}, "RandomBiasField"),
Expand Down Expand Up @@ -88,10 +92,3 @@ def test_interpolation():
for transform in ["RandomMotion"]:
c = create_transform_config(transform)(image_interpolation=mode)
assert c.image_interpolation == mode


def test_axes():
axes = [0, 1, 2, (0, 1), "AP", "LR", "IS", ("AP", "LR", "IS")]
for ax in axes:
c = create_transform_config("RandomGhosting")(axes=ax)
assert c.axes == ax
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_interpolation():


def test_axes():
axes = [0, 1, 2, (0, 1), "AP", "LR", "IS", ("AP", "LR", "IS")]
axes = [0, 1, 2, (0, 1), "LR", "PA", "IS", ("LR", "PA", "IS")]
for ax in axes:
c = create_transform_config("RandomFlip")(axes=ax)
assert c.axes == ax

0 comments on commit 7be5872

Please sign in to comment.