BERT data preparation script
varisd committed Feb 6, 2019
1 parent 63e19aa commit fbad9af
Showing 2 changed files with 86 additions and 3 deletions.
10 changes: 7 additions & 3 deletions neuralmonkey/dataset.py
@@ -389,8 +389,8 @@ def itergen():
         for s_name, (preprocessor, source) in prep_sl.items():
             if source not in iterators:
                 raise ValueError(
-                    "Source series {} for series-level preprocessor nonexistent: "
-                    "Preprocessed series '', source series ''".format(source))
+                    "Source series for series-level preprocessor nonexistent: "
+                    "Preprocessed series '{}', source series '{}'".format(s_name, source))
             iterators[s_name] = _make_sl_iterator(source, preprocessor)

         # Finally, dataset-level preprocessors.
@@ -443,6 +443,8 @@ def __init__(self,
         Arguments:
             name: The name for the dataset.
             iterators: A series-iterator generator mapping.
+            lazy: If False, load the data from the iterators into a list and
+                store the list in memory.
             buffer_size: Use this tuple as a minimum and maximum buffer size
                 for pre-loading data. This should be (a few times) larger than
                 the batch size used for mini-batching. When the buffer size
@@ -638,7 +640,9 @@ def itergen():
             buf.append(item)

         if self.shuffled:
-            random.shuffle(buf)  # type: ignore
+            lbuf = list(buf)
+            random.shuffle(lbuf)
+            buf = deque(lbuf)

         if not self.batching.drop_remainder:
             for bucket in buckets:
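A note on the shuffle change above: the Fisher-Yates shuffle in random.shuffle swaps elements by index, which is O(1) on a list but O(n) on a deque, so shuffling the deque in place would be quadratic (and the old line already needed a # type: ignore to satisfy the type checker). Copying to a list, shuffling, and copying back keeps the shuffle linear. The min/max buffer_size semantics documented in the docstring can be sketched roughly as follows; this is a minimal illustration of the idea, not Neural Monkey's actual reader implementation:

    import random
    from itertools import islice
    from typing import Iterable, Iterator, Tuple, TypeVar

    T = TypeVar("T")

    def buffered_shuffle(items: Iterable[T],
                         buffer_size: Tuple[int, int]) -> Iterator[T]:
        """Yield items in approximately shuffled order via a bounded buffer."""
        min_size, max_size = buffer_size
        it = iter(items)
        buf = list(islice(it, max_size))  # pre-load up to the maximum bound
        random.shuffle(buf)
        while buf:
            yield buf.pop()
            if len(buf) < min_size:  # refill and reshuffle at the floor
                buf.extend(islice(it, max_size - len(buf)))
                random.shuffle(buf)

The larger the buffer relative to the batch size, the closer the output order gets to a full shuffle, which is why the docstring recommends a buffer a few times larger than the batch size.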
79 changes: 79 additions & 0 deletions scripts/preprocess_bert.py
@@ -0,0 +1,79 @@
#!/usr/bin/env python3
"""Create training data for BERT network training.

Produces a noisified version of the input corpus together with the
masked gold predictions.
"""
# TODO: add support for other NM vocabularies (aside from t2t)

import argparse

import numpy as np

from neuralmonkey.logging import log as _log
from neuralmonkey.vocabulary import (
    Vocabulary, PAD_TOKEN, UNK_TOKEN, from_wordlist)


def log(message: str, color: str = "blue") -> None:
    _log(message, color)


def str2bool(value: str) -> bool:
    # argparse's type=bool treats any non-empty string (including "False")
    # as True, so boolean flags are parsed explicitly.
    return value.lower() in ("true", "yes", "1")


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--input_file", type=str, default="/dev/stdin")
    parser.add_argument("--vocabulary", type=str, required=True)
    parser.add_argument("--output_prefix", type=str, default=None)
    parser.add_argument("--mask_token", type=str, default=UNK_TOKEN)
    parser.add_argument("--coverage", type=float, default=0.15)
    parser.add_argument("--mask_prob", type=float, default=0.8)
    parser.add_argument("--replace_prob", type=float, default=0.1)
    parser.add_argument("--vocab_contains_header",
                        type=str2bool, default=True)
    parser.add_argument("--vocab_contains_frequencies",
                        type=str2bool, default=True)
    args = parser.parse_args()

    assert 0 <= args.coverage <= 1
    assert 0 <= args.mask_prob <= 1
    assert 0 <= args.replace_prob <= 1
    # keep_prob below must be non-negative.
    assert args.mask_prob + args.replace_prob <= 1

    log("Loading vocabulary.")
    vocabulary = from_wordlist(
        args.vocabulary,
        contains_header=args.vocab_contains_header,
        contains_frequencies=args.vocab_contains_frequencies)

    # Per-token corruption distribution: (keep_prob, mask_prob, replace_prob).
    mask_prob = args.mask_prob
    replace_prob = args.replace_prob
    keep_prob = 1 - mask_prob - replace_prob
    sample_probs = (keep_prob, mask_prob, replace_prob)

    output_prefix = args.output_prefix
    if output_prefix is None:
        output_prefix = args.input_file
    out_f_noise = "{}.noisy".format(output_prefix)
    out_f_mask = "{}.mask".format(output_prefix)

    log("Processing data.")
    with open(args.input_file, "r", encoding="utf-8") as input_h, \
            open(out_f_noise, "w", encoding="utf-8") as out_noise_h, \
            open(out_f_mask, "w", encoding="utf-8") as out_mask_h:
        # TODO: performance optimizations
        for line in input_h:
            tokens = line.strip().split(" ")
            num_samples = int(args.coverage * len(tokens))
            sampled_indices = np.random.choice(
                len(tokens), num_samples, replace=False)

            output_noisy = list(tokens)
            output_masked = [PAD_TOKEN] * len(tokens)
            for i in sampled_indices:
                # Pick a random vocabulary token, skipping the first four
                # special tokens (pad, start, end, unk).
                random_token = np.random.choice(vocabulary.index_to_word[4:])
                # Keep the token, mask it, or replace it with the random one,
                # with probabilities (keep_prob, mask_prob, replace_prob).
                new_token = np.random.choice(
                    [tokens[i], args.mask_token, random_token],
                    p=sample_probs)
                output_noisy[i] = new_token
                output_masked[i] = tokens[i]
            out_noise_h.write(" ".join(output_noisy) + "\n")
            out_mask_h.write(" ".join(output_masked) + "\n")


if __name__ == "__main__":
    main()
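A hypothetical invocation (file names are illustrative), assuming a t2t-style wordlist with a header and frequency column as the defaults expect:

    python3 scripts/preprocess_bert.py \
        --input_file train.tok \
        --vocabulary vocab.tsv \
        --output_prefix train

With the defaults this samples about 15% of the tokens on each line and, for each sampled position, masks the token with probability 0.8, swaps in a random vocabulary token with probability 0.1, and keeps it unchanged with probability 0.1. The corrupted corpus is written to train.noisy and the gold predictions (original tokens at the sampled positions, PAD_TOKEN everywhere else) to train.mask.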
