Skip to content

Commit

Permalink
Update SVRTK script to also run with the auto-2.20 docker that does r…
Browse files Browse the repository at this point in the history
…eorientation.
  • Loading branch information
t-sanchez committed Jul 13, 2023
1 parent 12e9130 commit aa752d7
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 46 deletions.
1 change: 0 additions & 1 deletion fetal_brain_utils/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def get_default_parser(srr):
p.add_argument(
"--masks_path",
default=None,
required=True,
help="Path to the brain masks.",
)

Expand Down
129 changes: 84 additions & 45 deletions fetal_brain_utils/cli/run_svrtk.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
import json
import traceback
from bids import BIDSLayout
import shutil

# Default data path
DATA_PATH = Path("/media/tsanchez/tsanchez_data/data/data")
SVRTK_AUTO_VERSION = "custom_2.20"
# Default parameters for the parameter sweep
# Relative paths in the docker image to various scripts.

Expand All @@ -40,8 +42,10 @@ def iterate_subject(
bids_layout,
data_path,
output_path,
target_res,
masks_folder,
participant_label,
use_svrtk_auto,
fake_run,
):
output_path = Path(output_path)
Expand All @@ -52,16 +56,11 @@ def iterate_subject(
print(f"Subject {sub} not found in {data_path}")
return

# Prepare the output path, and locate the
# pre-computed masks
# output_path = data_path / output_path
# Prepare the output path
docker_out_path = output_path / "run_files"
recon_out_path = output_path / "svrtk"
masks_layout = BIDSLayout(masks_folder, validate=False)
# Pre-processing: mask (and crop) the low-resolution stacks
cropped_path_base = docker_out_path / "preprocess"
prepro_path_base = docker_out_path / "preprocess"

# Output for the parameter study.
if not fake_run:
os.makedirs(docker_out_path, exist_ok=True)

Expand All @@ -71,6 +70,8 @@ def iterate_subject(
config_sub = [config_sub]
failure_list = []

if not use_svrtk_auto:
masks_layout = BIDSLayout(masks_folder, validate=False)
for conf in config_sub:
try:
ses = conf["session"] if "session" in conf else None
Expand All @@ -80,23 +81,26 @@ def iterate_subject(
extension="nii.gz",
return_type="filename",
)
mask_list = masks_layout.get_runs(
subject=sub,
session=ses,
extension="nii.gz",
return_type="filename",
)

stacks = conf["stacks"] if "stacks" in conf else find_run_id(img_list)
run_id = conf["sr-id"] if "sr-id" in conf else "1"

run_path = f"run-{run_id}"
ses_path = f"ses-{ses}" if ses is not None else ""

mask_list, auto_masks = filter_mask_list(stacks, sub, ses, mask_list)
img_list = filter_run_list(stacks, img_list)
conf["im_path"] = img_list
conf["mask_path"] = mask_list
if not use_svrtk_auto:
mask_list = masks_layout.get_runs(
subject=sub,
session=ses,
extension="nii.gz",
return_type="filename",
)
else:
mask_list = None

conf["mask_path"] = mask_list
conf["config_path"] = str(config_path)
if ses_path != "":
sub_ses_anat = f"{sub_path}/{ses_path}/anat"
Expand All @@ -107,44 +111,52 @@ def iterate_subject(
# base paths
input_path = data_path / sub_ses_anat

input_cropped_path = cropped_path_base / sub_ses_anat / run_path
input_prepro_path = prepro_path_base / sub_ses_anat / run_path
if not fake_run:
os.makedirs(input_cropped_path, exist_ok=True)
os.makedirs(input_prepro_path, exist_ok=True)
# os.makedirs(mask_cropped_path, exist_ok=True)

# Get in-plane resolution to be set as target resolution.
ip_res = []
tp_res = []
# Construct the path to each data point and mask in
# the filesystem of the docker image
filename_data, filename_masks = [], []
# filename_data, filename_masks = [], []
filename_data = []
boundary_mm = 15
crop_path = partial(
get_cropped_stack_based_on_mask,
boundary_i=boundary_mm,
boundary_j=boundary_mm,
boundary_k=0,
)
for image, mask in zip(img_list, mask_list):
print(f"Processing {image} {mask}")
it_list = zip(img_list, mask_list) if not use_svrtk_auto else img_list
for o in it_list: # mask_list
image = o if use_svrtk_auto else o[0]
mask = None if use_svrtk_auto else o[1]
# Copy the image file to input_prepro_path
im_file = Path(image).name
cropped_im = input_cropped_path / im_file
im, m = ni.load(image), ni.load(mask)
ip_res.append(im.header["pixdim"][1])
tp_res.append(str(round(im.header["pixdim"][3], 1)))
if not fake_run:
imc = crop_path(im, m)

maskc = crop_path(m, m)
imc = ni.Nifti1Image(imc.get_fdata() * maskc.get_fdata(), imc.affine)
ni.save(imc, cropped_im)
prepro_im = input_prepro_path / im_file
im = ni.load(image)
tp_res.append(round(im.header["pixdim"][3], 1))

if use_svrtk_auto:
if not fake_run:
shutil.copy(image, prepro_im)
else:
print(f"Processing {image} {mask}")
m = ni.load(mask)
if not fake_run:
imc = crop_path(im, m)
imc = ni.Nifti1Image(imc.get_fdata(), imc.affine)
ni.save(imc, prepro_im)

# Define the file and path names inside the docker volume
run_im = Path("/home/data") / im_file
filename_data.append(str(run_im))
filename_data = " ".join(filename_data)
filename_masks = " ".join(filename_masks)
tp_str = " ".join(tp_res)
# filename_masks = " ".join(filename_masks)
tp = str(round(np.mean(tp_res), 2))
tp_str = " ".join([str(t) for t in tp_res])
##
# Reconstruction stage
##
Expand All @@ -154,19 +166,30 @@ def iterate_subject(
os.makedirs(recon_path, exist_ok=True)

# Replace input and mask path by preprocessed
input_path = input_cropped_path
input_path = input_prepro_path
# , mask_path = , mask_cropped_path
recon_file = f"{sub_path}_{ses_path}_{run_path}_rec-SR_T2w.nii.gz"
cmd = (
"docker run "
f"-v {input_path}:/home/data "
# f"-v {mask_path}:/home/mask "
f"-v {recon_path}:/home/out/ "
"fetalsvrtk/svrtk mirtk reconstruct "
f"/home/out/{recon_file} {len(img_list)} "
f"{filename_data} "
f"-thickness {tp_str} -resolution {np.min(ip_res):.2f}"
)

if use_svrtk_auto:
cmd = (
"docker run "
f"-v {input_path}:/home/data "
f"-v {recon_path}:/home/out/ "
f"fetalsvrtk/svrtk:{SVRTK_AUTO_VERSION} "
"bash /home/auto-proc-svrtk/auto-brain-reconstruction.sh "
f"/home/data /home/out/ 1 {tp} {target_res} 1"
)
else:
cmd = (
"docker run "
f"-v {input_path}:/home/data "
# f"-v {mask_path}:/home/mask "
f"-v {recon_path}:/home/out/ "
"fetalsvrtk/svrtk mirtk reconstruct "
f"/home/out/{recon_file} {len(img_list)} "
f"{filename_data} "
f"-thickness {tp_str} -resolution {target_res}"
)

print("RECONSTRUCTION STAGE")
print(cmd)
Expand Down Expand Up @@ -207,15 +230,29 @@ def main(argv=None):
from .parser import get_default_parser

p = get_default_parser("SVRTK")
p.add_argument(
"--use_svrtk_auto",
action="store_true",
help="Use svrtk:auto-2.20 instead of svrtk:latest. This features automated reorientation.",
)
p.add_argument(
"--target_res",
type=float,
default=0.8,
help="Target resolution for the reconstruction.",
)

args = p.parse_args(argv)
data_path = Path(args.data_path).resolve()
config = Path(args.config).resolve()
masks_folder = Path(args.masks_path).resolve()
out_path = Path(args.out_path).resolve()
target_res = args.target_res
participant_label = args.participant_label
use_svrtk_auto = args.use_svrtk_auto
fake_run = args.fake_run

if use_svrtk_auto:
assert masks_folder is None, "Cannot use masks with svrtk:auto-2.20"
# Load a dictionary of subject-session-paths
# sub_ses_dict = iter_dir(data_path, add_run_only=True)

Expand All @@ -230,8 +267,10 @@ def main(argv=None):
config_path=config,
data_path=data_path,
output_path=out_path,
target_res=target_res,
masks_folder=masks_folder,
participant_label=participant_label,
use_svrtk_auto=use_svrtk_auto,
fake_run=fake_run,
)
for sub, config_sub in params.items():
Expand Down

0 comments on commit aa752d7

Please sign in to comment.