diff --git a/fetal_brain_utils/cli/parser.py b/fetal_brain_utils/cli/parser.py index 6432be2..8c05578 100644 --- a/fetal_brain_utils/cli/parser.py +++ b/fetal_brain_utils/cli/parser.py @@ -16,7 +16,6 @@ def get_default_parser(srr): p.add_argument( "--masks_path", default=None, - required=True, help="Path to the brain masks.", ) diff --git a/fetal_brain_utils/cli/run_svrtk.py b/fetal_brain_utils/cli/run_svrtk.py index fca4144..c115efb 100644 --- a/fetal_brain_utils/cli/run_svrtk.py +++ b/fetal_brain_utils/cli/run_svrtk.py @@ -26,9 +26,11 @@ import json import traceback from bids import BIDSLayout +import shutil # Default data path DATA_PATH = Path("/media/tsanchez/tsanchez_data/data/data") +SVRTK_AUTO_VERSION = "custom_2.20" # Default parameters for the parameter sweep # Relative paths in the docker image to various scripts. @@ -40,8 +42,10 @@ def iterate_subject( bids_layout, data_path, output_path, + target_res, masks_folder, participant_label, + use_svrtk_auto, fake_run, ): output_path = Path(output_path) @@ -52,16 +56,11 @@ def iterate_subject( print(f"Subject {sub} not found in {data_path}") return - # Prepare the output path, and locate the - # pre-computed masks - # output_path = data_path / output_path + # Prepare the output path docker_out_path = output_path / "run_files" recon_out_path = output_path / "svrtk" - masks_layout = BIDSLayout(masks_folder, validate=False) - # Pre-processing: mask (and crop) the low-resolution stacks - cropped_path_base = docker_out_path / "preprocess" + prepro_path_base = docker_out_path / "preprocess" - # Output for the parameter study. if not fake_run: os.makedirs(docker_out_path, exist_ok=True) @@ -71,6 +70,8 @@ def iterate_subject( config_sub = [config_sub] failure_list = [] + if not use_svrtk_auto: + masks_layout = BIDSLayout(masks_folder, validate=False) for conf in config_sub: try: ses = conf["session"] if "session" in conf else None @@ -80,12 +81,6 @@ def iterate_subject( extension="nii.gz", return_type="filename", ) - mask_list = masks_layout.get_runs( - subject=sub, - session=ses, - extension="nii.gz", - return_type="filename", - ) stacks = conf["stacks"] if "stacks" in conf else find_run_id(img_list) run_id = conf["sr-id"] if "sr-id" in conf else "1" @@ -93,10 +88,19 @@ def iterate_subject( run_path = f"run-{run_id}" ses_path = f"ses-{ses}" if ses is not None else "" - mask_list, auto_masks = filter_mask_list(stacks, sub, ses, mask_list) img_list = filter_run_list(stacks, img_list) conf["im_path"] = img_list - conf["mask_path"] = mask_list + if not use_svrtk_auto: + mask_list = masks_layout.get_runs( + subject=sub, + session=ses, + extension="nii.gz", + return_type="filename", + ) + else: + mask_list = None + + conf["mask_path"] = mask_list conf["config_path"] = str(config_path) if ses_path != "": sub_ses_anat = f"{sub_path}/{ses_path}/anat" @@ -107,17 +111,17 @@ def iterate_subject( # base paths input_path = data_path / sub_ses_anat - input_cropped_path = cropped_path_base / sub_ses_anat / run_path + input_prepro_path = prepro_path_base / sub_ses_anat / run_path if not fake_run: - os.makedirs(input_cropped_path, exist_ok=True) + os.makedirs(input_prepro_path, exist_ok=True) # os.makedirs(mask_cropped_path, exist_ok=True) # Get in-plane resolution to be set as target resolution. - ip_res = [] tp_res = [] # Construct the path to each data point and mask in # the filesystem of the docker image - filename_data, filename_masks = [], [] + # filename_data, filename_masks = [], [] + filename_data = [] boundary_mm = 15 crop_path = partial( get_cropped_stack_based_on_mask, @@ -125,26 +129,34 @@ def iterate_subject( boundary_j=boundary_mm, boundary_k=0, ) - for image, mask in zip(img_list, mask_list): - print(f"Processing {image} {mask}") + it_list = zip(img_list, mask_list) if not use_svrtk_auto else img_list + for o in it_list: # mask_list + image = o if use_svrtk_auto else o[0] + mask = None if use_svrtk_auto else o[1] + # Copy the image file to input_prepro_path im_file = Path(image).name - cropped_im = input_cropped_path / im_file - im, m = ni.load(image), ni.load(mask) - ip_res.append(im.header["pixdim"][1]) - tp_res.append(str(round(im.header["pixdim"][3], 1))) - if not fake_run: - imc = crop_path(im, m) - - maskc = crop_path(m, m) - imc = ni.Nifti1Image(imc.get_fdata() * maskc.get_fdata(), imc.affine) - ni.save(imc, cropped_im) + prepro_im = input_prepro_path / im_file + im = ni.load(image) + tp_res.append(round(im.header["pixdim"][3], 1)) + + if use_svrtk_auto: + if not fake_run: + shutil.copy(image, prepro_im) + else: + print(f"Processing {image} {mask}") + m = ni.load(mask) + if not fake_run: + imc = crop_path(im, m) + imc = ni.Nifti1Image(imc.get_fdata(), imc.affine) + ni.save(imc, prepro_im) # Define the file and path names inside the docker volume run_im = Path("/home/data") / im_file filename_data.append(str(run_im)) filename_data = " ".join(filename_data) - filename_masks = " ".join(filename_masks) - tp_str = " ".join(tp_res) + # filename_masks = " ".join(filename_masks) + tp = str(round(np.mean(tp_res), 2)) + tp_str = " ".join([str(t) for t in tp_res]) ## # Reconstruction stage ## @@ -154,19 +166,30 @@ def iterate_subject( os.makedirs(recon_path, exist_ok=True) # Replace input and mask path by preprocessed - input_path = input_cropped_path + input_path = input_prepro_path # , mask_path = , mask_cropped_path recon_file = f"{sub_path}_{ses_path}_{run_path}_rec-SR_T2w.nii.gz" - cmd = ( - "docker run " - f"-v {input_path}:/home/data " - # f"-v {mask_path}:/home/mask " - f"-v {recon_path}:/home/out/ " - "fetalsvrtk/svrtk mirtk reconstruct " - f"/home/out/{recon_file} {len(img_list)} " - f"{filename_data} " - f"-thickness {tp_str} -resolution {np.min(ip_res):.2f}" - ) + + if use_svrtk_auto: + cmd = ( + "docker run " + f"-v {input_path}:/home/data " + f"-v {recon_path}:/home/out/ " + f"fetalsvrtk/svrtk:{SVRTK_AUTO_VERSION} " + "bash /home/auto-proc-svrtk/auto-brain-reconstruction.sh " + f"/home/data /home/out/ 1 {tp} {target_res} 1" + ) + else: + cmd = ( + "docker run " + f"-v {input_path}:/home/data " + # f"-v {mask_path}:/home/mask " + f"-v {recon_path}:/home/out/ " + "fetalsvrtk/svrtk mirtk reconstruct " + f"/home/out/{recon_file} {len(img_list)} " + f"{filename_data} " + f"-thickness {tp_str} -resolution {target_res}" + ) print("RECONSTRUCTION STAGE") print(cmd) @@ -207,15 +230,29 @@ def main(argv=None): from .parser import get_default_parser p = get_default_parser("SVRTK") + p.add_argument( + "--use_svrtk_auto", + action="store_true", + help="Use svrtk:auto-2.20 instead of svrtk:latest. This features automated reorientation.", + ) + p.add_argument( + "--target_res", + type=float, + default=0.8, + help="Target resolution for the reconstruction.", + ) args = p.parse_args(argv) data_path = Path(args.data_path).resolve() config = Path(args.config).resolve() masks_folder = Path(args.masks_path).resolve() out_path = Path(args.out_path).resolve() + target_res = args.target_res participant_label = args.participant_label + use_svrtk_auto = args.use_svrtk_auto fake_run = args.fake_run - + if use_svrtk_auto: + assert masks_folder is None, "Cannot use masks with svrtk:auto-2.20" # Load a dictionary of subject-session-paths # sub_ses_dict = iter_dir(data_path, add_run_only=True) @@ -230,8 +267,10 @@ def main(argv=None): config_path=config, data_path=data_path, output_path=out_path, + target_res=target_res, masks_folder=masks_folder, participant_label=participant_label, + use_svrtk_auto=use_svrtk_auto, fake_run=fake_run, ) for sub, config_sub in params.items():