
Merge pull request #110 from LucasBoTang/gradproj
New feat: Update Violation Energy for Projected Gradient
drgona authored Oct 25, 2023
2 parents bf4bc62 + 0e56621 commit d9c162c
Showing 1 changed file with 25 additions and 15 deletions.
40 changes: 25 additions & 15 deletions src/neuromancer/modules/solvers.py
@@ -55,14 +55,15 @@ class GradientProjection(Solver):
     DC3 paper: https://arxiv.org/abs/2104.12225
     """
     def __init__(self, constraints, input_keys, output_keys=[], decay=0.1,
-                 num_steps=1, step_size=0.01, name=None):
+                 num_steps=1, step_size=0.01, energy_update=True, name=None):
         """
         :param constraints:
         :param input_keys: (List of str) List of input variable names
         :param constraints: list of objects which implement the Loss interface (e.g. Objective, Loss, or Constraint)
         :param num_steps: (int) number of iteration steps for the projected gradient method
         :param step_size: (float) scaling factor for gradient update
         :param decay: (float) decay factor of the step_size
+        :param energy_update: (bool) flag to update energy
         :param name:
         """
         super().__init__(constraints=constraints,
@@ -72,6 +73,7 @@ def __init__(self, constraints, input_keys, output_keys=[], decay=0.1,
         self.step_size = step_size
         self.input_keys = input_keys
         self.decay = decay
+        self.energy_update = energy_update

     def _constraints_check(self):
         """
@@ -96,23 +98,31 @@ def con_viol_energy(self, input_dict):

     def forward(self, data):
         """
-        foward pass of the projected gradient solver
+        forward pass of the projected gradient solver
         :param data: (dict: {str: Tensor})
         :return: (dict: {str: Tensor})
         """
-        energy = self.con_viol_energy(data)
-        output_data = {}
-        for in_key, out_key in zip(self.input_keys, self.output_keys):
-            x = data[in_key]
-            step = gradient(energy, x)
-            assert step.shape == x.shape, \
-                f'Dimensions of gradient step {step.shape} should be equal to dimensions ' \
-                f'{x.shape} of a single variable {in_key}'
-            d = 1.
-            for k in range(self.num_steps):
-                x = x - d*self.step_size*step
-                d = d - self.decay*d
-            output_data[out_key] = x
+        # init output
+        output_data = data.copy()
+        if self.energy_update:
+            data = output_data
+        # init decay rate
+        d = 1
+        # projected gradient
+        for k in range(self.num_steps):
+            # update energy
+            energy = self.con_viol_energy(data)
+            for in_key, out_key in zip(self.input_keys, self.output_keys):
+                # get grad
+                x = data[in_key]
+                step = gradient(energy, x)
+                assert step.shape == x.shape, \
+                    f'Dimensions of gradient step {step.shape} should be equal to dimensions ' \
+                    f'{x.shape} of a single variable {in_key}'
+                # update
+                x = x - d * self.step_size * step
+                d = d - self.decay * d
+                output_data[out_key] = x
         return output_data

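What the change does: previously, con_viol_energy was evaluated once before the loop, so a single gradient step was computed and reused (only rescaled by the decayed step size) for all num_steps iterations. With the new energy_update flag enabled, the violation energy, and hence its gradient, is recomputed from the current iterate at every step, which matches the standard projected-gradient iteration. Below is a minimal standalone sketch of the two behaviors in plain PyTorch; it is an illustration only, not the NeuroMANCER API, and violation_energy is a hypothetical stand-in for the solver's con_viol_energy.

    import torch

    def projected_gradient(x, violation_energy, num_steps=5, step_size=0.01,
                           decay=0.1, energy_update=True):
        # illustration of the PR's before/after behavior, not NeuroMANCER code
        d = 1.0  # scales the step size and shrinks geometrically by `decay`
        if not energy_update:
            # old behavior: gradient of the initial energy reused for every step
            step = torch.autograd.grad(violation_energy(x).sum(), x,
                                       create_graph=True)[0]
        for _ in range(num_steps):
            if energy_update:
                # new behavior: recompute energy and gradient at the current iterate
                step = torch.autograd.grad(violation_energy(x).sum(), x,
                                           create_graph=True)[0]
            x = x - d * step_size * step
            d = d - decay * d
        return x

    # toy constraint x1 + x2 <= 1; squared ReLU of the residual as violation energy
    energy = lambda x: torch.relu(x.sum(dim=-1) - 1.0).pow(2)
    x0 = torch.randn(8, 2, requires_grad=True)
    x_feasible = projected_gradient(x0, energy, num_steps=10, step_size=0.5)

With energy_update=True each step descends the gradient at the current point, so the iterates can keep making progress even after the first step leaves the region where the initial gradient was informative.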
Expand Down
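For completeness, a hedged construction sketch of the updated solver. The GradientProjection signature comes from the diff above; the constraint construction assumes NeuroMANCER's constraint DSL (variable objects overloading comparison operators into Constraint objects) and should be checked against the repository's examples.

    import torch
    from neuromancer.constraint import variable          # assumed import path
    from neuromancer.modules.solvers import GradientProjection

    x = variable("x")                # symbolic variable, read from the data dict
    con = x <= 1.0                   # inequality constraint; its violation drives the energy
    proj = GradientProjection(
        constraints=[con],
        input_keys=["x"],
        output_keys=["x"],           # overwrite "x" with the projected value
        num_steps=5,
        step_size=0.01,
        decay=0.1,
        energy_update=True,          # new flag introduced by this PR
    )
    # in practice x comes from an upstream network; requires_grad keeps the
    # sketch differentiable so gradient(energy, x) is well defined
    out = proj({"x": torch.randn(16, 2, requires_grad=True)})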
