From 464dddf5069f7cbde8730ebda58db9da63fdf3f3 Mon Sep 17 00:00:00 2001 From: Thibault de Varax <154365476+thibaultdvx@users.noreply.github.com> Date: Mon, 16 Dec 2024 18:11:49 +0100 Subject: [PATCH] Continuing refactoring of the extraction objects (#686) * remove ROI and improve extraction objects --- clinicadl/dataset/datasets/caps_dataset.py | 2 +- clinicadl/dataset/utils.py | 4 +- clinicadl/transforms/extraction/__init__.py | 3 +- clinicadl/transforms/extraction/base.py | 224 +++++++++-- clinicadl/transforms/extraction/image.py | 184 +++++++-- clinicadl/transforms/extraction/patch.py | 243 +++++++++--- clinicadl/transforms/extraction/roi.py | 357 ------------------ clinicadl/transforms/extraction/slice.py | 331 ++++++++++++---- clinicadl/transforms/transforms.py | 6 +- clinicadl/transforms/utils.py | 42 +++ clinicadl/utils/config.py | 5 +- clinicadl/utils/enum.py | 8 +- .../caps_example/subjects_sessions_list.tsv | 8 + tests/unittests/transforms/__init__.py | 0 .../transforms/extraction/__init__.py | 0 .../transforms/extraction/test_image.py | 164 ++++++++ .../transforms/extraction/test_patch.py | 192 ++++++++++ .../transforms/extraction/test_slice.py | 199 ++++++++++ tests/unittests/transforms/test_extraction.py | 1 - tests/unittests/transforms/test_utils.py | 26 ++ 20 files changed, 1440 insertions(+), 559 deletions(-) delete mode 100644 clinicadl/transforms/extraction/roi.py create mode 100644 clinicadl/transforms/utils.py create mode 100644 tests/unittests/transforms/__init__.py create mode 100644 tests/unittests/transforms/extraction/__init__.py create mode 100644 tests/unittests/transforms/extraction/test_image.py create mode 100644 tests/unittests/transforms/extraction/test_patch.py create mode 100644 tests/unittests/transforms/extraction/test_slice.py delete mode 100644 tests/unittests/transforms/test_extraction.py create mode 100644 tests/unittests/transforms/test_utils.py diff --git a/clinicadl/dataset/datasets/caps_dataset.py b/clinicadl/dataset/datasets/caps_dataset.py index fe0f6c5a2..95535d73f 100644 --- a/clinicadl/dataset/datasets/caps_dataset.py +++ b/clinicadl/dataset/datasets/caps_dataset.py @@ -104,7 +104,7 @@ def elem_per_image(self): Number of elements per image. """ if not hasattr(self, "_elem_per_image"): - self._elem_per_image = self.extraction.num_elem_per_image( + self._elem_per_image = self.extraction.num_samples_per_image( image=self._get_full_image()[0] ) return self._elem_per_image diff --git a/clinicadl/dataset/utils.py b/clinicadl/dataset/utils.py index f44df1a9d..6316b00bd 100644 --- a/clinicadl/dataset/utils.py +++ b/clinicadl/dataset/utils.py @@ -128,7 +128,7 @@ def make_case_insensitive_pattern(c: str) -> str: def get_extraction( extract_method: Union[str, ExtractionMethod], -) -> type[extraction.BaseExtraction]: +) -> type[extraction.Extraction]: """ Retrieves the extraction method based on the specified extraction method. Args: extract_method (Union[str, ExtractionMethod]): The extraction method as either a string or an `ExtractionMethod` enum. Returns: - type[extraction.BaseExtraction]: The corresponding extraction class (e.g., `ROI`, `Slice`, etc.). + type[extraction.Extraction]: The corresponding extraction class (e.g., `Image`, `Patch`, `Slice`). Raises: ValueError: If the provided `extract_method` is not supported or is invalid. 
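A minimal usage sketch of the renamed factory (illustrative only: the parameter values below are assumptions for the example, not defaults taken from this patch):

    from clinicadl.dataset.utils import get_extraction

    extraction_cls = get_extraction("slice")            # resolves to clinicadl.transforms.extraction.Slice
    extraction = extraction_cls(slice_direction=2, borders=10)
    n_samples = extraction.num_samples_per_image(image_tensor)  # image_tensor: a 4D torch.Tensor (C, H, W, D)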
diff --git a/clinicadl/transforms/extraction/__init__.py b/clinicadl/transforms/extraction/__init__.py index 692969f40..caadb419a 100644 --- a/clinicadl/transforms/extraction/__init__.py +++ b/clinicadl/transforms/extraction/__init__.py @@ -1,5 +1,4 @@ -from .base import BaseExtraction +from .base import Extraction from .image import Image from .patch import Patch -from .roi import ROI from .slice import Slice diff --git a/clinicadl/transforms/extraction/base.py b/clinicadl/transforms/extraction/base.py index 057b427bc..bc15a0a92 100644 --- a/clinicadl/transforms/extraction/base.py +++ b/clinicadl/transforms/extraction/base.py @@ -1,40 +1,52 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod +from copy import deepcopy from logging import getLogger from pathlib import Path -from typing import List, Tuple +from typing import Any, List, Optional, Tuple, Union import nibabel as nib import torch -from pydantic import PositiveInt +import torchio as tio +from pydantic import computed_field from clinicadl.utils.config import ClinicaDLConfig from clinicadl.utils.enum import ExtractionMethod logger = getLogger("clinicadl.base_extraction") -NII_GZ = ".nii.gz" -PT = ".pt" +class Sample(ClinicaDLConfig, ABC): + """Abstract class for outputs of CapsDataset.""" -class BaseExtraction(ClinicaDLConfig): + sample: torch.Tensor + participant_id: str + session_id: str + image_path: str + label: Optional[Union[float, int, torch.Tensor]] + + @computed_field + @property + @abstractmethod + def extraction(self) -> ExtractionMethod: + """The extraction method.""" + + +class Extraction(ClinicaDLConfig, ABC): """ Abstract base class for image extraction procedures. This class defines the common structure and methods for extracting data from neuroimaging files (such as NIfTI) into a tensor representation for further processing. - - Parameters - ---------- - extract_method : ExtractionMethod - The method to be used for the extraction process (ROI, Image, Patch, Slice). - use_uncropped_image : bool, optional - A flag to specify whether to use the uncropped image, by default True. """ - extract_method: ExtractionMethod - use_uncropped_image: bool = True + @computed_field + @property + @abstractmethod + def extract_method(self) -> ExtractionMethod: + """The method to be used for the extraction process (Image, Patch, Slice).""" - def extract_image(self, input_img: Path) -> torch.Tensor: + @staticmethod + def load_image(input_img: Path) -> torch.Tensor: """ Loads a NIfTI image and converts it to a float32 tensor. @@ -55,95 +67,229 @@ def extract_image(self, input_img: Path) -> torch.Tensor: nib.loadsave.ImageFileError If the image file cannot be read as a NIfTI file. """ + if not Path(input_img).exists(): + raise FileNotFoundError(f"The path '{input_img}' does not match any file.") + try: image_array = nib.load(input_img).get_fdata(dtype="float32") # type: ignore except Exception as e: - raise FileNotFoundError(f"Failed to load the image: {input_img}") from e + raise Exception( + f"Unable to read the image in {input_img}. Consider using a nifti file format " + "('.nii' or '.nii.gz')." + ) from e + return torch.from_numpy(image_array).unsqueeze(0).float() @abstractmethod - def extract_tensor( + def extract_sample( self, image_tensor: torch.Tensor, - index: int, + sample_index: int, ) -> torch.Tensor: """ - Abstract method for extracting specific data from a given image tensor. + Abstract method for extracting a sample from a given image. 
Parameters ---------- image_tensor : torch.Tensor - The image tensor to extract data from. - index : int - Index indicating the element to extract. + The image tensor to extract a sample from. + sample_index : int + Index indicating the sample to extract. Returns ------- torch.Tensor - A tensor containing the extracted data. + A tensor containing the extracted sample. + + Raises + ------ + IndexError + If 'sample_index' is greater or equal to the number of samples in the image. Notes ----- This method needs to be implemented in the subclasses. """ - pass + # TODO : remove? @abstractmethod - def extract_path(self, image_path, index): + def sample_path(self, image_path: Path, sample_index: int) -> Path: """ - Abstract method for defining the path where extracted elements will be saved. + Abstract method for defining the path where an extracted sample will be saved. Parameters ---------- image_path : Path Path to the original image. - index : int - Index of the element being extracted. + sample_index : int + Index of the sample being extracted. Returns ------- Path - Path where the extracted data will be saved. + Path where the extracted sample will be saved. Notes ----- This method needs to be implemented in the subclasses. """ - pass + # TODO : remove? @abstractmethod def extract(self, nii_path: Path) -> List[Tuple[Path, torch.Tensor]]: """ - Abstract method for performing the extraction based on the configured method. + Abstract method to extract all the samples. Parameters ---------- nii_path : Path Path to the NIfTI file to process. + Returns + ------- + List[Tuple[Path, torch.Tensor]] + A list of tuples, where each tuple contains an extracted sample, + and the path where it should be stored. + Notes ----- This method needs to be implemented in the subclasses. """ - pass @abstractmethod - def num_elem_per_image(self, image: torch.Tensor) -> PositiveInt: + def num_samples_per_image(self, image: torch.Tensor) -> int: """ - Abstract method to return the number of extracted elements per image. + Abstract method to return the number of extracted samples per image. Parameters ---------- image : torch.Tensor - The image tensor from which the number of elements will be determined. + The image tensor from which the number of samples will be determined. Returns ------- - PositiveInt - The number of extracted elements from the image. + int + The number of samples in the image. Notes ----- This method needs to be implemented in the subclasses. """ - pass + + @abstractmethod + def _get_sample_description( + self, image_tensor: torch.Tensor, sample_index: int + ) -> Any: + """A description of the sample, e.g. slice position or patch index.""" + + @abstractmethod + def format_output( + self, + tio_sample: tio.Subject, + participant_id: str, + session_id: str, + image_path: Union[str, Path], + ) -> Sample: + """ + Puts all the output information in a Sample object. + + Parameters + ---------- + tio_sample : tio.Subject + a TorchIO Subject corresponding to the sample, with at least a ScalarImage named 'sample', + an attribute named 'label' and an attribute named 'description'. + participant_id : str + the subject concerned. + session_id : str + the session concerned. + image_path : Union[str, Path] + the path of the base image, from which the sample was extracted. + + Returns + ------- + Sample + a Sample object with all the relevant information on the sample. + + Raises + ------ + AttributeError + if `tio_sample` doesn't have a TorchIO ScalarImage named 'sample', and attributes + 'label' and 'description'. 
+ """ + + def extract_tio_sample( + self, tio_image: tio.Subject, sample_index: int + ) -> tio.Subject: + """ + Extracts a sample from a TorchIO Subject. + + Parameters + ---------- + tio_image : tio.Subject + The image as a TorchIO Subject. Can contain masks associated + to the image as well. + sample_index : int + Index indicating the sample to extract. + + Returns + ------- + tio.Subject + A new TorchIO Subject with the extracted sample, accessible via the attribute + 'sample', and the potential masks, extracted in the same way as the sample. + + Raises + ------ + AttributeError + If 'tio_image' doesn't have a TorchIO ScalarImage named 'image'. + IndexError + If 'sample_index' is greater or equal to the number of samples in the image. + """ + if not hasattr(tio_image, "image") or not isinstance( + tio_image.image, tio.ScalarImage + ): + raise AttributeError( + "'tio_image' must contain ScalarImage named 'image'. Got only the following images: " + f"{tio_image.get_images_names()}" + ) + + tio_sample = deepcopy(tio_image) + + image: tio.Image + for name, image in tio_image.get_images_dict(intensity_only=False).items(): + sample = self.extract_sample(image.tensor, sample_index) + + if isinstance(image, tio.ScalarImage): + setattr(tio_sample, name, tio.ScalarImage(tensor=sample)) + elif isinstance(image, tio.LabelMap): + setattr(tio_sample, name, tio.LabelMap(tensor=sample)) + + tio_sample.description = self._get_sample_description( + image.tensor, sample_index + ) + + tio_sample.sample = tio_sample.image + delattr(tio_sample, "image") + + return tio_sample + + @staticmethod + def _check_tio_sample(tio_sample: tio.Subject): + """ + Checks that a TorchIO Subject is a valid sample, i.e. a sample with a TorchIO ScalarImage + named 'sample', a label named 'label' and a description named 'description'. + """ + if not hasattr(tio_sample, "sample") or not isinstance( + tio_sample.sample, tio.ScalarImage + ): + raise AttributeError( + "'tio_sample' must contain ScalarImage named 'image'. Got only the following images: " + f"{tio_sample.get_images_names()}" + ) + if not hasattr(tio_sample, "label"): + raise AttributeError( + "'tio_sample' must contain an attribute named 'label'." + ) + if not hasattr(tio_sample, "description"): + raise AttributeError( + "'tio_sample' must contain an attribute named 'description'." + ) diff --git a/clinicadl/transforms/extraction/image.py b/clinicadl/transforms/extraction/image.py index e5ecd3496..87b301490 100644 --- a/clinicadl/transforms/extraction/image.py +++ b/clinicadl/transforms/extraction/image.py @@ -1,38 +1,63 @@ from logging import getLogger from pathlib import Path -from typing import Tuple +from typing import List, Tuple, Union import torch -from pydantic import PositiveInt +import torchio as tio +from pydantic import PositiveInt, computed_field -from clinicadl.transforms.extraction.base import BaseExtraction from clinicadl.utils.enum import ExtractionMethod +from .base import Extraction, Sample + logger = getLogger("clinicadl.extraction.image") -NII_GZ = ".nii.gz" PT = ".pt" -class Image(BaseExtraction): +class ImageSample(Sample): """ - Configuration class for full image extraction as a single tensor. - - This class implements the extraction process for a full image, where the entire - image is loaded and returned as a single tensor. It handles extraction using - the `ExtractionMethod.IMAGE` and saves the output as a tensor file. + Output of a CapsDataset when image extraction is performed (i.e. no extraction). 
Attributes ---------- - extract_method : ExtractionMethod - The method used for the extraction. For this class, it's set to IMAGE. + sample : torch.Tensor + the image as a 4D PyTorch tensor (with one channel dimension). + participant_id : str + the subject concerned. + session_id : str + the session concerned. + image_path : str + the path to the image. + label : Optional[Union[float, int, torch.Tensor]] + the potential label associated with the image. """ - extract_method: ExtractionMethod = ExtractionMethod.IMAGE + @computed_field + @property + def extraction(self) -> ExtractionMethod: + """The extraction method.""" + return ExtractionMethod.IMAGE + + +class Image(Extraction): + """ + Transform class for full image extraction as a single tensor. - def extract(self, nii_path: Path) -> list[Tuple[Path, torch.Tensor]]: + This class implements the extraction process to get the full image, where the entire + image is loaded and returned as a single tensor. + """ + + @computed_field + @property + def extract_method(self) -> ExtractionMethod: + """The method to be used for the extraction process (Image, Patch, Slice).""" + return ExtractionMethod.IMAGE + + def extract(self, nii_path: Path) -> List[Tuple[Path, torch.Tensor]]: """ - Extracts the full image as a single tensor file and saves it. + Extracts the full image as a single tensor and returns it along with the path + where to save it. Parameters ---------- @@ -41,21 +66,21 @@ def extract(self, nii_path: Path) -> list[Tuple[Path, torch.Tensor]]: Returns ------- - list of Tuple[Path, torch.Tensor] - A list containing a tuple with the output file path and the extracted image tensor. + List[Tuple[Path, torch.Tensor]] + A list containing a single tuple with the output file path and the extracted image tensor. Notes ----- - The image is loaded, converted into a tensor, and saved with the same name as the original image but with a `.pt` extension. + The image is loaded as a tensor and returned along with the input path, whose extension is changed to `.pt`. """ - image_tensor = self.extract_image(nii_path) - output_file = nii_path.with_suffix("").with_suffix(PT), image_tensor.clone() - return [output_file] + image_tensor = self.load_image(nii_path) + + return [(self.sample_path(nii_path), self.extract_sample(image_tensor))] - def extract_tensor( + def extract_sample( self, image_tensor: torch.Tensor, - index: int, + sample_index: int = 0, ) -> torch.Tensor: """ Returns the entire image tensor as no further extraction is needed. @@ -64,7 +89,7 @@ def extract_tensor( ---------- image_tensor : torch.Tensor The image tensor to extract data from. - index : int - The index to identify the extracted data (though this is not used in this method). + sample_index : int (optional, default=0) + Index of the sample to extract. Must be 0, as the full image is the only sample. Returns @@ -72,13 +97,24 @@ def extract_tensor( torch.Tensor The same image tensor as no further extraction is applied. + Raises + ------ + IndexError + If 'sample_index' is not 0. + Notes ----- This method is a placeholder in this class as the full image is returned without modification. """ - return image_tensor + if sample_index != 0: + raise IndexError( + f"'sample_index' {sample_index} is out of range as there is only " + "1 sample in the image." + ) + + return image_tensor.clone() - def extract_path(self, image_path, index): + def sample_path(self, image_path: Path, sample_index: int = 0) -> Path: # pylint:disable=unused-argument """ Returns the input image path as the path to save the extracted data. 
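+
+         For instance (a doctest-style sketch mirroring the unit tests added below; `PosixPath` assumes a POSIX system):
+
+         >>> Image().sample_path(Path("sub-001/ses-M000/sub-001_ses-M000_T1w.nii.gz"))
+         PosixPath('sub-001/ses-M000/sub-001_ses-M000_T1w.pt')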
@@ -96,11 +132,11 @@ def extract_path(self, image_path, index): Notes ----- - This method does not alter the path, returning the same path as the input. + This method only changes the extension of the path. """ - return image_path + return image_path.with_suffix("").with_suffix(PT) - def num_elem_per_image(self, image: torch.Tensor) -> PositiveInt: + def num_samples_per_image(self, image: torch.Tensor) -> PositiveInt: """ - Returns the number of elements per image. Since the entire image is + Returns the number of samples per image. Since the entire image is extracted, this method always returns 1. @@ -119,3 +155,93 @@ - This method is specific to the full image extraction, where only one element (the image) is returned. + This method is specific to the full image extraction, where only one sample (the image) is returned. """ return 1 + + def _get_sample_description( + self, image_tensor: torch.Tensor, sample_index: int + ) -> None: + """No need for description in the case of image extraction.""" + return None + + def format_output( + self, + tio_sample: tio.Subject, + participant_id: str, + session_id: str, + image_path: Union[str, Path], + ) -> ImageSample: + """ + Puts all the output information in an ImageSample object. + + Parameters + ---------- + tio_sample : tio.Subject + a TorchIO Subject corresponding to the image, with at least a ScalarImage named 'sample', + an attribute named 'label' and an attribute named 'description'. + participant_id : str + the subject concerned. + session_id : str + the session concerned. + image_path : Union[str, Path] + the path of the image. + + Returns + ------- + ImageSample + an ImageSample object with all the relevant information on the image. + + Raises + ------ + AttributeError + if `tio_sample` doesn't have a TorchIO ScalarImage named 'sample', and attributes + 'label' and 'description'. + """ + self._check_tio_sample(tio_sample) + + sample = tio_sample.sample.tensor + if isinstance(tio_sample.label, tio.Image): + label = tio_sample.label.tensor + else: + label = tio_sample.label + + return ImageSample( + sample=sample, + participant_id=participant_id, + session_id=session_id, + image_path=str(image_path), + label=label, + ) + + def extract_tio_sample( + self, tio_image: tio.Subject, sample_index: int = 0 + ) -> tio.Subject: + """ + Converts a TorchIO Subject representing the image to a TorchIO Subject + representing the sample (which is here the image). + + Parameters + ---------- + tio_image : tio.Subject + The image as a TorchIO Subject. Can contain masks associated + to the image as well. + sample_index : int (optional, default=0) + For consistency with other extraction methods. Always 0 here. + + Returns + ------- + tio.Subject + A new TorchIO Subject representing the sample (the full image here), accessible via the + attribute 'sample', and the potential masks, extracted in the same way as the sample. + + Raises + ------ + AttributeError + If 'tio_image' doesn't have a TorchIO ScalarImage named 'image'. + IndexError + If 'sample_index' is not 0. + + Notes + ----- + This method is trivial here as no extraction is performed. The TorchIO Subject is just converted + to another format. 
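+
+         Examples
+         --------
+         A minimal sketch (toy tensor shape and label, chosen only for illustration):
+
+         >>> tio_image = tio.Subject(image=tio.ScalarImage(tensor=torch.randn(1, 3, 4, 5)), label=1)
+         >>> tio_sample = Image().extract_tio_sample(tio_image)
+         >>> tio_sample.sample.shape
+         (1, 3, 4, 5)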
+ """ + return super().extract_tio_sample(tio_image, sample_index) diff --git a/clinicadl/transforms/extraction/patch.py b/clinicadl/transforms/extraction/patch.py index cc2b74568..9473f6a7d 100644 --- a/clinicadl/transforms/extraction/patch.py +++ b/clinicadl/transforms/extraction/patch.py @@ -1,48 +1,102 @@ from logging import getLogger from pathlib import Path -from typing import List, Tuple +from typing import List, Tuple, Union import torch -from pydantic import PositiveInt +import torchio as tio +from pydantic import NonNegativeInt, PositiveInt, computed_field, field_validator -from clinicadl.transforms.extraction.base import BaseExtraction from clinicadl.utils.enum import ExtractionMethod +from .base import Extraction, Sample + logger = getLogger("clinicadl.extraction.patch") -NII_GZ = ".nii.gz" PT = ".pt" -class Patch(BaseExtraction): +class PatchSample(Sample): """ - Configuration class for patch extraction from an image with defined patch size and stride. - - This class extracts patches from an image tensor. The image is divided into smaller patches - using a sliding window approach, where the patch size and stride size are configurable. + Output of a CapsDataset when patch extraction is performed. Attributes ---------- - patch_size : int - The size of each patch (default is 50). - stride_size : int - The stride or step size used to move the sliding window (default is 50). - extract_method : ExtractionMethod - The extraction method used for this class, set to PATCH. + sample : torch.Tensor + the patch as 4D PyTorch tensor (with one channel dimension). + participant_id : str + the subject concerned. + session_id : str + the session concerned. + image_path : str + the path to the image from which the patch has been extracted. + label : Optional[Union[float, int, torch.Tensor]] + the potential label associated to the image. + patch_index : NonNegativeInt + the index of the patch among all patches extracted from the image. + patch_size : Tuple[PositiveInt, PositiveInt, PositiveInt] + the size of the patch. + patch_stride : Tuple[PositiveInt, PositiveInt, PositiveInt] + the stride used for patch extraction. + """ + + patch_index: NonNegativeInt + patch_size: Tuple[PositiveInt, PositiveInt, PositiveInt] + patch_stride: Tuple[PositiveInt, PositiveInt, PositiveInt] + + @computed_field + @property + def extraction(self) -> ExtractionMethod: + """The extraction method.""" + return ExtractionMethod.PATCH + + +class Patch(Extraction): + """ + Transform class to extract patches from an image. + + This class enables patches extraction from an image tensor. The image is divided into smaller patches + using a sliding window approach, where the patch size and the stride are configurable. + + Parameters + ---------- + patch_size : Union[PositiveInt, Tuple[PositiveInt, PositiveInt, PositiveInt]] (optional, default=50) + The size of each patch. If a single value is passed, the same patch size will be used for the three + spatial dimensions. + stride : Union[PositiveInt, Tuple[PositiveInt, PositiveInt, PositiveInt]] (optional, default=50) + The stride or step size used to move the sliding window. If a single value is passed, the same patch + stride will be used for the three spatial dimensions. 
""" - patch_size: int = 50 - stride_size: int = 50 - extract_method: ExtractionMethod = ExtractionMethod.PATCH + patch_size: Union[PositiveInt, Tuple[PositiveInt, PositiveInt, PositiveInt]] = 50 + stride: Union[PositiveInt, Tuple[PositiveInt, PositiveInt, PositiveInt]] = 50 + + @computed_field + @property + def extract_method(self) -> ExtractionMethod: + """The method to be used for the extraction process (Image, Patch, Slice).""" + return ExtractionMethod.PATCH - def num_elem_per_image(self, image: torch.Tensor) -> PositiveInt: + @field_validator("patch_size", "stride", mode="after") + @classmethod + def ensure_tuples( + cls, v: Union[PositiveInt, Tuple[PositiveInt, PositiveInt, PositiveInt]] + ) -> Tuple[PositiveInt, PositiveInt, PositiveInt]: """ - Returns the total number of patches generated from the image. + Ensures that 'patch_size' and 'stride' is always a tuple. + """ + if isinstance(v, int): + return (v, v, v) + else: + return v + + def num_samples_per_image(self, image: torch.Tensor) -> int: + """ + Returns the total number of patches extracted from an image. Parameters ---------- image : torch.Tensor - The input image tensor from which patches will be created. + The input image tensor (4D), where the first dimension represents the channel dimension. Returns ------- @@ -51,13 +105,13 @@ def num_elem_per_image(self, image: torch.Tensor) -> PositiveInt: Notes ----- - The number of patches is determined by the image size, patch size, and stride size. + The number of patches is determined by the image size, the patch size, and the stride. """ - return self.create_patches(image).shape[0] + return self.get_patches(image).shape[0] def extract(self, nii_path: Path) -> List[Tuple[Path, torch.Tensor]]: """ - Extracts patches from a NIfTI image tensor. + Extracts all the patches from a NIfTI image tensor. Parameters ---------- @@ -67,8 +121,8 @@ def extract(self, nii_path: Path) -> List[Tuple[Path, torch.Tensor]]: Returns ------- List[Tuple[Path, torch.Tensor]] - A list of tuples where each tuple contains the path to save the patch - and the corresponding patch tensor. + A list of tuples, where each tuple contains an extracted patch, + and the path where to store it. Notes ----- @@ -76,49 +130,61 @@ def extract(self, nii_path: Path) -> List[Tuple[Path, torch.Tensor]]: Each patch tensor is returned along with its associated file path. """ - image_tensor = self.extract_image(nii_path) - patches_tensor = self.create_patches(image_tensor) + image_tensor = self.load_image(nii_path) + patches_tensor = self.get_patches(image_tensor) patch_list = [ - (self.extract_path(nii_path, i), patches_tensor[i].unsqueeze(0)) - for i in range(patches_tensor.size(0)) + (self.sample_path(nii_path, idx), patches_tensor[idx]) + for idx in range(patches_tensor.size(0)) ] return patch_list - def extract_tensor( - self, image_tensor: torch.Tensor, patch_index: int + def extract_sample( + self, image_tensor: torch.Tensor, sample_index: int ) -> torch.Tensor: """ - Extracts a single patch from the image tensor. + Extracts a single patch from an image. Parameters ---------- image_tensor : torch.Tensor - The input image tensor from which a patch will be extracted. + The input image tensor from which a patch will be extracted. Must be a 4D tensor + with a channel dimension and 3 spatial dimensions. patch_index : int The index of the patch to extract from the image tensor. Returns ------- torch.Tensor - The extracted patch as a tensor, with a batch dimension added. 
+ The extracted patch as a 4D tensor (with a channel dimension). + + Raises + ------ + IndexError + If 'sample_index' is greater or equal to the number of patches in the image. Notes ----- This method allows for the extraction of individual patches based on the provided index. """ - patches_tensor = self.create_patches(image_tensor) - return patches_tensor[patch_index, ...].unsqueeze_(0).clone() + patches_tensor = self.get_patches(image_tensor) + try: + return patches_tensor[sample_index].unsqueeze(0).clone() + except IndexError as exc: + raise IndexError( + f"'sample_index' {sample_index} is out of range as there are only " + f"{len(patches_tensor)} patches in the image." + ) from exc - def extract_path(self, img_path: Path, patch_index: int) -> Path: + def sample_path(self, image_path: Path, sample_index: int) -> Path: """ - Constructs the save path for a given patch. + Constructs the path to save a given patch. Parameters ---------- - img_path : Path - The original image path used to derive the patch's save location. - patch_index : int - The index of the patch used to generate a unique filename. + image_path : Path + The original image path, used to derive the path for saving the patch. + sample_index : int + The index of the patch being saved. Returns ------- @@ -131,26 +197,34 @@ def extract_path(self, img_path: Path, patch_index: int) -> Path: The filename is generated using the original image name, appending patch size, stride, and the patch index to ensure each patch is saved with a unique name. """ - prefix_suffix = img_path.name.rsplit("_", 1) - return Path( - f"{prefix_suffix[0]}_patchsize-{self.patch_size}_stride-{self.stride_size}_patch-{patch_index}{prefix_suffix[1].replace(NII_GZ, PT)}" + parent = image_path.parent + prefix_suffix = image_path.name.rsplit("_", 1) + patch_size_str = "x".join([str(s) for s in self.patch_size]) + stride_str = "x".join([str(s) for s in self.stride]) + return ( + ( + parent + / f"{prefix_suffix[0]}_patchsize-{patch_size_str}_stride-{stride_str}_patch-{sample_index}_{prefix_suffix[1]}" + ) + .with_suffix("") + .with_suffix(PT) ) - def create_patches(self, image_tensor: torch.Tensor) -> torch.Tensor: + def get_patches(self, image_tensor: torch.Tensor) -> torch.Tensor: """ - Creates a tensor of patches from the image using `unfold`. + Creates a tensor of patches from the image using the PyTorch method `unfold`. Parameters ---------- image_tensor : torch.Tensor - The input image tensor from which patches will be extracted. + The input image tensor (4D), where the first dimension represents the channel dimension. Returns ------- torch.Tensor A tensor containing all the patches extracted from the image. The tensor shape - will be (num_patches, patch_size, patch_size, patch_size), where `num_patches` is - determined by the image size, patch size, and stride. + will be `(num_patches, patch_size[0], patch_size[1], patch_size[2])`, where `num_patches` is + determined by the image size, the patch size, and the stride. Notes ----- @@ -158,11 +232,70 @@ def create_patches(self, image_tensor: torch.Tensor) -> torch.Tensor: The patches are then reshaped into a 4D tensor where each patch is a separate element. 
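+
+         Examples
+         --------
+         For instance, with a toy 4x4x4 image, 2x2x2 patches and stride 2 (illustrative shapes):
+
+         >>> Patch(patch_size=2, stride=2).get_patches(torch.randn(1, 4, 4, 4)).shape
+         torch.Size([8, 2, 2, 2])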
""" patches_tensor = ( - image_tensor.unfold(1, self.patch_size, self.stride_size) - .unfold(2, self.patch_size, self.stride_size) - .unfold(3, self.patch_size, self.stride_size) + image_tensor.unfold(1, self.patch_size[0], self.stride[0]) + .unfold(2, self.patch_size[1], self.stride[1]) + .unfold(3, self.patch_size[2], self.stride[2]) .contiguous() ) + return patches_tensor.view( - -1, self.patch_size, self.patch_size, self.patch_size + -1, self.patch_size[0], self.patch_size[1], self.patch_size[2] + ) + + def _get_sample_description( + self, image_tensor: torch.Tensor, sample_index: int + ) -> int: + """The sample description for patch extraction is the index of the patch.""" + return sample_index + + def format_output( + self, + tio_sample: tio.Subject, + participant_id: str, + session_id: str, + image_path: Union[str, Path], + ) -> PatchSample: + """ + Puts all the output information in an PatchSample object. + + Parameters + ---------- + tio_sample : tio.Subject + a TorchIO Subject corresponding to the patch, with at least a ScalarImage named 'sample', + an attribute named 'label' and an attribute named 'description'. + participant_id : str + the subject concerned. + session_id : str + the session concerned. + image_path : Union[str, Path] + the path of the image from which the patch is extracted. + + Returns + ------- + PatchSample + a PatchSample object with all the relevant information on the patch. + + Raises + ------ + AttributeError + if `tio_sample` doesn't have a TorchIO ScalarImage named 'sample', and attributes + 'label' and 'description'. + """ + self._check_tio_sample(tio_sample) + + sample = tio_sample.sample.tensor + if isinstance(tio_sample.label, tio.Image): + label = tio_sample.label.tensor + else: + label = tio_sample.label + + return PatchSample( + sample=sample, + participant_id=participant_id, + session_id=session_id, + image_path=str(image_path), + label=label, + patch_index=tio_sample.description, + patch_size=self.patch_size, + patch_stride=self.stride, ) diff --git a/clinicadl/transforms/extraction/roi.py b/clinicadl/transforms/extraction/roi.py deleted file mode 100644 index f539373b3..000000000 --- a/clinicadl/transforms/extraction/roi.py +++ /dev/null @@ -1,357 +0,0 @@ -from logging import getLogger -from pathlib import Path -from typing import List, Tuple - -import nibabel as nib -import numpy as np -import torch -from pydantic import field_validator, model_validator -from typing_extensions import Self - -from clinicadl.transforms.extraction.base import BaseExtraction -from clinicadl.utils.enum import ExtractionMethod -from clinicadl.utils.exceptions import ClinicaDLArgumentError - -logger = getLogger("clinicadl.extraction.roi") - -NII_GZ = ".nii.gz" -PT = ".pt" - - -class ROI(BaseExtraction): - """ - Configuration class for extracting regions of interest (ROIs) from images using masks. 
- """ - - roi_list: List[str] - roi_mask_location: Path - roi_crop_input: bool = False - roi_crop_output: bool = True - roi_template: str = "MNI152NLin2009cSym" - roi_mask_pattern: str = "res-1x1x1" - - roi_custom_template: str = "" - roi_custom_mask_pattern: str = "" - extract_method: ExtractionMethod = ExtractionMethod.ROI - - @field_validator("roi_mask_pattern", mode="before") - def validate_roi_mask_pattern(cls, v: str) -> str: - """Validates the ROI mask pattern to ensure it starts and ends with an underscore.""" - if not v: - raise ClinicaDLArgumentError("A mask pattern must be defined.") - if not v.startswith("_"): - v = "_" + v - if not v.endswith("_"): - v = v + "_" - return v - - @field_validator("roi_list", mode="before") - def validate_roi_list(cls, v: List[str]) -> List[str]: - """Validates that the ROI list is not empty.""" - if not v: - raise NotImplementedError( - "Default regions are not available anymore in ClinicaDL. " - "Please define appropriate masks and give a roi_list." - ) - if len(v) == 0: - raise ClinicaDLArgumentError("A list of regions of interest must be given.") - - return v - - @field_validator("roi_mask_location", mode="before") - def validate_roi_mask_location(cls, v: Path) -> Path: - """Validates that the given path to the mask location is a valid directory""" - if not v: - raise ClinicaDLArgumentError("A path to the mask location must be given.") - if isinstance(v, str): - v = Path(v) - if not v.is_dir(): - raise ClinicaDLArgumentError( - f"The path '{v}' is not a directory, please give another directory with masks location" - ) - - return v - - @model_validator(mode="after") - def check_mask_list(self) -> Self: - """ - Checks that all the masks in the `roi_list` are valid and binary. - - Validates that the mask files contain binary values (0 and 1) and are present - in the correct directory. If the mask is missing or contains invalid values, - it raises an error. - - Returns - ------- - Self - The validated instance of the ROI extraction configuration. - - Raises - ------ - FileNotFoundError - If the ROI mask file is not found. - ValueError - If the mask is not binary (i.e., contains values other than 0 and 1). - """ - if self.roi_mask_location.resolve().parts[-1] != f"tpl-{self.roi_template}": - self.roi_mask_location = ( - self.roi_mask_location / f"tpl-{self.roi_template}" - ) # caps_directory / "masks" = mask_location - - for roi in self.roi_list: - roi_path, desc = self.find_mask_path(roi) - if roi_path is None: - raise FileNotFoundError( - f"The ROI '{roi}' does not correspond to a mask in the CAPS directory. {desc}" - ) - roi_mask = nib.loadsave.load(roi_path).get_fdata() # type: ignore # do we need to check here ? - mask_values = set(np.unique(roi_mask)) - if mask_values != {0, 1}: - raise ValueError( - "The ROI masks used should be binary (composed of 0 and 1 only)." - ) - return self - - def num_elem_per_image(self, image: torch.Tensor) -> int: - """ - Returns the number of ROIs to extract for the given image. - - Parameters - ---------- - image : torch.Tensor - The input image tensor. - - Returns - ------- - int - The number of regions of interest (ROIs) defined in `roi_list`. - """ - return len(self.roi_list) - - # def check_preprocessing(self, preprocessing: Preprocessing): - # if preprocessing == Preprocessing.CUSTOM: - # if not self.roi_template: - # raise ClinicaDLArgumentError( - # "A custom template must be defined when the modality is set to custom." 
- # ) - # # self.roi_template = self.roi_custom_template - # # self.roi_mask_pattern = self.roi_custom_mask_pattern - # else: - # if preprocessing == Preprocessing.T1_LINEAR: - # self.roi_template = Template.T1_LINEAR - # self.roi_mask_pattern = Pattern.T1_LINEAR - # elif preprocessing == Preprocessing.PET_LINEAR: - # self.roi_template = Template.PET_LINEAR - # self.roi_mask_pattern = Pattern.PET_LINEAR - # elif preprocessing == Preprocessing.FLAIR_LINEAR: - # self.roi_template = Template.FLAIR_LINEAR - # self.roi_mask_pattern = Pattern.FLAIR_LINEAR - - def find_mask_path(self, roi: str) -> Tuple[Path, str]: - """ - Finds the mask corresponding to the given ROI. - - Parameters - ---------- - roi : str - The name of the region of interest (ROI). - - Returns - ------- - Tuple[Path, str] - The path to the mask for the ROI and a description of the pattern used. - - Raises - ------ - FileNotFoundError - If no mask matching the pattern is found. - """ - candidates_pattern = f"*{self.roi_mask_pattern}*roi-{roi}_mask.nii*" - - desc = f"The mask should follow the pattern {candidates_pattern}. " - - candidates = [e for e in self.roi_mask_location.glob(candidates_pattern)] - - if self.roi_crop_input is None: - # pass - candidates2 = candidates - elif self.roi_crop_input: - candidates2 = [mask for mask in candidates if "_desc-Crop_" in mask.name] - desc += "and contain '_desc-Crop_' string." - else: - candidates2 = [ - mask for mask in candidates if "_desc-Crop_" not in mask.name - ] - desc += "and not contain '_desc-Crop_' string." - - if len(candidates2) == 0: - raise FileNotFoundError( - f"Could not find any masks corresponding to the pattern asked and containing the adequate {self.roi_crop_input} description " - ) - # return None, desc - else: - return min(candidates2), desc - - def compute_output_pattern(self, mask_path: Path): - """ - Computes the output filename pattern for the cropped ROI. - - Parameters - ---------- - mask_path : Path - The path to the mask file. - - Returns - ------- - str - The computed output filename pattern. - """ - - mask_filename = mask_path.name - template_id = mask_filename.split("_")[0].split("-")[1] - mask_descriptors = mask_filename.split("_")[1:-2:] - roi_id = mask_filename.split("_")[-2].split("-")[1] - if "desc-Crop" not in mask_descriptors and not self.roi_crop_output: - mask_descriptors = ["desc-CropRoi"] + mask_descriptors - elif "desc-Crop" in mask_descriptors: - mask_descriptors = [ - descriptor - for descriptor in mask_descriptors - if descriptor != "desc-Crop" - ] - if self.roi_crop_output: - mask_descriptors = ["desc-CropRoi"] + mask_descriptors - else: - mask_descriptors = ["desc-CropImage"] + mask_descriptors - - mask_pattern = "_".join(mask_descriptors) - - if mask_pattern == "": - output_pattern = f"space-{template_id}_roi-{roi_id}" - else: - output_pattern = f"space-{template_id}_{mask_pattern}_roi-{roi_id}" - - return output_pattern - - def extract(self, nii_path: Path) -> List[Tuple[str, torch.Tensor]]: - """ - Extracts the defined regions of interest (ROIs) from the given NIfTI image. - - Parameters - ---------- - nii_path : Path - The path to the NIfTI image file. - - Returns - ------- - List[Tuple[str, torch.Tensor]] - A list of tuples, where each tuple contains the output path and the extracted ROI tensor. 
- """ - image_tensor = self.extract_image(nii_path) - roi_list = [] - for roi_name in self.roi_list: - mask_path, _ = self.find_mask_path(roi_name) - mask_np = nib.loadsave.load(mask_path).get_fdata() # type: ignore - roi_list.append( - ( - self.extract_tensor(image_tensor, mask_np), - self.extract_path(nii_path, mask_path), - ) - ) - return roi_list - - def extract_tensor(self, image_tensor: torch.Tensor, roi_idx: int) -> torch.Tensor: - """ - Extracts the tensor for a single ROI. - - Parameters - ---------- - image_tensor : torch.Tensor - The input image tensor. - roi_idx : int - The index of the region of interest (ROI). - - Returns - ------- - torch.Tensor - The extracted ROI tensor. - - Raises - ------ - ValueError - If the ROI mask is not a valid 3D or 4D tensor. - """ - _, mask_arrays = self._get_mask_paths_and_tensors() - mask_np = mask_arrays[roi_idx] - - if len(mask_np.shape) == 3: - mask_np = np.expand_dims(mask_np, axis=0) - elif len(mask_np.shape) == 4: - assert mask_np.shape[0] == 1 - else: - raise ValueError( - "ROI masks must be 3D or 4D tensors. " - f"The dimension of your ROI mask is {len(mask_np.shape)}." - ) - - roi_tensor = image_tensor * mask_np - if self.roi_crop_output: - roi_tensor = roi_tensor[ - np.ix_( - mask_np.any((1, 2, 3)), - mask_np.any((0, 2, 3)), - mask_np.any((0, 1, 3)), - mask_np.any((0, 1, 2)), - ) - ] - return roi_tensor.float().clone() - - def extract_path(self, img_path: Path, mask_path: Path) -> str: - """ - Computes the output path for the extracted ROI. - - Parameters - ---------- - img_path : Path - The path to the input image file. - mask_path : Path - The path to the mask file for the ROI. - - Returns - ------- - str - The computed output path. - """ - - input_img_filename = img_path.name - - sub_ses_prefix = "_".join(input_img_filename.split("_")[0:3:]) - if not sub_ses_prefix.endswith("_T1w"): - sub_ses_prefix = "_".join(input_img_filename.split("_")[0:2:]) - input_suffix = input_img_filename.split("_")[-1].split(".")[0] - - output_pattern = self.compute_output_pattern(mask_path) - - return f"{sub_ses_prefix}_{output_pattern}_{input_suffix}{PT}" - - def _get_mask_paths_and_tensors(self) -> Tuple[List[str], List]: - """ - Loads the masks necessary for extracting regions of interest (ROIs). - - Returns - ------- - Tuple[List[str], List] - A tuple containing a list of mask paths and a list of corresponding mask arrays (NIfTI data). 
- """ - - mask_paths, mask_arrays = list(), list() - for roi in self.roi_list: - logger.info(f"Find mask for roi {roi}.") - mask_path, desc = self.find_mask_path(roi) - if mask_path is None: - raise FileNotFoundError(desc) - mask_nii = nib.loadsave.load(mask_path) - mask_paths.append(Path(mask_path)) - mask_arrays.append(mask_nii.get_fdata()) # type: ignore - - return mask_paths, mask_arrays diff --git a/clinicadl/transforms/extraction/slice.py b/clinicadl/transforms/extraction/slice.py index 21dfb7f89..59bd2c314 100644 --- a/clinicadl/transforms/extraction/slice.py +++ b/clinicadl/transforms/extraction/slice.py @@ -1,82 +1,151 @@ from logging import getLogger from pathlib import Path -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union +import numpy as np import torch -from pydantic import field_validator +import torchio as tio +from pydantic import ( + NonNegativeInt, + PositiveInt, + computed_field, + field_validator, + model_validator, +) +from typing_extensions import Self -from clinicadl.transforms.extraction.base import BaseExtraction from clinicadl.utils.enum import ( ExtractionMethod, SliceDirection, SliceMode, ) +from .base import Extraction, Sample + logger = getLogger("clinicadl.extraction.slice") -NII_GZ = ".nii.gz" PT = ".pt" -class Slice(BaseExtraction): +class SliceSample(Sample): + """ + Output of a CapsDataset when slice extraction is performed. + + Attributes + ---------- + sample : torch.Tensor + the 2D slice as 3D PyTorch tensor (with one channel dimension). + participant_id : str + the subject concerned. + session_id : str + the session concerned. + image_path : str + the path to the image from which the slice has been extracted. + slice_position : int + position of the slice in the original image. + slice_direction : SliceDirection + the slicing direction. Can be 0 (sagittal direction), 1 (coronal) or 2 (axial). + """ + + slice_position: NonNegativeInt + slice_direction: SliceDirection + + @computed_field + @property + def extraction(self) -> ExtractionMethod: + """The extraction method.""" + return ExtractionMethod.SLICE + + +class Slice(Extraction): """ - Configuration class for slice extraction from an image in specified directions. + Transform class to extract slices from an image in a specified direction. - This class allows users to define extraction configurations for obtaining slices from a 3D image tensor. + This class allows users to define extraction configurations for obtaining slices from a 4D image tensor. The extracted slices can be processed in different directions (e.g., sagittal, coronal, axial) and can be adjusted for RGB mode. + Parameters + ---------- + slices : Optional[List[NonNegativeInt]] (optional, default=None) + the slices to select. If None, slices will be selected with `discarded_slices`` + and/or `borders`. If all these three parameters are None, all slices will be + kept. + discarded_slices : Optional[List[NonNegativeInt]] (optional, default=None) + indices of the slices to discard. Cannot be used with `slices`. + borders : Optional[Union[PositiveInt, Tuple[PositiveInt, PositiveInt]]] (optional, default=None) + the number of border slices, that will be filtered out. If an integer `a` is passed, the first + `a` slices and the last `a` slices will be filtered out. If a tuple `(a, b)` is passed, the first + `a` slices and the last `b` slices will be filtered out. + slice_direction : SliceDirection (optional, default=SliceDirection.SAGITTAL) + the slicing direction. 
Can be 0 (sagittal direction), 1 (coronal) or 2 (axial). """ + slices: Optional[List[NonNegativeInt]] = None + discarded_slices: Optional[List[NonNegativeInt]] = None + borders: Optional[Union[PositiveInt, Tuple[PositiveInt, PositiveInt]]] = None slice_direction: SliceDirection = SliceDirection.SAGITTAL - slice_mode: SliceMode = SliceMode.RGB - discarded_slices: Tuple[int, int] = (0, 0) - extract_method: ExtractionMethod = ExtractionMethod.SLICE - @field_validator("discarded_slices", mode="before") - def validate_discarded_slice(cls, v: Union[int, Tuple]) -> Tuple[int, int]: - """ - Validates the discarded_slices attribute, ensuring it is either a single integer or a tuple of two integers. + @computed_field + @property + def extract_method(self) -> ExtractionMethod: + """The method to be used for the extraction process (Image, Patch, Slice).""" + return ExtractionMethod.SLICE - Raises - ------ - IndexError - If the value for discarded_slices is neither an integer nor a tuple with one or two elements. + @field_validator("borders", mode="after") + @classmethod + def validate_borders( + cls, v: Union[PositiveInt, Tuple[PositiveInt, PositiveInt]] + ) -> Tuple[PositiveInt, PositiveInt]: + """ + Ensures that 'borders' is always a tuple. """ if isinstance(v, int): return (v, v) - elif len(v) == 1: - return (v[0], v[0]) - elif len(v) == 2: - return v else: - raise IndexError( - f"Maximum two number of discarded slices can be defined. " - f"You gave discarded slices = {v}." + return v + + @model_validator(mode="after") + def validate_slices(self) -> Self: + """ + Checks consistency between 'slices', 'discarded_slices' and 'borders'. + """ + if (self.slices is not None) and (self.discarded_slices is not None): + raise ValueError( + "'slices' and 'discarded_slices' can't be passed simultaneously. Specify the wanted slices " + "in 'slices'." + ) + elif (self.slices is not None) and (self.borders is not None): + raise ValueError( + "'slices' and 'borders' can't be passed simultaneously. Specify the wanted slices " + "in 'slices'." ) + return self - def num_elem_per_image(self, image: torch.Tensor) -> int: + def num_samples_per_image(self, image: torch.Tensor) -> int: """ - Returns the number of slices that can be extracted from the input image tensor, - accounting for the discarded slices at the start and end. + Returns the number of slices that can be extracted from the input image tensor. + + If 'slices', 'discarded_slices' and 'borders' have not been passed, there is no + slice filtering, so the function will simply output the number of slices in the + image. Parameters ---------- image : torch.Tensor - The input image tensor (4D), where the first dimension represents the batch size - and the second dimension represents the slices in the specified direction. + The input image tensor (4D), where the first dimension represents the channel dimension. Returns ------- int - The number of slices available after applying the discarded slices. + The number of slices remaining after slice filtering. """ - direction = int(self.slice_direction) - return image.size(direction + 1) - sum(self.discarded_slices) + return self._get_slice_selection(image).sum() + # TODO : remove? def extract(self, nii_path: Path) -> List[Tuple[Path, torch.Tensor]]: """ - Extracts slices from the image at the specified direction, accounting for the discarded slices. + Extracts all the selected slices from the image in the specified direction. 
Parameters ---------- @@ -86,57 +155,60 @@ def extract(self, nii_path: Path) -> List[Tuple[Path, torch.Tensor]]: Returns ------- List[Tuple[Path, torch.Tensor]] - A list of tuples, where each tuple contains the file path for saving the slice - and the extracted slice tensor. + A list of tuples, where each tuple contains an extracted slice, + and the path where it should be stored. """ - image_tensor = self.extract_image(nii_path) - start, end = self.discarded_slices + image_tensor = self.load_image(nii_path) slices = [] - for i in range( - start, image_tensor.size(int(self.slice_direction.value) + 1) - end - ): - slice_tensor = self.extract_tensor(image_tensor, i) - slices.append((self.extract_path(nii_path, i), slice_tensor)) + for i in range(self.num_samples_per_image(image_tensor)): + slice_tensor = self.extract_sample(image_tensor, i).squeeze( + self.slice_direction + 1 + ) + slices.append((self.sample_path(nii_path, i), slice_tensor)) + return slices - def extract_tensor( - self, image_tensor: torch.Tensor, slice_index: int + def extract_sample( + self, image_tensor: torch.Tensor, sample_index: int ) -> torch.Tensor: """ - Extracts a single slice from the image tensor at the specified index. + Extracts a single slice from an image. Parameters ---------- image_tensor : torch.Tensor - The input image tensor, which is a 4D tensor with dimensions (batch_size, slices, height, width). - slice_index : int + The input image tensor, which is a 4D tensor with a channel dimension and 3 spatial + dimensions. + sample_index : int The index of the slice to extract in the specified direction. Returns ------- torch.Tensor - A tensor representing the extracted slice, with dimensions (3, height, width) if in RGB mode, - or (1, height, width) otherwise. + The extracted slice as a tensor. The tensor is still 4D (with the dimension in + the slice direction equal to 1). + + Raises + ------ + IndexError + If 'sample_index' is greater or equal to the number of slices in the image. """ - idx_tuple = tuple( - [slice(None)] * (int(self.slice_direction) + 1) - + [slice_index + self.discarded_slices[0]] - + [slice(None)] * (2 - int(self.slice_direction)) - ) - slice_tensor = image_tensor[idx_tuple] # shape is 1 * W * L - if self.slice_mode == SliceMode.RGB: - slice_tensor = torch.cat([slice_tensor] * 3) # shape is 3 * W * L + slice_position = self._get_slice_position(image_tensor, sample_index) + slice_tensor = self._get_slice(image_tensor, slice_position) + return slice_tensor.clone() + # TODO : remove? - def extract_path(self, img_path: Path, slice_index: int) -> Path: + def sample_path(self, image_path: Path, sample_index: int) -> Path: """ - Constructs the file path for saving a given slice, based on the input image path and slice index. + Constructs the file path for saving a given slice, based on the input image path and + the slice index. Parameters ---------- - img_path : Path + image_path : Path The path to the input image file. - slice_index : int + sample_index : int The index of the slice being saved. Returns ------- Path The constructed file path for the slice. 
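+
+         Examples
+         --------
+         A doctest-style sketch (default sagittal direction; the filename is only an example, `PosixPath` assumes a POSIX system):
+
+         >>> Slice().sample_path(Path("sub-001_ses-M000_T1w.nii.gz"), 0)
+         PosixPath('sub-001_ses-M000_axis-sag_slice-0_T1w.pt')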
""" - prefix_suffix = img_path.name.rsplit("_", 1) + parent = image_path.parent + prefix_suffix = image_path.name.rsplit("_", 1) slice_dict = {0: "sag", 1: "cor", 2: "axi"} - return Path( - f"{prefix_suffix[0]}_axis-{slice_dict[int(self.slice_direction.value)]}" - f"_channel-{self.slice_mode.value}_slice-{slice_index}{prefix_suffix[1].replace(NII_GZ, PT)}" + return ( + ( + parent / f"{prefix_suffix[0]}_axis-{slice_dict[self.slice_direction]}" + f"_slice-{sample_index}_{prefix_suffix[1]}" + ) + .with_suffix("") + .with_suffix(PT) + ) + + def _get_slice_selection(self, image: torch.Tensor) -> np.ndarray[bool]: + """ + Returns the slices of an image that can be extracted, depending on 'slices', + 'discarded_slices' and 'borders'. + """ + n_slices = image.size(self.slice_direction + 1) + selection = np.ones(n_slices).astype(bool) + + if self.slices: + selection = ~selection + try: + selection[self.slices] = True + except IndexError as exc: + raise IndexError( + "Invalid slices in 'slices': " + f"slices in the image are indexed from 0 to {n_slices-1}, but got " + f"slices={self.slices}." + ) from exc + else: + if self.discarded_slices: + try: + selection[self.discarded_slices] = False + except IndexError as exc: + raise IndexError( + "Invalid slices in 'discarded_slices': " + f"slices in the image are indexed from 0 to {n_slices-1}, but got " + f"discarded_slices={self.discarded_slices}." + ) from exc + + if self.borders: + selection[: self.borders[0]] = False + selection[n_slices - self.borders[1] :] = False + + return selection + + def _get_slice_position(self, image: torch.Tensor, slice_index: int) -> int: + """ + Returns the position in the image of 'slice_index'. They may differ as + 'slice_index' is the index among the selected slices. + """ + selection = self._get_slice_selection(image) + slice_positions = np.arange(len(selection))[selection] + + try: + return slice_positions[slice_index] + except IndexError as exc: + raise IndexError( + f"'sample_index' {slice_index} is out of range as there are only " + f"{len(slice_positions)} slices in the image." + ) from exc + + def _get_slice(self, image: torch.Tensor, slice_position: int) -> torch.Tensor: + """ + Gets the wanted slice, according to the slicing direction. + """ + if self.slice_direction == 0: + slice_tensor = image[:, slice_position, :, :] + elif self.slice_direction == 1: + slice_tensor = image[:, :, slice_position, :] + elif self.slice_direction == 2: + slice_tensor = image[:, :, :, slice_position] + + return slice_tensor.unsqueeze(self.slice_direction + 1) # pylint: disable=possibly-used-before-assignment + + def _get_sample_description( + self, image_tensor: torch.Tensor, sample_index: int + ) -> int: + """ + The sample description for slice extraction is the position of the slice + in the original image. + """ + return self._get_slice_position(image_tensor, sample_index) + + def format_output( + self, + tio_sample: tio.Subject, + participant_id: str, + session_id: str, + image_path: Union[str, Path], + ) -> SliceSample: + """ + Puts all the output information in an SliceSample object. + + Parameters + ---------- + tio_sample : tio.Subject + a TorchIO Subject corresponding to the slice, with at least a ScalarImage named 'sample', + an attribute named 'label' and an attribute named 'description'. + participant_id : str + the subject concerned. + session_id : str + the session concerned. + image_path : Union[str, Path] + the path of the image from which the slice is extracted. 
+ + Returns + ------- + SliceSample + a SliceSample object with all the relevant information on the slice. + + Raises + ------ + AttributeError + if `tio_sample` doesn't have a TorchIO ScalarImage named 'sample', and attributes + 'label' and 'description'. + """ + self._check_tio_sample(tio_sample) + + sample = tio_sample.sample.tensor.squeeze(self.slice_direction + 1) + if isinstance(tio_sample.label, tio.Image): + label = tio_sample.label.tensor.squeeze(self.slice_direction + 1) + else: + label = tio_sample.label + + return SliceSample( + sample=sample, + participant_id=participant_id, + session_id=session_id, + image_path=str(image_path), + label=label, + slice_position=tio_sample.description, + slice_direction=self.slice_direction, ) diff --git a/clinicadl/transforms/transforms.py b/clinicadl/transforms/transforms.py index 86ff86d1e..4055f4218 100644 --- a/clinicadl/transforms/transforms.py +++ b/clinicadl/transforms/transforms.py @@ -4,7 +4,7 @@ import torchvision.transforms as torch_transforms from pydantic import model_validator -from clinicadl.transforms.extraction import BaseExtraction, Image +from clinicadl.transforms.extraction import Extraction, Image from clinicadl.transforms.factory import ( MinMaxNormalization, NanRemoval, @@ -28,7 +28,7 @@ class Transforms(ClinicaDLConfig): Attributes ---------- - extraction : BaseExtraction + extraction : Extraction The extraction method used for preprocessing the data. image_augmentation : list[Callable] A list of augmentation functions for images. @@ -55,7 +55,7 @@ class Transforms(ClinicaDLConfig): Returns a tuple of composed transformations for images, objects, and augmentations. """ - extraction: BaseExtraction = Image() + extraction: Extraction = Image() image_augmentation: list[Callable] = [] object_augmentation: list[Callable] = [] image_transforms: list[Callable] = [] diff --git a/clinicadl/transforms/utils.py b/clinicadl/transforms/utils.py new file mode 100644 index 000000000..884edb620 --- /dev/null +++ b/clinicadl/transforms/utils.py @@ -0,0 +1,42 @@ +from typing import Optional, Union + +import torch +import torchio as tio + + +def get_tio_image( + image: torch.Tensor, + label: Optional[Union[float, int, torch.Tensor]], + **masks: torch.Tensor, +) -> tio.Subject: + """ + Creates a TorchIO Subject from the image, the label and possibly + masks related to the image. + + Parameters + ---------- + image : torch.Tensor + the image, as a Pytorch tensor. + label : Optional[Union[float, int, torch.Tensor]] + the label related to the image. Can be None if no label. + **masks : torch.Tensor + any mask related to the image and useful to compute transforms. + + Returns + ------- + tio.Subject + the TorchIO subject with the image and the label, accessible via + the attributes 'image' and 'label', as well as the masks, accessible + via their names. 
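+
+     Examples
+     --------
+     A minimal sketch (toy shapes and a scalar label; 'brain_mask' is just an example mask name):
+
+     >>> subject = get_tio_image(torch.randn(1, 4, 5, 6), label=0.5, brain_mask=torch.ones(1, 4, 5, 6))
+     >>> subject.label
+     0.5
+     >>> isinstance(subject.brain_mask, tio.LabelMap)
+     True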
+ """ + tio_image = tio.Subject(image=tio.ScalarImage(tensor=image)) + + if isinstance(label, torch.Tensor): + tio_image.add_image(tio.LabelMap(tensor=label), "label") + else: + setattr(tio_image, "label", label) + + for name, mask in masks.items(): + tio_image.add_image(tio.LabelMap(tensor=mask), name) + + return tio_image diff --git a/clinicadl/utils/config.py b/clinicadl/utils/config.py index 0526441df..70f7d0e86 100644 --- a/clinicadl/utils/config.py +++ b/clinicadl/utils/config.py @@ -5,5 +5,8 @@ class ClinicaDLConfig(BaseModel): """Base configuration class.""" model_config = ConfigDict( - validate_assignment=True, use_enum_values=True, validate_default=True + validate_assignment=True, + use_enum_values=True, + validate_default=True, + arbitrary_types_allowed=True, ) diff --git a/clinicadl/utils/enum.py b/clinicadl/utils/enum.py index 4e5c7721c..f407b6b98 100644 --- a/clinicadl/utils/enum.py +++ b/clinicadl/utils/enum.py @@ -121,12 +121,12 @@ class ExtractionMethod(str, Enum): ROI = "roi" -class SliceDirection(str, Enum): +class SliceDirection(int, Enum): """Possible directions for a slice.""" - SAGITTAL = "0" - CORONAL = "1" - AXIAL = "2" + SAGITTAL = 0 + CORONAL = 1 + AXIAL = 2 class SliceMode(str, Enum): diff --git a/tests/unittests/ressources/caps_example/subjects_sessions_list.tsv b/tests/unittests/ressources/caps_example/subjects_sessions_list.tsv index e505e36c3..cd3c1d0b2 100644 --- a/tests/unittests/ressources/caps_example/subjects_sessions_list.tsv +++ b/tests/unittests/ressources/caps_example/subjects_sessions_list.tsv @@ -3,3 +3,11 @@ sub-000 ses-M000 sub-000 ses-M006 sub-001 ses-M000 sub-001 ses-M018 +sub-OAS30010 ses-M000 +sub-OAS30011 ses-M000 +sub-OAS30011 ses-M054 +sub-OAS30012 ses-M006 +sub-OAS30012 ses-M018 +sub-OAS30013 ses-M006 +sub-OAS30014 ses-M006 +sub-OAS30014 ses-M036 diff --git a/tests/unittests/transforms/__init__.py b/tests/unittests/transforms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unittests/transforms/extraction/__init__.py b/tests/unittests/transforms/extraction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unittests/transforms/extraction/test_image.py b/tests/unittests/transforms/extraction/test_image.py new file mode 100644 index 000000000..463568c99 --- /dev/null +++ b/tests/unittests/transforms/extraction/test_image.py @@ -0,0 +1,164 @@ +import shutil +from pathlib import Path + +import nibabel as nib +import numpy as np +import pytest +import torch +import torchio as tio + +from clinicadl.transforms.extraction import Image + + +def test_extract_method(): + image = Image() + assert image.extract_method == "image" + + +def test_num_samples_per_image(): + image = Image() + assert image.num_samples_per_image(torch.randn(1, 3, 4, 5)) == 1 + + +def test_sample_path(): + image = Image() + assert image.sample_path( + Path("sub-001/ses-M000/sub-001_ses-M000_T1w.nii.gz"), 0 + ) == Path("sub-001/ses-M000/sub-001_ses-M000_T1w.pt") + assert image.sample_path( + Path("sub-001/ses-M001/sub-001_ses-M001_FLAIR.nii"), 0 + ) == Path("sub-001/ses-M001/sub-001_ses-M001_FLAIR.pt") + + +def test_extract_sample(): + image = Image() + image_tensor = torch.randn(1, 3, 4, 5) + assert (image.extract_sample(image_tensor, sample_index=0) == image_tensor).all() + + with pytest.raises(IndexError): + image.extract_sample(image_tensor, sample_index=1) + + +def test_extract(): + tmp_dir = Path(__file__).parents[2] / "ressources" / "tmp" + tmp_dir.mkdir(parents=True, exist_ok=True) + image_tensor = 
torch.randn(1, 3, 4, 5) + image_nifti = nib.Nifti1Image(image_tensor.squeeze(0).numpy(), np.eye(4)) + nib.save(image_nifti, tmp_dir / "sub-001_ses-M000_T1w.nii.gz") + nib.save(image_nifti, tmp_dir / "sub-001_ses-M001_FLAIR.nii") + + image = Image() + output = image.extract(tmp_dir / "sub-001_ses-M000_T1w.nii.gz") + assert len(output) == 1 + assert output[0][0] == tmp_dir / "sub-001_ses-M000_T1w.pt" + assert (output[0][1] == image_tensor).all() + + output = image.extract(tmp_dir / "sub-001_ses-M001_FLAIR.nii") + assert output[0][0] == tmp_dir / "sub-001_ses-M001_FLAIR.pt" + assert (output[0][1] == image_tensor).all() + + shutil.rmtree(tmp_dir) + + +def test_extract_tio_sample(): + image = Image() + image_tensor = torch.randn(1, 3, 4, 5) + mask_1 = torch.ones(1, 3, 4, 5) + mask_2 = torch.zeros(1, 3, 4, 5) + label = torch.ones(1, 3, 4, 5) + + tio_image = tio.Subject( + image=tio.ScalarImage(tensor=image_tensor), + label=tio.LabelMap(tensor=label), + mask_1=tio.LabelMap(tensor=mask_1), + mask_2=tio.LabelMap(tensor=mask_2), + ) + tio_sample = image.extract_tio_sample(tio_image) + assert isinstance(tio_sample.sample, tio.ScalarImage) + assert (tio_sample.sample.tensor == image_tensor).all() + assert isinstance(tio_sample.label, tio.LabelMap) + assert (tio_sample.label.tensor == label).all() + assert isinstance(tio_sample.mask_1, tio.LabelMap) + assert (tio_sample.mask_1.tensor == mask_1).all() + assert isinstance(tio_sample.mask_2, tio.LabelMap) + assert (tio_sample.mask_2.tensor == mask_2).all() + with pytest.raises(AttributeError): + tio_sample.image + + tio_image = tio.Subject(image=tio.ScalarImage(tensor=image_tensor), label=1) + tio_sample = image.extract_tio_sample(tio_image) + assert tio_sample.label == 1 + + with pytest.raises(IndexError): + image.extract_tio_sample(tio_image, sample_index=1) + with pytest.raises(AttributeError): + image.extract_tio_sample( + tio.Subject(label=tio.LabelMap(tensor=label)), sample_index=1 + ) + + +def test_format_output(): + image = Image() + image_tensor = torch.randn(1, 3, 4, 5) + mask_1 = torch.ones(1, 3, 4, 5) + label = torch.ones(1, 3, 4, 5) + + tio_sample = tio.Subject( + sample=tio.ScalarImage(tensor=image_tensor), + label=tio.LabelMap(tensor=label), + mask_1=tio.LabelMap(tensor=mask_1), + description=None, + ) + output = image.format_output( + tio_sample, + participant_id="sub-001", + session_id="ses-M001", + image_path=Path("sub-001_ses-M001_T1w.nii.gz"), + ) + assert (output.sample == image_tensor).all() + assert (output.label == label).all() + assert output.session_id == "ses-M001" + assert output.participant_id == "sub-001" + assert output.extraction == "image" + assert output.image_path == "sub-001_ses-M001_T1w.nii.gz" + + tio_sample = tio.Subject( + sample=tio.ScalarImage(tensor=image_tensor), + label=0.5, + description=None, + ) + output = image.format_output( + tio_sample, + participant_id="sub-001", + session_id="ses-M001", + image_path=Path("sub-001_ses-M001_T1w.nii.gz"), + ) + assert output.label == 0.5 + + +@pytest.mark.parametrize( + "tio_sample", + [ + tio.Subject( + label=tio.LabelMap(tensor=torch.ones(1, 3, 4, 5)), + description=None, + ), + tio.Subject( + sample=tio.ScalarImage(tensor=torch.randn(1, 3, 4, 5)), + description=None, + ), + tio.Subject( + label=0.5, + sample=tio.ScalarImage(tensor=torch.randn(1, 3, 4, 5)), + ), + ], +) +def test_format_output_errors(tio_sample): + image = Image() + with pytest.raises(AttributeError): + image.format_output( + tio_sample, + participant_id="sub-001", + session_id="ses-M001", + 
image_path=Path("sub-001_ses-M001_T1w.nii.gz"),
+        )
diff --git a/tests/unittests/transforms/extraction/test_patch.py b/tests/unittests/transforms/extraction/test_patch.py
new file mode 100644
index 000000000..b80e6ab6c
--- /dev/null
+++ b/tests/unittests/transforms/extraction/test_patch.py
@@ -0,0 +1,192 @@
+import shutil
+from pathlib import Path
+
+import nibabel as nib
+import numpy as np
+import pytest
+import torch
+import torchio as tio
+from pydantic import ValidationError
+
+from clinicadl.transforms.extraction import Patch
+
+
+def test_args():
+    with pytest.raises(ValidationError):
+        Patch(patch_size=0)
+    with pytest.raises(ValidationError):
+        Patch(stride=0)
+
+
+def test_extract_method():
+    patch = Patch()
+    assert patch.extract_method == "patch"
+
+
+def test_num_samples_per_image():
+    img = torch.randn(1, 5, 7, 3)
+
+    patch = Patch(patch_size=3, stride=1)
+    assert patch.num_samples_per_image(img) == 3 * 5 * 1
+
+    patch = Patch(patch_size=(2, 3, 2), stride=(1, 2, 1))
+    assert patch.num_samples_per_image(img) == 4 * 3 * 2
+
+    patch = Patch(patch_size=(2, 3, 2), stride=3)
+    assert patch.num_samples_per_image(img) == 2 * 2 * 1
+
+
+def test_sample_path():
+    patch = Patch(patch_size=(2, 3, 2), stride=3)
+    assert patch.sample_path(
+        Path("sub-001/ses-M000/sub-001_ses-M000_T1w.nii.gz"), 3
+    ) == Path(
+        "sub-001/ses-M000/sub-001_ses-M000_patchsize-2x3x2_stride-3x3x3_patch-3_T1w.pt"
+    )
+
+    patch = Patch(patch_size=(2, 3, 2), stride=(1, 2, 1))
+    assert patch.sample_path(
+        Path("sub-001/ses-M001/sub-001_ses-M001_FLAIR.nii"), 7
+    ) == Path(
+        "sub-001/ses-M001/sub-001_ses-M001_patchsize-2x3x2_stride-1x2x1_patch-7_FLAIR.pt"
+    )
+
+
+def test_extract_sample():
+    image_tensor = torch.randn(1, 5, 7, 3)
+
+    patch = Patch(patch_size=2, stride=1)
+    assert (
+        patch.extract_sample(image_tensor, sample_index=1)
+        == image_tensor[:, :2, :2, 1:3]  # patch indices vary fastest along the last dimension (as with .view)
+    ).all()
+
+    patch = Patch(patch_size=(2, 3, 2), stride=(1, 2, 1))
+    assert (
+        patch.extract_sample(image_tensor, sample_index=5)
+        == image_tensor[:, :2, 4:7, 1:3]
+    ).all()
+
+    patch = Patch(patch_size=(2, 3, 2), stride=(1, 2, 1))
+    assert (
+        patch.extract_sample(image_tensor, sample_index=7)
+        == image_tensor[:, 1:3, :3, 1:3]
+    ).all()
+
+    patch = Patch(patch_size=(2, 3, 2), stride=(1, 2, 1))
+    with pytest.raises(IndexError):
+        patch.extract_sample(image_tensor, sample_index=24)
+
+
+def test_extract():
+    tmp_dir = Path(__file__).parents[2] / "ressources" / "tmp"
+    tmp_dir.mkdir(parents=True, exist_ok=True)
+    image_tensor = torch.randn(1, 5, 7, 4)
+    image_nifti = nib.Nifti1Image(image_tensor.squeeze(0).numpy(), np.eye(4))
+    nib.save(image_nifti, tmp_dir / "sub-001_ses-M000_T1w.nii.gz")
+
+    patch = Patch(patch_size=2, stride=2)
+    output = patch.extract(tmp_dir / "sub-001_ses-M000_T1w.nii.gz")
+    assert len(output) == 12
+    assert (
+        output[0][0]
+        == tmp_dir / "sub-001_ses-M000_patchsize-2x2x2_stride-2x2x2_patch-0_T1w.pt"
+    )
+    assert (output[0][1] == image_tensor[:, :2, :2, :2]).all()
+    assert (
+        output[3][0]
+        == tmp_dir / "sub-001_ses-M000_patchsize-2x2x2_stride-2x2x2_patch-3_T1w.pt"
+    )
+    assert (output[3][1] == image_tensor[:, :2, 2:4, 2:4]).all()
+
+    shutil.rmtree(tmp_dir)
+
+
+def test_extract_tio_sample():
+    patch = Patch(patch_size=(2, 3, 2), stride=(1, 2, 1))
+    image_tensor = torch.randn(1, 5, 7, 3)
+    mask_1 = torch.ones(1, 5, 7, 3)
+    label = torch.ones(1, 5, 7, 3)
+
+    tio_image = tio.Subject(
+        image=tio.ScalarImage(tensor=image_tensor),
+        label=tio.LabelMap(tensor=label),
+        mask_1=tio.LabelMap(tensor=mask_1),
+    )
+    tio_sample = patch.extract_tio_sample(tio_image, sample_index=5)
+    assert isinstance(tio_sample.sample, tio.ScalarImage)
+    assert (tio_sample.sample.tensor == image_tensor[:, :2, 4:7, 1:3]).all()
+    assert isinstance(tio_sample.label, tio.LabelMap)
+    assert (tio_sample.label.tensor == label[:, :2, 4:7, 1:3]).all()
+    assert isinstance(tio_sample.mask_1, tio.LabelMap)
+    assert (tio_sample.mask_1.tensor == mask_1[:, :2, 4:7, 1:3]).all()
+    assert tio_sample.description == 5
+    with pytest.raises(AttributeError):
+        tio_sample.image
+
+    tio_image = tio.Subject(image=tio.ScalarImage(tensor=image_tensor), label=1)
+    tio_sample = patch.extract_tio_sample(tio_image, sample_index=5)
+    assert tio_sample.label == 1
+
+    with pytest.raises(IndexError):
+        patch.extract_tio_sample(tio_image, sample_index=25)
+    with pytest.raises(AttributeError):
+        patch.extract_tio_sample(
+            tio.Subject(label=tio.LabelMap(tensor=label)), sample_index=1
+        )
+
+
+def test_format_output():
+    patch = Patch(patch_size=(3, 4, 3), stride=2)
+    image_tensor = torch.randn(1, 3, 4, 5)
+    mask_1 = torch.ones(1, 3, 4, 5)
+    label = torch.ones(1, 3, 4, 5)
+
+    tio_sample = tio.Subject(
+        sample=tio.ScalarImage(tensor=image_tensor),
+        label=tio.LabelMap(tensor=label),
+        mask_1=tio.LabelMap(tensor=mask_1),
+        description=1,
+    )
+    output = patch.format_output(
+        tio_sample,
+        participant_id="sub-001",
+        session_id="ses-M001",
+        image_path=Path("sub-001_ses-M001_T1w.nii.gz"),
+    )
+    assert (output.sample == image_tensor).all()
+    assert (output.label == label).all()
+    assert output.session_id == "ses-M001"
+    assert output.participant_id == "sub-001"
+    assert output.extraction == "patch"
+    assert output.image_path == "sub-001_ses-M001_T1w.nii.gz"
+    assert output.patch_index == 1
+    assert output.patch_size == (3, 4, 3)
+    assert output.patch_stride == (2, 2, 2)
+
+    tio_sample = tio.Subject(
+        sample=tio.ScalarImage(tensor=image_tensor),
+        label=0.5,
+        description=1,
+    )
+    output = patch.format_output(
+        tio_sample,
+        participant_id="sub-001",
+        session_id="ses-M001",
+        image_path=Path("sub-001_ses-M001_T1w.nii.gz"),
+    )
+    assert output.label == 0.5
+
+    # check that the sanity checks on 'sample' are performed
+    tio_sample = tio.Subject(
+        sample=tio.ScalarImage(tensor=image_tensor),
+        label=tio.LabelMap(tensor=label),
+        mask_1=tio.LabelMap(tensor=mask_1),
+    )
+    with pytest.raises(AttributeError):
+        patch.format_output(
+            tio_sample,
+            participant_id="sub-001",
+            session_id="ses-M001",
+            image_path=Path("sub-001_ses-M001_T1w.nii.gz"),
+        )
diff --git a/tests/unittests/transforms/extraction/test_slice.py b/tests/unittests/transforms/extraction/test_slice.py
new file mode 100644
index 000000000..858c9ef3d
--- /dev/null
+++ b/tests/unittests/transforms/extraction/test_slice.py
@@ -0,0 +1,199 @@
+import shutil
+from pathlib import Path
+
+import nibabel as nib
+import numpy as np
+import pytest
+import torch
+import torchio as tio
+from pydantic import ValidationError
+
+from clinicadl.transforms.extraction import Slice
+
+
+def test_args():
+    with pytest.raises(ValidationError):
+        Slice(slices=[0], slice_direction=3)
+    with pytest.raises(ValidationError):
+        Slice(slices=[0], discarded_slices=[1])
+    with pytest.raises(ValidationError):
+        Slice(slices=[0], borders=1)
+
+
+def test_extract_method():
+    slice = Slice(slices=[0, 1, 2])
+    assert slice.extract_method == "slice"
+
+
+def test_num_samples_per_image():
+    img = torch.randn(1, 5, 7, 3)
+
+    slice = Slice()
+    assert slice.num_samples_per_image(img) == 5
+
+    slice = 
Slice(slices=[1, 2]) + assert slice.num_samples_per_image(img) == 2 + + slice = Slice(borders=2, slice_direction=1) + assert slice.num_samples_per_image(img) == 3 + + slice = Slice(discarded_slices=[1, 2], slice_direction=2) + assert slice.num_samples_per_image(img) == 1 + + slice = Slice(discarded_slices=[1], borders=2, slice_direction=0) + assert slice.num_samples_per_image(img) == 1 + + slice = Slice(discarded_slices=[2], borders=2, slice_direction=0) + assert slice.num_samples_per_image(img) == 0 + + slice = Slice(discarded_slices=[3], slice_direction=2) + with pytest.raises(IndexError): + slice.num_samples_per_image(img) + + slice = Slice(slices=[3], slice_direction=2) + with pytest.raises(IndexError): + slice.num_samples_per_image(img) + + +def test_sample_path(): + slice = Slice(slices=[1, 2, 3]) + assert slice.sample_path( + Path("sub-001/ses-M000/sub-001_ses-M000_T1w.nii.gz"), 1 + ) == Path("sub-001/ses-M000/sub-001_ses-M000_axis-sag_slice-1_T1w.pt") + + slice = Slice(slices=[1, 2, 3], slice_direction=1) + assert slice.sample_path( + Path("sub-001/ses-M001/sub-001_ses-M001_FLAIR.nii"), 2 + ) == Path("sub-001/ses-M001/sub-001_ses-M001_axis-cor_slice-2_FLAIR.pt") + + +def test_extract_sample(): + image_tensor = torch.randn(1, 5, 3, 7) + + slice = Slice(slices=[2, 3]) + assert ( + slice.extract_sample(image_tensor, sample_index=1) == image_tensor[:, 3:4] + ).all() + + slice = Slice(discarded_slices=[0], slice_direction=1) + assert ( + slice.extract_sample(image_tensor, sample_index=0) == image_tensor[:, :, 1:2] + ).all() + + slice = Slice(discarded_slices=[4], borders=1, slice_direction=2) + assert ( + slice.extract_sample(image_tensor, sample_index=3) == image_tensor[:, :, :, 5:6] + ).all() + + slice = Slice(discarded_slices=[4], borders=1, slice_direction=2) + with pytest.raises(IndexError): + slice.extract_sample(image_tensor, sample_index=4) + + +def test_extract(): + tmp_dir = Path(__file__).parents[2] / "ressources" / "tmp" + tmp_dir.mkdir(parents=True, exist_ok=True) + image_tensor = torch.randn(1, 3, 4, 7) + image_nifti = nib.Nifti1Image(image_tensor.squeeze(0).numpy(), np.eye(4)) + nib.save(image_nifti, tmp_dir / "sub-001_ses-M000_T1w.nii.gz") + + slice = Slice(discarded_slices=[1, 4], borders=1, slice_direction=2) + output = slice.extract(tmp_dir / "sub-001_ses-M000_T1w.nii.gz") + assert len(output) == 3 + assert output[0][0] == tmp_dir / "sub-001_ses-M000_axis-axi_slice-0_T1w.pt" + assert (output[0][1] == image_tensor[:, :, :, 2]).all() + assert output[1][0] == tmp_dir / "sub-001_ses-M000_axis-axi_slice-1_T1w.pt" + assert (output[1][1] == image_tensor[:, :, :, 3]).all() + assert output[2][0] == tmp_dir / "sub-001_ses-M000_axis-axi_slice-2_T1w.pt" + assert (output[2][1] == image_tensor[:, :, :, 5]).all() + + shutil.rmtree(tmp_dir) + + +def test_extract_tio_sample(): + slice = Slice(slices=[2, 3]) + image_tensor = torch.randn(1, 5, 7, 3) + mask_1 = torch.ones(1, 5, 7, 3) + label = torch.ones(1, 5, 7, 3) + + tio_image = tio.Subject( + image=tio.ScalarImage(tensor=image_tensor), + label=tio.LabelMap(tensor=label), + mask_1=tio.LabelMap(tensor=mask_1), + ) + tio_sample = slice.extract_tio_sample(tio_image, sample_index=1) + assert isinstance(tio_sample.sample, tio.ScalarImage) + assert (tio_sample.sample.tensor == image_tensor[:, 3:4]).all() + assert isinstance(tio_sample.label, tio.LabelMap) + assert (tio_sample.label.tensor == label[:, 3:4]).all() + assert isinstance(tio_sample.mask_1, tio.LabelMap) + assert (tio_sample.mask_1.tensor == mask_1[:, 3:4]).all() + assert 
tio_sample.description == 3
+    with pytest.raises(AttributeError):
+        tio_sample.image
+
+    tio_image = tio.Subject(image=tio.ScalarImage(tensor=image_tensor), label=1)
+    tio_sample = slice.extract_tio_sample(tio_image, sample_index=1)
+    assert tio_sample.label == 1
+
+    with pytest.raises(IndexError):
+        slice.extract_tio_sample(tio_image, sample_index=42)
+    with pytest.raises(AttributeError):
+        slice.extract_tio_sample(
+            tio.Subject(label=tio.LabelMap(tensor=label)), sample_index=1
+        )
+
+
+def test_format_output():
+    slice = Slice(slice_direction=2)
+    image_tensor = torch.randn(1, 3, 4, 1)
+    mask_1 = torch.ones(1, 3, 4, 1)
+    label = torch.ones(1, 3, 4, 1)
+
+    tio_sample = tio.Subject(
+        sample=tio.ScalarImage(tensor=image_tensor),
+        label=tio.LabelMap(tensor=label),
+        mask_1=tio.LabelMap(tensor=mask_1),
+        description=1,
+    )
+    output = slice.format_output(
+        tio_sample,
+        participant_id="sub-001",
+        session_id="ses-M001",
+        image_path=Path("sub-001_ses-M001_T1w.nii.gz"),
+    )
+    assert (output.sample == image_tensor.squeeze(3)).all()
+    assert (output.label == label.squeeze(3)).all()
+    assert output.session_id == "ses-M001"
+    assert output.participant_id == "sub-001"
+    assert output.extraction == "slice"
+    assert output.image_path == "sub-001_ses-M001_T1w.nii.gz"
+    assert output.slice_direction == 2
+    assert output.slice_position == 1
+
+    tio_sample = tio.Subject(
+        sample=tio.ScalarImage(tensor=image_tensor),
+        label=0.5,
+        description=1,
+    )
+    output = slice.format_output(
+        tio_sample,
+        participant_id="sub-001",
+        session_id="ses-M001",
+        image_path=Path("sub-001_ses-M001_T1w.nii.gz"),
+    )
+    assert output.label == 0.5
+
+    # check that the sanity checks on 'sample' are performed
+    tio_sample = tio.Subject(
+        sample=tio.ScalarImage(tensor=image_tensor),
+        label=tio.LabelMap(tensor=label),
+        mask_1=tio.LabelMap(tensor=mask_1),
+    )
+    with pytest.raises(AttributeError):
+        slice.format_output(
+            tio_sample,
+            participant_id="sub-001",
+            session_id="ses-M001",
+            image_path=Path("sub-001_ses-M001_T1w.nii.gz"),
+        )
diff --git a/tests/unittests/transforms/test_extraction.py b/tests/unittests/transforms/test_extraction.py
deleted file mode 100644
index ed73a04e7..000000000
--- a/tests/unittests/transforms/test_extraction.py
+++ /dev/null
@@ -1 +0,0 @@
-# TODO later when we are sure of the preprocessing architecture
diff --git a/tests/unittests/transforms/test_utils.py b/tests/unittests/transforms/test_utils.py
new file mode 100644
index 000000000..8a91b624a
--- /dev/null
+++ b/tests/unittests/transforms/test_utils.py
@@ -0,0 +1,26 @@
+import torch
+import torchio as tio
+
+from clinicadl.transforms.utils import get_tio_image
+
+
+def test_get_tio_image():
+    image_tensor = torch.randn(1, 3, 4, 5)
+    mask_1 = torch.ones(1, 3, 4, 5)
+    mask_2 = torch.zeros(1, 3, 4, 5)
+    label = torch.ones(1, 3, 4, 5)
+
+    tio_image = get_tio_image(image_tensor, label, mask_1=mask_1, mask_2=mask_2)
+    assert isinstance(tio_image.image, tio.ScalarImage)
+    assert (tio_image.image.tensor == image_tensor).all()
+    assert isinstance(tio_image.label, tio.LabelMap)
+    assert (tio_image.label.tensor == label).all()
+    assert isinstance(tio_image.mask_1, tio.LabelMap)
+    assert (tio_image.mask_1.tensor == mask_1).all()
+    assert isinstance(tio_image.mask_2, tio.LabelMap)
+    assert (tio_image.mask_2.tensor == mask_2).all()
+
+    tio_image = get_tio_image(image_tensor, label=None)
+    assert tio_image.label is None
+    tio_image = get_tio_image(image_tensor, label=1)
+    assert tio_image.label == 1
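
Reviewer note: below is a minimal end-to-end sketch of the refactored extraction API, inferred only from the behavior exercised by the unit tests above. The tensor shapes, participant/session IDs and file name are illustrative placeholders, not part of the patch.

import torch

from clinicadl.transforms.extraction import Slice
from clinicadl.transforms.utils import get_tio_image

# Coronal slicing, discarding two border slices on each side of the axis.
slice_extraction = Slice(slice_direction=1, borders=2)

image_tensor = torch.randn(1, 8, 10, 8)  # placeholder image (channel, dim0, dim1, dim2)
label = torch.ones(1, 8, 10, 8)          # placeholder voxel-wise label

# 10 coronal slices minus 2 border slices on each side -> 6 extractable samples.
assert slice_extraction.num_samples_per_image(image_tensor) == 6

# Wrap image and label in a TorchIO Subject, then extract the third selected slice.
tio_image = get_tio_image(image_tensor, label)
tio_sample = slice_extraction.extract_tio_sample(tio_image, sample_index=2)

# Package everything in a SliceSample for the CapsDataset.
sample = slice_extraction.format_output(
    tio_sample,
    participant_id="sub-001",
    session_id="ses-M000",
    image_path="sub-001_ses-M000_T1w.nii.gz",
)
assert sample.slice_position == 4   # index 2 among the selected positions [2..7]
assert sample.slice_direction == 1
assert sample.sample.shape == (1, 8, 8)  # slice axis squeezed out by format_output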