Skip to content

Commit

Permalink
removed squaring of gradients in EWCRegularizer
Browse files Browse the repository at this point in the history
  • Loading branch information
varisd committed Nov 16, 2018
1 parent 5adfd48 commit 6673ae7
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions neuralmonkey/trainers/regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ class EWCRegularizer(Regularizer):
def __init__(self,
name: str,
weight: float,
gradients_file: str,
fisher_file: str,
variables_file: str) -> None:
"""Create the regularizer.
Arguments:
name: Regularizer name.
weight: Weight of the regularization term.
gradients_file: File containing the gradient estimates
from the previous task.
fisher_file: File containing the diagonal of the Fisher information
matrix estimated on the previous task.
variables_file: File containing the variables learned
on the previous task.
"""
Expand All @@ -124,23 +124,28 @@ def __init__(self,
self.init_vars = tf.contrib.framework.load_checkpoint(variables_file)
log("EWC initial variables loaded.")

log("Loading gradient estimates from {}.".format(gradients_file))
self.gradients = np.load(gradients_file)
log("Loading gradient estimates from {}.".format(fisher_file))
self.fisher = np.load(fisher_file)
log("Gradient estimates loaded.")

def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
r"""Compute the value of the regularization term.
value = \sum_{i} (λ * F_{i} * (θ_{i} - θ_{i}^{*})^2)
where λ is the regularizer weight and F is the diagonal
of the Fisher Information matrix.
"""

ewc_value = tf.constant(0.0)
for var in variables:
init_var_name = var.name.split(":")[0]
if (var.name in self.gradients.files
if (var.name in self.fisher.files
and self.init_vars.has_tensor(init_var_name)):
init_var = tf.constant(
self.init_vars.get_tensor(init_var_name),
name="{}_init_value".format(init_var_name))
grad_squared = tf.constant(
np.square(self.gradients[var.name]),
name="{}_ewc_weight".format(init_var_name))
ewc_value += tf.reduce_sum(tf.multiply(
grad_squared, tf.square(var - init_var)))
self.fisher[var.name], tf.square(var - init_var)))

return ewc_value

0 comments on commit 6673ae7

Please sign in to comment.