From 6673ae74accb84808610d8ebdd3ec8b84834945f Mon Sep 17 00:00:00 2001
From: Dusan Varis
Date: Fri, 17 Aug 2018 16:25:54 +0200
Subject: [PATCH] removed squaring of gradients in EWCRegularizer

---
 neuralmonkey/trainers/regularizers.py | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/neuralmonkey/trainers/regularizers.py b/neuralmonkey/trainers/regularizers.py
index 691326d36..4986a5ccf 100644
--- a/neuralmonkey/trainers/regularizers.py
+++ b/neuralmonkey/trainers/regularizers.py
@@ -104,15 +104,15 @@ class EWCRegularizer(Regularizer):
     def __init__(self,
                  name: str,
                  weight: float,
-                 gradients_file: str,
+                 fisher_file: str,
                  variables_file: str) -> None:
         """Create the regularizer.
 
         Arguments:
             name: Regularizer name.
             weight: Weight of the regularization term.
-            gradients_file: File containing the gradient estimates
-                from the previous task.
+            fisher_file: File containing the diagonal of the fisher information
+                matrix estimated on the previous task.
             variables_files: File containing the variables learned
                 on the previous task.
         """
@@ -124,23 +124,28 @@ def __init__(self,
         self.init_vars = tf.contrib.framework.load_checkpoint(variables_file)
         log("EWC initial variables loaded.")
 
-        log("Loading gradient estimates from {}.".format(gradients_file))
-        self.gradients = np.load(gradients_file)
+        log("Loading gradient estimates from {}.".format(fisher_file))
+        self.fisher = np.load(fisher_file)
         log("Gradient estimates loaded.")
 
     def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
+        r"""Compute the value of the regularization term.
+
+        value = \sum_{i} (λ * F_{i} * (θ_{i} - θ_{i}^{*})^2)
+
+        where λ is the regularizer weight and F is the diagonal
+        of the Fisher Information matrix.
+        """
+
         ewc_value = tf.constant(0.0)
         for var in variables:
             init_var_name = var.name.split(":")[0]
-            if (var.name in self.gradients.files
+            if (var.name in self.fisher.files
                     and self.init_vars.has_tensor(init_var_name)):
                 init_var = tf.constant(
                     self.init_vars.get_tensor(init_var_name),
                     name="{}_init_value".format(init_var_name))
-                grad_squared = tf.constant(
-                    np.square(self.gradients[var.name]),
-                    name="{}_ewc_weight".format(init_var_name))
                 ewc_value += tf.reduce_sum(tf.multiply(
-                    grad_squared, tf.square(var - init_var)))
+                    self.fisher[var.name], tf.square(var - init_var)))
 
         return ewc_value