Commit 9c94e76
addressing PR reviews + fixed variable fetching in EWCRegularizer
varisd committed Aug 7, 2018
1 parent 5fb7fc9 commit 9c94e76
Showing 3 changed files with 67 additions and 58 deletions.
15 changes: 5 additions & 10 deletions neuralmonkey/trainers/cross_entropy_trainer.py
@@ -3,7 +3,6 @@
import tensorflow as tf
from typeguard import check_argument_types

from neuralmonkey.logging import warn
from neuralmonkey.trainers.generic_trainer import (
GenericTrainer, Objective, ObjectiveWeight)
from neuralmonkey.trainers.regularizers import (
@@ -42,17 +41,13 @@ def __init__(self,

if regularizers is None:
regularizers = []
if l1_weight > 0.:
if L1Regularizer in [type(r) for r in regularizers]:
warn("You specified both trainer l1_weight "
"and a L1Regularizer object in your config")
regularizers.append(L1Regularizer(weight=l1_weight))

if l1_weight > 0.:
regularizers.append(
L1Regularizer(name="train_l1", weight=l1_weight))
if l2_weight > 0.:
if L2Regularizer in [type(r) for r in regularizers]:
warn("You specified both trainer l2_weight "
"and a L2Regularizer object in your config")
regularizers.append(L2Regularizer(weight=l2_weight))
regularizers.append(
L2Regularizer(name="train_l2", weight=l2_weight))

if len(decoder_weights) != len(decoders):
raise ValueError(
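With this change, the l1_weight and l2_weight arguments act purely as shortcuts that append named regularizer objects (the earlier warning about also passing a matching Regularizer is dropped). A minimal sketch of the equivalent explicit configuration, assuming only the constructor signatures shown in the regularizers.py diff below; the trainer class and its remaining arguments are not shown here:

from neuralmonkey.trainers.regularizers import L1Regularizer, L2Regularizer

# Roughly equivalent to passing l1_weight=1.0e-8 and l2_weight=1.0e-8 to the
# trainer above; the list would be passed as its `regularizers` argument.
regularizers = [
    L1Regularizer(name="train_l1", weight=1.0e-8),
    L2Regularizer(name="train_l2", weight=1.0e-8),
]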
23 changes: 12 additions & 11 deletions neuralmonkey/trainers/generic_trainer.py
@@ -7,8 +7,7 @@
from neuralmonkey.model.model_part import ModelPart
from neuralmonkey.runners.base_runner import (
Executable, ExecutionResult, NextExecute)
from neuralmonkey.trainers.regularizers import (
Regularizer, L2Regularizer)
from neuralmonkey.trainers.regularizers import (Regularizer, L2Regularizer)

# pylint: disable=invalid-name
Gradients = List[Tuple[tf.Tensor, tf.Variable]]
@@ -40,6 +39,7 @@ class Objective(NamedTuple(


# pylint: disable=too-few-public-methods,too-many-locals,too-many-branches
# pylint: disable=too-many-statements
class GenericTrainer:

def __init__(self,
@@ -102,7 +102,9 @@ def __init__(self,

# we always want to include l2 values in the summary
if L2Regularizer not in [type(r) for r in self.regularizers]:
reg_values.append(L2Regularizer().value(regularizable))
l2_reg = L2Regularizer(name="train_l2", weight=0.)
tf.summary.scalar(l2_reg.name, l2_reg.value(regularizable),
collections=["summary_train"])
for reg, reg_value in zip(self.regularizers, reg_values):
tf.summary.scalar(reg.name, reg_value,
collections=["summary_train"])
@@ -119,8 +121,8 @@ def __init__(self,
with tf.name_scope("gradient_collection"):
differentiable_loss_sum = sum(
[(o.weight if o.weight is not None else 1.) * o.loss
for o in objectives
if o.gradients is None] + reg_costs)
for o in objectives if o.gradients is None])
differentiable_loss_sum += sum(reg_costs)
implicit_gradients = self._get_gradients(
differentiable_loss_sum)

@@ -130,25 +132,24 @@ def __init__(self,
for o in objectives if o.gradients is not None]

if other_gradients:
gradients = _sum_gradients(
self.gradients = _sum_gradients(
[implicit_gradients] + other_gradients)
else:
gradients = implicit_gradients
self.gradients = implicit_gradients

tf.summary.scalar("train_opt_cost",
differentiable_loss_sum,
collections=["summary_train"])

if clip_norm:
assert clip_norm > 0.0
gradients = [(tf.clip_by_norm(grad, clip_norm), var)
for grad, var in gradients
if grad is not None]
self.gradients = [(tf.clip_by_norm(grad, clip_norm), var)
for grad, var in self.gradients
if grad is not None]

self.all_coders = set.union(*(obj.decoder.get_dependencies()
for obj in objectives))

self.gradients = gradients
self.train_op = self.optimizer.apply_gradients(
self.gradients, global_step=step)

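The reworked gradient collection above first sums the weighted objective losses, then adds the regularization costs (reg_costs), and only afterwards differentiates and, if clip_norm is set, clips the gradients. A self-contained sketch of that pattern in plain TensorFlow 1.x, outside Neural Monkey; all names below are illustrative, not from this commit:

import tensorflow as tf

weights = tf.get_variable("weights", shape=[10, 10])
data_loss = tf.reduce_mean(tf.square(weights))  # stand-in for the objectives' loss
l2_value = tf.reduce_sum(weights ** 2)          # unweighted regularizer value
reg_cost = 1.0e-8 * l2_value                    # weight * value, as in reg_costs
total_loss = data_loss + reg_cost               # analogue of differentiable_loss_sum

optimizer = tf.train.AdamOptimizer(1e-3)
grads_and_vars = optimizer.compute_gradients(total_loss)
clipped = [(tf.clip_by_norm(grad, 1.0), var)    # analogue of the clip_norm branch
           for grad, var in grads_and_vars if grad is not None]
train_op = optimizer.apply_gradients(clipped)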
87 changes: 50 additions & 37 deletions neuralmonkey/trainers/regularizers.py
@@ -3,8 +3,8 @@
This module contains classes that can be used as a variable regularizers
during training. All implementation should be derived from the Regularizer
class.
"""
from abc import ABCMeta, abstractmethod
from typing import List

import numpy as np
@@ -14,8 +14,8 @@
from neuralmonkey.logging import log


class Regularizer:
"""Base class for the regularizers."""
class Regularizer(metaclass=ABCMeta):
"""Base clas s for regularizers.
Regularizer objects are used to introduce additional loss terms to
the trainerthus constraining the model variable during training. These
loss terms have an adjustable weight allowing to set the ``importance''
of the term.
"""

def __init__(self,
name: str,
Expand All @@ -24,10 +30,9 @@ def __init__(self,
Arguments:
name: Regularizer name.
weight: Weight of the regularization term.
weight: Weight of the regularization term (usually expressed
as ``lambda'' in the literature).
"""
check_argument_types()

self._name = name
self._weight = weight

@@ -39,34 +44,40 @@ def name(self) -> str:
def weight(self) -> float:
return self._weight

def value(self, variables) -> float:
@abstractmethod
def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
"""Compute the unweighted value of the regularization loss term.
Arguments:
variables: List of the regularizable model variables.
"""
raise NotImplementedError("Abstract method")


class L1Regularizer(Regularizer):
"""L1 regularizer."""

def __init__(self,
name: str = "train_l1",
weight: float = 1.0e-8) -> None:
name: str,
weight: float) -> None:
"""Create the regularizer.
Arguments:
name: Regularizer name.
weight: Weight of the regularization term.
weight: Weight of the regularization term (default=1.0e-8).
"""
Regularizer.__init__(self, name, weight)

def value(self, variables: List[tf.Tensor]) -> float:
def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
return sum(tf.reduce_sum(abs(v)) for v in variables)


class L2Regularizer(Regularizer):
"""L2 regularizer."""

def __init__(self,
name: str = "train_l2",
weight: float = 1.0e-8) -> None:
name: str,
weight: float) -> None:
"""Create the regularizer.
Arguments:
@@ -75,7 +86,7 @@ def __init__(self,
"""
Regularizer.__init__(self, name, weight)

def value(self, variables: List[tf.Tensor]) -> float:
def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
return sum(tf.reduce_sum(v ** 2) for v in variables)


@@ -84,15 +95,18 @@ class EWCRegularizer(Regularizer):
Implements Elastic Weight Consolidation from the "Overcoming catastrophic
forgetting in neural networks" paper.
The regularizer applies a separate regularization weight to each trainable
variable based on how important the variable was for the previously
learned task.
https://arxiv.org/pdf/1612.00796.pdf
"""

def __init__(self,
name: str = "train_ewc",
weight: float = 0.,
gradients_file: str = None,
variables_file: str = None) -> None:
name: str,
weight: float,
gradients_file: str,
variables_file: str) -> None:
"""Create the regularizer.
Arguments:
@@ -104,36 +118,35 @@ def __init__(self,
on the previous task.
"""
check_argument_types()

Regularizer.__init__(self, name, weight)

if gradients_file is None:
raise ValueError("Missing gradients_file")
if variables_file is None:
raise ValueError("Missing variables_file")

log("Loading initial variables for EWC from {}".format(variables_file))
log("Loading initial variables for EWC from "
"{}.".format(variables_file))
self.init_vars = tf.contrib.framework.load_checkpoint(variables_file)
log("EWC initial variables loaded")
log("EWC initial variables loaded.")

log("Loading gradient estimates from {}".format(gradients_file))
log("Loading gradient estimates from {}.".format(gradients_file))
self.gradients = np.load(gradients_file)
log("Gradient estimates loaded")
log("Gradient estimates loaded.")

def value(self, variables: List[tf.Tensor]) -> float:
def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
ewc_value = tf.constant(0.0)
for var in variables:
var_name = var.name.split(":")[0]
var_name = var.name
init_var_name = var_name.split(":")[0]
if (var_name in self.gradients.files
and self.init_vars.has_tensor(var_name)):
init_var = self.init_vars.get_tensor(var_name)
gradient = tf.constant(
self.gradients[var_name], name="ewc_gradients")
and self.init_vars.has_tensor(init_var_name)):
init_var = tf.constant(
self.init_vars.get_tensor(init_var_name),
name="{}_init_value".format(init_var_name))
grad_squared = tf.constant(
np.square(self.gradients[var_name]),
name="{}_ewc_weight".format(init_var_name))
ewc_value += tf.reduce_sum(tf.multiply(
tf.square(gradient), tf.square(var - init_var)))
grad_squared, tf.square(var - init_var)))

return ewc_value


L1 = L1Regularizer()
L2 = L2Regularizer()
L1 = L1Regularizer(name="train_l1", weight=1.0e-8)
L2 = L2Regularizer(name="train_l2", weight=1.0e-8)
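For reference, the unweighted penalty computed in value() above corresponds to the EWC term of the cited paper, with the diagonal Fisher information approximated by the squared gradient estimates loaded from gradients_file; the weighting by lambda (and the paper's factor of 1/2) is assumed to be folded into the regularizer weight that the trainer applies, as with the other regularizers:

\mathcal{L}_{\mathrm{EWC}}(\theta) = \sum_i F_i \, (\theta_i - \theta^{*}_i)^2,
\qquad F_i \approx \hat{g}_i^{\,2}

where \theta^{*} are the variables restored from variables_file and \hat{g} are the stored gradient estimates.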

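A hypothetical note on the expected inputs (inferred from the lookups above, not stated in this commit): np.load(gradients_file) must expose a .files attribute, so gradients_file is an .npz archive keyed by the full TensorFlow variable names, ":0" suffix included, while the checkpoint in variables_file is queried with the suffix stripped; this is exactly the variable-fetching fix named in the commit message.

import numpy as np

var_name = "decoder/kernel:0"            # key expected in the .npz archive
init_var_name = var_name.split(":")[0]   # "decoder/kernel", name in the checkpoint

# A name-keyed archive of gradient estimates could be written like this
# (illustrative shapes and values only):
np.savez("gradient_estimates.npz", **{var_name: np.ones((4, 4), np.float32)})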