Commit 4ab03ae

test
camillebrianceau committed Oct 4, 2024
1 parent e1f801f commit 4ab03ae
Showing 7 changed files with 112 additions and 160 deletions.
81 changes: 73 additions & 8 deletions clinicadl/maps_manager/maps_manager.py
@@ -1,4 +1,5 @@
import json
+import shutil
import subprocess
from datetime import datetime
from logging import getLogger
@@ -50,7 +51,7 @@ def __init__(
self,
maps_path: Path,
parameters: Optional[Dict[str, Any]] = None,
-        verbose: str = "info",
+        verbose: Optional[str] = "info",
):
"""
@@ -569,13 +570,13 @@ def _mode_to_image_tsv(
###############################
def _init_model(
self,
-        transfer_path: Path = None,
-        transfer_selection=None,
-        nb_unfrozen_layer=0,
-        split=None,
-        resume=False,
-        gpu=None,
-        network=None,
+        transfer_path: Optional[Path] = None,
+        transfer_selection: Optional[str] = None,
+        nb_unfrozen_layer: int = 0,
+        split: Optional[int] = None,
+        resume: bool = False,
+        gpu: bool = False,
+        network: Optional[int] = None,
):
"""
Instantiate the model
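For reference, the previous defaults did not match their annotations (e.g. transfer_path: Path = None). The dummy function below only mirrors the newly annotated signature so the expected call pattern can be checked in isolation; it is not the real method, which goes on to instantiate the network.

from pathlib import Path
from typing import Optional


def init_model_signature_demo(
    transfer_path: Optional[Path] = None,
    transfer_selection: Optional[str] = None,
    nb_unfrozen_layer: int = 0,
    split: Optional[int] = None,
    resume: bool = False,
    gpu: bool = False,
    network: Optional[int] = None,
) -> None:
    """Stand-in that only reproduces the annotations added by this commit."""
    print(transfer_path, transfer_selection, nb_unfrozen_layer, split, resume, gpu, network)


# Every argument now has a typed default, so keyword-only calls like this one work.
init_model_signature_demo(split=0, nb_unfrozen_layer=2, gpu=False)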
@@ -778,3 +779,67 @@ def std_amp(self) -> bool:
then calls the internal FSDP AMP mechanisms.
"""
return self.amp and not self.fully_sharded_data_parallel

def _erase_tmp(self, split: int):
"""
Erases checkpoints of the model and optimizer at the end of training.
Parameters
----------
split : int
The split on which the model has been trained.
"""
tmp_path = self.maps_path / f"split-{split}" / "tmp"
shutil.rmtree(tmp_path)

def _write_weights(
self,
state: Dict[str, Any],
metrics_dict: Optional[Dict[str, bool]],
split: int,
network: Optional[int] = None,
filename: str = "checkpoint.pth.tar",
save_all_models: bool = False,
) -> None:
"""
Update checkpoint and save the best model according to a set of
metrics.
Parameters
----------
state : Dict[str, Any]
The state of the training (model weights, epoch, etc.).
metrics_dict : Optional[Dict[str, bool]]
The output of RetainBest step. If None, only the checkpoint
is saved.
split : int
The split number.
network : int (optional, default=None)
The network number (multi-network framework).
filename : str (optional, default="checkpoint.pth.tar")
The name of the checkpoint file.
save_all_models : bool (optional, default=False)
Whether to save model weights at every epoch.
If False, only the best model will be saved.
"""
checkpoint_dir = self.maps_path / f"split-{split}" / "tmp"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = checkpoint_dir / filename
torch.save(state, checkpoint_path)

if save_all_models:
all_models_dir = self.maps_path / f"split-{split}" / "all_models"
all_models_dir.mkdir(parents=True, exist_ok=True)
torch.save(state, all_models_dir / f"model_epoch_{state['epoch']}.pth.tar")

best_filename = "model.pth.tar"
if network is not None:
best_filename = f"network-{network}_model.pth.tar"

# Save model according to several metrics
if metrics_dict is not None:
for metric_name, metric_bool in metrics_dict.items():
metric_path = self.maps_path / f"split-{split}" / f"best-{metric_name}"
if metric_bool:
metric_path.mkdir(parents=True, exist_ok=True)
shutil.copyfile(checkpoint_path, metric_path / best_filename)
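For orientation, the checkpoint layout that _write_weights produces and _erase_tmp later removes can be reproduced with a short standalone sketch. The MAPS root, split number and metric names below are made up for illustration, and a plain dict stands in for the real training state.

import shutil
from pathlib import Path

import torch

maps_path = Path("/tmp/maps_example")  # hypothetical MAPS root
split = 0
state = {"epoch": 3, "model": {"weight": torch.zeros(2)}}  # stand-in for the training state
metrics_dict = {"loss": True, "accuracy": False}  # stand-in for the RetainBest output

# The rolling checkpoint always goes to split-<i>/tmp/checkpoint.pth.tar.
checkpoint_dir = maps_path / f"split-{split}" / "tmp"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = checkpoint_dir / "checkpoint.pth.tar"
torch.save(state, checkpoint_path)

# Each metric flagged as improved gets a copy under split-<i>/best-<metric>/model.pth.tar.
for metric_name, improved in metrics_dict.items():
    if improved:
        best_dir = maps_path / f"split-{split}" / f"best-{metric_name}"
        best_dir.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(checkpoint_path, best_dir / "model.pth.tar")

# _erase_tmp removes the tmp checkpoints once training is over.
shutil.rmtree(checkpoint_dir)

With save_all_models=True an extra copy per epoch also lands in split-<i>/all_models, and in the multi-network case the best file is named network-<n>_model.pth.tar instead of model.pth.tar.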
4 changes: 2 additions & 2 deletions clinicadl/splitter/config.py
@@ -23,7 +23,7 @@ class SplitConfig(BaseModel):

n_splits: NonNegativeInt = 0
split: Optional[Tuple[NonNegativeInt, ...]] = None
-    tsv_path: Path # not needed in predict ?
+    tsv_path: Optional[Path] = None # not needed in interpret !

# pydantic config
model_config = ConfigDict(validate_assignment=True)
@@ -39,7 +39,7 @@ def adapt_cross_val_with_maps_manager_info(
    ):  # maps_manager is of type MapsManager but needs to be of type MapsConfig in the future
# TEMPORARY
if not self.split:
-            self.split = find_splits(maps_manager.maps_path)
+            self.split = tuple(find_splits(maps_manager.maps_path))
logger.debug(f"List of splits {self.split}")


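The two touched fields of SplitConfig behave as in the minimal sketch below, which assumes pydantic v2 (as the ConfigDict import suggests) and leaves out the unrelated fields of the real class.

from pathlib import Path
from typing import Optional, Tuple

from pydantic import BaseModel, ConfigDict
from pydantic.types import NonNegativeInt


class SplitConfigSketch(BaseModel):
    """Cut-down stand-in keeping only the fields touched by this commit."""

    n_splits: NonNegativeInt = 0
    split: Optional[Tuple[NonNegativeInt, ...]] = None
    tsv_path: Optional[Path] = None  # may now stay unset

    model_config = ConfigDict(validate_assignment=True)


config = SplitConfigSketch()  # no tsv_path required any more
config.split = tuple([0, 1])  # explicit tuple(...), as in adapt_cross_val_with_maps_manager_info
print(config.split, config.tsv_path)  # (0, 1) None

The explicit tuple(...) call makes the assigned value match the declared Tuple annotation directly instead of relying on pydantic to coerce the list returned by find_splits.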
4 changes: 2 additions & 2 deletions clinicadl/splitter/split_utils.py
@@ -1,13 +1,13 @@
from pathlib import Path
-from typing import List, Optional
+from typing import List


def find_splits(maps_path: Path) -> List[int]:
"""Find which splits that were trained in the MAPS."""
splits = [
int(split.name.split("-")[1])
for split in list(maps_path.iterdir())
-        if split.name.startswith(f"split-")
+        if split.name.startswith("split-")
]
return splits

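The behaviour of find_splits is easy to check on a throwaway directory; the sketch below copies the function body shown above so it runs without installing clinicadl.

from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List


def find_splits(maps_path: Path) -> List[int]:
    """Find which splits were trained in the MAPS (same logic as above)."""
    splits = [
        int(split.name.split("-")[1])
        for split in list(maps_path.iterdir())
        if split.name.startswith("split-")
    ]
    return splits


with TemporaryDirectory() as tmp:
    maps_path = Path(tmp)
    for i in (0, 1, 3):
        (maps_path / f"split-{i}").mkdir()
    (maps_path / "groups").mkdir()  # non-split entries are ignored
    print(sorted(find_splits(maps_path)))  # [0, 1, 3]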
16 changes: 3 additions & 13 deletions clinicadl/splitter/splitter.py
@@ -5,15 +5,7 @@

import pandas as pd

-from clinicadl.caps_dataset.data_config import DataConfig
-from clinicadl.splitter.config import SplitConfig, SplitterConfig
-from clinicadl.splitter.validation import ValidationConfig
-from clinicadl.utils.exceptions import (
-    ClinicaDLArgumentError,
-    ClinicaDLConfigurationError,
-    ClinicaDLTSVError,
-)
-from clinicadl.utils.iotools.clinica_utils import check_caps_folder
+from clinicadl.splitter.config import SplitterConfig

logger = getLogger("clinicadl.split_manager")

@@ -147,9 +139,7 @@ def get_dataframe_from_tsv_path(tsv_path: Path) -> pd.DataFrame:
return df

@staticmethod
-    def load_data(
-        tsv_path: Path, cohort_diagnoses: Optional[List[str]] = None
-    ) -> pd.DataFrame:
+    def load_data(tsv_path: Path, cohort_diagnoses: List[str]) -> pd.DataFrame:
df = Splitter.get_dataframe_from_tsv_path(tsv_path)
df = df[df.diagnosis.isin((cohort_diagnoses))]
df.reset_index(inplace=True, drop=True)
@@ -164,7 +154,7 @@ def concatenate_diagnoses(
"""Concatenated the diagnoses needed to form the train and validation sets."""

if cohort_diagnoses is None:
-            cohort_diagnoses = self.config.data.diagnoses
+            cohort_diagnoses = list(self.config.data.diagnoses)

tmp_cohort_path = (
cohort_path if cohort_path is not None else self.config.split.tsv_path
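load_data now requires the diagnosis list to be passed explicitly; the filtering it applies boils down to the sketch below, where a hand-made DataFrame stands in for the participants TSV read by get_dataframe_from_tsv_path (column values are made up).

import pandas as pd

# Hypothetical participants table; the real one is read from a TSV file.
df = pd.DataFrame(
    {
        "participant_id": ["sub-01", "sub-02", "sub-03"],
        "session_id": ["ses-M000", "ses-M000", "ses-M000"],
        "diagnosis": ["AD", "CN", "MCI"],
    }
)

cohort_diagnoses = ["AD", "CN"]  # must now be given explicitly (no None default)
df = df[df.diagnosis.isin(cohort_diagnoses)]
df.reset_index(inplace=True, drop=True)
print(df)  # keeps only sub-01 (AD) and sub-02 (CN)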
5 changes: 1 addition & 4 deletions clinicadl/splitter/validation.py
@@ -1,12 +1,9 @@
from logging import getLogger
from pathlib import Path
-from typing import Optional, Tuple
+from typing import Tuple

from pydantic import BaseModel, ConfigDict, field_validator
from pydantic.types import NonNegativeInt

-from clinicadl.splitter.split_utils import find_splits

logger = getLogger("clinicadl.validation_config")


Diffs for the remaining 2 changed files are not shown here.
