Skip to content

Commit

Permalink
Sl dart ssda (aramis-lab#485)
Browse files Browse the repository at this point in the history
* Add proposed SSDA method for MICCAI
  • Loading branch information
sophieloiz authored and camillebrianceau committed Mar 21, 2024
1 parent 55e2128 commit 759baf2
Show file tree
Hide file tree
Showing 23 changed files with 1,217 additions and 17 deletions.
4 changes: 4 additions & 0 deletions clinicadl/generate/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ def create_random_image(subject_id):
write_missing_mods(output_dir, output_df)
logger.info(f"Random dataset was generated at {output_dir}")

logger.info(f"Random dataset was generated at {output_dir}")


def generate_trivial_dataset(
caps_directory: Path,
Expand Down Expand Up @@ -355,6 +357,8 @@ def create_trivial_image(subject_id, output_df):
write_missing_mods(output_dir, output_df)
logger.info(f"Trivial dataset was generated at {output_dir}")

logger.info(f"Trivial dataset was generated at {output_dir}")


def generate_shepplogan_dataset(
output_dir: Path,
Expand Down
83 changes: 83 additions & 0 deletions clinicadl/mlflow_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import logging
import os
import sys
import warnings
from urllib.parse import urlparse

import mlflow
import mlflow.sklearn
import numpy as np
import pandas as pd
from sklearn.linear_model import ElasticNet
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)


def eval_metrics(actual, pred):
rmse = np.sqrt(mean_squared_error(actual, pred))
mae = mean_absolute_error(actual, pred)
r2 = r2_score(actual, pred)
return rmse, mae, r2


if __name__ == "__main__":
warnings.filterwarnings("ignore")
np.random.seed(40)

# Read the wine-quality csv file from the URL
csv_url = "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/data/winequality-red.csv"
try:
data = pd.read_csv(csv_url, sep=";")
except Exception as e:
logger.exception(
"Unable to download training & test CSV, check your internet connection. Error: %s",
e,
)

# Split the data into training and test sets. (0.75, 0.25) split.
train, test = train_test_split(data)

# The predicted column is "quality" which is a scalar from [3, 9]
train_x = train.drop(["quality"], axis=1)
test_x = test.drop(["quality"], axis=1)
train_y = train[["quality"]]
test_y = test[["quality"]]

alpha = float(sys.argv[1]) if len(sys.argv) > 1 else 0.5
l1_ratio = float(sys.argv[2]) if len(sys.argv) > 2 else 0.5

with mlflow.start_run():
lr = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, random_state=42)
lr.fit(train_x, train_y)

predicted_qualities = lr.predict(test_x)

(rmse, mae, r2) = eval_metrics(test_y, predicted_qualities)

print("Elasticnet model (alpha={:f}, l1_ratio={:f}):".format(alpha, l1_ratio))
print(" RMSE: %s" % rmse)
print(" MAE: %s" % mae)
print(" R2: %s" % r2)

mlflow.log_param("alpha", alpha)
mlflow.log_param("l1_ratio", l1_ratio)
mlflow.log_metric("rmse", rmse)
mlflow.log_metric("r2", r2)
mlflow.log_metric("mae", mae)

tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme

# Model registry does not work with file store
if tracking_url_type_store != "file":
# Register the model
# There are other ways to use the Model Registry, which depends on the use case,
# please refer to the doc for more information:
# https://mlflow.org/docs/latest/model-registry.html#api-workflow
mlflow.sklearn.log_model(
lr, "model", registered_model_name="ElasticnetWineModel"
)
else:
mlflow.sklearn.log_model(lr, "model")
1 change: 1 addition & 0 deletions clinicadl/random_search/random_search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def random_sampling(rs_options: Dict[str, Any]) -> Dict[str, Any]:
"mode": "fixed",
"multi_cohort": "fixed",
"multi_network": "choice",
"ssda_netork": "fixed",
"n_fcblocks": "randint",
"n_splits": "fixed",
"n_proc": "fixed",
Expand Down
7 changes: 6 additions & 1 deletion clinicadl/resources/config/train_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
[Model]
architecture = "default" # ex : Conv5_FC3
multi_network = false
ssda_network = false

[Architecture]
# CNN
Expand Down Expand Up @@ -67,6 +68,10 @@ data_augmentation = false
sampler = "random"
size_reduction=false
size_reduction_factor=2
caps_target = ""
tsv_target_lab = ""
tsv_target_unlab = ""
preprocessing_dict_target = ""

[Cross_validation]
n_splits = 0
Expand All @@ -83,4 +88,4 @@ accumulation_steps = 1
profiler = false

[Informations]
emissions_calculator = false
emissions_calculator = false
5 changes: 5 additions & 0 deletions clinicadl/train/tasks/classification_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# Model
@train_option.architecture
@train_option.multi_network
@train_option.ssda_network
# Data
@train_option.multi_cohort
@train_option.diagnoses
Expand All @@ -35,6 +36,10 @@
@train_option.normalize
@train_option.data_augmentation
@train_option.sampler
@train_option.caps_target
@train_option.tsv_target_lab
@train_option.tsv_target_unlab
@train_option.preprocessing_dict_target
# Cross validation
@train_option.n_splits
@train_option.split
Expand Down
5 changes: 5 additions & 0 deletions clinicadl/train/tasks/reconstruction_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# Model
@train_option.architecture
@train_option.multi_network
@train_option.ssda_network
# Data
@train_option.multi_cohort
@train_option.diagnoses
Expand All @@ -35,6 +36,10 @@
@train_option.normalize
@train_option.data_augmentation
@train_option.sampler
@train_option.caps_target
@train_option.tsv_target_lab
@train_option.tsv_target_unlab
@train_option.preprocessing_dict_target
# Cross validation
@train_option.n_splits
@train_option.split
Expand Down
5 changes: 5 additions & 0 deletions clinicadl/train/tasks/regression_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# Model
@train_option.architecture
@train_option.multi_network
@train_option.ssda_network
# Data
@train_option.multi_cohort
@train_option.diagnoses
Expand All @@ -35,6 +36,10 @@
@train_option.normalize
@train_option.data_augmentation
@train_option.sampler
@train_option.caps_target
@train_option.tsv_target_lab
@train_option.tsv_target_unlab
@train_option.preprocessing_dict_target
# Cross validation
@train_option.n_splits
@train_option.split
Expand Down
33 changes: 33 additions & 0 deletions clinicadl/train/tasks/task_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs):
"learning_rate",
"multi_cohort",
"multi_network",
"ssda_network",
"n_proc",
"n_splits",
"nb_unfrozen_layer",
Expand All @@ -66,6 +67,10 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs):
"sampler",
"seed",
"split",
"caps_target",
"tsv_target_lab",
"tsv_target_unlab",
"preprocessing_dict_target",
]
all_options_list = standard_options_list + task_options_list

