From c676d9bb5d1d0f8474890484f808d53a732fea8f Mon Sep 17 00:00:00 2001
From: varisd
Date: Wed, 30 Jan 2019 18:59:10 +0100
Subject: [PATCH] rebased EWC to master

---
 neuralmonkey/dataset.py                   | 16 +---
 neuralmonkey/evaluators/average.py        |  1 -
 neuralmonkey/model/parameterized.py       |  2 +-
 neuralmonkey/model/sequence.py            |  1 -
 neuralmonkey/processors/helpers.py        |  3 +-
 neuralmonkey/runners/gradient_runner.py   | 74 +++++++------------
 .../trainers/delayed_update_trainer.py    | 52 +++++++------
 neuralmonkey/trainers/generic_trainer.py  | 27 ++++---
 neuralmonkey/trainers/regularizers.py     | 26 +++----
 9 files changed, 86 insertions(+), 116 deletions(-)

diff --git a/neuralmonkey/dataset.py b/neuralmonkey/dataset.py
index dbad8bc67..f0c9e63de 100644
--- a/neuralmonkey/dataset.py
+++ b/neuralmonkey/dataset.py
@@ -372,7 +372,7 @@ def _add_preprocessed_series(iterators, s_name, prep_sl):
             _add_preprocessed_series(iterators, source, prep_sl)
         if source not in iterators:
             raise ValueError(
-                "Source series {} for series-level preprocessor nonexistent: "
+                "Source series {} for series-level preprocessor nonexistent: "
                 "Preprocessed series '', source series ''".format(source))

     iterators[s_name] = _make_sl_iterator(source, preprocessor)
@@ -469,7 +469,6 @@ def __init__(self,
             self.buffer_min_size, self.buffer_size = buffer_size
         else:
             self.lazy = False
-            self.buffer_size = None

         self.shuffled = shuffled
         self.length = None
@@ -585,11 +584,6 @@ def batches(self) -> Iterator["Dataset"]:
                 random.shuffle(lbuf)
             buf = deque(lbuf)

-        def _make_datagen(rows, key):
-            def itergen():
-                return (row[key] for row in rows)
-            return itergen
-
         def _make_datagen(rows, key):
             def itergen():
                 return (row[key] for row in rows)
@@ -653,14 +647,6 @@ def itergen():
                 if self.shuffled:
                     random.shuffle(buf)  # type: ignore

-            for bucket_id in buckets:
-                if buckets[bucket_id]:
-                    name = "{}.batch.{}".format(self.name, batch_index)
-                    data = {key: _make_datagen(buckets[bucket_id], key)
-                            for key in buckets[bucket_id][0]}
-
-                    yield Dataset(name=name, iterators=data)
-                    batch_index += 1

         if not self.batching.drop_remainder:
             for bucket in buckets:
diff --git a/neuralmonkey/evaluators/average.py b/neuralmonkey/evaluators/average.py
index 3b1db59b6..d146c8e66 100644
--- a/neuralmonkey/evaluators/average.py
+++ b/neuralmonkey/evaluators/average.py
@@ -2,7 +2,6 @@
 # This evaluator here is just an ugly hack to work with perplexity runner
 from neuralmonkey.evaluators.evaluator import Evaluator
-import numpy as np


 class AverageEvaluator(Evaluator[float]):
     """Just average the numeric output of a runner."""
diff --git a/neuralmonkey/model/parameterized.py b/neuralmonkey/model/parameterized.py
index 3eab213d7..2cba03a42 100644
--- a/neuralmonkey/model/parameterized.py
+++ b/neuralmonkey/model/parameterized.py
@@ -5,7 +5,7 @@
 import tensorflow as tf

 from neuralmonkey.tf_utils import update_initializers
-from neuralmonkey.logging import log, warn
+from neuralmonkey.logging import log

 # pylint: enable=invalid-name
 InitializerSpecs = List[Tuple[str, Callable]]
diff --git a/neuralmonkey/model/sequence.py b/neuralmonkey/model/sequence.py
index fb0e72a80..9e66aeb98 100644
--- a/neuralmonkey/model/sequence.py
+++ b/neuralmonkey/model/sequence.py
@@ -188,7 +188,6 @@ def temporal_states(self) -> tf.Tensor:
                 emb_factor = emb_factor * tf.expand_dims(self.temporal_mask, -1)
             embedded_factors.append(emb_factor)
-
         return tf.concat(embedded_factors, 2)

     # pylint: disable=unsubscriptable-object
diff --git a/neuralmonkey/processors/helpers.py b/neuralmonkey/processors/helpers.py
index 142b05a32..49b3f0c78 100644
--- a/neuralmonkey/processors/helpers.py
+++ b/neuralmonkey/processors/helpers.py
@@ -1,6 +1,5 @@
 from typing import Any, Callable, Generator, List
-
-import numpy as np
+from random import randint


 def preprocess_char_based(sentence: List[str]) -> List[str]:
diff --git a/neuralmonkey/runners/gradient_runner.py b/neuralmonkey/runners/gradient_runner.py
index b44c233ff..cd4605e17 100644
--- a/neuralmonkey/runners/gradient_runner.py
+++ b/neuralmonkey/runners/gradient_runner.py
@@ -1,10 +1,10 @@
-from typing import Any, Dict, List, Set, Union, Optional
+from typing import Dict, List, Union

+import tensorflow as tf
 from typeguard import check_argument_types

-from neuralmonkey.runners.base_runner import (
-    BaseRunner, Executable, ExecutionResult, NextExecute)
-from neuralmonkey.model.model_part import ModelPart
+from neuralmonkey.runners.base_runner import BaseRunner
+from neuralmonkey.model.model_part import GenericModelPart
 from neuralmonkey.decoders.autoregressive import AutoregressiveDecoder
 from neuralmonkey.decoders.classifier import Classifier
 from neuralmonkey.trainers.generic_trainer import GenericTrainer
@@ -14,38 +14,7 @@
 # pylint: enable=invalid-name


-class GradientRunnerExecutable(Executable):
-
-    def __init__(self,
-                 all_coders: Set[ModelPart],
-                 fetches: Dict[str, List[Any]]) -> None:
-        self._all_coders = all_coders
-        self._fetches = fetches
-
-        self.result = None  # type: Optional[ExecutionResult]
-
-    def next_to_execute(self) -> NextExecute:
-        """Get the feedables and tensors to run."""
-        return self._all_coders, self._fetches, []
-
-    def collect_results(self, results: List[Dict]) -> None:
-        assert len(results) == 1
-
-        for sess_result in results:
-            gradient_dict = {}
-            tensor_names = [t.name for t in self._fetches["gradients"]]
-            for name, val in zip(tensor_names, sess_result["gradients"]):
-                gradient_dict[name] = val
-
-        self.result = ExecutionResult(
-            outputs=[gradient_dict],
-            losses=[],
-            scalar_summaries=None,
-            histogram_summaries=None,
-            image_summaries=None)
-
-
-class GradientRunner(BaseRunner[SupportedDecoder]):
+class GradientRunner(BaseRunner[GenericModelPart]):
     """Runner for fetching gradients computed over the dataset.

     Gradient runner applies provided trainer on a desired dataset
@@ -55,25 +24,34 @@ class GradientRunner(BaseRunner[SupportedDecoder]):
     (https://arxiv.org/pdf/1612.00796.pdf)
     """

+    # pylint: disable=too-few-public-methods
+    class Executable(BaseRunner.Executable["GradientRunner"]):
+
+        def collect_results(self, results: List[Dict]) -> None:
+            assert len(results) == 1
+
+            for sess_result in results:
+                gradient_dict = {}
+                tensor_names = [
+                    t.name for t in self.executor.fetches()["gradients"]]
+                for name, val in zip(tensor_names, sess_result["gradients"]):
+                    gradient_dict[name] = val
+
+            self.set_runner_result(outputs=gradient_dict, losses=[])
+    # pylint: enable=too-few-public-methods
+
     def __init__(self,
                  output_series: str,
-                 trainer: GenericTrainer,
-                 decoder: SupportedDecoder) -> None:
+                 decoder: SupportedDecoder,
+                 trainer: GenericTrainer) -> None:
         check_argument_types()
-        BaseRunner[AutoregressiveDecoder].__init__(
+        BaseRunner[GenericModelPart].__init__(
             self, output_series, decoder)

         self._gradients = trainer.gradients

-    # pylint: disable=unused-argument
-    def get_executable(self,
-                       compute_losses: bool,
-                       summaries: bool,
-                       num_sessions: int) -> GradientRunnerExecutable:
-        fetches = {"gradients": [g[1] for g in self._gradients]}
-
-        return GradientRunnerExecutable(self.all_coders, fetches)
-    # pylint: enable=unused-argument
+    def fetches(self) -> Dict[str, tf.Tensor]:
+        return {"gradients": [g[1] for g in self._gradients]}

     @property
     def loss_names(self) -> List[str]:
diff --git a/neuralmonkey/trainers/delayed_update_trainer.py b/neuralmonkey/trainers/delayed_update_trainer.py
index dacea68e6..3dac99750 100644
--- a/neuralmonkey/trainers/delayed_update_trainer.py
+++ b/neuralmonkey/trainers/delayed_update_trainer.py
@@ -9,6 +9,8 @@
 from neuralmonkey.runners.base_runner import GraphExecutor, NextExecute
 from neuralmonkey.trainers.generic_trainer import (GenericTrainer, Objective,
                                                    Gradients)
+from neuralmonkey.trainers.regularizers import (Regularizer, L1Regularizer,
+                                                L2Regularizer)


 class DelayedUpdateTrainer(GenericTrainer):
@@ -76,7 +78,7 @@ def collect_results(self, results: List[Dict]) -> None:
             assert self.res_batch is not None

             objective_names = [obj.name for obj in self.executor.objectives]
-            objective_names += ["L1", "L2"]
+            objective_names += [reg.name for reg in self.executor.regularizers]

             losses = dict(zip(objective_names, self.res_losses))

             self.set_result({}, losses, self.res_batch, self.res_sums)
@@ -85,16 +87,14 @@ def __init__(self,
                  batches_per_update: int,
                  objectives: List[Objective],
-                 l1_weight: float = 0.0,
-                 l2_weight: float = 0.0,
                  clip_norm: float = None,
                  optimizer: tf.train.Optimizer = None,
+                 regularizers: List[Regularizer] = None,
                  var_scopes: List[str] = None,
                  var_collection: str = None) -> None:
         check_argument_types()
-        GenericTrainer.__init__(self, objectives, l1_weight, l2_weight,
-                                clip_norm, optimizer, var_scopes,
-                                var_collection)
+        GenericTrainer.__init__(self, objectives, clip_norm, optimizer,
+                                regularizers, var_scopes, var_collection)

         self.batches_per_update = batches_per_update
     # pylint: enable=too-many-arguments
@@ -186,17 +186,6 @@ def raw_gradients(self) -> Gradients:
                 for grad in self.gradient_buffers]
             # pylint: enable=not-an-iterable

-        tf.summary.scalar(
-            "train_opt_cost",
-            self.diff_buffer / tf.to_float(self.cumulator_counter),
-            collections=["summary_train"])
-
-        # log all objectives
-        for obj, objbuf in zip(self.objectives, self.objective_buffers):
-            tf.summary.scalar(
-                obj.name, objbuf / tf.to_float(self.cumulator_counter),
-                collections=["summary_train"])
-
         # now, zip averaged grads with associated vars to a Gradients struct.
         # pylint: disable=unpacking-non-sequence
         _, existing_vars = self.existing_grads_and_vars
@@ -205,18 +194,37 @@
     @tensor
     def summaries(self) -> Dict[str, tf.Tensor]:
+        # pylint: disable=protected-access
         if isinstance(self.optimizer._lr, tf.Tensor):
             tf.summary.scalar("learning_rate", self.optimizer._lr,
                               collections=["summary_train"])
         # pylint: enable=protected-access

-        # pylint: disable=unpacking-non-sequence
-        l1_norm, l2_norm = self.regularization_losses
-        # pylint: enable=unpacking-non-sequence
+        reg_values = self.regularization_losses
+        # we always want to include l2 values in the summary
+        if L1Regularizer not in [type(r) for r in self.regularizers]:
+            l1_reg = L1Regularizer(name="train_l1", weight=0.)
+            tf.summary.scalar(l1_reg.name, l1_reg.value(self.regularizable),
+                              collections=["summary_train"])
+        if L2Regularizer not in [type(r) for r in self.regularizers]:
+            l2_reg = L2Regularizer(name="train_l2", weight=0.)
+            tf.summary.scalar(l2_reg.name, l2_reg.value(self.regularizable),
+                              collections=["summary_train"])
+
+        for reg, reg_value in zip(self.regularizers, reg_values):
+            tf.summary.scalar(reg.name, reg_value,
+                              collections=["summary_train"])

-        tf.summary.scalar("train_l1", l1_norm, collections=["summary_train"])
-        tf.summary.scalar("train_l2", l2_norm, collections=["summary_train"])
+        for obj, objbuf in zip(self.objectives, self.objective_buffers):
+            tf.summary.scalar(
+                obj.name, objbuf / tf.to_float(self.cumulator_counter),
+                collections=["summary_train"])
+
+        tf.summary.scalar(
+            "train_opt_cost",
+            self.diff_buffer / tf.to_float(self.cumulator_counter),
+            collections=["summary_train"])

     # pylint: disable=not-an-iterable
     # Pylint does not understand @tensor annotations
diff --git a/neuralmonkey/trainers/generic_trainer.py b/neuralmonkey/trainers/generic_trainer.py
index 2770a5550..05d056c10 100644
--- a/neuralmonkey/trainers/generic_trainer.py
+++ b/neuralmonkey/trainers/generic_trainer.py
@@ -48,7 +48,7 @@ def collect_results(self, results: List[Dict]) -> None:
                     result["histogram_summaries"]])

             objective_names = [obj.name for obj in self.executor.objectives]
-            objective_names += ["L1", "L2"]
+            objective_names += [reg.name for reg in self.executor.regularizers]

             losses = dict(zip(objective_names, result["losses"]))
@@ -86,25 +86,28 @@ def __init__(self,
             optimizer if optimizer is not None else self.default_optimizer())

     # pylint: disable=no-self-use
+    @tensor
+    def regularizable(self) -> List[tf.Tensor]:
+        return [v for v in tf.trainable_variables()
+                if not BIAS_REGEX.findall(v.name)
+                and not v.name.startswith("vgg")
+                and not v.name.startswith("Inception")
+                and not v.name.startswith("resnet")]
+    # pylint: enable=no-self-use
+
     @tensor
     def regularization_losses(self) -> List[tf.Tensor]:
L1 and L2.""" - regularizable = [v for v in tf.trainable_variables() - if not BIAS_REGEX.findall(v.name) - and not v.name.startswith("vgg") - and not v.name.startswith("Inception") - and not v.name.startswith("resnet")] - + regularizable = self.regularizable if not regularizable: warn("It seems that there are no trainable variables in the model") - return tf.zeros([]), tf.zeros([]) + return [tf.zeros([]) for _ in self.regularizers] with tf.name_scope("regularization"): reg_values = [reg.value(regularizable) for reg in self.regularizers] return reg_values - # pylint: enable=no-self-use @tensor def objective_values(self) -> List[tf.Tensor]: @@ -127,7 +130,7 @@ def differentiable_loss_sum(self) -> tf.Tensor: else: obj_weights.append(obj.weight) - obj_weights += [reg.weights for reg in self.regularizers] + obj_weights += [reg.weight for reg in self.regularizers] diff_loss = sum( o * w for o, w in zip(self.objective_values, obj_weights) if w is not None) @@ -219,11 +222,11 @@ def summaries(self) -> Dict[str, tf.Tensor]: # we always want to include l2 values in the summary if L1Regularizer not in [type(r) for r in self.regularizers]: l1_reg = L1Regularizer(name="train_l1", weight=0.) - tf.summary.scalar(l1_reg.name, l1_reg.value(regularizable), + tf.summary.scalar(l1_reg.name, l1_reg.value(self.regularizable), collections=["summary_train"]) if L2Regularizer not in [type(r) for r in self.regularizers]: l2_reg = L2Regularizer(name="train_l2", weight=0.) - tf.summary.scalar(l2_reg.name, l2_reg.value(regularizable), + tf.summary.scalar(l2_reg.name, l2_reg.value(self.regularizable), collections=["summary_train"]) for reg, reg_value in zip(self.regularizers, reg_values): diff --git a/neuralmonkey/trainers/regularizers.py b/neuralmonkey/trainers/regularizers.py index b9f7b1b67..527febd96 100644 --- a/neuralmonkey/trainers/regularizers.py +++ b/neuralmonkey/trainers/regularizers.py @@ -5,7 +5,7 @@ class. """ from abc import ABCMeta, abstractmethod -from typing import List, Union +from typing import List import numpy as np import tensorflow as tf @@ -104,34 +104,30 @@ class EWCRegularizer(Regularizer): def __init__(self, name: str, weight: float, - fisher_file: Union[str, List[str]], - variables_file: Union[str, List[str]]) -> None: + fisher_files: List[str], + variables_files: List[str]) -> None: # TODO: change *file -> *files """Create the regularizer. Arguments: name: Regularizer name. weight: Weight of the regularization term. - fisher_file: File containing the diagonal of the fisher information - matrix estimated on the previous task. + fisher_files: File containing the diagonal of the fisher + information matrix estimated on the previous task. variables_files: File containing the variables learned on the previous task. """ - if isinstance(fisher_file, str): - fisher_file = [fisher_file] - if isinstance(variables_file, str): - variables_file = [variables_file] - check_argument_types() Regularizer.__init__(self, name, weight) log("Loading initial variables for EWC from {}." 
-            .format(variables_file))
-        self.init_vars = [tf.contrib.framework.load_checkpoint(f) for f in variables_file]
+            .format(variables_files))
+        self.init_vars = [tf.contrib.framework.load_checkpoint(f)
+                          for f in variables_files]
         log("EWC initial variables loaded.")
-        log("Loading gradient estimates from {}.".format(fisher_file))
-        self.fisher = [np.load(f) for f in fisher_file]
+        log("Loading gradient estimates from {}.".format(fisher_files))
+        self.fisher = [np.load(f) for f in fisher_files]
         log("Gradient estimates loaded.")

     def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
@@ -145,6 +141,7 @@ def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
         ewc_value = tf.constant(0.0)
         for var in variables:
+            # pylint: disable=invalid-name
             for f, v in zip(self.fisher, self.init_vars):
                 init_var_name = var.name.split(":")[0]
                 if (var.name in f.files
@@ -154,5 +151,6 @@ def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
                         name="{}_init_value".format(init_var_name))
                     ewc_value += tf.reduce_sum(tf.multiply(
                         f[var.name], tf.square(var - init_var)))
+        # pylint: enable=invalid-name

         return ewc_value
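
A minimal usage sketch of the regularizer API introduced by this patch. Only
the EWCRegularizer and GenericTrainer signatures come from the patch itself;
the objective, regularization weight, and file paths below are hypothetical
placeholders:

    from neuralmonkey.trainers.generic_trainer import GenericTrainer
    from neuralmonkey.trainers.regularizers import EWCRegularizer

    # Penalize deviation from the variables learned on the previous task,
    # weighted by the stored diagonal of the Fisher information matrix.
    ewc = EWCRegularizer(
        name="ewc",
        weight=5.0,
        fisher_files=["fisher_estimate.npz"],
        variables_files=["previous_task.ckpt"])

    # Regularizers are now passed to the trainer as a list instead of the
    # former l1_weight/l2_weight arguments.
    trainer = GenericTrainer(
        objectives=[xent_objective],  # an Objective built elsewhere
        clip_norm=1.0,
        regularizers=[ewc])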