Various updates to the scripts: niftymic without masking by default, etc.
t-sanchez committed Oct 5, 2023
1 parent aa752d7 commit 1f9551d
Showing 7 changed files with 165 additions and 52 deletions.
66 changes: 64 additions & 2 deletions fetal_brain_utils/cli/compute_nesvor_uq_lr.py
@@ -50,6 +50,12 @@ def main(argv=None):
help="Out path for the data.",
)

p.add_argument(
"--normalize_subject",
action="store_true",
help="Normalize the data by subject.",
default=False,
)
args = p.parse_args(argv)

data_path = Path(args.data_path).resolve()
@@ -58,7 +64,6 @@
# Load a dictionary of subject-session-paths
# sub_ses_dict = iter_dir(data_path, add_run_only=True)
bids_layout = BIDSLayout(data_path, validate=False)

# Create a variety of measurements based on the
# sigma, var, and sigma_var maps.
# They can be aggregated based on their mean, median, variance or
@@ -82,6 +87,19 @@
"sigma_mad",
"var_mad",
"sigma_var_mad",
"sigma_avg_s",
"var_avg_s",
"sigma_var_avg_s",
"n_qc_s",
"sigma_med_s",
"var_med_s",
"sigma_var_med_s",
"sigma__var_s",
"var__var_s",
"sigma_var__var_s",
"sigma_mad_s",
"var_mad_s",
"sigma_var_mad_s",
]
)
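Note: the new columns suffixed with _s mirror the existing voxel-level statistics but are computed from per-slice aggregates, so every slice contributes equally regardless of how many voxels it holds. A minimal sketch of the two aggregation levels with synthetic values (the mad helper matches the one defined further down in the diff):

import numpy as np

def mad(x):
    # Median absolute deviation.
    return np.median(np.abs(x - np.median(x)))

# Two hypothetical slices with different numbers of masked voxels.
slices = [np.array([0.1, 0.2, 0.3]), np.array([0.4])]

voxels = np.concatenate(slices)                     # pool all voxels
sigma_avg = np.mean(voxels)                         # voxel-level mean (0.25)
per_slice = np.array([np.mean(s) for s in slices])  # reduce each slice first
sigma_avg_s = np.mean(per_slice)                    # slice-level mean (0.30)
sigma_mad_s = mad(per_slice)                        # slice-level MAD (0.10)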

@@ -100,6 +118,25 @@
sigma_tot = []
var_tot = []
sigma_var_tot = []
# Aggregate the result of each slice counting equally
sigma_tot_slice = []
var_tot_slice = []
sigma_var_tot_slice = []
max_sigma = 0.0
max_var = 0.0
max_sigma_var = 0.0

if args.normalize_subject:
for el in out_id:
max_sigma = max(ni.load(el).get_fdata().max(), max_sigma)
max_var = max(
ni.load(el.replace("_sigma", "_var")).get_fdata().max(), max_var
)
max_sigma_var = max(
ni.load(el.replace("_sigma", "_sigma_var")).get_fdata().max(),
max_sigma_var,
)
aggr_fc = np.mean
for el in out_id:
slice_sigma = ni.load(el).get_fdata()
slice_T2w = ni.load(el.replace("_sigma", "_T2w")).get_fdata()
@@ -109,14 +146,26 @@
sigma = slice_sigma[loc]
var = slice_var[loc]
sigma_var = slice_sigma_var[loc]

# print(
# f"{max_sigma:.3f} ({slice_sigma.max():3f}) "
# f"{max_var:.3f} ({slice_var.max():3f}) "
# f"{max_sigma_var:.3f} ({slice_sigma_var.max():3f})"
# )
if len(sigma) == 0:
print("Empty slice ", Path(el).name)
continue
if args.normalize_subject:
sigma /= max_sigma
var /= max_var
sigma_var /= max_sigma_var
sigma_tot += sigma.tolist()
var_tot += var.tolist()
sigma_var_tot += sigma_var.tolist()

sigma_tot_slice.append(aggr_fc(sigma))
var_tot_slice.append(aggr_fc(var))
sigma_var_tot_slice.append(aggr_fc(sigma_var))

def mad(x):
return np.median(np.abs(x - np.median(x)))
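When --normalize_subject is passed, the script first sweeps all of a subject's maps to find the subject-wide maxima, then rescales each slice by those maxima before pooling. A minimal sketch of that two-pass pattern with synthetic arrays standing in for the loaded sigma maps:

import numpy as np

slice_sigmas = [np.array([0.2, 0.5]), np.array([1.0, 0.25])]

# First pass: subject-wide maximum (mirrors the max_sigma loop above).
max_sigma = 0.0
for s in slice_sigmas:
    max_sigma = max(s.max(), max_sigma)

# Second pass: rescale each slice before pooling the values.
sigma_tot = []
for s in slice_sigmas:
    sigma_tot += (s / max_sigma).tolist()  # max_sigma == 1.0 here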

@@ -138,6 +187,19 @@ def mad(x):
"sigma_mad": mad(sigma_tot),
"var_mad": mad(var_tot),
"sigma_var_mad": mad(sigma_var_tot),
"sigma_avg_s": np.mean(sigma_tot_slice),
"var_avg_s": np.mean(var_tot_slice),
"sigma_var_avg_s": np.mean(sigma_var_tot_slice),
"n_qc_s": len(sigma_tot_slice),
"sigma_med_s": np.median(sigma_tot_slice),
"var_med_s": np.median(var_tot_slice),
"sigma_var_med_s": np.median(sigma_var_tot_slice),
"sigma__var_s": np.var(sigma_tot_slice),
"var__var_s": np.var(var_tot_slice),
"sigma_var__var_s": np.var(sigma_var_tot_slice),
"sigma_mad_s": mad(sigma_tot_slice),
"var_mad_s": mad(var_tot_slice),
"sigma_var_mad_s": mad(sigma_var_tot_slice),
}
print(d)
# Add d to the dataframe:
13 changes: 7 additions & 6 deletions fetal_brain_utils/cli/run_mialsrtk.py
@@ -11,9 +11,11 @@ def edit_output_json(out_folder, config, cmd, auto_dict, participant_label=None)
if participant_label:
if sub not in participant_label:
continue
if not isinstance(sub_list, list):
sub_list = [sub_list]
for sub_ses_dict in sub_list:
if "session" in sub_ses_dict.keys():
ses, run_id = sub_ses_dict["session"], sub_ses_dict["sr-id"]
ses, run_id = sub_ses_dict["session"], sub_ses_dict.get("sr-id", 1000)
ses_path = f"ses-{ses}"
output_sub_ses = out_folder / sub_path / ses_path / "anat"
out_json = output_sub_ses / f"{sub_path}_{ses_path}_rec-SR_id-{run_id}_T2w.json"
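The switch from sub_ses_dict["sr-id"] to .get("sr-id", 1000) makes a missing run id non-fatal, in the same spirit as the new isinstance guard that wraps single entries in a list. A minimal illustration with a hypothetical config entry:

# Hypothetical config entry without an "sr-id" field.
sub_ses_dict = {"session": "01"}
run_id = sub_ses_dict.get("sr-id", 1000)  # -> 1000 instead of raising KeyError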
@@ -108,7 +110,6 @@ def merge_and_overwrite_folder(src, dest):


def main(argv=None):

import os
import time
from pathlib import Path
@@ -119,7 +120,7 @@ def main(argv=None):
PYMIALSRTK_PATH = (
"/home/tsanchez/Documents/mial/" "repositories/mialsuperresolutiontoolkit/pymialsrtk"
)
DOCKER_VERSION = "v2.1.0-dev"
DOCKER_VERSION = "v2.1.0"

PATH_TO_ATLAS = "/media/tsanchez/tsanchez_data/data/atlas"
p = get_default_parser("MIALSRTK")
@@ -166,7 +167,7 @@ def main(argv=None):
p.add_argument(
"--no_python_mount",
action="store_true",
default=True,
default=False,
help="Whether the python folder should not be mounted.",
)

@@ -255,8 +256,8 @@ def main(argv=None):
base_command += f" -v {param_file.parent}:/code"
base_command += (
f" -v {PATH_TO_ATLAS}:/sta"
f" sebastientourbier/mialsuperresolutiontoolkit-"
f"bidsapp:{docker_version}"
f" sebastientourbier/mialsuperresolutiontoolkit-bidsapp" # -bidsapp
f":{docker_version}"
f" /bids_dir /output_dir participant"
f" --param_file {subject_json}"
f" --openmp_nb_of_cores 3"
93 changes: 56 additions & 37 deletions fetal_brain_utils/cli/run_nesvor_docker.py
@@ -17,6 +17,7 @@
AUTO_MASK_PATH = "/media/tsanchez/tsanchez_data/data/out_anon/masks"

BATCH_SIZE = 8192
NESVOR_VERSION = "v0.5.0"


def get_mask_path(bids_dir, subject, ses, run):
@@ -58,27 +59,12 @@ def find_run_id(file_list):
return run_dict


def filter_and_complement_mask_list(stacks, sub, ses, mask_list):
"""Filter and sort a run list according to the stacks ordering"""
run_dict = find_run_id(mask_list)
auto_masks = []
for s in stacks:
if s not in run_dict.keys():
print(f"Mask for stack {s} not found.")
mask = get_mask_path(AUTO_MASK_PATH, sub, ses, s)
assert os.path.isfile(mask), f"Automated mask not found at {mask}"
print(f"Using automated mask {mask}.")
run_dict[s] = mask
auto_masks.append(s)
return [run_dict[s] for s in stacks], auto_masks


def filter_run_list(stacks, run_list):
run_dict = find_run_id(run_list)
return [run_dict[s] for s in stacks]


def crop_input(sub, ses, output_path, img_list, mask_list, fake_run=False):
def crop_input(sub, ses, output_path, img_list, mask_list, mask_input, fake_run=False):
import nibabel as ni
from fetal_brain_utils import get_cropped_stack_based_on_mask
from functools import partial
@@ -106,8 +92,10 @@ def crop_input(sub, ses, output_path, img_list, mask_list, fake_run=False):
imc = crop_path(im, m)
maskc = crop_path(m, m)
# Masking

imc = ni.Nifti1Image(imc.get_fdata() * maskc.get_fdata(), imc.affine)
if mask_input:
imc = ni.Nifti1Image(imc.get_fdata() * maskc.get_fdata(), imc.affine)
else:
imc = ni.Nifti1Image(imc.get_fdata(), imc.affine)

ni.save(imc, cropped_im_path)
ni.save(maskc, cropped_mask_path)
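With the new mask_input flag, crop_input only multiplies the cropped stack by its cropped mask when masking is requested; otherwise the cropped intensities are written out unchanged. A minimal sketch with synthetic nibabel images standing in for the cropped stack and mask:

import numpy as np
import nibabel as ni

imc = ni.Nifti1Image(np.random.rand(4, 4, 4), np.eye(4))
maskc = ni.Nifti1Image((np.random.rand(4, 4, 4) > 0.5).astype(np.float64), np.eye(4))

mask_input = False  # keep the cropped but unmasked intensities
data = imc.get_fdata() * maskc.get_fdata() if mask_input else imc.get_fdata()
imc = ni.Nifti1Image(data, imc.affine)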
@@ -126,9 +114,10 @@ def iterate_subject(
participant_label,
target_res,
config,
nesvor_version,
mask_input,
fake_run,
):

if participant_label:
if sub not in participant_label:
return
Expand Down Expand Up @@ -161,10 +150,9 @@ def iterate_subject(
run_id = conf["sr-id"] if "sr-id" in conf else "1"
run_path = f"run-{run_id}"

mask_list, auto_masks = filter_and_complement_mask_list(stacks, sub, ses, mask_list)
mask_list = [str(f) for f in mask_list]
img_list = [str(f) for f in filter_run_list(stacks, img_list)]
conf["use_auto_mask"] = auto_masks
mask_list = [str(f) for f in filter_run_list(stacks, mask_list)]
conf["im_path"] = img_list
conf["mask_path"] = mask_list
conf["config_path"] = str(config)
@@ -175,7 +163,15 @@
if not fake_run:
os.makedirs(output_sub_ses, exist_ok=True)

img_list, mask_list = crop_input(sub, ses, output_path_crop, img_list, mask_list, fake_run)
img_list, mask_list = crop_input(
sub,
ses,
output_path_crop,
img_list,
mask_list,
mask_input,
fake_run,
)
mount_base = Path(img_list[0]).parent
img_str = " ".join([str(Path("/data") / Path(im).name) for im in img_list])
mask_str = " ".join([str(Path("/data") / Path(m).name) for m in mask_list])
@@ -186,22 +182,28 @@
for i, res in enumerate(target_res):
res_str = str(res).replace(".", "p")
out_base = f"{sub_path}_{ses_path}_" f"acq-haste_res-{res_str}_{run_path}_T2w"
output_file = str(out / out_base) + "_misaligned.nii.gz"
if nesvor_version == "v0.5.0":
output_file = str(out / out_base) + ".nii.gz"
else:
output_file = str(out / out_base) + "_misaligned.nii.gz"
output_json = str(output_sub_ses / out_base) + ".json"
if i == 0:
cmd = (
f"docker run --gpus '\"device=0\"' "
f"-v {mount_base}:/data "
f"-v {output_sub_ses}:/out "
f"junshenxu/nesvor:v0.1.0 nesvor reconstruct "
f"-v /media/tsanchez/tsanchez_data/data/NeSVoR/nesvor/:/usr/local/NeSVoR/nesvor/ -it "
f"junshenxu/nesvor:{nesvor_version} nesvor reconstruct "
f"--input-stacks {img_str} "
f"--stack-masks {mask_str} "
f"--n-levels-bias 1 "
f"--output-volume {output_file} "
f"--output-resolution {res} "
f"--output-model {model} "
f"--batch-size {BATCH_SIZE}"
f"--n-levels-bias 1 "
f"--batch-size {BATCH_SIZE} "
)
if nesvor_version == "v0.5.0":
cmd += "--bias-field-correction"
else:
cmd = (
f"docker run --gpus '\"device=0\"' "
@@ -228,15 +230,16 @@
os.system(cmd)

# Transform the affine of the sr reconstruction
out_file = str(output_sub_ses / out_base) + "_misaligned.nii.gz"
out_file_reo = str(output_sub_ses / out_base) + ".nii.gz"
sr = ni.load(out_file)
affine = sr.affine[[2, 1, 0, 3]]
affine[1, :] *= -1
ni.save(
ni.Nifti1Image(sr.get_fdata()[:, :, :], affine, sr.header),
out_file_reo,
)
if nesvor_version != "v0.5.0":
out_file = str(output_sub_ses / out_base) + "_misaligned.nii.gz"
out_file_reo = str(output_sub_ses / out_base) + ".nii.gz"
sr = ni.load(out_file)
affine = sr.affine[[2, 1, 0, 3]]
affine[1, :] *= -1
ni.save(
ni.Nifti1Image(sr.get_fdata()[:, :, :], affine, sr.header),
out_file_reo,
)
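For pre-v0.5.0 outputs, the block above permutes the rows of the reconstruction's affine (first and third rows swapped) and negates the second row before re-saving. A minimal sketch of what that fancy indexing does, using a synthetic diagonal affine in place of sr.affine:

import numpy as np

affine = np.diag([1.0, 2.0, 3.0, 1.0])  # stand-in for sr.affine
reoriented = affine[[2, 1, 0, 3]]       # reorder rows: 3rd, 2nd, 1st, 4th (copy)
reoriented[1, :] *= -1                  # flip the sign of the new second row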


def main(argv=None):
@@ -252,17 +255,31 @@ def main(argv=None):
help="Target resolutions at which the reconstruction should be done.",
)

p.add_argument(
"--version",
default=NESVOR_VERSION,
type=str,
choices=["v0.1.0", "v0.5.0"],
help="Version of NeSVoR to use.",
)

p.add_argument(
"--mask_input",
action=argparse.BooleanOptionalAction,
default=True,
help="Whether the input stacks should be masked prior to computation.",
)
args = p.parse_args(argv)
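The --mask_input option uses argparse.BooleanOptionalAction (Python 3.9+), which automatically generates a paired --no-mask_input flag, so masking stays enabled by default and can be switched off from the command line. A minimal sketch of that behaviour in isolation:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--mask_input", action=argparse.BooleanOptionalAction, default=True)

print(p.parse_args([]).mask_input)                   # True (default)
print(p.parse_args(["--no-mask_input"]).mask_input)  # False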

data_path = Path(args.data_path).resolve()
out_path = Path(args.out_path).resolve()
masks_folder = Path(args.masks_path).resolve()
config = args.config
config = Path(args.config).resolve()

# Load a dictionary of subject-session-paths
sub_ses_dict = iter_dir(data_path, add_run_only=True)

with open(data_path / "code" / config, "r") as f:
with open(config, "r") as f:
params = json.load(f)
# Iterate over all subjects and sessions
iterate = partial(
@@ -274,6 +291,8 @@
participant_label=args.participant_label,
target_res=args.target_res,
config=config,
nesvor_version=args.version,
mask_input=args.mask_input,
fake_run=args.fake_run,
)
for sub, config_sub in params.items():