Skip to content

Commit

Permalink
Merge pull request #57 from SyneRBI/metrics-log
Browse files Browse the repository at this point in the history
metrics: standalone `log()`
  • Loading branch information
casperdcl authored Jul 12, 2024
2 parents 35a6c4c + f31d077 commit 11112d2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 27 deletions.
2 changes: 1 addition & 1 deletion main_OSEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def update(self):
def update_objective(self):
"""
NB: The objective value is not required by OSEM nor by PETRIC, so this returns `0`.
NB: In theory it should be `sum(prompts * log(acq_model.forward(self.x)) - self.x * sensitivity)` across all subsets.
NB: It should be `sum(prompts * log(acq_model.forward(self.x)) - self.x * sensitivity)` across all subsets.
"""
return 0

Expand Down
48 changes: 22 additions & 26 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __call__(self, algo: Algorithm):
algo.x.write(str(self.outdir / 'iter_final.hv'))


class TensorBoard(cbks.Callback):
class StatsLog(cbks.Callback):
"""Log image slices & objective value"""
def __init__(self, verbose=1, transverse_slice=None, coronal_slice=None, vmax=None, logdir=OUTDIR):
super().__init__(verbose)
Expand Down Expand Up @@ -98,30 +98,26 @@ def __init__(self, reference_image, whole_object_mask, background_mask, **kwargs
self.ref_im_arr = reference_image.as_array()
self.norm = self.ref_im_arr[self.background_indices].mean()

def __call__(self, algorithm):
iteration = algorithm.iteration
if iteration % algorithm.update_objective_interval != 0 and iteration != algorithm.max_iteration:
def __call__(self, algo: Algorithm):
iteration = algo.iteration
if iteration % algo.update_objective_interval != 0 and iteration != algo.max_iteration:
return
for tag, value in self.evaluate(algo.x).items():
self.tb_summary_writer.add_scalar(tag, value, iteration)

def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
assert not any(self.filter.values()), "Filtering not implemented"
test_im_arr = algorithm.x.as_array()

# (1) global metrics & statistics
self.tb_summary_writer.add_scalar(
"RMSE_whole_object",
np.sqrt(mse(self.ref_im_arr[self.whole_object_indices], test_im_arr[self.whole_object_indices])) /
self.norm, iteration)
self.tb_summary_writer.add_scalar(
"RMSE_background",
np.sqrt(mse(self.ref_im_arr[self.background_indices], test_im_arr[self.background_indices])) / self.norm,
iteration)

# (2) local metrics & statistics
for voi_name, voi_indices in sorted(self.voi_indices.items()):
# AEM not to be confused with MAE
self.tb_summary_writer.add_scalar(
f"AEM_VOI_{voi_name}",
np.abs(test_im_arr[voi_indices].mean() - self.ref_im_arr[voi_indices].mean()) / self.norm, iteration)
test_im_arr = test_im.as_array()
whole = {
"RMSE_whole_object": np.sqrt(
mse(self.ref_im_arr[self.whole_object_indices], test_im_arr[self.whole_object_indices])) / self.norm,
"RMSE_background": np.sqrt(
mse(self.ref_im_arr[self.background_indices], test_im_arr[self.background_indices])) / self.norm}
local = {
f"AEM_VOI_{voi_name}": np.abs(test_im_arr[voi_indices].mean() - self.ref_im_arr[voi_indices].mean()) /
self.norm
for voi_name, voi_indices in sorted(self.voi_indices.items())}
return {**whole, **local}


class MetricsWithTimeout(cbks.Callback):
Expand All @@ -132,20 +128,20 @@ def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_sl
self.callbacks = [
cbks.ProgressCallback(),
SaveIters(outdir=outdir),
(tb_cbk := TensorBoard(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]
(tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]
self.tb = tb_cbk.tb
self.reset()

def reset(self, seconds=None):
self.limit = time() + (self._seconds if seconds is None else seconds)

def __call__(self, algorithm: Algorithm):
def __call__(self, algo: Algorithm):
if (now := time()) > self.limit:
log.warning("Timeout reached. Stopping algorithm.")
raise StopIteration
if self.callbacks:
for c in self.callbacks:
c(algorithm)
c(algo)
self.limit += time() - now

@staticmethod
Expand Down Expand Up @@ -224,7 +220,7 @@ def get_image(fname):

if SRCDIR.is_dir():
# create list of existing data
# Note: as MetricsWithTimeout initialises Tensorboard, this will currently create directories in OUTDIR accordingly
# NB: `MetricsWithTimeout` initialises `SaveIters` which creates `outdir`
data_dirs_metrics = [(SRCDIR / "Siemens_mMR_NEMA_IQ", OUTDIR / "mMR_NEMA",
[MetricsWithTimeout(outdir=OUTDIR / "mMR_NEMA", transverse_slice=72, coronal_slice=109)]),
(SRCDIR / "NeuroLF_Hoffman_Dataset", OUTDIR / "NeuroLF_Hoffman",
Expand Down

0 comments on commit 11112d2

Please sign in to comment.