Skip to content

Commit

Permalink
final commit i hope
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Oct 17, 2024
1 parent 6412243 commit 06ec3c2
Show file tree
Hide file tree
Showing 11 changed files with 18 additions and 32 deletions.
1 change: 0 additions & 1 deletion clinicadl/commandline/pipelines/generate/trivial/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def create_trivial_image(subject_id: int) -> pd.DataFrame:
if caps_config.data.mask_path is None:
caps_config.data.mask_path = get_mask_path()
path_to_mask = caps_config.data.mask_path / f"mask-{label + 1}.nii"
print(path_to_mask)
if path_to_mask.is_file():
atlas_to_mask = nib.loadsave.load(path_to_mask).get_fdata()
else:
Expand Down
4 changes: 3 additions & 1 deletion clinicadl/commandline/pipelines/train/from_json/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def cli(**kwargs):
logger.info(f"Reading JSON file at path {kwargs['config_file']}...")

trainer = Trainer.from_json(
config_file=kwargs["config_file"], maps_path=kwargs["output_maps_directory"]
config_file=kwargs["config_file"],
maps_path=kwargs["output_maps_directory"],
split=kwargs["split"],
)
trainer.train(split_list=kwargs["split"], overwrite=True)
1 change: 0 additions & 1 deletion clinicadl/monai_networks/config/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def _divide(
numerator: Union[int, Tuple[int, ...]],
denominator: Union[int, Tuple[int, ...]],
) -> bool:
print(self.dim)
"""Checks if numerator is divisible by denominator."""
if isinstance(numerator, int):
numerator = (numerator,) * self.dim
Expand Down
2 changes: 0 additions & 2 deletions clinicadl/nn/networks/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ def __init__(self, convolution_layers: nn.Module, fc_layers: nn.Module) -> None:

def forward(self, x):
inter = self.convolutions(x)
print(self.convolutions)
print(inter.shape)
return self.fc(inter)


