Skip to content

Commit

Permalink
Update clinicadl/API_test_v2.py
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau authored Oct 29, 2024
1 parent 90c7526 commit b0d54d0
Showing 1 changed file with 76 additions and 66 deletions.
142 changes: 76 additions & 66 deletions clinicadl/API_test_v2.py
Original file line number Diff line number Diff line change
@@ -1,102 +1,112 @@
# %% class
class MapsIO:
from pathlib import Path
from clinicadl.caps_dataset2.config.preprocessing import PreprocessingConfig
from clinicadl.caps_dataset2.config.extraction import ExtractionConfig
from clinicadl.caps_dataset2.data import CapsDatasetRoi, CapsDatasetPatch, CapsDatasetSlice
from clinicadl.transforms.config import TransformsConfig
import torchio
from clinicadl import tsvtools
from clinicadl.trainer.trainer import Trainer

class ExperimentManager:
pass


class CapsDataset:
class CapsReader:
pass


class Splitter:
class Transforms:
pass


class ClinicaDLModels:
class Predictor:
pass

class ClinicaDLModel:
pass

class Networks:
class KFolder:
pass

def get_loss_function():
pass

class VAE(Networks):
def get_network_from_config():
pass

def create_network_config():
pass

class Optimizer:
def get_single_split():
pass

# Create the Maps Manager / Read/write manager /
maps_path = Path("/")
manager = ExperimentManager(maps_path, overwrite = False)

class Loss:
pass
caps_directory = Path("caps_directory") # output of clinica pipelines
caps_reader = CapsReader(caps_directory, manager= manager)

preprocessing_1: PreprocessingConfig = caps_reader.get_preprocessing("t1-linear")
extraction_1: ExtractionConfig = caps_reader.extract_slice(preprocessing=preprocessing_1, arg_slice = 2)
transforms_1 = Transforms(data_augmentation = [torchio.t1, torchio.t2], image_transforms= [torchio.t1, torchio.t2], object_transforms=[torchio.t1, torchio.t2]) # not mandatory

class Metrics:
pass
preprocessing_2: PreprocessingConfig = caps_reader.get_preprocessing("pet-linear")
extraction_2: ExtractionConfig= caps_reader.extract_patch(preprocessing=preprocessing_2, arg_patch = 2)
transforms_2 = Transforms(data_augmentation = [torchio.t2], image_transforms= [torchio.t1], object_transforms=[torchio.t1, torchio.t2])

sub_ses_tsv = Path("")
split_dir = tsvtools.split_tsv(sub_ses_tsv) #-> creer un test.tsv et un train.tsv

class Trainer:
pass
dataset_t1_roi: CapsDatasetRoi = caps_reader.get_dataset( extraction = extraction_1, preprocessing = preprocessing_1, sub_ses_tsv = split_dir / "train.tsv", transforms = transforms_1)
dataset_pet_patch: CapsDatasetPatch = caps_reader.get_dataset(extraction = extraction_2, preprocessing= preprocessing_2, sub_ses_tsv = split_dir / "train.tsv", transforms = transforms_2)

dataset_multi_modality_multi_extract = concat_dataset(dataset_t1, dataset_pet) # 2 train.tsv en entrée qu'il faut concat et pareil pour les transforms à faire attention

class Validator:
pass
config_file = Path("config_file")
trainer = Trainer.from_json(config_file=config_file, manager=manager)

# CAS CROSS-VALIDATION
splitter = KFolder(n_splits = 3, caps_dataset=dataset_multi_modality_multi_extract, manager = manager)

# %% maps
maps = MapsIO("/path/to/maps") # Crée un dossier
for split in splitter.split_iterator(split_list = [0,1]):
# bien définir ce qu'il y a dans l'objet split

loss, loss_config = get_loss_function(CrossEntropyLossConfig())
network_config = create_network_config(ImplementedNetworks.CNN)(in_shape = [2,2,2], num_outputs = 1, conv_args = ConvEncoderOptions(channels = [3, 2, 2]))
network = get_network_from_config(network_config)

model = ClinicaDLModel(
network= network,
loss=loss,
optimizer= AdamConfig()
)

# %% Dataset
DataConfig = {
"caps_dir": "",
"tsv": "",
"mode": "",
}
capsdataset = CapsDataset(DataConfig, maps) # torch.dataset
trainer.train(model, split)
# le trainer va instancier un predictor/valdiator dans le train ou dans le init

# %% Model
network = VAE() # nn.module
loss = Loss()
optimizer = ClinicaDLOptim(
Adam()
) # get_optimizer({"name": "Adam", "par1": 0.5}) # torch.optim
# model = ClinicaDLModels(
# network,
# loss,
# optimizer,
# )

# %% Cross val
SplitConfig = SplitterConfig()
splitter = Splitter(SplitConfig, capsdataset)
# CAS SINGLE SPLIT
split = get_single_split(n_subject_validation = 0, caps_dataset=dataset_multi_modality_multi_extract, manager = manager)

# %% Metrics
metrics1 = Metrics("MAE") # monai.metric
metrics2 = Metrics("MSE") # monai.metric
loss, loss_config = get_loss_function(CrossEntropyLossConfig())
network_config = create_network_config(ImplementedNetworks.CNN)(in_shape = [2,2,2], num_outputs = 1, conv_args = ConvEncoderOptions(channels = [3, 2, 2]))
network = get_network_from_config(network_config)

model = ClinicaDLModel(
network= network,
loss=loss,
optimizer= AdamConfig()
)

trainer.train(model, split)
# le trainer va instancier un predictor/valdiator dans le train ou dans le init

# %% Option 1
for split in splitter.iterate():
trainer = Trainer(split, maps, (optimizer))
validator = Validator(split, [metrics1, metrics2], maps)

trainer.train(validator, model)
# TEST

preprocessing_test: PreprocessingConfig = caps_reader.get_preprocessing("pet-linear")
extraction_test: ExtractionConfig= caps_reader.extract_patch(preprocessing=preprocessing_2, arg_patch = 2)
transforms_test = Transforms(data_augmentation = [torchio.t2], image_transforms= [torchio.t1], object_transforms=[torchio.t1, torchio.t2])

# %% Option 2
val = Validator([metrics1, metrics2], maps)
trainer = Trainer(validator, maps)
for split in splitter.iterate():
trainer.train(model, split)
dataset_test: CapsDatasetROI = caps_reader.get_dataset( extraction = extraction_test, preprocessing = preprocessing_test, sub_ses_tsv = split_dir / "test.tsv", transforms = transforms_test)

# %% Option 3
trainer = Trainer(
maps, [metrics1, metrics2]
) # Initialise un maps manager + initialise un validator
for split in splitter.iterate():
model = ClinicaDLModels(
network,
loss,
optimizer,
)
trainer.train(model, split)
predictor = Predictor(manager= manager)
predictor.predict(dataset_test= dataset_test, split = 2)

0 comments on commit b0d54d0

Please sign in to comment.