From e6732cef50b1399bf5d57d675fdbad66cf9b47a9 Mon Sep 17 00:00:00 2001 From: "Sam.Porter" Date: Mon, 30 Sep 2024 21:11:13 +0100 Subject: [PATCH] update SVRG gradient --- main.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index 7ab8bdc..3d8e303 100644 --- a/main.py +++ b/main.py @@ -15,6 +15,7 @@ from petric import Dataset from sirf.contrib.partitioner import partitioner import numpy as np +import numbers assert issubclass(ISTA, Algorithm) @@ -24,7 +25,8 @@ def __init__(self, functions, sampler=None, snapshot_update_interval=None, store init_steps=0, **kwargs): super(FullGradientInitialiserFunction, self).__init__(functions, sampler=sampler, - snapshot_update_interval=snapshot_update_interval, store_gradients=store_gradients, **kwargs) + snapshot_update_interval=snapshot_update_interval, + store_gradients=store_gradients, **kwargs) self.counter = 0 self.init_steps = init_steps @@ -46,11 +48,19 @@ def gradient(self, x, out=None): self.counter += 1 return self.full_gradient(x, out=out) - self.function_num = self.sampler.next() - - self._update_data_passes_indices([self.function_num]) - - return self.approximate_gradient(x, self.function_num, out=out) + if ( (self.snapshot_update_interval != 0) and (self._svrg_iter_number % (self.snapshot_update_interval)) == 0): + + return self._update_full_gradient_and_return(x, out=out) + + else: + + self.function_num = self.sampler.next() + if not isinstance(self.function_num, numbers.Number): + raise ValueError("Batch gradient is not yet implemented") + if self.function_num >= self.num_functions or self.function_num < 0: + raise IndexError( + f"The sampler has produced the index {self.function_num} which does not match the expected range of available functions to sample from. Please ensure your sampler only selects from [0,1,...,len(functions)-1] ") + return self.approximate_gradient(x, self.function_num, out=out) class BSREMPreconditioner(Preconditioner): ''' @@ -251,4 +261,4 @@ def __init__(self, data: Dataset, update_objective_interval=10): super().__init__(initial=data.OSEM_image, f=f, g=g, step_size=step_size_rule, preconditioner=preconditioner, update_objective_interval=update_objective_interval) -submission_callbacks = [] \ No newline at end of file +submission_callbacks = [] # 10000 iterations max \ No newline at end of file