def initial_step_size_search_rule(x, f, g, grad, max_iter=100, tol=0.1, *,
                                  initial_step_size=1.0, shrink_factor=0.5):
    """Backtracking search for an initial proximal-gradient step size.

    Starting from ``initial_step_size``, the trial step is repeatedly
    multiplied by ``shrink_factor`` until the proximal-gradient point

        x_new = g.proximal(x - step_size * grad, step_size)

    satisfies the sufficient-decrease test

        f(x_new) + g(x_new) <= f(x) + g(x) - tol * step_size * grad.squared_norm()

    or until ``max_iter`` trial steps have been evaluated.

    Args:
        x: Current point. Must support ``x - step_size * grad``.
        f: Callable returning the smooth part of the objective at a point.
        g: Callable returning the non-smooth part of the objective at a
            point; must also provide ``proximal(point, step_size)``.
        grad: Search direction that is subtracted from ``x``; must provide
            ``squared_norm()``. NOTE(review): the caller appears to pass a
            preconditioned gradient -- confirm its sign and scaling match ``f``.
        max_iter: Maximum number of trial step sizes to evaluate.
        tol: Sufficient-decrease constant multiplying
            ``step_size * grad.squared_norm()``.
        initial_step_size: First trial step size (keyword-only).
        shrink_factor: Multiplier applied after a rejected trial; must lie
            strictly between 0 and 1 (keyword-only).

    Returns:
        The first trial step size that passes the sufficient-decrease test.
        If no trial passes within ``max_iter`` evaluations, the last trial
        shrunk once more, i.e. ``initial_step_size * shrink_factor ** max_iter``;
        that returned value has not itself been evaluated.

    Raises:
        ValueError: If ``initial_step_size`` is not positive or
            ``shrink_factor`` is not in the open interval (0, 1).
    """
    if initial_step_size <= 0:
        raise ValueError("initial_step_size must be positive")
    if not 0 < shrink_factor < 1:
        raise ValueError("shrink_factor must lie strictly between 0 and 1")

    step_size = initial_step_size
    # Both quantities are independent of the trial step, so evaluate them once.
    objective_at_x = f(x) + g(x)
    grad_sq_norm = grad.squared_norm()

    for _ in range(max_iter):
        candidate = g.proximal(x - step_size * grad, step_size)
        objective_at_candidate = f(candidate) + g(candidate)
        required = objective_at_x - tol * step_size * grad_sq_norm
        # A NaN objective compares False here, so the step keeps shrinking.
        if objective_at_candidate <= required:
            break
        step_size *= shrink_factor
    return step_size