diff --git a/clinicadl/data/datasets/caps_dataset.py b/clinicadl/data/datasets/caps_dataset.py index 08885e09d..3bc1723eb 100644 --- a/clinicadl/data/datasets/caps_dataset.py +++ b/clinicadl/data/datasets/caps_dataset.py @@ -242,10 +242,9 @@ def _get_df_from_input( df = self._check_data_instance(data) self.df = df - if not self._check_preprocessing_config(): - raise ClinicaDLCAPSError( - f"The DataFrame does not match the preprocessing configuration: {self.preprocessing.preprocessing.value}" - ) + self.caps_reader.check_preprocessing( + self._get_participant_session_couples(), self.preprocessing + ) return df @@ -260,11 +259,11 @@ def _check_data_instance(self, data: Optional[Union[pd.DataFrame, Path]] = None) "Please ensure the file path is correct and accessible." ) df = tsv_to_df(data) - if isinstance(data, pd.DataFrame): + elif isinstance(data, pd.DataFrame): df = check_df(data) else: raise ValueError( - f"'data' must be a Pandas DataFrame, a path to a TSV file or None. Got{data}" + f"'data' must be a Pandas DataFrame, a path to a TSV file or None. Got {data}" ) return df diff --git a/clinicadl/predictor/old_predictor.py b/clinicadl/predictor/old_predictor.py index f6ddf1377..5108d7717 100644 --- a/clinicadl/predictor/old_predictor.py +++ b/clinicadl/predictor/old_predictor.py @@ -1059,7 +1059,7 @@ def _compute_output_tensors( Compute the output tensors and saves them in the MAPS. Args: - dataset (clinicadl.dataset.caps_dataset.CapsDataset): wrapper of the data set. + dataset (clinicadl.data.datasets.caps_dataset.CapsDataset): wrapper of the data set. data_group (str): name of the data group used for the task. split (int): split number. selection_metrics (list[str]): metrics used for model selection. diff --git a/tests/unittests/dataset/test_config.py b/tests/unittests/dataset/test_config.py index 57c5c681f..cf923af6a 100644 --- a/tests/unittests/dataset/test_config.py +++ b/tests/unittests/dataset/test_config.py @@ -1,6 +1,6 @@ import pytest -from clinicadl.dataset.config import FileType +from clinicadl.data.config import FileType from clinicadl.utils.enum import PreprocessingMethod diff --git a/tests/unittests/dataset/test_datasets.py b/tests/unittests/dataset/test_datasets.py index 44f023681..d8009576e 100644 --- a/tests/unittests/dataset/test_datasets.py +++ b/tests/unittests/dataset/test_datasets.py @@ -3,8 +3,8 @@ import pytest import torchio as tio -from clinicadl.dataset.datasets.caps_dataset import CapsDataset -from clinicadl.dataset.preprocessing import PreprocessingPET, PreprocessingT1 +from clinicadl.data.datasets import CapsDataset +from clinicadl.data.preprocessing import PreprocessingPET, PreprocessingT1 from clinicadl.transforms import Transforms from clinicadl.utils.exceptions import ( ClinicaDLArgumentError, diff --git a/tests/unittests/dataset/test_reader.py b/tests/unittests/dataset/test_reader.py index bfea2c12d..1dcffd923 100644 --- a/tests/unittests/dataset/test_reader.py +++ b/tests/unittests/dataset/test_reader.py @@ -2,8 +2,8 @@ import pytest -from clinicadl.dataset.preprocessing import PreprocessingT1 -from clinicadl.dataset.readers import CapsReader +from clinicadl.data.preprocessing import PreprocessingT1 +from clinicadl.data.readers import CapsReader from clinicadl.utils.enum import PreprocessingMethod from clinicadl.utils.exceptions import ( ClinicaDLArgumentError, diff --git a/tests/unittests/splitter/test_splitter.py b/tests/unittests/splitter/test_splitter.py index a1c9b0399..aefcfb792 100644 --- a/tests/unittests/splitter/test_splitter.py +++ b/tests/unittests/splitter/test_splitter.py @@ -1,30 +1,16 @@ -import json from pathlib import Path -import nibabel as nib import numpy as np -import pandas as pd import pytest from pydantic import ValidationError from clinicadl.data.datasets.caps_dataset import CapsDataset -from clinicadl.data.preprocessing import PreprocessingT1, PreprocessingT2 -from clinicadl.splitter.split import Split +from clinicadl.data.preprocessing import PreprocessingT1 from clinicadl.splitter.splitter.kfold import KFold, KFoldConfig from clinicadl.splitter.splitter.single_split import SingleSplit, SingleSplitConfig -from clinicadl.splitter.splitter.splitter import ( - Splitter, - SplitterConfig, - SubjectsSessionsSplit, -) +from clinicadl.splitter.splitter.splitter import SubjectsSessionsSplit from clinicadl.transforms import Transforms -from clinicadl.utils.enum import Preprocessing -from clinicadl.utils.exceptions import ( - ClinicaDLArgumentError, - ClinicaDLCAPSError, - ClinicaDLConfigurationError, - ClinicaDLTSVError, -) +from clinicadl.utils.exceptions import ClinicaDLTSVError caps_dir = Path(__file__).parents[1] / "ressources" / "caps_example" split_dir = caps_dir / "split_test" / "split" diff --git a/tests/unittests/transforms/test_transforms.py b/tests/unittests/transforms/test_transforms.py index f2384f31d..a776160f9 100644 --- a/tests/unittests/transforms/test_transforms.py +++ b/tests/unittests/transforms/test_transforms.py @@ -71,22 +71,22 @@ def test_get_transforms(): assert (tio_image.label.tensor == old_tio_image.label.tensor).all() assert (tio_image.mask_1.tensor == old_tio_image.mask_1.tensor).all() - tio_sample = transforms.extraction.extract_tio_sample(tio_image, 0) + tio_sample, _ = transforms.extraction.extract_tio_sample(tio_image, 0) patch_mask = np.zeros((1, 4, 4, 4)) patch_mask[:, 1:, 1:, 1:] = 1 patch_mask = torch.from_numpy(patch_mask) - assert (tio_sample.sample.tensor == tio_image.image.tensor[:, :4, :4, :4]).all() + assert (tio_sample.image.tensor == tio_image.image.tensor[:, :4, :4, :4]).all() assert (tio_sample.label.tensor == tio_image.label.tensor[:, :4, :4, :4]).all() assert (tio_sample.mask_1.tensor == patch_mask).all() tio_sample = sample_transforms(tio_sample) - assert tio_sample.sample.tensor.shape == (1, 6, 6, 6) + assert tio_sample.image.tensor.shape == (1, 6, 6, 6) assert tio_sample.label.tensor.shape == (1, 6, 6, 6) assert tio_sample.mask_1.tensor.shape == (1, 6, 6, 6) tio_sample = sample_augmentations(tio_sample) - assert (tio_sample.sample.tensor[:, :2, :2, :2] == 0.0).all() - assert (tio_sample.sample.tensor[:, 5:, 5:, 5:] == 0.0).all() + assert (tio_sample.image.tensor[:, :2, :2, :2] == 0.0).all() + assert (tio_sample.image.tensor[:, 5:, 5:, 5:] == 0.0).all() def test_str():