addressing PR reviews
varisd committed Jul 31, 2018
1 parent aaa8e8a commit 3a19ab7
Showing 3 changed files with 69 additions and 7 deletions.
8 changes: 8 additions & 0 deletions neuralmonkey/runners/gradient_runner.py
@@ -46,6 +46,14 @@ def collect_results(self, results: List[Dict]) -> None:


class GradientRunner(BaseRunner[SupportedDecoder]):
    """Runner for fetching gradients computed over the dataset.

    The gradient runner applies the provided trainer to a dataset and uses
    it to compute gradients over the gold data. It is currently used to
    gather gradients for Elastic Weight Consolidation
    (https://arxiv.org/pdf/1612.00796.pdf).
    """

    def __init__(self,
                 output_series: str,
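For context, a minimal sketch of the idea behind this runner, written outside of Neural Monkey: compute the loss gradients over a whole dataset, average them per variable, and dump them to a file that can later be passed to EWCRegularizer as gradients_file. The toy model, variable naming, and npz file format here are assumptions for illustration, not the project's actual conventions.

import numpy as np
import tensorflow as tf

# Toy regression model standing in for a Neural Monkey decoder (assumption).
x = tf.placeholder(tf.float32, [None, 4], name="x")
y = tf.placeholder(tf.float32, [None, 1], name="y")
w = tf.get_variable("w", shape=[4, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
grad_w = tf.gradients(loss, [w])[0]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batches = [(np.random.rand(8, 4), np.random.rand(8, 1)) for _ in range(10)]
    # Average the gradient of the loss over the "previous task" data.
    grad_sum = np.zeros((4, 1), dtype=np.float32)
    for batch_x, batch_y in batches:
        grad_sum += sess.run(grad_w, {x: batch_x, y: batch_y})
    # Store one array per variable name; EWCRegularizer squares these
    # values when it builds the penalty term.
    np.savez("gradients.npz", w=grad_sum / len(batches))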
20 changes: 19 additions & 1 deletion neuralmonkey/trainers/cross_entropy_trainer.py
@@ -3,9 +3,11 @@
import tensorflow as tf
from typeguard import check_argument_types

from neuralmonkey.logging import warn
from neuralmonkey.trainers.generic_trainer import (
    GenericTrainer, Objective, ObjectiveWeight)
-from neuralmonkey.trainers.regularizers import BaseRegularizer
from neuralmonkey.trainers.regularizers import (
    BaseRegularizer, L1Regularizer, L2Regularizer)


def xent_objective(decoder, weight=None) -> Objective:
@@ -29,13 +31,29 @@ def __init__(self,
                 clip_norm: float = None,
                 optimizer: tf.train.Optimizer = None,
                 regularizers: List[BaseRegularizer] = None,
                 l1_weight: float = 0.,
                 l2_weight: float = 0.,
                 var_scopes: List[str] = None,
                 var_collection: str = None) -> None:
        check_argument_types()

        if decoder_weights is None:
            decoder_weights = [None for _ in decoders]

        if regularizers is None:
            regularizers = []
        if l1_weight > 0.:
            if L1Regularizer in [type(r) for r in regularizers]:
                warn("You specified both trainer l1_weight "
                     "and an L1Regularizer object in your config")
            regularizers.append(L1Regularizer(weight=l1_weight))

        if l2_weight > 0.:
            if L2Regularizer in [type(r) for r in regularizers]:
                warn("You specified both trainer l2_weight "
                     "and an L2Regularizer object in your config")
            regularizers.append(L2Regularizer(weight=l2_weight))

        if len(decoder_weights) != len(decoders):
            raise ValueError(
                "decoder_weights (length {}) do not match decoders (length {})"
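The new l1_weight and l2_weight arguments are a shorthand for appending the corresponding regularizer objects. A hedged sketch of the two equivalent ways to request L1 regularization (the decoder variable is a placeholder for a real decoder object, and the remaining constructor arguments are left at their defaults):

from neuralmonkey.trainers.cross_entropy_trainer import CrossEntropyTrainer
from neuralmonkey.trainers.regularizers import L1Regularizer

# Shorthand: the trainer builds the L1Regularizer itself.
trainer_a = CrossEntropyTrainer(decoders=[decoder], l1_weight=1.0e-8)

# Explicit: pass a regularizer object. Combining both forms triggers the
# warning added above, and the penalty would presumably be added twice.
trainer_b = CrossEntropyTrainer(
    decoders=[decoder], regularizers=[L1Regularizer(weight=1.0e-8)])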
48 changes: 42 additions & 6 deletions neuralmonkey/trainers/regularizers.py
@@ -20,6 +20,12 @@ class BaseRegularizer:
    def __init__(self,
                 name: str,
                 weight: float) -> None:
        """Create the regularizer.

        Arguments:
            name: Regularizer name.
            weight: Weight of the regularization term.
        """
        check_argument_types()

        self._name = name
@@ -38,21 +44,35 @@ def value(self, variables) -> float:


class L1Regularizer(BaseRegularizer):
    """L1 regularizer."""

    def __init__(self,
                 name: str = "train_l1",
                 weight: float = 1.0e-8) -> None:
        """Create the regularizer.

        Arguments:
            name: Regularizer name.
            weight: Weight of the regularization term.
        """
        BaseRegularizer.__init__(self, name, weight)

    def value(self, variables: List[tf.Tensor]) -> float:
        return sum(tf.reduce_sum(abs(v)) for v in variables)


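A note on what these penalties amount to, assuming the trainer multiplies the returned value() by the regularizer's weight when adding it to the loss: the L1 term above contributes weight · Σ|θ| over the regularized variables, and the L2 term below, whose body is collapsed in this diff, contributes the analogous weight · Σ θ².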
class L2Regularizer(BaseRegularizer):
    """L2 regularizer."""

    def __init__(self,
                 name: str = "train_l2",
                 weight: float = 1.0e-8) -> None:
        """Create the regularizer.

        Arguments:
            name: Regularizer name.
            weight: Weight of the regularization term.
        """
        BaseRegularizer.__init__(self, name, weight)

    def value(self, variables: List[tf.Tensor]) -> float:
@@ -62,14 +82,27 @@ def value(self, variables: List[tf.Tensor]) -> float:
class EWCRegularizer(BaseRegularizer):
    """Regularizer based on the Elastic Weight Consolidation.

-    TODO description
    Implements Elastic Weight Consolidation from the "Overcoming catastrophic
    forgetting in neural networks" paper.
    https://arxiv.org/pdf/1612.00796.pdf
    """

    def __init__(self,
                 name: str = "train_ewc",
                 weight: float = 0.,
                 gradients_file: str = None,
                 variables_file: str = None) -> None:
        """Create the regularizer.

        Arguments:
            name: Regularizer name.
            weight: Weight of the regularization term.
            gradients_file: File containing the gradient estimates
                from the previous task.
            variables_file: File containing the variables learned
                on the previous task.
        """
        check_argument_types()

        BaseRegularizer.__init__(self, name, weight)
@@ -88,13 +121,16 @@ def __init__(self,
log("Gradient estimates loaded")

def value(self, variables: List[tf.Tensor]) -> float:
ewc_value = 0.0
ewc_value = tf.constant(0.0)
for var in variables:
var_name = var.name.split(":")[0]
init_var = self.init_vars.get_tensor(var_name)
gradient = self.gradients[var_name]
ewc_value += tf.reduce_sum(tf.multiply(
tf.square(gradient), tf.square(var - init_var)))
if (var_name in self.gradients.files
and self.init_vars.has_tensor(var_name)):
init_var = self.init_vars.get_tensor(var_name)
gradient = tf.constant(
self.gradients[var_name], name="ewc_gradients")
ewc_value += tf.reduce_sum(tf.multiply(
tf.square(gradient), tf.square(var - init_var)))

return ewc_value

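Putting the new pieces together: assuming the trainer scales the returned value() by the regularizer's weight (so that weight plays the role of the λ/2 factor from the paper), the loss being optimized takes the usual EWC form

    L(θ) = L_current(θ) + weight · Σ_i F_i · (θ_i − θ*_i)²

where θ*_i are the variables restored from variables_file and F_i is a diagonal Fisher estimate obtained by squaring the gradients stored in gradients_file. A hypothetical instantiation, with placeholder file names:

    EWCRegularizer(weight=5.0,
                   gradients_file="gradients.npz",
                   variables_file="variables.ckpt")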
