diff --git a/main.py b/main.py index a8f7292..3152815 120000 --- a/main.py +++ b/main.py @@ -1 +1 @@ -main_SGD.py \ No newline at end of file +main_ISTA.py \ No newline at end of file diff --git a/main_SGD.py b/main_ISTA.py similarity index 76% rename from main_SGD.py rename to main_ISTA.py index 070cdaa..eeda971 100644 --- a/main_SGD.py +++ b/main_ISTA.py @@ -5,13 +5,13 @@ >>> algorithm = Submission(data) >>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks) """ -from cil.optimisation.algorithms import GD, Algorithm -from cil.optimisation.functions import SGFunction +from cil.optimisation.algorithms import ISTA, Algorithm +from cil.optimisation.functions import IndicatorBox, SGFunction from cil.optimisation.utilities import ConstantStepSize, Sampler, callbacks from petric import Dataset from sirf.contrib.partitioner import partitioner -assert issubclass(GD, Algorithm) +assert issubclass(ISTA, Algorithm) class MaxIteration(callbacks.Callback): @@ -28,9 +28,9 @@ def __call__(self, algorithm: Algorithm): raise StopIteration -class Submission(GD): - # note that `issubclass(GD, Algorithm) == True` - def __init__(self, data: Dataset, num_subsets: int = 7, step_size: float = 1e-10, +class Submission(ISTA): + # note that `issubclass(ISTA, Algorithm) == True` + def __init__(self, data: Dataset, num_subsets: int = 7, step_size: float = 1e-6, update_objective_interval: int = 10): """ Initialisation function, setting up data & (hyper)parameters. @@ -45,10 +45,11 @@ def __init__(self, data: Dataset, num_subsets: int = 7, step_size: float = 1e-10 f.set_prior(data.prior) sampler = Sampler.random_without_replacement(len(obj_funs)) - F = -SGFunction(obj_funs, sampler=sampler) # negative to turn minimiser into maximiser - step_size_rule = ConstantStepSize(step_size) # ISTA default step_size is 0.99*2.0/F.L + F = -SGFunction(obj_funs, sampler=sampler) # negative to turn minimiser into maximiser + step_size_rule = ConstantStepSize(step_size) # ISTA default step_size is 0.99*2.0/F.L + g = IndicatorBox(lower=1e-6, accelerated=False) # "non-negativity" constraint - super().__init__(initial=data.OSEM_image, objective_function=F, step_size=step_size_rule, + super().__init__(initial=data.OSEM_image, f=F, g=g, step_size=step_size_rule, update_objective_interval=update_objective_interval)