Skip to content

Commit

Permalink
Update SMCOpt
Browse files Browse the repository at this point in the history
  • Loading branch information
Rolf Johan Lorentzen committed Dec 18, 2024
1 parent 1ae0992 commit 97aaf2c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
31 changes: 16 additions & 15 deletions popt/loop/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ def __set__variable(var_name=None, defalut=None):
# Inflation factor used in SmcOpt
self.inflation_factor = None
self.survival_factor = None
self.particles = np.empty((self.cov.shape[0],0))
self.particle_values = np.empty((0))
self.particles = np.empty((self.cov.shape[0],self.ne))
self.particle_values = np.empty(self.ne)
self.resample_index = None

# Initialize variables for bias correction
if 'bias_file' in self.sim.input_dict: # use bias correction
Expand Down Expand Up @@ -388,24 +389,23 @@ def calc_ensemble_weights(self, x, *args, **kwargs):
initial_state = deepcopy(self.state) # store this to update current objective values

# Generate ensemble of states
if self.particles.shape[1] == 0:
if self.resample_index is None:
self.ne = self.num_samples
self.resample_index = []
else:
self.ne = int(np.round(self.num_samples*self.survival_factor))
self._aux_input()
self.state = self._gen_state_ensemble()

self._invert_scale_state() # ensure that state is in [lb,ub]
self.calc_prediction() # calculate flow data
self._scale_state() # scale back to [0, 1]

#self.ens_func_values = self.obj_func(self.pred_data, self.sim.input_dict, self.sim.true_order)
#self.ens_func_values = np.array(self.ens_func_values)
state_ens = at.aug_state(self.state, list(self.state.keys()))
self.function(state_ens, **kwargs)

self.particles = np.hstack((self.particles, state_ens))
self.particle_values = np.hstack((self.particle_values,self.ens_func_values))
self.particles[:,:(self.num_samples-self.ne)] = self.particles[:,self.resample_index]
#np.hstack((self.particles, state_ens))
self.particles[:,(self.num_samples-self.ne):] = deepcopy(state_ens)
self.particle_values[:(self.num_samples-self.ne)] = self.particle_values[self.resample_index]
# np.hstack((self.particle_values,self.ens_func_values))
self.particle_values[(self.num_samples - self.ne):] = deepcopy(self.ens_func_values)

# If bias correction is used we need to calculate the bias factors, J(u_j,m_j)/J(u_j,m)
if self.bias_file is not None: # use bias corrections
Expand All @@ -415,7 +415,8 @@ def calc_ensemble_weights(self, x, *args, **kwargs):
warnings.filterwarnings('ignore') # suppress warnings
weights = np.zeros(self.num_samples)
for i in np.arange(self.num_samples):
weights[i] = np.exp(-(self.particle_values[i]-np.min(self.particle_values))*self.inflation_factor)
weights[i] = np.exp(np.clip(-(self.particle_values[i] - np.min(
self.particle_values)) * self.inflation_factor, None, 10))

weights = weights + 0.000001
weights = weights/np.sum(weights) # TODO: Sjekke at disse er riktig
Expand All @@ -424,10 +425,10 @@ def calc_ensemble_weights(self, x, *args, **kwargs):
index = np.argmin(self.particle_values)
best_ens = self.particles[:, index]
best_func = self.particle_values[index]
resample_index = np.random.choice(self.num_samples,int(np.round(self.num_samples-
self.resample_index = np.random.choice(self.num_samples,int(np.round(self.num_samples-
self.num_samples*self.survival_factor)),replace=True,p=weights)
self.particles = self.particles[:, resample_index]
self.particle_values = self.particle_values[resample_index]
#self.particles = self.particles[:, resample_index]
#self.particle_values = self.particle_values[resample_index]
return sens_matrix, best_ens, best_func

def _gen_state_ensemble(self):
Expand Down
5 changes: 3 additions & 2 deletions popt/update_schemes/smcopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ def __set__variable(var_name=None, defalut=None):
self.alpha_iter_max = __set__variable('alpha_maxiter', 5)
self.max_resample = __set__variable('resample', 0)
self.cov_factor = __set__variable('cov_factor', 0.5)
self.inflation_factor = __set__variable('inflation_factor', 1)
self.survival_factor = __set__variable('survival_factor', 0)
self.inflation_factor = __set__variable('inflation_factor', 1.0)
self.survival_factor = __set__variable('survival_factor', 1.0)
self.survival_factor = np.clip(self.survival_factor,0.1, 1.0)

# Calculate objective function of startpoint
if not self.restart:
Expand Down

0 comments on commit 97aaf2c

Please sign in to comment.