Sl dart ssda #485

Merged
merged 123 commits on Oct 6, 2023
Changes from all commits
123 commits
a1337cb
merge dev into master (#432)
camillebrianceau May 24, 2023
bd8e9d5
Add proposed SSDA method for MICCAI
sophieloiz Sep 28, 2023
0985c27
Add ssda network option
sophieloiz Sep 28, 2023
381e018
Add ssda option
sophieloiz Sep 29, 2023
d2ae2de
Add all ssda options
sophieloiz Sep 29, 2023
826f9b4
Fix issue with Path lib
sophieloiz Sep 29, 2023
1df24aa
Fix issue with Path lib
sophieloiz Sep 29, 2023
7e2a0b1
Fix issue with param name
sophieloiz Sep 29, 2023
4cc46e7
Fix issue with param name
sophieloiz Sep 29, 2023
7bd8c35
Fix issue dict_target
sophieloiz Sep 29, 2023
d4594b1
Fix issue dict_target
sophieloiz Sep 29, 2023
1eef433
Fix issue dict_target
sophieloiz Sep 29, 2023
48e7add
Fix issue dict_target
sophieloiz Sep 29, 2023
dfcee3f
Fix issue dict_target
sophieloiz Sep 29, 2023
6b6f1eb
Fix issue dict_target
sophieloiz Sep 29, 2023
5372243
Fix issue dict_target
sophieloiz Sep 29, 2023
4bed782
Fix issue dict_target
sophieloiz Sep 29, 2023
8953c43
Fix issue dict_target
sophieloiz Sep 29, 2023
004e46a
Fix sampler
sophieloiz Sep 29, 2023
41a0869
Fix optimizer
sophieloiz Sep 29, 2023
5e70d04
Fix optimizer
sophieloiz Sep 29, 2023
4dc476e
Update subnetworks
sophieloiz Sep 29, 2023
7cf339f
Update subnetworks
sophieloiz Sep 29, 2023
69f2968
Update subnetworks
sophieloiz Sep 29, 2023
864ead9
Update subnetworks
sophieloiz Sep 29, 2023
f3be271
Update subnetworks
sophieloiz Sep 29, 2023
e4abe31
Update subnetworks
sophieloiz Sep 29, 2023
34db0c4
Update subnetworks
sophieloiz Sep 29, 2023
13e5823
Update subnetworks
sophieloiz Sep 29, 2023
f6c4422
Update subnetworks
sophieloiz Sep 29, 2023
45162ce
Update subnetworks
sophieloiz Sep 29, 2023
86266a9
Update subnetworks
sophieloiz Sep 29, 2023
8661c42
Update subnetworks
sophieloiz Sep 29, 2023
2770dbc
Update subnetworks
sophieloiz Sep 29, 2023
7494265
Update subnetworks
sophieloiz Sep 29, 2023
2e0ddd0
Update subnetworks
sophieloiz Sep 29, 2023
2e07669
Update subnetworks
sophieloiz Sep 29, 2023
66e34c0
Update subnetworks
sophieloiz Sep 29, 2023
590ef92
Update subnetworks
sophieloiz Sep 29, 2023
930f469
Update subnetworks
sophieloiz Sep 29, 2023
32f6936
Update maps_manager
sophieloiz Sep 29, 2023
a2b0a6c
Update maps_manager
sophieloiz Sep 29, 2023
4f22136
Update maps_manager
sophieloiz Sep 29, 2023
33cb0db
Update log writer
sophieloiz Oct 2, 2023
6a8d780
Update log writer
sophieloiz Oct 2, 2023
d89a519
Update test loader
sophieloiz Oct 2, 2023
a831552
Update test loader
sophieloiz Oct 2, 2023
12d04b7
Update test loader
sophieloiz Oct 2, 2023
25650dd
Update test loader
sophieloiz Oct 2, 2023
fb1af2e
Update test loader
sophieloiz Oct 2, 2023
20fc841
merge
camillebrianceau Oct 2, 2023
8b72472
black
camillebrianceau Oct 3, 2023
bf0a427
Test initialization
sophieloiz Oct 3, 2023
34292e5
tests
camillebrianceau Oct 3, 2023
a8215b4
tests
camillebrianceau Oct 3, 2023
12586e4
Weights the loss
sophieloiz Oct 4, 2023
a295235
Debug test_da function
sophieloiz Oct 4, 2023
39beb23
Debug test_da function
sophieloiz Oct 4, 2023
2926e15
Debug test_da function
sophieloiz Oct 4, 2023
605fbd4
Debug test_da function
sophieloiz Oct 4, 2023
a54422b
Debug test_da function
sophieloiz Oct 4, 2023
46b09a9
Debug test_da function
sophieloiz Oct 4, 2023
e6b76d9
Debug test_da function
sophieloiz Oct 4, 2023
27cb127
Debug test_da function
sophieloiz Oct 4, 2023
4a018b0
Debug test_da function
sophieloiz Oct 4, 2023
0be8c94
Debug test_da function
sophieloiz Oct 4, 2023
aa96175
Debug test_da function
sophieloiz Oct 4, 2023
dbd05b1
Debug test_da function
sophieloiz Oct 4, 2023
1581092
Debug test_da function
sophieloiz Oct 4, 2023
cd28952
Debug test_da function
sophieloiz Oct 4, 2023
601a914
Debug test_da function
sophieloiz Oct 4, 2023
34ac3b1
Debug test_da function
sophieloiz Oct 4, 2023
3515f45
Debug test_da function
sophieloiz Oct 4, 2023
0881f8d
Debug test_da function
sophieloiz Oct 4, 2023
02e33b6
Debug test_da function
sophieloiz Oct 4, 2023
7e54ca7
Debug test_da function
sophieloiz Oct 4, 2023
24ca5fb
Debug test_da function
sophieloiz Oct 4, 2023
c32ce63
Debug test_da function
sophieloiz Oct 4, 2023
0368ad2
Debug test_da function
sophieloiz Oct 4, 2023
69985fd
Debug test_da function
sophieloiz Oct 4, 2023
10811bf
Debug test_da function
sophieloiz Oct 4, 2023
e05da1c
Debug test_da function
sophieloiz Oct 4, 2023
317c888
Debug test_da function
sophieloiz Oct 4, 2023
1169c4f
Debug test_da function
sophieloiz Oct 4, 2023
ae3a62f
Debug test_da function
sophieloiz Oct 4, 2023
9540aac
Debug test_da function
sophieloiz Oct 4, 2023
bd98d67
Debug test_da function
sophieloiz Oct 4, 2023
cc70f3f
Debug test_da function
sophieloiz Oct 4, 2023
de7f9a1
Debug test_da function
sophieloiz Oct 4, 2023
3ce0555
Debug test_da function
sophieloiz Oct 4, 2023
80f8c56
Debug test_da function
sophieloiz Oct 4, 2023
1a8d2ac
Debug test_da function
sophieloiz Oct 4, 2023
412e3cf
Merge branch 'sl_dart_ssda' into sl_dart_eds
sophieloiz Oct 4, 2023
2b1d23f
Last changes before merge
sophieloiz Oct 4, 2023
14ffe4b
Merge pull request #5 from sophieloiz/sl_dart_eds
sophieloiz Oct 4, 2023
9f05346
Black and isort
sophieloiz Oct 4, 2023
de06f73
Fix conflicts
sophieloiz Oct 4, 2023
9198ec8
pre-commit
sophieloiz Oct 4, 2023
8a0c09c
pre-commit
sophieloiz Oct 4, 2023
5a2edc4
Merge pull request #6 from sophieloiz/sl_dart_eds
sophieloiz Oct 4, 2023
56651ab
Fix test issues
sophieloiz Oct 4, 2023
a24e3aa
Merge pull request #7 from sophieloiz/sl_dart_eds
sophieloiz Oct 4, 2023
5d6568e
Fix test issues
sophieloiz Oct 4, 2023
40571bb
Merge pull request #8 from sophieloiz/sl_dart_eds
sophieloiz Oct 4, 2023
1dd76d0
Fix test issues
sophieloiz Oct 4, 2023
1a3c880
Merge pull request #9 from sophieloiz/sl_dart_eds
sophieloiz Oct 4, 2023
ae57cca
Hard code ssda_network
sophieloiz Oct 5, 2023
bd8cd77
Hard code ssda_network
sophieloiz Oct 5, 2023
72c0f45
Merge pull request #10 from sophieloiz/sl_dart_eds
sophieloiz Oct 5, 2023
d652fb2
Clean code
sophieloiz Oct 5, 2023
3bc930d
Clean code
sophieloiz Oct 5, 2023
485cd4c
Fix parameters
sophieloiz Oct 5, 2023
a07c04d
Merge pull request #11 from sophieloiz/sl_dart_eds
sophieloiz Oct 5, 2023
9d967c6
Fix parameters
sophieloiz Oct 5, 2023
71a368c
Merge pull request #12 from sophieloiz/sl_dart_eds
sophieloiz Oct 5, 2023
7485013
Debug task utils with ssda parameters
sophieloiz Oct 5, 2023
c9abd6d
Merge pull request #13 from sophieloiz/sl_dart_eds
sophieloiz Oct 5, 2023
d7640e1
Update data files
sophieloiz Oct 5, 2023
0556cbc
Merge pull request #14 from sophieloiz/sl_dart_eds
sophieloiz Oct 5, 2023
f72240e
Update data files
sophieloiz Oct 5, 2023
ea9c57d
Update data files
sophieloiz Oct 5, 2023
82ae8e6
Add ssda for reconstruction and regression
sophieloiz Oct 5, 2023
a4f0edc
Merge pull request #15 from sophieloiz/sl_dart_eds
sophieloiz Oct 5, 2023
4 changes: 4 additions & 0 deletions clinicadl/generate/generate.py
@@ -165,6 +165,8 @@ def create_random_image(subject_id):
    write_missing_mods(output_dir, output_df)
    logger.info(f"Random dataset was generated at {output_dir}")

    logger.info(f"Random dataset was generated at {output_dir}")


def generate_trivial_dataset(
    caps_directory: Path,
@@ -355,6 +357,8 @@ def create_trivial_image(subject_id, output_df):
    write_missing_mods(output_dir, output_df)
    logger.info(f"Trivial dataset was generated at {output_dir}")

    logger.info(f"Trivial dataset was generated at {output_dir}")


def generate_shepplogan_dataset(
    output_dir: Path,
83 changes: 83 additions & 0 deletions clinicadl/mlflow_test.py
@@ -0,0 +1,83 @@
import logging
import os
import sys
import warnings
from urllib.parse import urlparse

import mlflow
import mlflow.sklearn
import numpy as np
import pandas as pd
from sklearn.linear_model import ElasticNet
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)


def eval_metrics(actual, pred):
    rmse = np.sqrt(mean_squared_error(actual, pred))
    mae = mean_absolute_error(actual, pred)
    r2 = r2_score(actual, pred)
    return rmse, mae, r2


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    np.random.seed(40)

    # Read the wine-quality csv file from the URL
    csv_url = "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/data/winequality-red.csv"
    try:
        data = pd.read_csv(csv_url, sep=";")
    except Exception as e:
        logger.exception(
            "Unable to download training & test CSV, check your internet connection. Error: %s",
            e,
        )

    # Split the data into training and test sets. (0.75, 0.25) split.
    train, test = train_test_split(data)

    # The predicted column is "quality" which is a scalar from [3, 9]
    train_x = train.drop(["quality"], axis=1)
    test_x = test.drop(["quality"], axis=1)
    train_y = train[["quality"]]
    test_y = test[["quality"]]

    alpha = float(sys.argv[1]) if len(sys.argv) > 1 else 0.5
    l1_ratio = float(sys.argv[2]) if len(sys.argv) > 2 else 0.5

    with mlflow.start_run():
        lr = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, random_state=42)
        lr.fit(train_x, train_y)

        predicted_qualities = lr.predict(test_x)

        (rmse, mae, r2) = eval_metrics(test_y, predicted_qualities)

        print("Elasticnet model (alpha={:f}, l1_ratio={:f}):".format(alpha, l1_ratio))
        print(" RMSE: %s" % rmse)
        print(" MAE: %s" % mae)
        print(" R2: %s" % r2)

        mlflow.log_param("alpha", alpha)
        mlflow.log_param("l1_ratio", l1_ratio)
        mlflow.log_metric("rmse", rmse)
        mlflow.log_metric("r2", r2)
        mlflow.log_metric("mae", mae)

        tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme

        # Model registry does not work with file store
        if tracking_url_type_store != "file":
            # Register the model
            # There are other ways to use the Model Registry, which depends on the use case,
            # please refer to the doc for more information:
            # https://mlflow.org/docs/latest/model-registry.html#api-workflow
            mlflow.sklearn.log_model(
                lr, "model", registered_model_name="ElasticnetWineModel"
            )
        else:
            mlflow.sklearn.log_model(lr, "model")
1 change: 1 addition & 0 deletions clinicadl/random_search/random_search_utils.py
@@ -128,6 +128,7 @@ def random_sampling(rs_options: Dict[str, Any]) -> Dict[str, Any]:
        "mode": "fixed",
        "multi_cohort": "fixed",
        "multi_network": "choice",
        "ssda_netork": "fixed",
        "n_fcblocks": "randint",
        "n_splits": "fixed",
        "n_proc": "fixed",
7 changes: 6 additions & 1 deletion clinicadl/resources/config/train_config.toml
@@ -4,6 +4,7 @@
[Model]
architecture = "default" # ex : Conv5_FC3
multi_network = false
ssda_network = false

[Architecture]
# CNN
@@ -66,6 +67,10 @@ data_augmentation = false
sampler = "random"
size_reduction=false
size_reduction_factor=2
caps_target = ""
tsv_target_lab = ""
tsv_target_unlab = ""
preprocessing_dict_target = ""

[Cross_validation]
n_splits = 0
@@ -82,4 +87,4 @@ accumulation_steps = 1
profiler = false

[Informations]
emissions_calculator = false
emissions_calculator = false
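As a side note, a minimal sketch of reading the new SSDA keys back from this config file, assuming Python 3.11+ for tomllib and assuming the keys sit in the [Model] and [Data] tables as the hunk context suggests (clinicadl's own config loader may differ):

import tomllib  # Python >= 3.11; older versions can use the third-party `tomli` package
from pathlib import Path

config_path = Path("clinicadl/resources/config/train_config.toml")
with config_path.open("rb") as f:
    config = tomllib.load(f)

print(config["Model"]["ssda_network"])            # False until enabled
print(config["Data"]["caps_target"])              # "" until the user points it at a target CAPS
print(config["Data"]["preprocessing_dict_target"])
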
5 changes: 5 additions & 0 deletions clinicadl/train/tasks/classification_cli.py
@@ -27,13 +27,18 @@
# Model
@train_option.architecture
@train_option.multi_network
@train_option.ssda_network
# Data
@train_option.multi_cohort
@train_option.diagnoses
@train_option.baseline
@train_option.normalize
@train_option.data_augmentation
@train_option.sampler
@train_option.caps_target
@train_option.tsv_target_lab
@train_option.tsv_target_unlab
@train_option.preprocessing_dict_target
# Cross validation
@train_option.n_splits
@train_option.split
5 changes: 5 additions & 0 deletions clinicadl/train/tasks/reconstruction_cli.py
@@ -27,13 +27,18 @@
# Model
@train_option.architecture
@train_option.multi_network
@train_option.ssda_network
# Data
@train_option.multi_cohort
@train_option.diagnoses
@train_option.baseline
@train_option.normalize
@train_option.data_augmentation
@train_option.sampler
@train_option.caps_target
@train_option.tsv_target_lab
@train_option.tsv_target_unlab
@train_option.preprocessing_dict_target
# Cross validation
@train_option.n_splits
@train_option.split
5 changes: 5 additions & 0 deletions clinicadl/train/tasks/regression_cli.py
@@ -27,13 +27,18 @@
# Model
@train_option.architecture
@train_option.multi_network
@train_option.ssda_network
# Data
@train_option.multi_cohort
@train_option.diagnoses
@train_option.baseline
@train_option.normalize
@train_option.data_augmentation
@train_option.sampler
@train_option.caps_target
@train_option.tsv_target_lab
@train_option.tsv_target_unlab
@train_option.preprocessing_dict_target
# Cross validation
@train_option.n_splits
@train_option.split
33 changes: 33 additions & 0 deletions clinicadl/train/tasks/task_utils.py
@@ -50,6 +50,7 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs):
        "learning_rate",
        "multi_cohort",
        "multi_network",
        "ssda_network",
        "n_proc",
        "n_splits",
        "nb_unfrozen_layer",
@@ -65,6 +66,10 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs):
        "sampler",
        "seed",
        "split",
        "caps_target",
        "tsv_target_lab",
        "tsv_target_unlab",
        "preprocessing_dict_target",
    ]
    all_options_list = standard_options_list + task_options_list

@@ -79,6 +84,13 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs):
            / "tensor_extraction"
            / kwargs["preprocessing_json"]
        )

        if train_dict["ssda_network"]:
            preprocessing_json_target = (
                Path(kwargs["caps_target"])
                / "tensor_extraction"
                / kwargs["preprocessing_dict_target"]
            )
    else:
        caps_dict = CapsDataset.create_caps_dict(
            train_dict["caps_directory"], train_dict["multi_cohort"]
@@ -98,12 +110,33 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs):
                f"Preprocessing JSON {kwargs['preprocessing_json']} was not found for any CAPS "
                f"in {caps_dict}."
            )
    # To CHECK AND CHANGE
    if train_dict["ssda_network"]:
        caps_target = Path(kwargs["caps_target"])
        preprocessing_json_target = (
            caps_target / "tensor_extraction" / kwargs["preprocessing_dict_target"]
        )

        if preprocessing_json_target.is_file():
            logger.info(
                f"Preprocessing JSON {preprocessing_json_target} found in CAPS {caps_target}."
            )
            json_found = True
        if not json_found:
            raise ValueError(
                f"Preprocessing JSON {kwargs['preprocessing_json_target']} was not found for any CAPS "
                f"in {caps_target}."
            )

    # Mode and preprocessing
    preprocessing_dict = read_preprocessing(preprocessing_json)
    train_dict["preprocessing_dict"] = preprocessing_dict
    train_dict["mode"] = preprocessing_dict["mode"]

    if train_dict["ssda_network"]:
        preprocessing_dict_target = read_preprocessing(preprocessing_json_target)
        train_dict["preprocessing_dict_target"] = preprocessing_dict_target

    # Add default values if missing
    if (
        preprocessing_dict["mode"] == "roi"
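To make the target-JSON lookup above concrete, a small stand-alone illustration (the CAPS directory and JSON file name are made up, not taken from the PR):

from pathlib import Path

# Hypothetical CLI inputs.
kwargs = {
    "caps_target": "/data/CAPS_FLAIR",
    "preprocessing_dict_target": "extract_image.json",
}

preprocessing_json_target = (
    Path(kwargs["caps_target"]) / "tensor_extraction" / kwargs["preprocessing_dict_target"]
)
print(preprocessing_json_target)  # /data/CAPS_FLAIR/tensor_extraction/extract_image.json

if not preprocessing_json_target.is_file():
    raise ValueError(f"Preprocessing JSON {preprocessing_json_target} was not found.")
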
2 changes: 2 additions & 0 deletions clinicadl/tsvtools/split/split.py
@@ -281,6 +281,8 @@ def split_diagnoses(

    if categorical_split_variable is None:
        categorical_split_variable = "diagnosis"
    else:
        categorical_split_variable.append("diagnosis")

    # Read files
    diagnosis_df_path = data_tsv.name
39 changes: 32 additions & 7 deletions clinicadl/utils/caps_dataset/data.py
@@ -71,8 +71,11 @@ def __init__(
            raise AttributeError("Child class of CapsDataset, must set mode attribute.")

        self.df = data_df

        mandatory_col = {"participant_id", "session_id", "cohort"}
        mandatory_col = {
            "participant_id",
            "session_id",
            "cohort",
        }
        if self.label_presence and self.label is not None:
            mandatory_col.add(self.label)

@@ -108,6 +111,18 @@ def label_fn(self, target: Union[str, float, int]) -> Union[float, int]:
        else:
            return self.label_code[str(target)]

    def domain_fn(self, target: Union[str, float, int]) -> Union[float, int]:
        """
        Returns the label value usable in criterion.

        Args:
            target: value of the target.
        Returns:
            label: value of the label usable in criterion.
        """
        domain_code = {"t1": 0, "flair": 1}
        return domain_code[str(target)]

    def __len__(self) -> int:
        return len(self.df) * self.elem_per_image

@@ -209,7 +224,12 @@ def _get_meta_data(self, idx: int) -> Tuple[str, str, str, int, int]:
        else:
            label = -1

        return participant, session, cohort, elem_idx, label
        if "domain" in self.df.columns:
            domain = self.df.loc[image_idx, "domain"]
            domain = self.domain_fn(domain)
        else:
            domain = ""  # TO MODIFY
        return participant, session, cohort, elem_idx, label, domain

    def _get_full_image(self) -> torch.Tensor:
        """
@@ -323,7 +343,7 @@ def elem_index(self):
        return None

    def __getitem__(self, idx):
        participant, session, cohort, _, label = self._get_meta_data(idx)
        participant, session, cohort, _, label, domain = self._get_meta_data(idx)

        image_path = self._get_image_path(participant, session, cohort)
        image = torch.load(image_path)
@@ -341,6 +361,7 @@ def __getitem__(self, idx):
            "session_id": session,
            "image_id": 0,
            "image_path": image_path.as_posix(),
            "domain": domain,
        }

        return sample
@@ -400,7 +421,9 @@ def elem_index(self):
        return self.patch_index

    def __getitem__(self, idx):
        participant, session, cohort, patch_idx, label = self._get_meta_data(idx)
        participant, session, cohort, patch_idx, label, domain = self._get_meta_data(
            idx
        )
        image_path = self._get_image_path(participant, session, cohort)

        if self.prepare_dl:
@@ -507,7 +530,7 @@ def elem_index(self):
        return self.roi_index

    def __getitem__(self, idx):
        participant, session, cohort, roi_idx, label = self._get_meta_data(idx)
        participant, session, cohort, roi_idx, label, domain = self._get_meta_data(idx)
        image_path = self._get_image_path(participant, session, cohort)

        if self.roi_list is None:
@@ -672,7 +695,9 @@ def elem_index(self):
        return self.slice_index

    def __getitem__(self, idx):
        participant, session, cohort, slice_idx, label = self._get_meta_data(idx)
        participant, session, cohort, slice_idx, label, domain = self._get_meta_data(
            idx
        )
        slice_idx = slice_idx + self.discarded_slices[0]
        image_path = self._get_image_path(participant, session, cohort)

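A minimal stand-alone illustration of the new domain handling in CapsDataset: domain_fn maps the modality string from the optional `domain` column to an integer code, and the sample dict returned by __getitem__ now carries that code (the participant and session values below are made up):

from typing import Union

domain_code = {"t1": 0, "flair": 1}  # same mapping as CapsDataset.domain_fn above


def domain_fn(target: Union[str, float, int]) -> int:
    return domain_code[str(target)]


sample = {
    "participant_id": "sub-0001",   # illustrative only
    "session_id": "ses-M000",
    "image_id": 0,
    "domain": domain_fn("flair"),   # -> 1; falls back to "" when no `domain` column exists
}
print(sample["domain"])
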
7 changes: 7 additions & 0 deletions clinicadl/utils/cli_param/option.py
@@ -57,6 +57,13 @@
    multiple=True,
    default=None,
)
ssda_network = click.option(
    "--ssda_network",
    type=bool,
    default=False,
    show_default=True,
    help="ssda training.",
)
# GENERATE
participant_list = click.option(
    "--participants_tsv",
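For context, a minimal sketch of how a click option defined this way attaches to a command. The toy command below is not part of clinicadl; it only shows how the flag value reaches the decorated function:

import click

# Mirrors the ssda_network option added above.
ssda_network = click.option(
    "--ssda_network",
    type=bool,
    default=False,
    show_default=True,
    help="ssda training.",
)


@click.command()
@ssda_network
def train(ssda_network: bool):
    """Toy command: prints the parsed flag value."""
    click.echo(f"ssda_network={ssda_network}")


if __name__ == "__main__":
    train()  # e.g. `python toy_cli.py --ssda_network True`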