Expand All @@ -80,6 +85,13 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs):
/ "tensor_extraction"
/ kwargs["preprocessing_json"]
)

if train_dict["ssda_network"]:
preprocessing_json_target = (
Path(kwargs["caps_target"])
/ "tensor_extraction"
/ kwargs["preprocessing_dict_target"]
)
else:
caps_dict = CapsDataset.create_caps_dict(
train_dict["caps_directory"], train_dict["multi_cohort"]
Expand All @@ -99,12 +111,33 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs):
f"Preprocessing JSON {kwargs['preprocessing_json']} was not found for any CAPS "
f"in {caps_dict}."
)
# To CHECK AND CHANGE
if train_dict["ssda_network"]:
caps_target = Path(kwargs["caps_target"])
preprocessing_json_target = (
caps_target / "tensor_extraction" / kwargs["preprocessing_dict_target"]
)

if preprocessing_json_target.is_file():
logger.info(
f"Preprocessing JSON {preprocessing_json_target} found in CAPS {caps_target}."
)
json_found = True
if not json_found:
raise ValueError(
f"Preprocessing JSON {kwargs['preprocessing_json_target']} was not found for any CAPS "
f"in {caps_target}."
)

# Mode and preprocessing
preprocessing_dict = read_preprocessing(preprocessing_json)
train_dict["preprocessing_dict"] = preprocessing_dict
train_dict["mode"] = preprocessing_dict["mode"]

