Skip to content

Commit

Permalink
update SVRG gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
samdporter committed Sep 30, 2024
1 parent b1084a9 commit e6732ce
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from petric import Dataset
from sirf.contrib.partitioner import partitioner
import numpy as np
import numbers

assert issubclass(ISTA, Algorithm)

Expand All @@ -24,7 +25,8 @@ def __init__(self, functions, sampler=None, snapshot_update_interval=None, store
init_steps=0, **kwargs):

super(FullGradientInitialiserFunction, self).__init__(functions, sampler=sampler,
snapshot_update_interval=snapshot_update_interval, store_gradients=store_gradients, **kwargs)
snapshot_update_interval=snapshot_update_interval,
store_gradients=store_gradients, **kwargs)
self.counter = 0
self.init_steps = init_steps

Expand All @@ -46,11 +48,19 @@ def gradient(self, x, out=None):
self.counter += 1
return self.full_gradient(x, out=out)

self.function_num = self.sampler.next()

self._update_data_passes_indices([self.function_num])

return self.approximate_gradient(x, self.function_num, out=out)
if ( (self.snapshot_update_interval != 0) and (self._svrg_iter_number % (self.snapshot_update_interval)) == 0):

return self._update_full_gradient_and_return(x, out=out)

else:

self.function_num = self.sampler.next()
if not isinstance(self.function_num, numbers.Number):
raise ValueError("Batch gradient is not yet implemented")
if self.function_num >= self.num_functions or self.function_num < 0:
raise IndexError(
f"The sampler has produced the index {self.function_num} which does not match the expected range of available functions to sample from. Please ensure your sampler only selects from [0,1,...,len(functions)-1] ")
return self.approximate_gradient(x, self.function_num, out=out)

class BSREMPreconditioner(Preconditioner):
'''
Expand Down Expand Up @@ -251,4 +261,4 @@ def __init__(self, data: Dataset, update_objective_interval=10):
super().__init__(initial=data.OSEM_image, f=f, g=g, step_size=step_size_rule,
preconditioner=preconditioner, update_objective_interval=update_objective_interval)

submission_callbacks = []
submission_callbacks = [] # 10000 iterations max

0 comments on commit e6732ce

Please sign in to comment.