Skip to content

Commit

Permalink
QualityMetrics.evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Jul 12, 2024
1 parent facebd1 commit f31d077
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __call__(self, algo: Algorithm):
algo.x.write(str(self.outdir / 'iter_final.hv'))


class TensorBoard(cbks.Callback):
class StatsLog(cbks.Callback):
"""Log image slices & objective value"""
def __init__(self, verbose=1, transverse_slice=None, coronal_slice=None, vmax=None, logdir=OUTDIR):
super().__init__(verbose)
Expand Down Expand Up @@ -98,27 +98,26 @@ def __init__(self, reference_image, whole_object_mask, background_mask, **kwargs
self.ref_im_arr = reference_image.as_array()
self.norm = self.ref_im_arr[self.background_indices].mean()

def __call__(self, algorithm):
iteration = algorithm.iteration
if iteration % algorithm.update_objective_interval != 0 and iteration != algorithm.max_iteration:
def __call__(self, algo: Algorithm):
iteration = algo.iteration
if iteration % algo.update_objective_interval != 0 and iteration != algo.max_iteration:
return
self.log(algorithm.x, iteration)
for tag, value in self.evaluate(algo.x).items():
self.tb_summary_writer.add_scalar(tag, value, iteration)

def log(self, test_im, iteration):
def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
assert not any(self.filter.values()), "Filtering not implemented"
test_im_arr = test_im.as_array()
self.tb_summary_writer.add_scalar(
"RMSE_whole_object",
np.sqrt(mse(self.ref_im_arr[self.whole_object_indices], test_im_arr[self.whole_object_indices])) /
self.norm, iteration)
self.tb_summary_writer.add_scalar(
"RMSE_background",
np.sqrt(mse(self.ref_im_arr[self.background_indices], test_im_arr[self.background_indices])) / self.norm,
iteration)
for voi_name, voi_indices in sorted(self.voi_indices.items()):
self.tb_summary_writer.add_scalar(
f"AEM_VOI_{voi_name}",
np.abs(test_im_arr[voi_indices].mean() - self.ref_im_arr[voi_indices].mean()) / self.norm, iteration)
whole = {
"RMSE_whole_object": np.sqrt(
mse(self.ref_im_arr[self.whole_object_indices], test_im_arr[self.whole_object_indices])) / self.norm,
"RMSE_background": np.sqrt(
mse(self.ref_im_arr[self.background_indices], test_im_arr[self.background_indices])) / self.norm}
local = {
f"AEM_VOI_{voi_name}": np.abs(test_im_arr[voi_indices].mean() - self.ref_im_arr[voi_indices].mean()) /
self.norm
for voi_name, voi_indices in sorted(self.voi_indices.items())}
return {**whole, **local}


class MetricsWithTimeout(cbks.Callback):
Expand All @@ -129,20 +128,20 @@ def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_sl
self.callbacks = [
cbks.ProgressCallback(),
SaveIters(outdir=outdir),
(tb_cbk := TensorBoard(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]
(tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]
self.tb = tb_cbk.tb
self.reset()

def reset(self, seconds=None):
self.limit = time() + (self._seconds if seconds is None else seconds)

def __call__(self, algorithm: Algorithm):
def __call__(self, algo: Algorithm):
if (now := time()) > self.limit:
log.warning("Timeout reached. Stopping algorithm.")
raise StopIteration
if self.callbacks:
for c in self.callbacks:
c(algorithm)
c(algo)
self.limit += time() - now

@staticmethod
Expand Down Expand Up @@ -221,7 +220,7 @@ def get_image(fname):

if SRCDIR.is_dir():
# create list of existing data
# Note: as MetricsWithTimeout initialises Tensorboard, this will currently create directories in OUTDIR accordingly
# NB: `MetricsWithTimeout` initialises `SaveIters` which creates `outdir`
data_dirs_metrics = [(SRCDIR / "Siemens_mMR_NEMA_IQ", OUTDIR / "mMR_NEMA",
[MetricsWithTimeout(outdir=OUTDIR / "mMR_NEMA", transverse_slice=72, coronal_slice=109)]),
(SRCDIR / "NeuroLF_Hoffman_Dataset", OUTDIR / "NeuroLF_Hoffman",
Expand Down

0 comments on commit f31d077

Please sign in to comment.