if train_dict["ssda_network"]:
preprocessing_dict_target = read_preprocessing(preprocessing_json_target)
train_dict["preprocessing_dict_target"] = preprocessing_dict_target

# Add default values if missing
if (
preprocessing_dict["mode"] == "roi"
Expand Down
2 changes: 2 additions & 0 deletions clinicadl/tsvtools/split/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,8 @@ def split_diagnoses(

if categorical_split_variable is None:
categorical_split_variable = "diagnosis"
else:
categorical_split_variable.append("diagnosis")

# Read files
diagnosis_df_path = data_tsv.name
Expand Down
39 changes: 32 additions & 7 deletions clinicadl/utils/caps_dataset/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,11 @@ def __init__(
raise AttributeError("Child class of CapsDataset, must set mode attribute.")

self.df = data_df

mandatory_col = {"participant_id", "session_id", "cohort"}
mandatory_col = {
"participant_id",
"session_id",
"cohort",
}
if self.label_presence and self.label is not None:
mandatory_col.add(self.label)

Expand Down Expand Up @@ -108,6 +111,18 @@ def label_fn(self, target: Union[str, float, int]) -> Union[float, int]:
else:
return self.label_code[str(target)]

def domain_fn(self, target: Union[str, float, int]) -> Union[float, int]:
"""
Returns the label value usable in criterion.
Args:
target: value of the target.
Returns:
label: value of the label usable in criterion.
"""
domain_code = {"t1": 0, "flair": 1}
return domain_code[str(target)]

def __len__(self) -> int:
return len(self.df) * self.elem_per_image

Expand Down Expand Up @@ -209,7 +224,12 @@ def _get_meta_data(self, idx: int) -> Tuple[str, str, str, int, int]:
else:
label = -1

return participant, session, cohort, elem_idx, label
if "domain" in self.df.columns:
domain = self.df.loc[image_idx, "domain"]
domain = self.domain_fn(domain)
else:
domain = "" # TO MODIFY
return participant, session, cohort, elem_idx, label, domain

def _get_full_image(self) -> torch.Tensor:
"""
Expand Down Expand Up @@ -323,7 +343,7 @@ def elem_index(self):
return None

def __getitem__(self, idx):
participant, session, cohort, _, label = self._get_meta_data(idx)
participant, session, cohort, _, label, domain = self._get_meta_data(idx)

image_path = self._get_image_path(participant, session, cohort)
image = torch.load(image_path)
Expand All @@ -341,6 +361,7 @@ def __getitem__(self, idx):
"session_id": session,
"image_id": 0,
"image_path": image_path.as_posix(),
"domain": domain,
}

return sample
Expand Down Expand Up @@ -400,7 +421,9 @@ def elem_index(self):
return self.patch_index

def __getitem__(self, idx):
participant, session, cohort, patch_idx, label = self._get_meta_data(idx)
participant, session, cohort, patch_idx, label, domain = self._get_meta_data(
idx
)
image_path = self._get_image_path(participant, session, cohort)

if self.prepare_dl:
Expand Down Expand Up @@ -507,7 +530,7 @@ def elem_index(self):
return self.roi_index

def __getitem__(self, idx):
participant, session, cohort, roi_idx, label = self._get_meta_data(idx)
participant, session, cohort, roi_idx, label, domain = self._get_meta_data(idx)
image_path = self._get_image_path(participant, session, cohort)

if self.roi_list is None:
Expand Down Expand Up @@ -672,7 +695,9 @@ def elem_index(self):
return self.slice_index

def __getitem__(self, idx):
participant, session, cohort, slice_idx, label = self._get_meta_data(idx)
participant, session, cohort, slice_idx, label, domain = self._get_meta_data(
idx
)
slice_idx = slice_idx + self.discarded_slices[0]
image_path = self._get_image_path(participant, session, cohort)

Expand Down
7 changes: 7 additions & 0 deletions clinicadl/utils/cli_param/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@
multiple=True,
default=None,
)
ssda_network = click.option(
"--ssda_network",
type=bool,
default=False,
show_default=True,
help="ssda training.",
)
# GENERATE
participant_list = click.option(
"--participants_tsv",
Expand Down
Loading

0 comments on commit 759baf2

Please sign in to comment.