Skip to content

Commit

Permalink
first draft fow new extraction objects
Browse files Browse the repository at this point in the history
  • Loading branch information
thibaultdvx committed Dec 9, 2024
1 parent c28d4d5 commit 559040e
Show file tree
Hide file tree
Showing 8 changed files with 478 additions and 470 deletions.
2 changes: 1 addition & 1 deletion clinicadl/transforms/extraction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import BaseExtraction
from .base import Extraction
from .image import Image
from .patch import Patch
from .roi import ROI
Expand Down
73 changes: 35 additions & 38 deletions clinicadl/transforms/extraction/base.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,33 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from logging import getLogger
from pathlib import Path
from typing import List, Tuple

import nibabel as nib
import torch
from pydantic import PositiveInt
from pydantic import computed_field

from clinicadl.utils.config import ClinicaDLConfig
from clinicadl.utils.enum import ExtractionMethod

logger = getLogger("clinicadl.base_extraction")

NII_GZ = ".nii.gz"
PT = ".pt"


class BaseExtraction(ClinicaDLConfig):
class Extraction(ClinicaDLConfig, ABC):
"""
Abstract base class for image extraction procedures.
This class defines the common structure and methods for extracting data from
neuroimaging files (such as NIfTI) into a tensor representation for further processing.
Parameters
----------
extract_method : ExtractionMethod
The method to be used for the extraction process (ROI, Image, Patch, Slice).
use_uncropped_image : bool, optional
A flag to specify whether to use the uncropped image, by default True.
"""

extract_method: ExtractionMethod
use_uncropped_image: bool = True
@computed_field
@property
@abstractmethod
def extract_method(self) -> ExtractionMethod:
"""The method to be used for the extraction process (ROI, Image, Patch, Slice)."""

