From 560e4d6ef6c8cc0a83011759c208b49e5d1f36f9 Mon Sep 17 00:00:00 2001 From: Thomas Sanchez Date: Thu, 15 Aug 2024 13:21:19 +0200 Subject: [PATCH] Modify nesvor running script. --- fetal_brain_utils/cli/run_nesvor_docker.py | 57 +++++++++++++++++----- 1 file changed, 46 insertions(+), 11 deletions(-) diff --git a/fetal_brain_utils/cli/run_nesvor_docker.py b/fetal_brain_utils/cli/run_nesvor_docker.py index 4ae1fd4..073a89a 100644 --- a/fetal_brain_utils/cli/run_nesvor_docker.py +++ b/fetal_brain_utils/cli/run_nesvor_docker.py @@ -19,7 +19,7 @@ DATA_PATH = Path("/media/tsanchez/tsanchez_data/data/data") AUTO_MASK_PATH = "/media/tsanchez/tsanchez_data/data/out_anon/masks" -BATCH_SIZE = 8192 +BATCH_SIZE = 4096 NESVOR_VERSION = "v0.5.0" @@ -34,8 +34,11 @@ def iterate_subject( target_res, config, nesvor_version, + batch_size, + recon_type, mask_input, fake_run, + extra_commands, ): if participant_label: if sub not in participant_label: @@ -43,6 +46,9 @@ def iterate_subject( if sub not in bids_layout.get_subjects(): print(f"Subject {sub} not found in {data_path}") return + + if recon_type == "svr": + assert len(target_res) == 1, "Only one resolution can be used for SVR." # if sub not in sub_ses_dict: # print(f"Subject {sub} not found in {data_path}") # return @@ -122,33 +128,38 @@ def iterate_subject( output_file = str(out / out_base) + "_misaligned.nii.gz" output_json = str(output_sub_ses / out_base) + ".json" if i == 0: + nesvor_arg = "svr" if recon_type == "svr" else "reconstruct" cmd = ( f"docker run --gpus '\"device=0\"' " f"-v {mount_base}:/data " f"-v {output_sub_ses}:/out " - # f"-v /home/tsanchez/Documents/mial/repositories/NeSVoR/nesvor:/usr/local/NeSVoR/nesvor " - f"junshenxu/nesvor:{nesvor_version} nesvor reconstruct " + f"junshenxu/nesvor:{nesvor_version} nesvor {nesvor_arg} " f"--input-stacks {img_str} " f"--stack-masks {mask_str} " f"--output-volume {output_file} " - f"--output-resolution {res} " - f"--output-model {model} " - f"--n-levels-bias 1 " - f"--batch-size {BATCH_SIZE} " - " --n-proc-n4 1 " ) + if recon_type == "nesvor": + cmd += ( + f"--output-resolution {res} " + f"--output-model {model} " + f"--n-levels-bias 1 " + f"--batch-size {batch_size} " + " --n-proc-n4 1 " + ) if nesvor_version == "v0.5.0": cmd += "--bias-field-correction" + if extra_commands != "": + cmd += f" {extra_commands}" else: cmd = ( f"docker run --gpus '\"device=0\"' " f"-v {output_sub_ses}:/out " - f"junshenxu/nesvor:v0.1.0 nesvor sample-volume " + f"junshenxu/nesvor:{nesvor_version} nesvor sample-volume " f"--input-model {model} " f"--output-resolution {res} " f"--output-volume {output_file} " f"--output-resolution {res} " - f"--inference-batch-size 16384" + f"--inference-batch-size {batch_size*2}" ) conf["info"] = { "reconstruction": "NeSVoR", @@ -204,6 +215,28 @@ def main(argv=None): default=True, help="Whether the input stacks should be masked prior to computation.", ) + + p.add_argument( + "--batch_size", + type=int, + default=BATCH_SIZE, + help="Batch size for the NeSVoR pipeline.", + ) + + p.add_argument( + "--recon_type", + type=str, + default="nesvor", + choices=["nesvor", "svr"], + help="Types of reconstruction to be run: train a NeSVoR model or just run SVR.", + ) + + p.add_argument( + "--extra_commands", + type=str, + default="", + help="Extra commands to be added to the NeSVoR command.", + ) args = p.parse_args(argv) data_path = Path(args.data_path).resolve() @@ -219,7 +252,6 @@ def main(argv=None): iterate = partial( iterate_subject, bids_layout=bids_layout, - # sub_ses_dict=sub_ses_dict, data_path=data_path, output_path=out_path, mask_base_path=masks_folder, @@ -227,7 +259,10 @@ def main(argv=None): target_res=args.target_res, config=config, nesvor_version=args.version, + batch_size=args.batch_size, mask_input=args.mask_input, + recon_type=args.recon_type, + extra_commands=args.extra_commands, fake_run=args.fake_run, ) for sub, config_sub in params.items():