From d6eb1db52296abb5d828e406d3b17044cad19802 Mon Sep 17 00:00:00 2001 From: martvanrijthoven Date: Mon, 23 Oct 2023 22:31:01 +0200 Subject: [PATCH] update apply to include mask --- hooknet/inference/apply.py | 44 +++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/hooknet/inference/apply.py b/hooknet/inference/apply.py index d293af3..860b67c 100644 --- a/hooknet/inference/apply.py +++ b/hooknet/inference/apply.py @@ -11,7 +11,7 @@ ) from tqdm import tqdm from wholeslidedata.interoperability.asap.masks import MaskType -# from hooknet.inference.writing import create_writers +from hooknet.inference.writing import create_writers @@ -24,24 +24,24 @@ def execute_inference_single( tmp_folder, ): - # print("Init writers...") - # writers = create_writers( - # image_path=image_path, - # files=files, - # output_folder=output_folder, - # tmp_folder=tmp_folder, - # ) + print("Init writers...") + writers = create_writers( + image_path=image_path, + files=files, + output_folder=output_folder, + tmp_folder=tmp_folder, + ) - # if not writers: - # print(f"Nothing to process for image {image_path}") - # return + if not writers: + print(f"Nothing to process for image {image_path}") + return prediction_times = [] batch_times = [] print("Applying...") index = 0 batch_time = -1 - for x_batch, info in tqdm(iterator): + for x_batch, y_batch, info in tqdm(iterator): if index > 0: batch_times.append(time.time()-batch_time) print("batch time", batch_times[-1]) @@ -58,12 +58,12 @@ def execute_inference_single( info["y"] - model.output_shape[0] // 4, ) - # for writer in writers: - # writer.write_tile( - # tile=prediction, - # coordinates=(int(c), int(r)), - # mask=y_batch[idx][0], - # ) + for writer in writers: + writer.write_tile( + tile=prediction, + coordinates=(int(c), int(r)), + mask=y_batch[idx][0], + ) index += 1 batch_time = time.time() @@ -71,10 +71,10 @@ def execute_inference_single( print(f"average batch time: {np.mean(batch_times)}") print(f"average prediction time: {np.mean(prediction_times)}") - # # save predictions - # print("Saving...") - # for writer in writers: - # writer.save() + # save predictions + print("Saving...") + for writer in writers: + writer.save()