Skip to content

Commit

Permalink
Transforms in CapsDataset (#687)
Browse files Browse the repository at this point in the history
* torchio Subject in CapsDataset

---------

Co-authored-by: camillebrianceau <[email protected]>
  • Loading branch information
thibaultdvx and camillebrianceau authored Dec 19, 2024
1 parent dc3443e commit 1a7b08b
Show file tree
Hide file tree
Showing 34 changed files with 991 additions and 768 deletions.
8 changes: 2 additions & 6 deletions clinicadl/data/config/data.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from logging import getLogger
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Dict, Optional, Union

import pandas as pd
from pydantic import field_validator

from clinicadl.utils.config import ClinicaDLConfig
from clinicadl.utils.exceptions import (
ClinicaDLArgumentError,
ClinicaDLTSVError,
)
from clinicadl.utils.exceptions import ClinicaDLTSVError

logger = getLogger("clinicadl.data_config")

Expand Down
14 changes: 8 additions & 6 deletions clinicadl/data/config/file_type.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from enum import Enum
from typing import Optional, Union

from pydantic import field_validator

from clinicadl.utils.config import ClinicaDLConfig
from clinicadl.utils.enum import Preprocessing
from clinicadl.utils.enum import PreprocessingMethod


class FileType(ClinicaDLConfig):
Expand All @@ -14,9 +13,10 @@ class FileType(ClinicaDLConfig):

pattern: str
description: str
needed_pipeline: Optional[Preprocessing] = None
needed_pipeline: Optional[PreprocessingMethod] = None

@field_validator("pattern", mode="before")
@classmethod
def check_pattern(cls, v):
if not v:
raise ValueError("A pattern must be specified")
Expand All @@ -30,18 +30,20 @@ def check_pattern(cls, v):
return v

@field_validator("description", mode="before")
@classmethod
def check_description(cls, v):
if not v:
raise ValueError("A description must be specified")
return v

@field_validator("needed_pipeline", mode="after")
def check_needed_pipeline(cls, v: Optional[Union[str, Preprocessing]]):
@classmethod
def check_needed_pipeline(cls, v: Optional[Union[str, PreprocessingMethod]]):
if v:
try:
v = Preprocessing(v)
v = PreprocessingMethod(v)
except ValueError:
raise ValueError(
f"Invalid pipeline: {v}. Choose from {[e.value for e in Preprocessing]}"
f"Invalid pipeline: {v}. Choose from {[e.value for e in PreprocessingMethod]}"
)
return v
Loading

0 comments on commit 1a7b08b

Please sign in to comment.