Expand Down
1 change: 0 additions & 1 deletion clinicadl/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def compute_output_size(
input_ = torch.randn(input_size).unsqueeze(0)
if isinstance(layer, nn.MaxUnpool3d) or isinstance(layer, nn.MaxUnpool2d):
indices = torch.zeros_like(input_, dtype=int)
print(indices)
output = layer(input_, indices)
else:
output = layer(input_)
Expand Down
12 changes: 0 additions & 12 deletions clinicadl/predictor/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,11 @@ def __init__(self, _config: Union[PredictConfig, InterpretConfig]) -> None:

self.maps_manager = MapsManager(_config.maps_manager.maps_dir)
self._config.adapt_with_maps_manager_info(self.maps_manager)
print(
self._config.data.model_dump(
exclude=set(["preprocessing_dict", "mode", "caps_dict"])
)
)
tmp = self._config.data.model_dump(
exclude=set(["preprocessing_dict", "mode", "caps_dict"])
)
print(tmp)
tmp.update(self._config.split.model_dump())
print(tmp)
tmp.update(self._config.validation.model_dump())
print(tmp)
self.splitter = Splitter(SplitterConfig(**tmp))

def predict(
Expand Down Expand Up @@ -205,7 +197,6 @@ def _predict_single(
)
if self._config.maps_manager.save_tensor:
logger.debug("Saving tensors")
print("save_tensor")
self._compute_output_tensors(
maps_manager=self.maps_manager,
dataset=data_test,
Expand Down Expand Up @@ -526,9 +517,7 @@ def _check_data_group(
/ self._config.maps_manager.data_group
)
logger.debug(f"Group path {group_dir}")
print(f"group_dir: {group_dir}")
if group_dir.is_dir(): # Data group already exists
print("is dir")
if self._config.maps_manager.overwrite:
if self._config.maps_manager.data_group in ["train", "validation"]:
raise MAPSError("Cannot overwrite train or validation data group.")
Expand Down Expand Up @@ -1047,7 +1036,6 @@ def _test_loader(

if cluster.master:
# Replace here
print("before saving")
maps_manager._mode_level_to_tsv(
prediction_df,
metrics,
Expand Down
1 change: 0 additions & 1 deletion clinicadl/predictor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def get_prediction(
prediction_dir = (
maps_path / f"split-{split}" / f"best-{selection_metric}" / data_group
)
print(prediction_dir)
if not prediction_dir.is_dir():
raise MAPSError(
f"No prediction corresponding to data group {data_group} was found."
Expand Down
2 changes: 0 additions & 2 deletions clinicadl/trainer/tasks_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ def generate_label_code(
network_task = Task(network_task)
if network_task == Task.CLASSIFICATION:
unique_labels = sorted(set(df[label]))
print("unique labels", unique_labels)
return {str(key): value for value, key in enumerate(unique_labels)}

return None
Expand Down Expand Up @@ -624,7 +623,6 @@ def generate_sampler(

def calculate_weights_classification(df):
labels = df[dataset.config.data.label].unique()
print(dataset.config.data.label_code)
codes = {label_code[label] for label in labels}
count = np.zeros(len(codes))

Expand Down
21 changes: 10 additions & 11 deletions clinicadl/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
predict_config = PredictConfig(**config.get_dict())
self.validator = Predictor(predict_config)

### test
# test
splitter_config = SplitterConfig(**self.config.get_dict())
self.splitter = Splitter(splitter_config)
self._check_args()
Expand All @@ -92,7 +92,12 @@ def _init_maps_manager(self, config) -> MapsManager:
) # TODO : precise which parameters in config are useful

@classmethod
def from_json(cls, config_file: str | Path, maps_path: str | Path) -> Trainer:
def from_json(
cls,
config_file: str | Path,
maps_path: str | Path,
split: Optional[list[int]] = None,
) -> Trainer:
"""
Creates a Trainer from a json configuration file.
Expand All @@ -119,10 +124,10 @@ def from_json(cls, config_file: str | Path, maps_path: str | Path) -> Trainer:
raise FileNotFoundError(f"No file found at {str(config_file)}.")
config_dict = patch_to_read_json(read_json(config_file)) # TODO : remove patch
config_dict["maps_dir"] = maps_path
config_dict["split"] = split if split else ()
config_object = create_training_config(config_dict["network_task"])(
**config_dict
)
print(config_object.model_dump())
return cls(config_object)

@classmethod
Expand Down Expand Up @@ -173,9 +178,6 @@ def resume(self) -> None:
split_iterator = self.splitter.split_iterator()
###
absent_splits = set(split_iterator) - stopped_splits - finished_splits
print("split:", set(split_iterator))
print("stopped split:", stopped_splits)
print("finished split:", finished_splits)

logger.info(
f"Finished splits {finished_splits}\n"
Expand All @@ -194,8 +196,8 @@ def resume(self) -> None:

def _check_args(self):
self.config.reproducibility.seed = get_seed(self.config.reproducibility.seed)
# if (len(self.config.data.label_code) == 0):
# self.config.data.label_code = self.maps_manager.label_code
if len(self.config.data.label_code) == 0:
self.config.data.label_code = self.maps_manager.label_code
# TODO: deal with label_code and replace self.maps_manager.label_code
from clinicadl.trainer.tasks_utils import generate_label_code

Expand All @@ -208,9 +210,6 @@ def _check_args(self):
self.config.data.label_code = generate_label_code(
self.config.network_task, train_df, self.config.data.label
)
print(train_df)
print(self.config.network_task)
print("in check args : ", self.config.data.label_code)

def train(
self,
Expand Down
1 change: 1 addition & 0 deletions tests/test_train_from_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_determinism(cmdopt, tmp_path):

# Reproduce experiment (train from json)
config_json = tmp_out_dir / "maps_roi_cnn/maps.json"

flag_error = not system(
f"clinicadl train from_json {str(config_json)} {str(reproduced_maps_dir)} -s 0"
)
Expand Down
4 changes: 4 additions & 0 deletions tests/testing_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def modify_maps(
base_dir: Path,
no_gpu: bool = False,
adapt_base_dir: bool = False,
modify_split: bool = False,
) -> Dict[str, Any]:
"""
Modifies a MAPS dictionary if the user passed --no-gpu or --adapt-base-dir flags.
Expand Down Expand Up @@ -208,6 +209,9 @@ def modify_maps(
)
except KeyError: # maps with only caps directory
pass

if modify_split:
maps["split"] = (0,)
return maps


Expand Down

0 comments on commit 06ec3c2

Please sign in to comment.