From 60f4e65c6fc70ddb149a6cc33cea6d688bcd074a Mon Sep 17 00:00:00 2001 From: aecelaya Date: Thu, 19 Sep 2024 17:42:05 -0500 Subject: [PATCH] Update inference for changes in preprocessing. --- mist/inference/main_inference.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/mist/inference/main_inference.py b/mist/inference/main_inference.py index e5e8cfb..4c4fe53 100755 --- a/mist/inference/main_inference.py +++ b/mist/inference/main_inference.py @@ -337,7 +337,7 @@ def test_time_inference(df, no_preprocess=False, output_std=False): config = read_json_file(config_file) - + create_empty_dir(dest) # Set up rich progress bar @@ -367,18 +367,15 @@ def test_time_inference(df, og_ants_img = ants.image_read(image_list[0]) if no_preprocess: - torch_img, _, fg_bbox, _ = convert_nifti_to_numpy(image_list, None) + preprocessed_example = convert_nifti_to_numpy(image_list) else: - torch_img, _, fg_bbox, _ = preprocess_example( - config, - image_list, - None, - False, - False, - None + preprocessed_example = preprocess_example( + config, + image_list, ) # Make image channels first and add batch dimension + torch_img = preprocessed_example["image"] torch_img = np.transpose(torch_img, axes=(3, 0, 1, 2)) torch_img = np.expand_dims(torch_img, axis=0) @@ -394,7 +391,7 @@ def test_time_inference(df, blend_mode, tta, output_std, - fg_bbox + preprocessed_example["fg_bbox"] ) # Apply postprocessing if required