From df971738c88e3f9d98356c2d1e7d74a1ce4d953a Mon Sep 17 00:00:00 2001
From: Dusan Varis
Date: Thu, 26 Jul 2018 00:10:58 +0200
Subject: [PATCH] Implement EWC regularization, a gradient runner, and a
 gradient-averaging script

---
 neuralmonkey/learning_utils.py                 |  2 +-
 neuralmonkey/runners/gradient_runner.py        | 77 +++++++++++++++
 .../trainers/cross_entropy_trainer.py          |  7 +-
 neuralmonkey/trainers/generic_trainer.py       | 39 ++++----
 neuralmonkey/trainers/regularizers.py          | 96 +++++++++++++++++++
 scripts/avg_tensors.py                         | 54 ++++++++++++
 6 files changed, 254 insertions(+), 21 deletions(-)
 create mode 100644 neuralmonkey/runners/gradient_runner.py
 create mode 100644 neuralmonkey/trainers/regularizers.py
 create mode 100755 scripts/avg_tensors.py

diff --git a/neuralmonkey/learning_utils.py b/neuralmonkey/learning_utils.py
index 7d4921821..9e94bd679 100644
--- a/neuralmonkey/learning_utils.py
+++ b/neuralmonkey/learning_utils.py
@@ -423,7 +423,7 @@ def _check_savable_dict(data):
         return False
 
     supported_type = Union[
-        List[Dict[str, np.ndarray]],
+        List[Dict[str, Union[np.ndarray, np.float32]]],
         List[List[Dict[str, np.ndarray]]]]
 
     try:
diff --git a/neuralmonkey/runners/gradient_runner.py b/neuralmonkey/runners/gradient_runner.py
new file mode 100644
index 000000000..45d182ba0
--- /dev/null
+++ b/neuralmonkey/runners/gradient_runner.py
@@ -0,0 +1,77 @@
+from typing import Dict, List, Set, Union
+
+from typeguard import check_argument_types
+
+from neuralmonkey.runners.base_runner import (
+    BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute)
+from neuralmonkey.model.model_part import ModelPart
+from neuralmonkey.decoders.autoregressive import AutoregressiveDecoder
+from neuralmonkey.decoders.classifier import Classifier
+from neuralmonkey.trainers.generic_trainer import GenericTrainer
+
+# pylint: disable=invalid-name
+SupportedDecoder = Union[AutoregressiveDecoder, Classifier]
+# pylint: enable=invalid-name
+
+
+class GradientRunExecutable(Executable):
+    """Collect the gradient values fetched for a single batch."""
+
+    def __init__(self,
+                 all_coders: Set[ModelPart],
+                 fetches: FeedDict,
+                 gradient_names: List[str]) -> None:
+        self._all_coders = all_coders
+        self._fetches = fetches
+        self._gradient_names = gradient_names
+
+        self.result = None
+
+    def next_to_execute(self) -> NextExecute:
+        """Get the feedables and tensors to run."""
+        return self._all_coders, self._fetches, []
+
+    def collect_results(self, results: List[Dict]) -> None:
+        assert len(results) == 1
+
+        gradient_dict = dict(
+            zip(self._gradient_names, results[0]["gradients"]))
+
+        self.result = ExecutionResult(
+            outputs=[gradient_dict],
+            losses=[],
+            scalar_summaries=None,
+            histogram_summaries=None,
+            image_summaries=None)
+
+
+class GradientRunner(BaseRunner[SupportedDecoder]):
+    """Runner that outputs the gradients of the trainer's objectives."""
+
+    def __init__(self,
+                 output_series: str,
+                 trainer: GenericTrainer,
+                 decoder: SupportedDecoder) -> None:
+        check_argument_types()
+        BaseRunner[SupportedDecoder].__init__(
+            self, output_series, decoder)
+
+        self._gradients = trainer.gradients
+
+    # pylint: disable=unused-argument
+    def get_executable(self,
+                       compute_losses: bool,
+                       summaries: bool,
+                       num_sessions: int) -> GradientRunExecutable:
+        # Fetch the gradient tensors and key them by the names of the
+        # variables they belong to, which is how the EWCRegularizer
+        # looks them up later.
+        fetches = {"gradients": [g[0] for g in self._gradients]}
+        names = [g[1].name.split(":")[0] for g in self._gradients]
+
+        return GradientRunExecutable(self.all_coders, fetches, names)
+    # pylint: enable=unused-argument
+
+    @property
+    def loss_names(self) -> List[str]:
+        return []
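Each GradientRunner execution produces one {variable name: gradient array} dict per batch. A minimal NumPy sketch (not part of the patch; names, shapes, and the file name are made up) of how such per-batch outputs can be stacked and written to an npz shard in the batch-first layout that scripts/avg_tensors.py below expects:

    import numpy as np

    # two per-batch gradient dicts, as produced by collect_results()
    batches = [
        {"encoder/w": np.ones((2, 3)), "decoder/b": np.zeros(3)},
        {"encoder/w": np.full((2, 3), 3.0), "decoder/b": np.ones(3)},
    ]

    # stack along a new first (batch) axis and save one shard
    stacked = {name: np.stack([b[name] for b in batches])
               for name in batches[0]}
    np.savez("gradients.0.npz", **stacked)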
diff --git a/neuralmonkey/trainers/cross_entropy_trainer.py b/neuralmonkey/trainers/cross_entropy_trainer.py
index 83083c39d..2f8282903 100644
--- a/neuralmonkey/trainers/cross_entropy_trainer.py
+++ b/neuralmonkey/trainers/cross_entropy_trainer.py
@@ -5,6 +5,7 @@
 
 from neuralmonkey.trainers.generic_trainer import (
     GenericTrainer, Objective, ObjectiveWeight)
+from neuralmonkey.trainers.regularizers import BaseRegularizer
 
 
 def xent_objective(decoder, weight=None) -> Objective:
@@ -25,10 +26,9 @@ class CrossEntropyTrainer(GenericTrainer):
     def __init__(self,
                  decoders: List[Any],
                  decoder_weights: List[ObjectiveWeight] = None,
-                 l1_weight: float = 0.,
-                 l2_weight: float = 0.,
                  clip_norm: float = None,
                  optimizer: tf.train.Optimizer = None,
+                 regularizers: List[BaseRegularizer] = None,
                  var_scopes: List[str] = None,
                  var_collection: str = None) -> None:
         check_argument_types()
@@ -47,9 +47,8 @@ def __init__(self,
         GenericTrainer.__init__(
             self,
             objectives=objectives,
-            l1_weight=l1_weight,
-            l2_weight=l2_weight,
             clip_norm=clip_norm,
             optimizer=optimizer,
+            regularizers=regularizers,
             var_scopes=var_scopes,
             var_collection=var_collection)
diff --git a/neuralmonkey/trainers/generic_trainer.py b/neuralmonkey/trainers/generic_trainer.py
index 5ea79b87e..d5b72b551 100644
--- a/neuralmonkey/trainers/generic_trainer.py
+++ b/neuralmonkey/trainers/generic_trainer.py
@@ -6,6 +6,7 @@
 from neuralmonkey.model.model_part import ModelPart
 from neuralmonkey.runners.base_runner import (
     Executable, ExecutionResult, NextExecute)
+from neuralmonkey.trainers.regularizers import BaseRegularizer
 
 # pylint: disable=invalid-name
 Gradients = List[Tuple[tf.Tensor, tf.Variable]]
@@ -41,13 +42,17 @@ class GenericTrainer:
 
     def __init__(self,
                  objectives: List[Objective],
-                 l1_weight: float = 0.0,
-                 l2_weight: float = 0.0,
                  clip_norm: float = None,
                  optimizer: tf.train.Optimizer = None,
+                 regularizers: List[BaseRegularizer] = None,
                  var_scopes: List[str] = None,
                  var_collection: str = None) -> None:
+        if regularizers is not None:
+            self.regularizers = regularizers
+        else:
+            self.regularizers = []
+
         if var_collection is None:
             var_collection = tf.GraphKeys.TRAINABLE_VARIABLES
 
@@ -84,18 +89,15 @@ def __init__(self,
                 and not v.name.startswith("vgg")
                 and not v.name.startswith("Inception")
                 and not v.name.startswith("resnet")]
-        l1_value = sum(tf.reduce_sum(abs(v)) for v in regularizable)
-        l1_cost = l1_weight * l1_value if l1_weight > 0 else 0.0
-
-        l2_value = sum(tf.reduce_sum(v ** 2) for v in regularizable)
-        l2_cost = l2_weight * l2_value if l2_weight > 0 else 0.0
+        reg_values = [reg.value(regularizable)
+                      for reg in self.regularizers]
 
         # unweighted losses for fetching
-        self.losses = [o.loss for o in objectives] + [l1_value, l2_value]
-        tf.summary.scalar("train_l1", l1_value,
-                          collections=["summary_train"])
-        tf.summary.scalar("train_l2", l2_value,
-                          collections=["summary_train"])
+        self.losses = [o.loss for o in objectives] + reg_values
+
+        for reg, reg_value in zip(self.regularizers, reg_values):
+            tf.summary.scalar(reg.name, reg_value,
+                              collections=["summary_train"])
 
         # log all objectives
         for obj in objectives:
@@ -108,9 +110,13 @@ def __init__(self,
         with tf.control_dependencies(update_ops):
             with tf.name_scope("gradient_collection"):
                 differentiable_loss_sum = sum(
-                    (o.weight if o.weight is not None else 1) * o.loss
+                    (o.weight if o.weight is not None else 1.) * o.loss
                     for o in objectives
-                    if o.gradients is None) + l1_cost + l2_cost
+                    if o.gradients is None)
+                differentiable_loss_sum += sum(
+                    reg.weight * reg_value
+                    for reg, reg_value in zip(self.regularizers,
+                                              reg_values))
                 implicit_gradients = self._get_gradients(
                     differentiable_loss_sum)
 
@@ -138,10 +144,11 @@ def __init__(self,
         self.all_coders = set.union(*(obj.decoder.get_dependencies()
                                       for obj in objectives))
 
+        self.gradients = gradients
         self.train_op = self.optimizer.apply_gradients(
-            gradients, global_step=step)
+            self.gradients, global_step=step)
 
-        for grad, var in gradients:
+        for grad, var in self.gradients:
             if grad is not None:
                 tf.summary.histogram(
                     "gr_" + var.name,
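With regularization generalized, the differentiable loss the trainer minimizes is sum_o(w_o * loss_o) + sum_r(reg_r.weight * reg_r.value(vars)). A toy numeric check of that composition (all values made up):

    objectives = [(1.0, 0.7), (0.5, 1.2)]     # (weight, loss) pairs
    reg_terms = [(1e-4, 42.0), (1e-3, 3.14)]  # (reg.weight, reg.value(...))

    loss = sum(w * l for w, l in objectives)  # 0.7 + 0.6 = 1.3
    loss += sum(w * v for w, v in reg_terms)  # + 0.0042 + 0.00314
    print(round(loss, 5))                     # 1.30734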
diff --git a/neuralmonkey/trainers/regularizers.py b/neuralmonkey/trainers/regularizers.py
new file mode 100644
index 000000000..d21f47ba0
--- /dev/null
+++ b/neuralmonkey/trainers/regularizers.py
@@ -0,0 +1,96 @@
+"""Variable regularizers.
+
+This module contains classes that can be used as variable regularizers
+during training. All implementations should derive from the
+BaseRegularizer class.
+
+"""
+from typing import List
+
+import numpy as np
+import tensorflow as tf
+from typeguard import check_argument_types
+
+from neuralmonkey.logging import log
+
+
+# pylint: disable=too-few-public-methods
+class BaseRegularizer:
+    """Base class for the regularizers."""
+
+    def __init__(self,
+                 name: str,
+                 weight: float) -> None:
+        check_argument_types()
+        self.name = name
+        self.weight = weight
+
+    def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
+        raise NotImplementedError("Abstract method")
+
+
+class L1Regularizer(BaseRegularizer):
+    """L1 regularizer, penalizing the absolute values of variables."""
+
+    def __init__(self,
+                 name: str = "train_l1",
+                 weight: float = 0.) -> None:
+        BaseRegularizer.__init__(self, name, weight)
+
+    def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
+        return sum(tf.reduce_sum(abs(v)) for v in variables)
+
+
+class L2Regularizer(BaseRegularizer):
+    """L2 regularizer, penalizing the squared values of variables."""
+
+    def __init__(self,
+                 name: str = "train_l2",
+                 weight: float = 0.) -> None:
+        BaseRegularizer.__init__(self, name, weight)
+
+    def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
+        return sum(tf.reduce_sum(v ** 2) for v in variables)
+
+
+class EWCRegularizer(BaseRegularizer):
+    """Regularizer based on Elastic Weight Consolidation.
+
+    Implements a quadratic penalty in the spirit of Kirkpatrick et al.
+    (2017), "Overcoming catastrophic forgetting in neural networks"
+    (arXiv:1612.00796), with squared dataset-averaged gradients standing
+    in for the diagonal of the Fisher information matrix.
+    """
+
+    def __init__(self,
+                 name: str = "train_ewc",
+                 weight: float = 0.,
+                 gradients_file: str = None,
+                 variables_file: str = None) -> None:
+        check_argument_types()
+
+        BaseRegularizer.__init__(self, name, weight)
+
+        if gradients_file is None:
+            raise ValueError("Missing gradients_file")
+        if variables_file is None:
+            raise ValueError("Missing variables_file")
+
+        log("Loading EWC initial variables from {}".format(variables_file))
+        self.init_vars = tf.contrib.framework.load_checkpoint(variables_file)
+        log("EWC initial variables loaded")
+
+        log("Loading gradient estimates from {}".format(gradients_file))
+        self.gradients = np.load(gradients_file)
+        log("Gradient estimates loaded")
+
+    def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
+        ewc_value = 0.0
+        for var in variables:
+            var_name = var.name.split(":")[0]
+            init_var = self.init_vars.get_tensor(var_name)
+            gradient = self.gradients[var_name]
+            ewc_value += tf.reduce_sum(tf.multiply(
+                tf.square(gradient), tf.square(var - init_var)))
+
+        return ewc_value
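The penalty that EWCRegularizer.value builds in the graph, restated in plain NumPy for a two-parameter toy example (all numbers made up):

    import numpy as np

    theta = np.array([0.9, -0.2])    # current variable values
    theta0 = np.array([1.0, 0.0])    # values loaded from variables_file
    avg_grad = np.array([2.0, 0.5])  # averaged gradients (gradients_file)

    # sum_i avg_grad_i^2 * (theta_i - theta0_i)^2
    penalty = np.sum(np.square(avg_grad) * np.square(theta - theta0))
    print(penalty)  # 4.0 * 0.01 + 0.25 * 0.04 = 0.05 (up to rounding)

The trainer then adds weight * penalty to the differentiable loss, so training on a new task is pulled back towards theta0 most strongly along directions where the old task's gradients were large.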
diff --git a/scripts/avg_tensors.py b/scripts/avg_tensors.py
new file mode 100755
index 000000000..ad99cc309
--- /dev/null
+++ b/scripts/avg_tensors.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+"""Compute the mean over a set of tensors.
+
+The tensors can be spread over multiple npz files. The mean is computed
+over the first dimension, which is assumed to be the batch dimension.
+
+"""
+
+import argparse
+import glob
+
+import numpy as np
+
+from neuralmonkey.logging import log as _log
+
+
+def log(message: str, color: str = "blue") -> None:
+    _log(message, color)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--file_prefix", type=str, required=True,
+                        help="Prefix of the npz files to be averaged.")
+    parser.add_argument("--output_path", type=str, required=True,
+                        help="Path to the output npz file.")
+    args = parser.parse_args()
+
+    output_dict = {}
+    n = 0
+    for file in glob.glob("{}.*npz".format(args.file_prefix)):
+        log("Processing {}".format(file))
+        tensors = np.load(file)
+
+        # the first (batch) dimension must be equal for all tensors
+        shapes = [tensors[f].shape for f in tensors.files]
+        assert all(x[0] == shapes[0][0] for x in shapes)
+
+        for varname in tensors.files:
+            res = np.sum(tensors[varname], 0)
+            if varname in output_dict:
+                output_dict[varname] += res
+            else:
+                output_dict[varname] = res
+        n += shapes[0][0]
+
+    for name in output_dict:
+        output_dict[name] /= n
+
+    np.savez(args.output_path, **output_dict)
+
+
+if __name__ == "__main__":
+    main()
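Taken together, the pieces above support the following continual-learning workflow; the paths, weight, and file names are illustrative only:

    # 1. On the old task's data, run a GradientRunner to dump per-batch
    #    gradients into npz shards (e.g. gradients.0.npz, gradients.1.npz).
    # 2. Average the shards:
    #      scripts/avg_tensors.py --file_prefix gradients \
    #          --output_path avg_gradients.npz
    # 3. Regularize training on the new task towards the old variables:
    from neuralmonkey.trainers.regularizers import EWCRegularizer

    ewc = EWCRegularizer(
        weight=0.01,                         # penalty strength, tuned per task
        gradients_file="avg_gradients.npz",  # output of step 2
        variables_file="old-task/variables.data")  # old task's checkpoint

    # `ewc` can then be passed in the trainer's `regularizers` list.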