
Commit

added script to compute Empirical Fisher
varisd committed Nov 16, 2018
1 parent 6673ae7 commit a7637d8
Showing 3 changed files with 69 additions and 2 deletions.
2 changes: 2 additions & 0 deletions neuralmonkey/decoders/transformer.py
@@ -214,6 +214,8 @@ def __init__(self,
         self._variable_scope.set_initializer(tf.variance_scaling_initializer(
             mode="fan_avg", distribution="uniform"))
 
+        if reuse:
+            self._variable_scope.reuse_variables()
         log("Decoder cost op: {}".format(self.cost))
         self._variable_scope.reuse_variables()
         log("Runtime logits: {}".format(self.runtime_logits))
56 changes: 56 additions & 0 deletions scripts/compute_fisher.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+"""Compute the Empirical Fisher matrix using a list of gradients.
+The gradient tensors can be spread over multiple npz files. The mean
+is computed over the first dimension (supposed to be a batch).
+"""
+
+import argparse
+import os
+import re
+import glob
+
+import numpy as np
+
+from neuralmonkey.logging import log as _log
+
+
+def log(message: str, color: str = "blue") -> None:
+    _log(message, color)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--file_prefix", type=str,
+                        help="prefix of the npz files containing the gradients")
+    parser.add_argument("--output_path", type=str,
+                        help="Path to output the Empirical Fisher to.")
+    args = parser.parse_args()
+
+    output_dict = {}
+    n = 0
+    for file in glob.glob("{}.*npz".format(args.file_prefix)):
+        log("Processing {}".format(file))
+        tensors = np.load(file)
+
+        # first dimension must be equal for all tensors (batch)
+        shapes = [tensors[f].shape for f in tensors.files]
+        assert all([x[0] == shapes[0][0] for x in shapes])
+
+        for varname in tensors.files:
+            res = np.sum(np.square(tensors[varname]), 0)
+            if varname in output_dict:
+                output_dict[varname] += res
+            else:
+                output_dict[varname] = res
+        n += shapes[0][0]
+
+    for name in output_dict:
+        output_dict[name] /= n
+
+    np.savez(args.output_path, **output_dict)
+
+
+if __name__ == "__main__":
+    main()
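For context (not part of the commit): averaging the squared per-example gradients over the batch axis, as the loop above does, gives the usual diagonal approximation of the empirical Fisher, roughly F_jj ≈ (1/N) * sum_n g_{n,j}^2. A minimal sketch, assuming hypothetical input files named grads.<i>.npz with made-up variable names, that fabricates such inputs and reproduces the same accumulation in plain NumPy:

# Hypothetical sketch only; variable names and file names are invented.
import glob

import numpy as np

rng = np.random.default_rng(0)
for i in range(2):
    # Each file holds per-example gradients; axis 0 is the batch dimension.
    np.savez("grads.{}.npz".format(i),
             decoder_w=rng.normal(size=(16, 8, 8)),
             decoder_b=rng.normal(size=(16, 8)))

fisher = {}
n_examples = 0
for path in glob.glob("grads.*npz"):
    tensors = np.load(path)
    for name in tensors.files:
        fisher[name] = fisher.get(name, 0.0) + np.sum(np.square(tensors[name]), 0)
    n_examples += tensors[tensors.files[0]].shape[0]

# Diagonal empirical Fisher estimate: mean of squared per-example gradients.
fisher = {name: value / n_examples for name, value in fisher.items()}
np.savez("fisher.npz", **fisher)

The script itself would be invoked analogously, e.g. python scripts/compute_fisher.py --file_prefix grads --output_path fisher.npz.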
13 changes: 11 additions & 2 deletions tests/vocab.ini
@@ -4,7 +4,7 @@ tf_manager=<tf_manager>
 output="tests/outputs/vocab"
 overwrite_output_dir=True
 batch_size=16
-epochs=0
+epochs=1
 train_dataset=<train_data>
 val_dataset=<val_data>
 trainer=<trainer>
@@ -66,11 +66,20 @@ dropout_keep_prob=0.5
 data_id="target"
 vocabulary=<decoder_vocabulary>
 
-[trainer]
+[trainer1]
 class=trainers.cross_entropy_trainer.CrossEntropyTrainer
 decoders=[<decoder>]
 regularizers=[<train_l2>]
 
+[trainer2]
+class=trainers.cross_entropy_trainer.CrossEntropyTrainer
+decoders=[<decoder>]
+regularizers=[<train_l2>]
+
+[trainer]
+class=trainers.multitask_trainer.MultitaskTrainer
+trainers=[<trainer1>, <trainer2>]
+
 [train_l2]
 class=trainers.regularizers.L2Regularizer
 weight=1.0e-8
