
Commit c676d9b
rebased EWC to master
varisd committed Jan 30, 2019
1 parent dbe24ec commit c676d9b
Showing 9 changed files with 86 additions and 116 deletions.
16 changes: 1 addition & 15 deletions neuralmonkey/dataset.py
@@ -372,7 +372,7 @@ def _add_preprocessed_series(iterators, s_name, prep_sl):
_add_preprocessed_series(iterators, source, prep_sl)
if source not in iterators:
raise ValueError(
"Source series {} for series-level preprocessor nonexistent: "
"Source series {} for series-level preprocessor nonexistent: "
"Preprocessed series '', source series ''".format(source))
iterators[s_name] = _make_sl_iterator(source, preprocessor)

@@ -469,7 +469,6 @@ def __init__(self,
self.buffer_min_size, self.buffer_size = buffer_size
else:
self.lazy = False
self.buffer_size = None

self.shuffled = shuffled
self.length = None
@@ -585,11 +584,6 @@ def batches(self) -> Iterator["Dataset"]:
random.shuffle(lbuf)
buf = deque(lbuf)

def _make_datagen(rows, key):
def itergen():
return (row[key] for row in rows)
return itergen

def _make_datagen(rows, key):
def itergen():
return (row[key] for row in rows)
@@ -653,14 +647,6 @@ def itergen():

if self.shuffled:
random.shuffle(buf) # type: ignore
for bucket_id in buckets:
if buckets[bucket_id]:
name = "{}.batch.{}".format(self.name, batch_index)
data = {key: _make_datagen(buckets[bucket_id], key)
for key in buckets[bucket_id][0]}

yield Dataset(name=name, iterators=data)
batch_index += 1

if not self.batching.drop_remainder:
for bucket in buckets:
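
For context on the hunks above: the batches() method groups examples into buckets (typically by length) and yields one Dataset per non-empty bucket, optionally dropping incomplete remainders. A minimal standalone sketch of that bucketing idea, with hypothetical names and a made-up bucketing rule rather than Neural Monkey's actual API:

from typing import Callable, Dict, Iterator, List


def bucketed_batches(rows: List[dict],
                     bucket_of: Callable[[dict], int],
                     batch_size: int) -> Iterator[List[dict]]:
    """Group rows into buckets and yield a batch whenever a bucket fills up."""
    buckets: Dict[int, List[dict]] = {}
    for row in rows:
        bucket = buckets.setdefault(bucket_of(row), [])
        bucket.append(row)
        if len(bucket) == batch_size:
            yield list(bucket)
            bucket.clear()
    # flush non-empty remainders (cf. batching.drop_remainder in the diff above)
    for bucket in buckets.values():
        if bucket:
            yield list(bucket)


# usage: bucket sentences by length so padding inside a batch stays small
rows = [{"source": ["tok"] * n} for n in (3, 5, 3, 9, 5, 3)]
for batch in bucketed_batches(rows, lambda r: len(r["source"]) // 4, 2):
    print([len(r["source"]) for r in batch])
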
1 change: 0 additions & 1 deletion neuralmonkey/evaluators/average.py
@@ -2,7 +2,6 @@
# This evaluator here is just an ugly hack to work with perplexity runner
from neuralmonkey.evaluators.evaluator import Evaluator

import numpy as np

class AverageEvaluator(Evaluator[float]):
"""Just average the numeric output of a runner."""
2 changes: 1 addition & 1 deletion neuralmonkey/model/parameterized.py
@@ -5,7 +5,7 @@
import tensorflow as tf

from neuralmonkey.tf_utils import update_initializers
from neuralmonkey.logging import log, warn
from neuralmonkey.logging import log

# pylint: enable=invalid-name
InitializerSpecs = List[Tuple[str, Callable]]
1 change: 0 additions & 1 deletion neuralmonkey/model/sequence.py
@@ -188,7 +188,6 @@ def temporal_states(self) -> tf.Tensor:
emb_factor = emb_factor * tf.expand_dims(self.temporal_mask, -1)
embedded_factors.append(emb_factor)


return tf.concat(embedded_factors, 2)

# pylint: disable=unsubscriptable-object
3 changes: 1 addition & 2 deletions neuralmonkey/processors/helpers.py
@@ -1,6 +1,5 @@
from typing import Any, Callable, Generator, List

import numpy as np
from random import randint


def preprocess_char_based(sentence: List[str]) -> List[str]:
74 changes: 26 additions & 48 deletions neuralmonkey/runners/gradient_runner.py
@@ -1,10 +1,10 @@
from typing import Any, Dict, List, Set, Union, Optional
from typing import Dict, List, Union

import tensorflow as tf
from typeguard import check_argument_types

from neuralmonkey.runners.base_runner import (
BaseRunner, Executable, ExecutionResult, NextExecute)
from neuralmonkey.model.model_part import ModelPart
from neuralmonkey.runners.base_runner import BaseRunner
from neuralmonkey.model.model_part import GenericModelPart
from neuralmonkey.decoders.autoregressive import AutoregressiveDecoder
from neuralmonkey.decoders.classifier import Classifier
from neuralmonkey.trainers.generic_trainer import GenericTrainer
@@ -14,38 +14,7 @@
# pylint: enable=invalid-name


class GradientRunnerExecutable(Executable):

def __init__(self,
all_coders: Set[ModelPart],
fetches: Dict[str, List[Any]]) -> None:
self._all_coders = all_coders
self._fetches = fetches

self.result = None # type: Optional[ExecutionResult]

def next_to_execute(self) -> NextExecute:
"""Get the feedables and tensors to run."""
return self._all_coders, self._fetches, []

def collect_results(self, results: List[Dict]) -> None:
assert len(results) == 1

for sess_result in results:
gradient_dict = {}
tensor_names = [t.name for t in self._fetches["gradients"]]
for name, val in zip(tensor_names, sess_result["gradients"]):
gradient_dict[name] = val

self.result = ExecutionResult(
outputs=[gradient_dict],
losses=[],
scalar_summaries=None,
histogram_summaries=None,
image_summaries=None)


class GradientRunner(BaseRunner[SupportedDecoder]):
class GradientRunner(BaseRunner[GenericModelPart]):
"""Runner for fetching gradients computed over the dataset.
Gradient runner applies provided trainer on a desired dataset
@@ -55,25 +24,34 @@ class GradientRunner(BaseRunner[SupportedDecoder]):
(https://arxiv.org/pdf/1612.00796.pdf)
"""

# pylint: disable=too-few-public-methods
class Executable(BaseRunner.Executable["GradientRunner"]):

def collect_results(self, results: List[Dict]) -> None:
assert len(results) == 1

for sess_result in results:
gradient_dict = {}
tensor_names = [
t.name for t in self.executor.fetches()["gradients"]]
for name, val in zip(tensor_names, sess_result["gradients"]):
gradient_dict[name] = val

self.set_runner_result(outputs=gradient_dict, losses=[])
# pylint: enable=too-few-public-methods

def __init__(self,
output_series: str,
trainer: GenericTrainer,
decoder: SupportedDecoder) -> None:
decoder: SupportedDecoder,
trainer: GenericTrainer) -> None:
check_argument_types()
BaseRunner[AutoregressiveDecoder].__init__(
BaseRunner[GenericModelPart].__init__(
self, output_series, decoder)

self._gradients = trainer.gradients

# pylint: disable=unused-argument
def get_executable(self,
compute_losses: bool,
summaries: bool,
num_sessions: int) -> GradientRunnerExecutable:
fetches = {"gradients": [g[1] for g in self._gradients]}

return GradientRunnerExecutable(self.all_coders, fetches)
# pylint: enable=unused-argument
def fetches(self) -> Dict[str, tf.Tensor]:
return {"gradients": [g[1] for g in self._gradients]}

@property
def loss_names(self) -> List[str]:
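
The GradientRunner docstring above refers to Elastic Weight Consolidation (EWC, arXiv:1612.00796): per-variable gradients collected over a dataset are squared and averaged into a diagonal Fisher estimate, which later weights a quadratic penalty that keeps parameters close to their old-task values. A rough NumPy sketch of that downstream computation, not part of this commit and with hypothetical names:

import numpy as np


def diagonal_fisher(per_batch_gradients):
    """Average squared per-variable gradients over all batches of a dataset."""
    fisher = {}
    for grads in per_batch_gradients:        # one {tensor_name: gradient} dict per batch
        for name, grad in grads.items():
            fisher[name] = fisher.get(name, 0.0) + np.square(grad)
    return {name: acc / len(per_batch_gradients) for name, acc in fisher.items()}


def ewc_penalty(params, old_params, fisher, strength=1.0):
    """EWC penalty: (strength / 2) * sum_i F_i * (theta_i - theta_old_i) ** 2."""
    return 0.5 * strength * sum(
        np.sum(fisher[name] * np.square(params[name] - old_params[name]))
        for name in fisher)
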
52 changes: 30 additions & 22 deletions neuralmonkey/trainers/delayed_update_trainer.py
@@ -9,6 +9,8 @@
from neuralmonkey.runners.base_runner import GraphExecutor, NextExecute
from neuralmonkey.trainers.generic_trainer import (GenericTrainer, Objective,
Gradients)
from neuralmonkey.trainers.regularizers import (Regularizer, L1Regularizer,
L2Regularizer)


class DelayedUpdateTrainer(GenericTrainer):
@@ -76,7 +78,7 @@ def collect_results(self, results: List[Dict]) -> None:
assert self.res_batch is not None

objective_names = [obj.name for obj in self.executor.objectives]
objective_names += ["L1", "L2"]
objective_names += [reg.name for reg in self.executor.regularizers]
losses = dict(zip(objective_names, self.res_losses))

self.set_result({}, losses, self.res_batch, self.res_sums)
@@ -85,16 +87,14 @@ def collect_results(self, results: List[Dict]) -> None:
def __init__(self,
batches_per_update: int,
objectives: List[Objective],
l1_weight: float = 0.0,
l2_weight: float = 0.0,
clip_norm: float = None,
optimizer: tf.train.Optimizer = None,
regularizers: List[Regularizer] = None,
var_scopes: List[str] = None,
var_collection: str = None) -> None:
check_argument_types()
GenericTrainer.__init__(self, objectives, l1_weight, l2_weight,
clip_norm, optimizer, var_scopes,
var_collection)
GenericTrainer.__init__(self, objectives, clip_norm, optimizer,
regularizers, var_scopes, var_collection)

self.batches_per_update = batches_per_update
# pylint: enable=too-many-arguments
@@ -186,17 +186,6 @@ def raw_gradients(self) -> Gradients:
for grad in self.gradient_buffers]
# pylint: enable=not-an-iterable

tf.summary.scalar(
"train_opt_cost",
self.diff_buffer / tf.to_float(self.cumulator_counter),
collections=["summary_train"])

# log all objectives
for obj, objbuf in zip(self.objectives, self.objective_buffers):
tf.summary.scalar(
obj.name, objbuf / tf.to_float(self.cumulator_counter),
collections=["summary_train"])

# now, zip averaged grads with associated vars to a Gradients struct.
# pylint: disable=unpacking-non-sequence
_, existing_vars = self.existing_grads_and_vars
@@ -205,18 +194,37 @@

@tensor
def summaries(self) -> Dict[str, tf.Tensor]:

# pylint: disable=protected-access
if isinstance(self.optimizer._lr, tf.Tensor):
tf.summary.scalar("learning_rate", self.optimizer._lr,
collections=["summary_train"])
# pylint: enable=protected-access

# pylint: disable=unpacking-non-sequence
l1_norm, l2_norm = self.regularization_losses
# pylint: enable=unpacking-non-sequence
reg_values = self.regularization_losses
# we always want to include l2 values in the summary
if L1Regularizer not in [type(r) for r in self.regularizers]:
l1_reg = L1Regularizer(name="train_l1", weight=0.)
tf.summary.scalar(l1_reg.name, l1_reg.value(self.regularizable),
collections=["summary_train"])
if L2Regularizer not in [type(r) for r in self.regularizers]:
l2_reg = L2Regularizer(name="train_l2", weight=0.)
tf.summary.scalar(l2_reg.name, l2_reg.value(self.regularizable),
collections=["summary_train"])

for reg, reg_value in zip(self.regularizers, reg_values):
tf.summary.scalar(reg.name, reg_value,
collections=["summary_train"])

tf.summary.scalar("train_l1", l1_norm, collections=["summary_train"])
tf.summary.scalar("train_l2", l2_norm, collections=["summary_train"])
for obj, objbuf in zip(self.objectives, self.objective_buffers):
tf.summary.scalar(
obj.name, objbuf / tf.to_float(self.cumulator_counter),
collections=["summary_train"])

tf.summary.scalar(
"train_opt_cost",
self.diff_buffer / tf.to_float(self.cumulator_counter),
collections=["summary_train"])

# pylint: disable=not-an-iterable
# Pylint does not understand @tensor annotations
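
The gradient_buffers, diff_buffer and cumulator_counter seen above implement delayed updates: per-batch gradients are summed for batches_per_update batches and their average is applied in a single optimizer step, emulating a larger batch. A TensorFlow-free sketch of the same accumulation logic, with hypothetical names rather than the class's real attributes:

import numpy as np


class GradientAccumulator:
    """Sum per-batch gradients and release their average every N batches."""

    def __init__(self, batches_per_update: int) -> None:
        self.batches_per_update = batches_per_update
        self.buffers = None      # analogue of gradient_buffers
        self.counter = 0         # analogue of cumulator_counter

    def accumulate(self, gradients):
        if self.buffers is None:
            self.buffers = [np.array(g) for g in gradients]
        else:
            self.buffers = [b + np.array(g)
                            for b, g in zip(self.buffers, gradients)]
        self.counter += 1
        if self.counter < self.batches_per_update:
            return None          # keep accumulating
        averaged = [b / self.counter for b in self.buffers]
        self.buffers, self.counter = None, 0
        return averaged          # caller applies this with its optimizer
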
27 changes: 15 additions & 12 deletions neuralmonkey/trainers/generic_trainer.py
@@ -48,7 +48,7 @@ def collect_results(self, results: List[Dict]) -> None:
result["histogram_summaries"]])

objective_names = [obj.name for obj in self.executor.objectives]
objective_names += ["L1", "L2"]
objective_names += [reg.name for reg in self.executor.regularizers]

losses = dict(zip(objective_names, result["losses"]))

@@ -86,25 +86,28 @@ def __init__(self,
optimizer if optimizer is not None else self.default_optimizer())

# pylint: disable=no-self-use
@tensor
def regularizable(self) -> List[tf.Tensor]:
return [v for v in tf.trainable_variables()
if not BIAS_REGEX.findall(v.name)
and not v.name.startswith("vgg")
and not v.name.startswith("Inception")
and not v.name.startswith("resnet")]
# pylint: enable=no-self-use

@tensor
def regularization_losses(self) -> List[tf.Tensor]:
"""Compute the regularization losses, e.g. L1 and L2."""
regularizable = [v for v in tf.trainable_variables()
if not BIAS_REGEX.findall(v.name)
and not v.name.startswith("vgg")
and not v.name.startswith("Inception")
and not v.name.startswith("resnet")]

regularizable = self.regularizable
if not regularizable:
warn("It seems that there are no trainable variables in the model")
return tf.zeros([]), tf.zeros([])
return [tf.zeros([]) for _ in self.regularizers]

with tf.name_scope("regularization"):
reg_values = [reg.value(regularizable)
for reg in self.regularizers]

return reg_values
# pylint: enable=no-self-use

@tensor
def objective_values(self) -> List[tf.Tensor]:
@@ -127,7 +130,7 @@ def differentiable_loss_sum(self) -> tf.Tensor:
else:
obj_weights.append(obj.weight)

obj_weights += [reg.weights for reg in self.regularizers]
obj_weights += [reg.weight for reg in self.regularizers]
diff_loss = sum(
o * w for o, w in zip(self.objective_values, obj_weights)
if w is not None)
Expand Down Expand Up @@ -219,11 +222,11 @@ def summaries(self) -> Dict[str, tf.Tensor]:
# we always want to include l2 values in the summary
if L1Regularizer not in [type(r) for r in self.regularizers]:
l1_reg = L1Regularizer(name="train_l1", weight=0.)
tf.summary.scalar(l1_reg.name, l1_reg.value(regularizable),
tf.summary.scalar(l1_reg.name, l1_reg.value(self.regularizable),
collections=["summary_train"])
if L2Regularizer not in [type(r) for r in self.regularizers]:
l2_reg = L2Regularizer(name="train_l2", weight=0.)
tf.summary.scalar(l2_reg.name, l2_reg.value(regularizable),
tf.summary.scalar(l2_reg.name, l2_reg.value(self.regularizable),
collections=["summary_train"])

for reg, reg_value in zip(self.regularizers, reg_values):
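
The refactoring above replaces the fixed l1_weight and l2_weight arguments with a list of Regularizer objects; each carries a name (used as the loss/summary key), a weight (summed into the differentiable loss), and a value(variables) method evaluated over the regularizable variables. A hedged sketch of what such an interface could look like; the actual classes live in neuralmonkey/trainers/regularizers.py and may differ:

import tensorflow as tf


class Regularizer:
    """Named, weighted penalty computed from a list of trainable variables."""

    def __init__(self, name: str, weight: float) -> None:
        self.name = name
        self.weight = weight

    def value(self, variables: list) -> tf.Tensor:
        raise NotImplementedError()


class L1Regularizer(Regularizer):
    def value(self, variables: list) -> tf.Tensor:
        return tf.add_n([tf.reduce_sum(tf.abs(v)) for v in variables])


class L2Regularizer(Regularizer):
    def value(self, variables: list) -> tf.Tensor:
        return tf.add_n([tf.reduce_sum(tf.square(v)) for v in variables])
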

