From fffc8bdefaf13fe984686c6eaedcb8ec66f4eb45 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca <14138589+paskino@users.noreply.github.com> Date: Thu, 4 Jul 2024 15:54:51 +0100 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Kris Thielemans Co-authored-by: Casper da Costa-Luis --- main_ISTA.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/main_ISTA.py b/main_ISTA.py index 3aee1a9..eeda8a1 100644 --- a/main_ISTA.py +++ b/main_ISTA.py @@ -28,16 +28,17 @@ def __call__(self, algorithm: Algorithm): raise StopIteration class MyPreconditioner(Preconditioner): + # Use a preconditioner based on the row-sum of the Hessian of the log-likelihood as an example # See: # Tsai, Y.-J., Bousse, A., Ehrhardt, M.J., Stearns, C.W., Ahn, S., Hutton, B., Arridge, S., Thielemans, K., 2017. # Fast Quasi-Newton Algorithms for Penalized Reconstruction in Emission Tomography and Further Improvements via Preconditioning. # IEEE Transactions on Medical Imaging 1. https://doi.org/10.1109/tmi.2017.2786865 def __init__(self, kappa): - self.kappasq = kappa * kappa + 1e-5 + # add an epsilon to avoid division by zero. This eps value probably should be made dependent on kappa though. + self.kappasq = kappa * kappa + 1e-6 - def apply(self, algorithm, gradient, out): - out = gradient.divide(self.kappasq, out=out) - return out + def apply(self, algorithm, gradient, out=None): + return gradient.divide(self.kappasq, out=out) class Submission(ISTA): # note that `issubclass(ISTA, Algorithm) == True` @@ -57,7 +58,7 @@ def __init__(self, data: Dataset, num_subsets: int = 7, step_size: float = 1e-6, sampler = Sampler.random_without_replacement(len(obj_funs)) F = -SGFunction(obj_funs, sampler=sampler) # negative to turn minimiser into maximiser step_size_rule = ConstantStepSize(step_size) # ISTA default step_size is 0.99*2.0/F.L - g = IndicatorBox(lower=1e-6, accelerated=False) # "non-negativity" constraint + g = IndicatorBox(lower=0, accelerated=False) # non-negativity constraint my_preconditioner = MyPreconditioner(data.kappa) super().__init__(initial=data.OSEM_image,