Skip to content

Commit

Permalink
ClinicaDLConfig
Browse files (browse the repository at this point in the history)
  • Loading branch information
thibaultdvx committed Dec 18, 2024
1 parent 9eed791 commit 286bbef
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 15 deletions.
2 changes: 1 addition & 1 deletion clinicadl/splitter/splitter/kfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, split_dir: Path):
super().__init__(split_dir=split_dir)

def _init_config(self, **args):
self.config = KFoldConfig(**args)
self.config: KFoldConfig = KFoldConfig(**args)

def _read_splits(self) -> List[SubjectsSessionsSplit]:
"""
Expand Down
1 change: 1 addition & 0 deletions clinicadl/splitter/splitter/single_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def pattern(self) -> str:
return "split"

@field_validator("p_categorical_threshold", "p_continuous_threshold", mode="before")
@classmethod
def validate_thresholds(cls, value: Union[float, int]) -> float:
if not (0 <= value <= 1):
raise ValueError(f"Threshold must be between 0 and 1, got {value}")
Expand Down
23 changes: 9 additions & 14 deletions clinicadl/splitter/splitter/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,40 @@

import pandas as pd
from pydantic import (
BaseModel,
ConfigDict,
computed_field,
field_validator,
)

from clinicadl.dataset.datasets.caps_dataset import CapsDataset
from clinicadl.splitter.split import Split
from clinicadl.utils.config import ClinicaDLConfig
from clinicadl.utils.exceptions import ClinicaDLTSVError
from clinicadl.utils.iotools.utils import path_encoder


class SubjectsSessionsSplit(BaseModel):
class SubjectsSessionsSplit(ClinicaDLConfig):
"""
Dataclass to store train and validation splits for subjects and sessions.
"""

train: pd.DataFrame
validation: pd.DataFrame

model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

@computed_field
@property
def train_val_df(self):
return pd.concat([self.train, self.validation], ignore_index=True)


class SplitterConfig(BaseModel):
class SplitterConfig(ClinicaDLConfig):
json_name: str
split_dir: Path
subset_name: str
stratification: Union[str, List[str], bool] = False
valid_longitudinal: bool = False

model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

@field_validator("split_dir", mode="after")
@classmethod
def validate_split_dir(cls, v):
if not isinstance(v, Path):
v = Path(v)
Expand Down Expand Up @@ -109,7 +107,7 @@ def __init__(self, split_dir: Path):

@abstractmethod
def _init_config(self, **args):
self.config = ...
self.config: SplitterConfig

def _read_json(self):
"""
Expand Down Expand Up @@ -170,10 +168,10 @@ def _read_split(self, split_path: Path) -> SubjectsSessionsSplit:
split_path / f"{self.config.subset_name}_baseline.tsv", sep="\t"
) # type: ignore

except FileNotFoundError:
except FileNotFoundError as exc:
raise FileNotFoundError(
f"One or more of the required files are missing: 'train_baseline.tsv', '{self.config.subset_name}_baseline.tsv'"
) # type: ignore
) from exc # type: ignore

return SubjectsSessionsSplit(
train=train,
Expand All @@ -195,7 +193,6 @@ def _read_splits(self) -> List[SubjectsSessionsSplit]:
None
Populates `subjects_sessions_split` and `config` attributes.
"""
pass

def check_dataset_and_tsv_consistency(self, dataset: CapsDataset):
df1 = self.subjects_sessions_split[0].train_val_df
Expand Down Expand Up @@ -233,8 +230,6 @@ def get_splits(
If the requested split indices are out of range or no splits are available.
"""

pass

def _get_split(
self,
dataset: CapsDataset,
Expand Down

0 comments on commit 286bbef

Please sign in to comment.