addressing PR reviews
varisd committed Jul 31, 2018
1 parent aaa8e8a commit f7e6795
Showing 4 changed files with 81 additions and 19 deletions.
8 changes: 8 additions & 0 deletions neuralmonkey/runners/gradient_runner.py
@@ -46,6 +46,14 @@ def collect_results(self, results: List[Dict]) -> None:


class GradientRunner(BaseRunner[SupportedDecoder]):
"""Runner for fetching gradients computed over the dataset.
Gradient runner applies provided trainer on a desired dataset
and uses it to compute gradients over the gold data. It is currently
used to gather gradients for Elastic Weight Consolidation.
(https://arxiv.org/pdf/1612.00796.pdf)
"""

def __init__(self,
output_series: str,
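To make the runner's job concrete, here is a minimal TF1-style sketch, outside Neural Monkey and with illustrative names only, of fetching per-variable gradients of a loss over gold data:

```python
import tensorflow as tf

# Toy stand-ins for the trainer's loss and the model's variables.
w = tf.get_variable("w", shape=[3], initializer=tf.ones_initializer())
x = tf.placeholder(tf.float32, shape=[3], name="gold_input")
loss = tf.reduce_sum((w * x - 1.0) ** 2)

# The runner fetches gradients like these for each batch of gold data.
grads = tf.gradients(loss, tf.trainable_variables())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    grad_values = sess.run(grads, feed_dict={x: [1.0, 2.0, 3.0]})
    # Squared and averaged over the dataset, such gradients give the
    # Fisher-diagonal estimates that EWCRegularizer later loads from disk.
```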
22 changes: 20 additions & 2 deletions neuralmonkey/trainers/cross_entropy_trainer.py
@@ -3,9 +3,11 @@
import tensorflow as tf
from typeguard import check_argument_types

from neuralmonkey.logging import warn
from neuralmonkey.trainers.generic_trainer import (
GenericTrainer, Objective, ObjectiveWeight)
from neuralmonkey.trainers.regularizers import BaseRegularizer
from neuralmonkey.trainers.regularizers import (
Regularizer, L1Regularizer, L2Regularizer)


def xent_objective(decoder, weight=None) -> Objective:
@@ -28,14 +30,30 @@ def __init__(self,
decoder_weights: List[ObjectiveWeight] = None,
clip_norm: float = None,
optimizer: tf.train.Optimizer = None,
regularizers: List[BaseRegularizer] = None,
regularizers: List[Regularizer] = None,
l1_weight: float = 0.,
l2_weight: float = 0.,
var_scopes: List[str] = None,
var_collection: str = None) -> None:
check_argument_types()

if decoder_weights is None:
decoder_weights = [None for _ in decoders]

if regularizers is None:
regularizers = []
if l1_weight > 0.:
if L1Regularizer in [type(r) for r in regularizers]:
warn("You specified both trainer l1_weight "
"and a L1Regularizer object in your config")
regularizers.append(L1Regularizer(weight=l1_weight))

if l2_weight > 0.:
if L2Regularizer in [type(r) for r in regularizers]:
warn("You specified both trainer l2_weight "
"and a L2Regularizer object in your config")
regularizers.append(L2Regularizer(weight=l2_weight))

if len(decoder_weights) != len(decoders):
raise ValueError(
"decoder_weights (length {}) do not match decoders (length {})"
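One behavioral detail worth noting: the duplicate check above only warns, it still appends, so an explicit L1Regularizer plus a nonzero l1_weight yields two L1 terms. A standalone sketch with stub classes (not the Neural Monkey ones) demonstrating this:

```python
import warnings
from typing import List, Optional


class L1Regularizer:
    """Stub standing in for neuralmonkey's L1Regularizer."""

    def __init__(self, weight: float) -> None:
        self.weight = weight


def merge_legacy_weight(regularizers: Optional[List[object]],
                        l1_weight: float) -> List[object]:
    """Fold the legacy l1_weight argument into the regularizers list."""
    regularizers = list(regularizers or [])
    if l1_weight > 0.:
        if L1Regularizer in [type(r) for r in regularizers]:
            # Mirrors the trainer's warn(): the user is alerted, but the
            # duplicate is still appended.
            warnings.warn("Both l1_weight and an L1Regularizer were given.")
        regularizers.append(L1Regularizer(weight=l1_weight))
    return regularizers


regs = merge_legacy_weight([L1Regularizer(weight=1e-8)], l1_weight=1e-8)
print(len(regs))  # 2 -- both L1 penalties would be applied
```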
6 changes: 3 additions & 3 deletions neuralmonkey/trainers/generic_trainer.py
Expand Up @@ -7,7 +7,7 @@
from neuralmonkey.model.model_part import ModelPart
from neuralmonkey.runners.base_runner import (
Executable, ExecutionResult, NextExecute)
from neuralmonkey.trainers.regularizers import BaseRegularizer
from neuralmonkey.trainers.regularizers import Regularizer

# pylint: disable=invalid-name
Gradients = List[Tuple[tf.Tensor, tf.Variable]]
@@ -45,12 +45,12 @@ def __init__(self,
objectives: List[Objective],
clip_norm: float = None,
optimizer: tf.train.Optimizer = None,
regularizers: List[BaseRegularizer] = None,
regularizers: List[Regularizer] = None,
var_scopes: List[str] = None,
var_collection: str = None) -> None:
check_argument_types()

self.regularizers = [] # type: List[BaseRegularizer]
self.regularizers = [] # type: List[Regularizer]
if regularizers is not None:
self.regularizers = regularizers

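Presumably the trainer consumes this list by adding one weighted term per regularizer to its loss; a sketch under that assumption (toy tensors, not the actual GenericTrainer code):

```python
import tensorflow as tf

# Toy stand-ins for the model's variables and the training objective.
w = tf.get_variable("w", shape=[4], initializer=tf.ones_initializer())
objective_loss = tf.reduce_sum(w ** 2)

# (weight, value) pairs as a Regularizer would provide them.
reg_terms = [(1e-8, tf.reduce_sum(tf.abs(w))),   # L1-style term
             (1e-8, tf.reduce_sum(w ** 2))]      # L2-style term

# Total loss: objective plus each regularizer's weight * value(variables).
total_loss = objective_loss + sum(weight * value
                                  for weight, value in reg_terms)
```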
64 changes: 50 additions & 14 deletions neuralmonkey/trainers/regularizers.py
@@ -1,7 +1,7 @@
"""Variable regularizers.
This module contains classes that can be used as variable regularizers
during training. All implementations should be derived from the BaseRegularizer
during training. All implementations should be derived from the Regularizer
class.
"""
@@ -14,12 +14,18 @@
from neuralmonkey.logging import log


class BaseRegularizer:
class Regularizer:
"""Base class for the regularizers."""

def __init__(self,
name: str,
weight: float) -> None:
"""Create the regularizer.
Arguments:
name: Regularizer name.
weight: Weight of the regularization term.
"""
check_argument_types()

self._name = name
@@ -37,42 +37,69 @@ def value(self, variables) -> float:
raise NotImplementedError("Abstract method")


class L1Regularizer(BaseRegularizer):
class L1Regularizer(Regularizer):
"""L1 regularizer."""

def __init__(self,
name: str = "train_l1",
weight: float = 1.0e-8) -> None:
BaseRegularizer.__init__(self, name, weight)
"""Create the regularizer.
Arguments:
name: Regularizer name.
weight: Weight of the regularization term.
"""
Regularizer.__init__(self, name, weight)

def value(self, variables: List[tf.Tensor]) -> float:
return sum(tf.reduce_sum(abs(v)) for v in variables)


class L2Regularizer(BaseRegularizer):
class L2Regularizer(Regularizer):
"""L2 regularizer."""

def __init__(self,
name: str = "train_l2",
weight: float = 1.0e-8) -> None:
BaseRegularizer.__init__(self, name, weight)
"""Create the regularizer.
Arguments:
name: Regularizer name.
weight: Weight of the regularization term.
"""
Regularizer.__init__(self, name, weight)

def value(self, variables: List[tf.Tensor]) -> float:
return sum(tf.reduce_sum(v ** 2) for v in variables)
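In formula form, the two value() implementations compute the unweighted penalties below; the weight factor is applied separately by the trainer:

```latex
R_{L1}(\theta) = \sum_i |\theta_i|, \qquad
R_{L2}(\theta) = \sum_i \theta_i^2
```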


class EWCRegularizer(BaseRegularizer):
class EWCRegularizer(Regularizer):
"""Regularizer based on the Elastic Weight Consolidation.
TODO description
Implements the Elastic Weight Consolidation penalty from the paper
"Overcoming catastrophic forgetting in neural networks"
(https://arxiv.org/pdf/1612.00796.pdf).
"""

def __init__(self,
name: str = "train_ewc",
weight: float = 0.,
gradients_file: str = None,
variables_file: str = None) -> None:
"""Create the regularizer.
Arguments:
name: Regularizer name.
weight: Weight of the regularization term.
gradients_file: File containing the gradient estimates
from the previous task.
variables_file: File containing the variables learned
on the previous task.
"""
check_argument_types()

BaseRegularizer.__init__(self, name, weight)
Regularizer.__init__(self, name, weight)

if gradients_file is None:
raise ValueError("Missing gradients_file")
@@ -88,13 +121,16 @@ def __init__(self,
log("Gradient estimates loaded")

def value(self, variables: List[tf.Tensor]) -> float:
ewc_value = 0.0
ewc_value = tf.constant(0.0)
for var in variables:
var_name = var.name.split(":")[0]
init_var = self.init_vars.get_tensor(var_name)
gradient = self.gradients[var_name]
ewc_value += tf.reduce_sum(tf.multiply(
tf.square(gradient), tf.square(var - init_var)))
if (var_name in self.gradients.files
and self.init_vars.has_tensor(var_name)):
init_var = self.init_vars.get_tensor(var_name)
gradient = tf.constant(
self.gradients[var_name], name="ewc_gradients")
ewc_value += tf.reduce_sum(tf.multiply(
tf.square(gradient), tf.square(var - init_var)))

return ewc_value

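The constructor body that loads the two files is collapsed in the diff above. Judging from the attributes value() uses (self.gradients[var_name] with a .files attribute, and self.init_vars.get_tensor/has_tensor), the gradients file is presumably a NumPy .npz archive and the variables file a TensorFlow checkpoint; a hedged sketch of such loading, with hypothetical paths:

```python
import numpy as np
import tensorflow as tf

gradients_file = "gradients.npz"       # hypothetical path
variables_file = "previous_task.ckpt"  # hypothetical path

# An .npz archive loaded this way exposes a .files list and per-array
# indexing, matching self.gradients[var_name] / self.gradients.files.
gradients = np.load(gradients_file)

# A checkpoint reader exposes has_tensor()/get_tensor(), matching the
# self.init_vars calls in value().
init_vars = tf.train.NewCheckpointReader(variables_file)
```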