def extract_image(self, input_img: Path) -> torch.Tensor:
def load_image(self, input_img: Path) -> torch.Tensor:
"""
Loads a NIfTI image and converts it to a float32 tensor.
Expand Down Expand Up @@ -62,88 +55,92 @@ def extract_image(self, input_img: Path) -> torch.Tensor:
return torch.from_numpy(image_array).unsqueeze(0).float()

@abstractmethod
def extract_tensor(
def extract_sample(
self,
image_tensor: torch.Tensor,
index: int,
sample_index: int,
) -> torch.Tensor:
"""
Abstract method for extracting specific data from a given image tensor.
Abstract method for extracting a sample from a given image.
Parameters
----------
image_tensor : torch.Tensor
The image tensor to extract data from.
index : int
Index indicating the element to extract.
The image tensor to extract a sample from.
sample_index : int
Index indicating the sample to extract.
Returns
-------
torch.Tensor
A tensor containing the extracted data.
A tensor containing the extracted sample.
Notes
-----
This method needs to be implemented in the subclasses.
"""
pass

# TODO : remove?
@abstractmethod
def extract_path(self, image_path, index):
def sample_path(self, image_path: Path, sample_index: int) -> Path:
"""
Abstract method for defining the path where extracted elements will be saved.
Abstract method for defining the path where extracted sample will be saved.
Parameters
----------
image_path : Path
Path to the original image.
index : int
Index of the element being extracted.
sample_index : int
Index of the sample being extracted.
Returns
-------
Path
Path where the extracted data will be saved.
Path where the extracted sample will be saved.
Notes
-----
This method needs to be implemented in the subclasses.
"""
pass

# TODO : remove?
@abstractmethod
def extract(self, nii_path: Path) -> List[Tuple[Path, torch.Tensor]]:
"""
Abstract method for performing the extraction based on the configured method.
Abstract method to extract all the samples.
Parameters
----------
nii_path : Path
Path to the NIfTI file to process.
Returns
-------
List[Tuple[Path, torch.Tensor]]
A list of tuples, where each tuple contains an extracted sample,
and the path where to store it.
Notes
-----
This method needs to be implemented in the subclasses.
"""
pass

@abstractmethod
def num_elem_per_image(self, image: torch.Tensor) -> PositiveInt:
def num_sample_per_image(self, image: torch.Tensor) -> int:
"""
Abstract method to return the number of extracted elements per image.
Abstract method to return the number of extracted samples per image.
Parameters
----------
image : torch.Tensor
The image tensor from which the number of elements will be determined.
The image tensor from which the number of samples will be determined.
Returns
-------
PositiveInt
The number of extracted elements from the image.
int
The number of samples in the image.
Notes
-----
This method needs to be implemented in the subclasses.
"""
pass
58 changes: 28 additions & 30 deletions clinicadl/transforms/extraction/image.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,36 @@
from logging import getLogger
from pathlib import Path
from typing import Tuple
from typing import List, Tuple

import torch
from pydantic import PositiveInt
from pydantic import PositiveInt, computed_field

from clinicadl.transforms.extraction.base import BaseExtraction
from clinicadl.transforms.extraction.base import Extraction
from clinicadl.utils.enum import ExtractionMethod

logger = getLogger("clinicadl.extraction.image")

NII_GZ = ".nii.gz"
PT = ".pt"


class Image(BaseExtraction):
class Image(Extraction):
"""
Configuration class for full image extraction as a single tensor.
Transform class for full image extraction as a single tensor.
This class implements the extraction process for a full image, where the entire
image is loaded and returned as a single tensor. It handles extraction using
the `ExtractionMethod.IMAGE` and saves the output as a tensor file.
Attributes
----------
extract_method : ExtractionMethod
The method used for the extraction. For this class, it's set to IMAGE.
This class implements the extraction process to get the full image, where the entire
image is loaded and returned as a single tensor.
"""

extract_method: ExtractionMethod = ExtractionMethod.IMAGE
@computed_field
@property
def extract_method(self) -> ExtractionMethod:
"""The method to be used for the extraction process (ROI, Image, Patch, Slice)."""
return ExtractionMethod.IMAGE

def extract(self, nii_path: Path) -> list[Tuple[Path, torch.Tensor]]:
def extract(self, nii_path: Path) -> List[Tuple[Path, torch.Tensor]]:
"""
Extracts the full image as a single tensor file and saves it.
Extracts the full image as a single tensor file and returns the path
where to save it.
Parameters
----------
Expand All @@ -41,21 +39,21 @@ def extract(self, nii_path: Path) -> list[Tuple[Path, torch.Tensor]]:
Returns
-------
list of Tuple[Path, torch.Tensor]
A list containing a tuple with the output file path and the extracted image tensor.
List[Tuple[Path, torch.Tensor]]
A list containing a single tuple with the output file path and the extracted image tensor.
Notes
-----
The image is loaded, converted into a tensor, and saved with the same name as the original image but with a `.pt` extension.
The image is loaded and returned into a tensor along with the input path with the `.pt` extension.
"""
image_tensor = self.extract_image(nii_path)
output_file = nii_path.with_suffix("").with_suffix(PT), image_tensor.clone()
return [output_file]
image_tensor = self.load_image(nii_path)

return [(self.sample_path(nii_path), self.extract_sample(image_tensor))]

def extract_tensor(
def extract_sample(
self,
image_tensor: torch.Tensor,
index: int,
sample_index: int = 0, # pylint:disable=unused-argument
) -> torch.Tensor:
"""
Returns the entire image tensor as no further extraction is needed.
Expand All @@ -64,7 +62,7 @@ def extract_tensor(
----------
image_tensor : torch.Tensor
The image tensor to extract data from.
index : int
sample_index : int
The index to identify the extracted data (though this is not used in this method).
Returns
Expand All @@ -78,7 +76,7 @@ def extract_tensor(
"""
return image_tensor

def extract_path(self, image_path, index):
def sample_path(self, image_path: Path, sample_index: int = 0) -> Path: # pylint:disable=unused-argument
"""
Returns the input image path as the path to save the extracted data.
Expand All @@ -96,11 +94,11 @@ def extract_path(self, image_path, index):
Notes
-----
This method does not alter the path, returning the same path as the input.
This method only changes the extension of the path.
"""
return image_path
return image_path.with_suffix("").with_suffix(PT)

def num_elem_per_image(self, image: torch.Tensor) -> PositiveInt:
def num_sample_per_image(self, image: torch.Tensor) -> PositiveInt:
"""
Returns the number of elements per image. Since the entire image is extracted, this method always returns 1.
Expand Down
Loading

0 comments on commit 559040e

Please sign in to comment.