From f18f5fe36e156102924656d2cc3ccdfed7229a80 Mon Sep 17 00:00:00 2001 From: thibaultdvx Date: Thu, 19 Dec 2024 10:27:15 +0100 Subject: [PATCH] default preprocessing in capsdataset --- clinicadl/data/datasets/caps_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/clinicadl/data/datasets/caps_dataset.py b/clinicadl/data/datasets/caps_dataset.py index 3bc1723eb..c6bb7c828 100644 --- a/clinicadl/data/datasets/caps_dataset.py +++ b/clinicadl/data/datasets/caps_dataset.py @@ -11,7 +11,7 @@ from pydantic import NonNegativeInt from torch.utils.data import Dataset -from clinicadl.data.preprocessing import Preprocessing +from clinicadl.data.preprocessing import Preprocessing, PreprocessingT1 from clinicadl.data.readers.caps_reader import CapsReader from clinicadl.data.utils import ( check_df, @@ -69,7 +69,7 @@ class CapsDataset(Dataset): def __init__( self, caps_directory: Union[str, Path], - preprocessing: Preprocessing, + preprocessing: Preprocessing = PreprocessingT1(), transforms: Transforms = Transforms(), data: Optional[Union[pd.DataFrame, str, Path]] = None, label: Optional[str] = None, @@ -208,7 +208,7 @@ def describe(self): } def _get_df_from_input( - self, data: Optional[Union[pd.DataFrame, Path]] + self, data: Optional[Union[pd.DataFrame, Path, str]] ) -> pd.DataFrame: """ Generates or validates the DataFrame from the input data. @@ -248,7 +248,7 @@ def _get_df_from_input( return df - def _check_data_instance(self, data: Optional[Union[pd.DataFrame, Path]] = None): + def _check_data_instance(self, data: Optional[Union[pd.DataFrame, Path, str]]): if isinstance(data, str): data = Path(data)