diff --git a/neuralmonkey/decoders/transformer.py b/neuralmonkey/decoders/transformer.py index 67ea0a7ab..b91bb3dac 100644 --- a/neuralmonkey/decoders/transformer.py +++ b/neuralmonkey/decoders/transformer.py @@ -214,6 +214,8 @@ def __init__(self, self._variable_scope.set_initializer(tf.variance_scaling_initializer( mode="fan_avg", distribution="uniform")) + if reuse: + self._variable_scope.reuse_variables() log("Decoder cost op: {}".format(self.cost)) self._variable_scope.reuse_variables() log("Runtime logits: {}".format(self.runtime_logits)) diff --git a/scripts/compute_fisher.py b/scripts/compute_fisher.py new file mode 100755 index 000000000..2690bd924 --- /dev/null +++ b/scripts/compute_fisher.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +"""Compute the Empirical Fisher matrix using a list of gradients. + +The gradient tensors can be spread over multiple npz files. The mean +is computed over the first dimension (supposed to be a batch). + +""" + +import argparse +import os +import re +import glob + +import numpy as np + +from neuralmonkey.logging import log as _log + + +def log(message: str, color: str = "blue") -> None: + _log(message, color) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--file_prefix", type=str, + help="prefix of the npz files containing the gradients") + parser.add_argument("--output_path", type=str, + help="Path to output the Empirical Fisher to.") + args = parser.parse_args() + + output_dict = {} + n = 0 + for file in glob.glob("{}.*npz".format(args.file_prefix)): + log("Processing {}".format(file)) + tensors = np.load(file) + + # first dimension must be equal for all tensors (batch) + shapes = [tensors[f].shape for f in tensors.files] + assert all([x[0] == shapes[0][0] for x in shapes]) + + for varname in tensors.files: + res = np.sum(np.square(tensors[varname]), 0) + if varname in output_dict: + output_dict[varname] += res + else: + output_dict[varname] = res + n += shapes[0][0] + + for name in output_dict: + output_dict[name] /= n + + np.savez(args.output_path, **output_dict) + + +if __name__ == "__main__": + main() diff --git a/tests/vocab.ini b/tests/vocab.ini index 0f4883d44..ce180216d 100644 --- a/tests/vocab.ini +++ b/tests/vocab.ini @@ -4,7 +4,7 @@ tf_manager= output="tests/outputs/vocab" overwrite_output_dir=True batch_size=16 -epochs=0 +epochs=1 train_dataset= val_dataset= trainer= @@ -66,11 +66,20 @@ dropout_keep_prob=0.5 data_id="target" vocabulary= -[trainer] +[trainer1] class=trainers.cross_entropy_trainer.CrossEntropyTrainer decoders=[] regularizers=[] +[trainer2] +class=trainers.cross_entropy_trainer.CrossEntropyTrainer +decoders=[] +regularizers=[] + +[trainer] +class=trainers.multitask_trainer.MultitaskTrainer +trainers=[, ] + [train_l2] class=trainers.regularizers.L2Regularizer weight=1.0e-8