Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ClinicaDL dictionary #688

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
7 changes: 7 additions & 0 deletions clinicadl/dictionary/suffixes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
TSV = ".tsv"
JSON = ".json"
LOG = ".log"
PT = ".pt"
NII = ".nii"
GZ = ".gz"
NII_GZ = NII + GZ
1 change: 1 addition & 0 deletions clinicadl/dictionary/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SEP = "\t"
32 changes: 32 additions & 0 deletions clinicadl/dictionary/words.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
AGE = "age"
AUGMENTATION = "augmentation"
BASELINE = "baseline"
BEST = "best"
CONFIG = "config"
COUNT = "count"
CUDA = "cuda"
DESCRIPTION = "description"
FOLD = "fold"
ID = "id"
IMAGE = "image"
KFOLD = "k" + FOLD
LABEL = "label"
MEAN = "mean"
OBJECT = "object"
PARTICIPANT = "participant"
PARTICIPANT_ID = PARTICIPANT + "_" + ID
PROPORTION = "proportion"
SAMPLE = "sample"
SEX = "sex"
SESSION = "session"
SESSION_ID = SESSION + "_" + ID
SINGLE = "single"
SPLIT = "split"
STATISTIC = "statistic"
STD = "std"
TEST = "test"
TMP = "tmp"
TRAIN = "train"
TRANSFORMATION = "transformation"
VALIDATION = "validation"
VALUE = "value"
9 changes: 5 additions & 4 deletions clinicadl/transforms/extraction/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torchio as tio
from pydantic import computed_field

from clinicadl.dictionary.words import IMAGE, LABEL, SAMPLE
from clinicadl.utils.config import ClinicaDLConfig
from clinicadl.utils.enum import ExtractionMethod

Expand Down Expand Up @@ -244,7 +245,7 @@ def extract_tio_sample(
IndexError
If 'sample_index' is greater or equal to the number of samples in the image.
"""
if not hasattr(tio_image, "image") or not isinstance(
if not hasattr(tio_image, IMAGE) or not isinstance(
tio_image.image, tio.ScalarImage
):
raise AttributeError(
Expand All @@ -268,7 +269,7 @@ def extract_tio_sample(
)

tio_sample.sample = tio_sample.image
delattr(tio_sample, "image")
delattr(tio_sample, IMAGE)

return tio_sample

Expand All @@ -278,14 +279,14 @@ def _check_tio_sample(tio_sample: tio.Subject):
Checks that a TorchIO Subject is a valid sample, i.e. a sample with a TorchIO ScalarImage
named 'sample', a label named 'label' and a description named 'description'.
"""
if not hasattr(tio_sample, "sample") or not isinstance(
if not hasattr(tio_sample, SAMPLE) or not isinstance(
tio_sample.sample, tio.ScalarImage
):
raise AttributeError(
"'tio_sample' must contain ScalarImage named 'image'. Got only the following images: "
f"{tio_sample.get_images_names()}"
)
if not hasattr(tio_sample, "label"):
if not hasattr(tio_sample, LABEL):
raise AttributeError(
"'tio_sample' must contain an attribute named 'label'."
)
Expand Down
3 changes: 1 addition & 2 deletions clinicadl/transforms/extraction/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
import torchio as tio
from pydantic import PositiveInt, computed_field

from clinicadl.dictionary.suffixes import PT
from clinicadl.utils.enum import ExtractionMethod

from .base import Extraction, Sample

logger = getLogger("clinicadl.extraction.image")

PT = ".pt"


class ImageSample(Sample):
"""
Expand Down
3 changes: 1 addition & 2 deletions clinicadl/transforms/extraction/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
import torchio as tio
from pydantic import NonNegativeInt, PositiveInt, computed_field, field_validator

from clinicadl.dictionary.suffixes import PT
from clinicadl.utils.enum import ExtractionMethod

from .base import Extraction, Sample

logger = getLogger("clinicadl.extraction.patch")

PT = ".pt"


class PatchSample(Sample):
"""
Expand Down
3 changes: 1 addition & 2 deletions clinicadl/transforms/extraction/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from typing_extensions import Self

from clinicadl.dictionary.suffixes import PT
from clinicadl.utils.enum import (
ExtractionMethod,
SliceDirection,
Expand All @@ -24,8 +25,6 @@

logger = getLogger("clinicadl.extraction.slice")

PT = ".pt"


class SliceSample(Sample):
"""
Expand Down
13 changes: 7 additions & 6 deletions clinicadl/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torchvision.transforms as torch_transforms
from pydantic import model_validator

from clinicadl.dictionary.words import AUGMENTATION, IMAGE, OBJECT, TRANSFORMATION
from clinicadl.transforms.extraction import Extraction, Image
from clinicadl.transforms.factory import (
MinMaxNormalization,
Expand Down Expand Up @@ -115,8 +116,8 @@ def __str__(self) -> str:

def _to_str(
list_: list[Callable] = [],
object_: str = "object",
transfo_: str = "transformation",
object_: str = OBJECT,
transfo_: str = TRANSFORMATION,
):
str_ = ""
if list_:
Expand All @@ -128,13 +129,13 @@ def _to_str(

return str_

transform_str += _to_str(self.image_transforms, object_="image")
transform_str += _to_str(self.object_transforms, object_="object")
transform_str += _to_str(self.image_transforms, object_=IMAGE)
transform_str += _to_str(self.object_transforms, object_=OBJECT)
transform_str += _to_str(
self.image_augmentation, object_="image", transfo_="augmentation"
self.image_augmentation, object_=IMAGE, transfo_=AUGMENTATION
)
transform_str += _to_str(
self.object_augmentation, object_="object", transfo_="augmentation"
self.object_augmentation, object_=OBJECT, transfo_=AUGMENTATION
)

return transform_str
Expand Down
6 changes: 4 additions & 2 deletions clinicadl/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
import torchio as tio

from clinicadl.dictionary.words import LABEL


def get_tio_image(
image: torch.Tensor,
Expand Down Expand Up @@ -32,9 +34,9 @@ def get_tio_image(
tio_image = tio.Subject(image=tio.ScalarImage(tensor=image))

if isinstance(label, torch.Tensor):
tio_image.add_image(tio.LabelMap(tensor=label), "label")
tio_image.add_image(tio.LabelMap(tensor=label), LABEL)
else:
setattr(tio_image, "label", label)
setattr(tio_image, LABEL, label)

for name, mask in masks.items():
tio_image.add_image(tio.LabelMap(tensor=mask), name)
Expand Down
Loading