Skip to content

Commit

Permalink
add some utilities for reading/plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
KrisThielemans committed Jul 14, 2024
1 parent 85fac04 commit 2035dd0
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
38 changes: 38 additions & 0 deletions SIRF_data_preparation/evaluation_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Some utilities for plotting objectives and metrics
"""
import csv
import os
import numpy
from typing import Iterator
import matplotlib.pyplot as plt
from petric import QualityMetrics
import sirf.STIR as STIR

def read_objectives(datadir='.') -> numpy.array:
"""Reads objectives.csv and returns as 2d array"""
with open(os.path.join(datadir, 'objectives.csv'), newline='') as csvfile:
reader = csv.reader(csvfile)
next(reader) # skip first line
objs = numpy.array([(float(row[0]), float(row[1])) for row in reader])
return objs


def get_metrics(qm: QualityMetrics, iters: Iterator, srcdir='.') -> numpy.array :
"""Read 'iter*.hv' images from datadir, compute metrics and return as 2d array"""
m = []
for iter in iters:
im = STIR.ImageData(os.path.join(srcdir,f"iter_{iter:04d}.hv"))
m.append([*qm.evaluate(im).values()])
return numpy.array(m)
#%%
def plot_metrics(iters: Iterator, m: numpy.array, labels=[], suffix=""):
"""Make 2 subplots of metrics"""
ax = plt.subplot(121)
plt.plot(iters, m[:,0], label=labels[0]+suffix)
plt.plot(iters, m[:,1], label=labels[1]+suffix)
ax.legend()
ax = plt.subplot(122)
for i in range(2, m.shape[1]):
plt.plot(iters, m[:,i], label=labels[i]+suffix)
ax.legend()
2 changes: 1 addition & 1 deletion petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
for voi_name, voi_indices in sorted(self.voi_indices.items())}
return {**whole, **local}

def keys():
def keys(self):
l = ["RMSE_whole_object", "RMSE_background"]
for voi_name in sorted(self.voi_indices.keys()):
l.append(f"AEM_VOI_{voi_name}")
Expand Down

0 comments on commit 2035dd0

Please sign in to comment.