Commit
implementation of EWC, gradient_runner and gradient averaging script
Showing 6 changed files with 246 additions and 21 deletions.
@@ -0,0 +1,72 @@
from typing import Dict, List, Set, Union

from typeguard import check_argument_types

from neuralmonkey.runners.base_runner import (
    BaseRunner, Executable, FeedDict, ExecutionResult, NextExecute)
from neuralmonkey.model.model_part import ModelPart
from neuralmonkey.decoders.autoregressive import AutoregressiveDecoder
from neuralmonkey.decoders.classifier import Classifier
from neuralmonkey.trainers.generic_trainer import GenericTrainer

# pylint: disable=invalid-name
SupportedDecoder = Union[AutoregressiveDecoder, Classifier]
# pylint: enable=invalid-name


class GradientRunExecutable(Executable):
    """Execute the gradient fetches for a single batch and collect them."""

    def __init__(self,
                 all_coders: Set[ModelPart],
                 fetches: FeedDict) -> None:
        self._all_coders = all_coders
        self._fetches = fetches

        self.result = None

    def next_to_execute(self) -> NextExecute:
        """Get the feedables and tensors to run."""
        return self._all_coders, self._fetches, []

    def collect_results(self, results: List[Dict]) -> None:
        assert len(results) == 1

        for sess_result in results:
            # Map each gradient tensor's name to its evaluated value.
            gradient_dict = {}
            tensor_names = [t.name for t in self._fetches["gradients"]]
            for name, val in zip(tensor_names, sess_result["gradients"]):
                gradient_dict[name] = val

            self.result = ExecutionResult(
                outputs=[gradient_dict],
                losses=[],
                scalar_summaries=None,
                histogram_summaries=None,
                image_summaries=None)


class GradientRunner(BaseRunner[SupportedDecoder]):
    """Runner that fetches the trainer's gradients instead of outputs."""

    def __init__(self,
                 output_series: str,
                 trainer: GenericTrainer,
                 decoder: SupportedDecoder) -> None:
        check_argument_types()
        BaseRunner[SupportedDecoder].__init__(
            self, output_series, decoder)

        self._gradients = trainer.gradients

    # pylint: disable=unused-argument
    def get_executable(self,
                       compute_losses: bool,
                       summaries: bool,
                       num_sessions: int) -> GradientRunExecutable:
        # The trainer stores (variable, gradient) pairs; fetch the gradients.
        fetches = {"gradients": [g[1] for g in self._gradients]}

        return GradientRunExecutable(self.all_coders, fetches)
    # pylint: enable=unused-argument

    @property
    def loss_names(self) -> List[str]:
        return []
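The executable above yields one dictionary per batch, mapping gradient tensor names to their evaluated values. As a minimal sketch (not part of this commit), such per-batch dictionaries could be stacked along a new first axis and written to an .npz file in the batch-first layout the averaging script later in this diff expects; the helper name save_gradient_batches and the key normalization are my assumptions.

from typing import Dict, List

import numpy as np


def save_gradient_batches(batches: List[Dict[str, np.ndarray]],
                          path: str) -> None:
    stacked = {}
    for name in batches[0]:
        # Strip TensorFlow's ":0" output suffix so the keys match the bare
        # variable names that EWCRegularizer.value later looks up.
        key = name.split(":")[0]
        stacked[key] = np.stack([b[name] for b in batches], axis=0)
    np.savez(path, **stacked)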
@@ -0,0 +1,91 @@
"""Variable regularizers. | ||
This module contains classes that can be used as a variable regularizers | ||
during training. All implementation should be derived from the BaseRegularizer | ||
class. | ||
""" | ||
from typing import List | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
from typeguard import check_argument_types | ||
|
||
from neuralmonkey.logging import log | ||
|
||
|
||
# pylint: disable=too-few-public-methods | ||
class BaseRegularizer: | ||
"""Base class for the regularizers.""" | ||
|
||
def __init__(self, | ||
name: str, | ||
weight: float) -> None: | ||
check_argument_types() | ||
self.name = name | ||
self.weight = weight | ||
|
||
def value(self, variables) -> float: | ||
raise NotImplementedError("Abstract method") | ||
|
||
|
||
class L1Regularizer(BaseRegularizer): | ||
|
||
def __init__(self, | ||
name: str = "train_l1", | ||
weight: float = 0.) -> None: | ||
BaseRegularizer.__init__(self, name, weight) | ||
|
||
def value(self, variables: List[tf.Tensor]) -> float: | ||
return sum(tf.reduce_sum(abs(v)) for v in variables) | ||
|
||
|
||
class L2Regularizer(BaseRegularizer): | ||
|
||
def __init__(self, | ||
name: str = "train_l2", | ||
weight: float = 0.) -> None: | ||
BaseRegularizer.__init__(self, name, weight) | ||
|
||
def value(self, variables: List[tf.Tensor]) -> float: | ||
return sum(tf.reduce_sum(v ** 2) for v in variables) | ||
|
||
|
||
class EWCRegularizer(BaseRegularizer): | ||
"""Regularizer based on the Elastic Weight Consolidation. | ||
TODO description | ||
""" | ||
|
||
def __init__(self, | ||
name: str = "train_ewc", | ||
weight: float = 0., | ||
gradients_file: str = None, | ||
variables_file: str = None) -> None: | ||
check_argument_types() | ||
|
||
BaseRegularizer.__init__(self, name, weight) | ||
|
||
if gradients_file is None: | ||
raise ValueError("Missing gradients_file") | ||
if variables_file is None: | ||
raise ValueError("Missing variables_file") | ||
|
||
log("Loading initial variables for EWC from {}".format(variables_file)) | ||
self.init_vars = tf.contrib.framework.load_checkpoint(variables_file) | ||
log("EWC initial variables loaded") | ||
|
||
log("Loading gradient estimates from {}".format(gradients_file)) | ||
self.gradients = np.load(gradients_file) | ||
log("Gradient estimates loaded") | ||
|
||
def value(self, variables: List[tf.Tensor]) -> float: | ||
ewc_value = 0.0 | ||
for var in variables: | ||
var_name = var.name.split(":")[0] | ||
init_var = self.init_vars.get_tensor(var_name) | ||
gradient = self.gradients[var_name] | ||
ewc_value += tf.reduce_sum(tf.multiply( | ||
tf.square(gradient), tf.square(var - init_var))) | ||
|
||
return ewc_value |
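For reference, the loop above computes the penalty sum_i g_i^2 * (theta_i - theta*_i)^2 over all variable elements, with the averaged squared gradients standing in for a diagonal Fisher estimate; the trainer presumably scales the result by the regularizer's weight. A toy numpy check of the same arithmetic (mine, not from the commit):

import numpy as np

rng = np.random.RandomState(0)
theta = rng.randn(4)       # current variable values
theta_star = rng.randn(4)  # variable values from the loaded checkpoint
grad = rng.randn(4)        # averaged gradient estimates for this variable

# Same element-wise expression as EWCRegularizer.value, in numpy.
penalty = np.sum(np.square(grad) * np.square(theta - theta_star))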
@@ -0,0 +1,56 @@
#!/usr/bin/env python3
"""Compute the mean over a set of tensors.

The tensors can be spread over multiple npz files. The mean is computed
over the first dimension (supposed to be a batch).
"""

import argparse
import glob

import numpy as np

from neuralmonkey.logging import log as _log


def log(message: str, color: str = "blue") -> None:
    _log(message, color)


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--file_prefix", type=str,
                        help="Prefix of the npz files to be averaged.")
    parser.add_argument("--output_path", type=str,
                        help="Path to output the averaged tensors to.")
    args = parser.parse_args()

    output_dict = {}
    n = 0
    for file in glob.glob("{}.*npz".format(args.file_prefix)):
        log("Processing {}".format(file))
        tensors = np.load(file)

        # The first dimension (batch) must be equal for all tensors.
        shapes = [tensors[f].shape for f in tensors.files]
        assert all(x[0] == shapes[0][0] for x in shapes)

        # Sum each tensor over the batch dimension and accumulate
        # across files; keep count of the total number of samples.
        for varname in tensors.files:
            res = np.sum(tensors[varname], 0)
            if varname in output_dict:
                output_dict[varname] += res
            else:
                output_dict[varname] = res
        n += shapes[0][0]

    for name in output_dict:
        output_dict[name] /= n

    np.savez(args.output_path, **output_dict)


if __name__ == "__main__":
    main()
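To illustrate the semantics on toy data (file names and values here are made up; the script's own file name is not shown in this diff, so avg_gradients.py below is a placeholder): two npz files sharing the prefix grads hold batches of 3 and 2 samples, and the script's sum-then-divide accumulation yields the element-wise mean over all 5 samples, the same as concatenating and averaging directly.

import numpy as np

a = np.arange(6, dtype=float).reshape(3, 2)  # batch of 3 samples
b = np.arange(4, dtype=float).reshape(2, 2)  # batch of 2 samples
np.savez("grads.0.npz", w=a)
np.savez("grads.1.npz", w=b)

# Running the script as e.g.
#   avg_gradients.py --file_prefix grads --output_path avg.npz
# should then store the same values as:
expected = np.concatenate([a, b], axis=0).mean(axis=0)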