diff --git a/clinicadl/dictionary/__init__.py b/clinicadl/dictionary/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/clinicadl/dictionary/suffixes.py b/clinicadl/dictionary/suffixes.py new file mode 100644 index 000000000..77e7da05a --- /dev/null +++ b/clinicadl/dictionary/suffixes.py @@ -0,0 +1,7 @@ +TSV = ".tsv" +JSON = ".json" +LOG = ".log" +PT = ".pt" +NII = ".nii" +GZ = ".gz" +NII_GZ = NII + GZ diff --git a/clinicadl/dictionary/utils.py b/clinicadl/dictionary/utils.py new file mode 100644 index 000000000..54caa46d3 --- /dev/null +++ b/clinicadl/dictionary/utils.py @@ -0,0 +1 @@ +SEP = "\t" diff --git a/clinicadl/dictionary/words.py b/clinicadl/dictionary/words.py new file mode 100644 index 000000000..9c287c677 --- /dev/null +++ b/clinicadl/dictionary/words.py @@ -0,0 +1,32 @@ +AGE = "age" +AUGMENTATION = "augmentation" +BASELINE = "baseline" +BEST = "best" +CONFIG = "config" +COUNT = "count" +CUDA = "cuda" +DESCRIPTION = "description" +FOLD = "fold" +ID = "id" +IMAGE = "image" +KFOLD = "k" + FOLD +LABEL = "label" +MEAN = "mean" +OBJECT = "object" +PARTICIPANT = "participant" +PARTICIPANT_ID = PARTICIPANT + "_" + ID +PROPORTION = "proportion" +SAMPLE = "sample" +SEX = "sex" +SESSION = "session" +SESSION_ID = SESSION + "_" + ID +SINGLE = "single" +SPLIT = "split" +STATISTIC = "statistic" +STD = "std" +TEST = "test" +TMP = "tmp" +TRAIN = "train" +TRANSFORMATION = "transformation" +VALIDATION = "validation" +VALUE = "value" diff --git a/clinicadl/transforms/extraction/base.py b/clinicadl/transforms/extraction/base.py index bc15a0a92..34e1ba16f 100644 --- a/clinicadl/transforms/extraction/base.py +++ b/clinicadl/transforms/extraction/base.py @@ -9,6 +9,7 @@ import torchio as tio from pydantic import computed_field +from clinicadl.dictionary.words import IMAGE, LABEL, SAMPLE from clinicadl.utils.config import ClinicaDLConfig from clinicadl.utils.enum import ExtractionMethod @@ -244,7 +245,7 @@ def extract_tio_sample( IndexError If 'sample_index' is greater or equal to the number of samples in the image. """ - if not hasattr(tio_image, "image") or not isinstance( + if not hasattr(tio_image, IMAGE) or not isinstance( tio_image.image, tio.ScalarImage ): raise AttributeError( @@ -268,7 +269,7 @@ def extract_tio_sample( ) tio_sample.sample = tio_sample.image - delattr(tio_sample, "image") + delattr(tio_sample, IMAGE) return tio_sample @@ -278,14 +279,14 @@ def _check_tio_sample(tio_sample: tio.Subject): Checks that a TorchIO Subject is a valid sample, i.e. a sample with a TorchIO ScalarImage named 'sample', a label named 'label' and a description named 'description'. """ - if not hasattr(tio_sample, "sample") or not isinstance( + if not hasattr(tio_sample, SAMPLE) or not isinstance( tio_sample.sample, tio.ScalarImage ): raise AttributeError( "'tio_sample' must contain ScalarImage named 'image'. Got only the following images: " f"{tio_sample.get_images_names()}" ) - if not hasattr(tio_sample, "label"): + if not hasattr(tio_sample, LABEL): raise AttributeError( "'tio_sample' must contain an attribute named 'label'." ) diff --git a/clinicadl/transforms/extraction/image.py b/clinicadl/transforms/extraction/image.py index 87b301490..1aa08096c 100644 --- a/clinicadl/transforms/extraction/image.py +++ b/clinicadl/transforms/extraction/image.py @@ -6,14 +6,13 @@ import torchio as tio from pydantic import PositiveInt, computed_field +from clinicadl.dictionary.suffixes import PT from clinicadl.utils.enum import ExtractionMethod from .base import Extraction, Sample logger = getLogger("clinicadl.extraction.image") -PT = ".pt" - class ImageSample(Sample): """ diff --git a/clinicadl/transforms/extraction/patch.py b/clinicadl/transforms/extraction/patch.py index 9473f6a7d..9b7e70e0c 100644 --- a/clinicadl/transforms/extraction/patch.py +++ b/clinicadl/transforms/extraction/patch.py @@ -6,14 +6,13 @@ import torchio as tio from pydantic import NonNegativeInt, PositiveInt, computed_field, field_validator +from clinicadl.dictionary.suffixes import PT from clinicadl.utils.enum import ExtractionMethod from .base import Extraction, Sample logger = getLogger("clinicadl.extraction.patch") -PT = ".pt" - class PatchSample(Sample): """ diff --git a/clinicadl/transforms/extraction/slice.py b/clinicadl/transforms/extraction/slice.py index 59bd2c314..8396d82cc 100644 --- a/clinicadl/transforms/extraction/slice.py +++ b/clinicadl/transforms/extraction/slice.py @@ -14,6 +14,7 @@ ) from typing_extensions import Self +from clinicadl.dictionary.suffixes import PT from clinicadl.utils.enum import ( ExtractionMethod, SliceDirection, @@ -24,8 +25,6 @@ logger = getLogger("clinicadl.extraction.slice") -PT = ".pt" - class SliceSample(Sample): """ diff --git a/clinicadl/transforms/transforms.py b/clinicadl/transforms/transforms.py index 4055f4218..b8c3565d4 100644 --- a/clinicadl/transforms/transforms.py +++ b/clinicadl/transforms/transforms.py @@ -4,6 +4,7 @@ import torchvision.transforms as torch_transforms from pydantic import model_validator +from clinicadl.dictionary.words import AUGMENTATION, IMAGE, OBJECT, TRANSFORMATION from clinicadl.transforms.extraction import Extraction, Image from clinicadl.transforms.factory import ( MinMaxNormalization, @@ -115,8 +116,8 @@ def __str__(self) -> str: def _to_str( list_: list[Callable] = [], - object_: str = "object", - transfo_: str = "transformation", + object_: str = OBJECT, + transfo_: str = TRANSFORMATION, ): str_ = "" if list_: @@ -128,13 +129,13 @@ def _to_str( return str_ - transform_str += _to_str(self.image_transforms, object_="image") - transform_str += _to_str(self.object_transforms, object_="object") + transform_str += _to_str(self.image_transforms, object_=IMAGE) + transform_str += _to_str(self.object_transforms, object_=OBJECT) transform_str += _to_str( - self.image_augmentation, object_="image", transfo_="augmentation" + self.image_augmentation, object_=IMAGE, transfo_=AUGMENTATION ) transform_str += _to_str( - self.object_augmentation, object_="object", transfo_="augmentation" + self.object_augmentation, object_=OBJECT, transfo_=AUGMENTATION ) return transform_str diff --git a/clinicadl/transforms/utils.py b/clinicadl/transforms/utils.py index 884edb620..f40a32e8a 100644 --- a/clinicadl/transforms/utils.py +++ b/clinicadl/transforms/utils.py @@ -3,6 +3,8 @@ import torch import torchio as tio +from clinicadl.dictionary.words import LABEL + def get_tio_image( image: torch.Tensor, @@ -32,9 +34,9 @@ def get_tio_image( tio_image = tio.Subject(image=tio.ScalarImage(tensor=image)) if isinstance(label, torch.Tensor): - tio_image.add_image(tio.LabelMap(tensor=label), "label") + tio_image.add_image(tio.LabelMap(tensor=label), LABEL) else: - setattr(tio_image, "label", label) + setattr(tio_image, LABEL, label) for name, mask in masks.items(): tio_image.add_image(tio.LabelMap(tensor=mask), name)