diff --git a/main.py b/main.py index c310920..74486ac 100644 --- a/main.py +++ b/main.py @@ -107,6 +107,7 @@ def __init__(self, data: Dataset): sampler = Sampler.random_without_replacement(len(obj_funs)) f = -SVRGFunction(obj_funs, sampler=sampler, snapshot_update_interval=None, store_gradients=True) g = IndicatorBox(lower=0, accelerated=True) # non-negativity constraint + initial_step_size = initial_step_size_search_rule(data.OSEM_image, f, g, grad) step_size_rule = LinearDecayStepSizeRule(initial_step_size, decay=decay) super().__init__(initial=data.OSEM_image, f=f, g=g, step_size=step_size_rule, preconditioner=preconditioner,