Skip to content

Commit

Permalink
fix writing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
martvanrijthoven committed Nov 2, 2023
1 parent d6eb1db commit 1788659
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 25 deletions.
27 changes: 11 additions & 16 deletions hooknet/inference/apply.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import time
import traceback
from pathlib import Path

import numpy as np
from hooknet.inference.utils import (
create_lock_file,
Expand All @@ -9,10 +10,10 @@
get_files,
release_lock_file,
)
from hooknet.inference.writing import create_writers, TILE_SIZE
from tqdm import tqdm
from wholeslidedata.interoperability.asap.masks import MaskType
from hooknet.inference.writing import create_writers

from wholeslidedata.samplers.utils import crop_data


def execute_inference_single(
Expand All @@ -23,7 +24,6 @@ def execute_inference_single(
output_folder,
tmp_folder,
):

print("Init writers...")
writers = create_writers(
image_path=image_path,
Expand All @@ -43,41 +43,37 @@ def execute_inference_single(
batch_time = -1
for x_batch, y_batch, info in tqdm(iterator):
if index > 0:
batch_times.append(time.time()-batch_time)
print("batch time", batch_times[-1])
batch_times.append(time.time() - batch_time)
x_batch = list(x_batch.transpose(1, 0, 2, 3, 4))
prediction_time = time.time()
predictions = model.predict_on_batch(x_batch, argmax=False)
if index > 0:
prediction_times.append(time.time()-prediction_time)
print("prediction_time", prediction_times[-1])

prediction_times.append(time.time() - prediction_time)

for idx, prediction in enumerate(predictions):
c, r = (
info["x"] - model.output_shape[1] // 4,
info["y"] - model.output_shape[0] // 4,
info["x"] - TILE_SIZE//2,
info["y"] - TILE_SIZE//2
)

mask = crop_data(y_batch[idx][0], model.output_shape[:2])
for writer in writers:
writer.write_tile(
tile=prediction,
coordinates=(int(c), int(r)),
mask=y_batch[idx][0],
mask=mask,
)
index += 1
batch_time = time.time()


print(f"average batch time: {np.mean(batch_times)}")
print(f"average prediction time: {np.mean(prediction_times)}")

# save predictions
print("Saving...")
for writer in writers:
writer.save()



def create_lock_file(lock_file_path):
print(f"Creating lock file: {lock_file_path}")
Path(lock_file_path).touch()
Expand Down Expand Up @@ -132,7 +128,6 @@ def execute_inference(
tmp_folder,
heatmaps,
):

image_path = Path(image_path)
output_folder = Path(output_folder)
tmp_folder = Path(tmp_folder)
Expand Down
18 changes: 9 additions & 9 deletions hooknet/inference/writing.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from enum import Enum, auto
from shutil import copyfile
from pathlib import Path

from wholeslidedata.image.wholeslideimage import WholeSlideImage
from wholeslidedata.interoperability.asap.imagewriter import (HeatmapTileCallback,
PredictionTileCallback,
WholeSlideMaskWriter)
from wholeslidedata.interoperability.asap.imagewriter import (
HeatmapTileCallback,
PredictionTileCallback,
WholeSlideMaskWriter,
)
from wholeslidedata.interoperability.asap.masks import MaskType

SPACING = 0.5
TILE_SIZE = 1024


class TmpWholeSlideMaskWriter(WholeSlideMaskWriter):


def __init__(self, output_path: Path, callbacks=(), suffix='.tif'):
def __init__(self, output_path: Path, callbacks=(), suffix=".tif"):
"""Writes temp file and copies the tmp file to an output folder in the save method.
Args:
Expand Down Expand Up @@ -61,7 +61,7 @@ def _create_writer(
"""

if file["type"] == MaskType.HEATMAP:
callbacks = (HeatmapTileCallback(heatmap_index=file['heatmap_index']),)
callbacks = (HeatmapTileCallback(heatmap_index=file["heatmap_index"]),)
elif file["type"] == MaskType.PREDICTION:
callbacks = (PredictionTileCallback(),)
else:
Expand Down Expand Up @@ -101,7 +101,7 @@ def create_writers(
writers = []

# get info
with WholeSlideImage(image_path) as wsi:
with WholeSlideImage(image_path, backend="asap") as wsi:
shape = wsi.shapes[wsi.get_level_from_spacing(SPACING)]
real_spacing = wsi.get_real_spacing(SPACING)

Expand Down

0 comments on commit 1788659

Please sign in to comment.