diff --git a/.travis.yml b/.travis.yml index 5e5c3ea37..039b3a6bb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,11 +14,7 @@ env: - TEST_SUITE=mypy python: - #- "2.7" - #- "3.4" - - "3.5" - #- "3.5-dev" # 3.5 development branch - #- "nightly" # currently points to 3.6-dev + - "3.6" # commands to install dependencies before_install: diff --git a/docs/requirements.txt b/docs/requirements.txt index 580608dd2..fe2ae61f1 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -10,4 +10,4 @@ python_speech_features pygments typeguard sacrebleu -tensorflow>=1.10.0,<1.11 +tensorflow>=1.12.0,<1.13 diff --git a/neuralmonkey/checking.py b/neuralmonkey/checking.py index d9ac7e2dc..dfa9a0167 100644 --- a/neuralmonkey/checking.py +++ b/neuralmonkey/checking.py @@ -4,64 +4,14 @@ constructing the computational graph. """ - -from typing import List, Optional, Iterable - +from typing import List, Optional import tensorflow as tf -from neuralmonkey.logging import log, debug -from neuralmonkey.dataset import Dataset -from neuralmonkey.runners.base_runner import BaseRunner - class CheckingException(Exception): pass -def check_dataset_and_coders(dataset: Dataset, - runners: Iterable[BaseRunner]) -> None: - # pylint: disable=protected-access - - data_list = [] - for runner in runners: - for c in runner.feedables: - if hasattr(c, "data_id"): - data_list.append((getattr(c, "data_id"), c)) - elif hasattr(c, "data_ids"): - data_list.extend([(d, c) for d in getattr(c, "data_ids")]) - elif hasattr(c, "input_sequence"): - inpseq = getattr(c, "input_sequence") - if hasattr(inpseq, "data_id"): - data_list.append((getattr(inpseq, "data_id"), c)) - elif hasattr(inpseq, "data_ids"): - data_list.extend( - [(d, c) for d in getattr(inpseq, "data_ids")]) - else: - log("Input sequence: {} does not have a data attribute" - .format(str(inpseq))) - else: - log(("Coder: {} has neither an input sequence attribute nor a " - "a data attribute.").format(c)) - - debug("Found series: {}".format(str(data_list)), "checking") - missing = [] - - for (serie, coder) in data_list: - if serie not in dataset: - log("dataset {} does not have serie {}".format( - dataset.name, serie)) - missing.append((coder, serie)) - - if missing: - formated = ["{} ({}, {}.{})" .format(serie, str(cod), - cod.__class__.__module__, - cod.__class__.__name__) - for cod, serie in missing] - - raise CheckingException("Dataset '{}' is mising series {}:" - .format(dataset.name, ", ".join(formated))) - - def assert_shape(tensor: tf.Tensor, expected_shape: List[Optional[int]]) -> None: """Check shape of a tensor. 
diff --git a/neuralmonkey/checkpython.py b/neuralmonkey/checkpython.py index e8841edff..979453fb7 100644 --- a/neuralmonkey/checkpython.py +++ b/neuralmonkey/checkpython.py @@ -1,7 +1,7 @@ import sys -if sys.version_info[0] < 3 or sys.version_info[1] < 5: +if sys.version_info[0] < 3 or sys.version_info[1] < 6: print("Error:", file=sys.stderr) - print("Neural Monkey must use Python >= 3.5", file=sys.stderr) + print("Neural Monkey must use Python >= 3.6", file=sys.stderr) print("Your Python is", sys.version, sys.executable, file=sys.stderr) sys.exit(1) diff --git a/neuralmonkey/config/normalize.py b/neuralmonkey/config/normalize.py index a8cbc8bf7..5e471a370 100644 --- a/neuralmonkey/config/normalize.py +++ b/neuralmonkey/config/normalize.py @@ -14,7 +14,6 @@ import numpy as np -from neuralmonkey.dataset import BatchingScheme from neuralmonkey.logging import warn from neuralmonkey.tf_manager import get_default_tf_manager from neuralmonkey.trainers.delayed_update_trainer import DelayedUpdateTrainer @@ -34,25 +33,6 @@ def normalize_configuration(cfg: Namespace, train_mode: bool) -> None: if cfg.tf_manager is None: cfg.tf_manager = get_default_tf_manager() - if (cfg.batch_size is None) == (cfg.batching_scheme is None): - raise ValueError("You must specify either batch_size or " - "batching_scheme (not both).") - - if cfg.batch_size is not None: - assert cfg.batching_scheme is None - cfg.batching_scheme = BatchingScheme(batch_size=cfg.batch_size) - else: - assert cfg.batching_scheme is not None - cfg.batch_size = cfg.batching_scheme.batch_size - - if cfg.runners_batch_size is None: - cfg.runners_batch_size = cfg.batching_scheme.batch_size - - cfg.runners_batching_scheme = BatchingScheme( - batch_size=cfg.runners_batch_size, - token_level_batching=cfg.batching_scheme.token_level_batching, - use_leftover_buckets=True) - cfg.evaluation = [(e[0], e[0], e[1]) if len(e) == 2 else e for e in cfg.evaluation] diff --git a/neuralmonkey/config/parsing.py b/neuralmonkey/config/parsing.py index c70866c00..edcafc019 100644 --- a/neuralmonkey/config/parsing.py +++ b/neuralmonkey/config/parsing.py @@ -150,12 +150,13 @@ def _parse_class_name(string: str, vars_dict: VarsDict) -> ClassSymbol: def _parse_value(string: str, vars_dict: VarsDict) -> Any: - """Parse the value recursively according to the Nerualmonkey grammar. + """Parse the value recursively according to the Neural Monkey grammar. Arguments: string: the string to be parsed vars_dict: a dictionary of variables for substitution """ + string = string.strip() if string in CONSTANTS: return CONSTANTS[string] diff --git a/neuralmonkey/dataset.py b/neuralmonkey/dataset.py index 2627fa33e..ae963d75d 100644 --- a/neuralmonkey/dataset.py +++ b/neuralmonkey/dataset.py @@ -55,32 +55,43 @@ class BatchingScheme: def __init__(self, - batch_size: int, - batch_bucket_span: int = None, - token_level_batching: bool = False, - bucketing_ignore_series: List[str] = None, - use_leftover_buckets: bool = True) -> None: + batch_size: int = None, + drop_remainder: bool = False, + bucket_boundaries: List[int] = None, + bucket_batch_sizes: List[int] = None, + ignore_series: List[str] = None) -> None: """Construct the baching scheme. Attributes: batch_size: Number of examples in one mini-batch. - batch_bucket_span: The span of the bucket for bucketed batching. - token_level_batching: Count the batch_size per individual tokens - in the batch instead of examples. - bucketing_ignore_series: Series to ignore during bucketing.
- use_leftover_buckets: Whether to throw out bucket contents at the - end of the epoch or to use them. + drop_remainder: Whether to throw out the last batch in the epoch + if it is not complete. + bucket_boundaries: Upper length boundaries of buckets. + bucket_batch_sizes: Batch size per bucket. Length should be + `len(bucket_boundaries) + 1`. + ignore_series: Series to ignore during bucketing. """ check_argument_types() self.batch_size = batch_size - self.batch_bucket_span = batch_bucket_span - self.token_level_batching = token_level_batching - self.use_leftover_buckets = use_leftover_buckets - - self.bucketing_ignore_series = [] # type: List[str] - if bucketing_ignore_series is not None: - self.bucketing_ignore_series = bucketing_ignore_series + self.drop_remainder = drop_remainder + self.bucket_boundaries = bucket_boundaries + self.bucket_batch_sizes = bucket_batch_sizes + + self.ignore_series = [] # type: List[str] + if ignore_series is not None: + self.ignore_series = ignore_series + + if (self.batch_size is None) == (self.bucket_boundaries is None): + raise ValueError("You must specify either batch_size or " + "bucket_boundaries, not both") + + if self.bucket_boundaries is not None: + if self.bucket_batch_sizes is None: + raise ValueError("You must specify bucket_batch_sizes") + if len(self.bucket_batch_sizes) != len(self.bucket_boundaries) + 1: + raise ValueError( + "There should be N+1 batch sizes for N bucket boundaries") # pylint: enable=too-few-public-methods @@ -192,79 +203,11 @@ def _get_series_outputs(series_config: SeriesConfig) -> List[OutputSpec]: return [(key, val, AutoWriter) for key, val in outputs.items()] -# pylint: disable=too-many-locals -# This is a deprecated function, no point in removing one local var from it -def load_dataset_from_files( - name: str, - lazy: bool = False, - preprocessors: List[Tuple[str, str, Callable]] = None, - **kwargs) -> "Dataset": - """Load a dataset from the files specified by the provided arguments. - - Paths to the data are provided in a form of dictionary. - - Keyword arguments: - name: The name of the dataset to use. If None (default), the name will - be inferred from the file names. - lazy: Boolean flag specifying whether to use lazy loading (useful for - large files). Note that the lazy dataset cannot be shuffled. - Defaults to False. - preprocessor: A callable used for preprocessing of the input sentences. - kwargs: Dataset keyword argument specs. These parameters should begin - with 's_' prefix and may end with '_out' suffix. For example, - a data series 'source' which specify the source sentences - should be initialized with the 's_source' parameter, which - specifies the path and optinally reader of the source file. If - runners generate data of the 'target' series, the output file - should be initialized with the 's_target_out' parameter. - Series identifiers should not contain underscores. - Dataset-level preprocessors are defined with 'pre_' prefix - followed by a new series name. In case of the pre-processed - series, a callable taking the dataset and returning a new - series is expected as a value. - - Returns: - The newly created dataset. - """ - warn("Use of deprecated function. 
Consider using dataset.load instead.") - check_argument_types() - - series_paths_and_readers = _get_series_paths_and_readers(kwargs) - outputs = _get_series_outputs(kwargs) - - if not series_paths_and_readers: - raise ValueError("No input files were provided.") - - series, data = [list(x) for x in zip(*series_paths_and_readers.items())] - - # Series-level preprocessors - if preprocessors: - for src, tgt, fun in preprocessors: - series.append(tgt) - data.append((fun, src)) - - # Dataset-level preprocessors - keys = [key for key in kwargs if PREPROCESSED_SERIES.match(key)] - - for key in keys: - s_name = get_first_match(PREPROCESSED_SERIES, key) - preprocessor = cast(DatasetPreprocess, kwargs[key]) - series.append(s_name) - data.append(preprocessor) - - buffer_size = None if not lazy else 5000 - return load(name, series, data, outputs, buffer_size, False) -# pylint: enable=too-many-locals - - -def from_files(*args, **kwargs): - return load_dataset_from_files(*args, **kwargs) - - # pylint: disable=too-many-locals,too-many-branches def load(name: str, series: List[str], data: List[SourceSpec], + batching: BatchingScheme = None, outputs: List[OutputSpec] = None, buffer_size: int = None, shuffled: bool = False) -> "Dataset": @@ -288,11 +231,20 @@ def load(name: str, (much) larger than the batch size. Note that the buffer gets refilled each time its size is less than half the `buffer_size`. When refilling, the buffer gets refilled to the specified size. - shuffled: Whether to shuffle the dataset buffer (done upon refill). - """ check_argument_types() + if batching is None: + from neuralmonkey.experiment import Experiment + log("Using default batching scheme for dataset {}.".format(name)) + # pylint: disable=no-member + batch_size = Experiment.get_current().config.args.batch_size + # pylint: enable=no-member + if batch_size is None: + raise ValueError("Argument main.batch_size is not specified, " + "cannot use default batching scheme.") + batching = BatchingScheme(batch_size=batch_size) + if not series: raise ValueError("No dataset series specified.") @@ -374,10 +326,10 @@ def itergen(): in [_normalize_outputspec(out) for out in outputs]} if buffer_size is not None: - return Dataset(name, iterators, output_dict, + return Dataset(name, iterators, batching, output_dict, (buffer_size // 2, buffer_size), shuffled) - return Dataset(name, iterators, output_dict, None, shuffled) + return Dataset(name, iterators, batching, output_dict, None, shuffled) # pylint: enable=too-many-locals,too-many-branches @@ -398,6 +350,7 @@ class Dataset: def __init__(self, name: str, iterators: Dict[str, Callable[[], Iterator]], + batching: BatchingScheme, outputs: Dict[str, Tuple[str, Writer]] = None, buffer_size: Tuple[int, int] = None, shuffled: bool = False) -> None: @@ -424,6 +377,7 @@ def __init__(self, """ self.name = name self.iterators = iterators + self.batching = batching self.outputs = outputs if buffer_size is not None: @@ -509,21 +463,23 @@ def maybe_get_series(self, name: str) -> Optional[Iterator]: return self.get_series(name) return None - # pylint: disable=too-many-locals,too-many-branches - def batches(self, - scheme: BatchingScheme) -> Iterator["Dataset"]: + # pylint: disable=too-many-locals,too-many-branches,too-many-statements + def batches(self) -> Iterator["Dataset"]: """Split the dataset into batches. - Arguments: - scheme: `BatchingScheme` configuration object. - Returns: Generator yielding the batches. 
""" - if self.lazy and self.buffer_min_size < scheme.batch_size: + if self.batching.batch_size is not None: + max_bs = self.batching.batch_size + else: + assert self.batching.bucket_batch_sizes is not None + max_bs = max(self.batching.bucket_batch_sizes) + + if self.lazy and self.buffer_min_size < max_bs: warn("Minimum buffer size ({}) lower than batch size ({}). " "It is recommended to use large buffer size." - .format(self.buffer_min_size, scheme.batch_size)) + .format(self.buffer_min_size, max_bs)) # Initialize iterators iterators = {s: it() for s, it in self.iterators.items()} @@ -551,28 +507,39 @@ def itergen(): # Iterate over the rest of the data until buffer is empty batch_index = 0 - buckets = {} \ - # type: Dict[int, List[DataExample]] + buckets = [[]] # type: List[List[DataExample]] + + if self.batching.bucket_boundaries is not None: + buckets += [[] for _ in self.batching.bucket_boundaries] + while buf: row = buf.popleft() - if scheme.batch_bucket_span is None: + if self.batching.bucket_boundaries is None: bucket_id = 0 else: # TODO: use only specific series to determine the bucket number - bucket_id = (max(len(row[key]) for key in row) - // scheme.batch_bucket_span) + length = max(len(row[key]) for key in row) + + bucket_id = -1 + for b_id, limit in enumerate(self.batching.bucket_boundaries): + fits_in = length <= limit + tighter_fit = ( + bucket_id == -1 + or limit < self.batching.bucket_boundaries[ + bucket_id]) + + if fits_in and tighter_fit: + bucket_id = b_id - if bucket_id not in buckets: - buckets[bucket_id] = [] buckets[bucket_id].append(row) - is_full = (len(buckets[bucket_id]) >= scheme.batch_size) - if scheme.token_level_batching: - bucket_width = max(max(len(row[key]) for key in row) - for row in buckets[bucket_id]) - is_full = (bucket_width * len(buckets[bucket_id]) - >= scheme.batch_size) + if self.batching.bucket_batch_sizes is None: + assert self.batching.batch_size is not None + is_full = len(buckets[bucket_id]) >= self.batching.batch_size + else: + is_full = (len(buckets[bucket_id]) + >= self.batching.bucket_batch_sizes[bucket_id]) if is_full: # Create the batch @@ -580,7 +547,8 @@ def itergen(): data = {key: _make_datagen(buckets[bucket_id], key) for key in buckets[bucket_id][0]} - yield Dataset(name=name, iterators=data) + yield Dataset( + name=name, iterators=data, batching=self.batching) batch_index += 1 buckets[bucket_id] = [] @@ -598,14 +566,15 @@ def itergen(): random.shuffle(lbuf) buf = deque(lbuf) - if scheme.use_leftover_buckets: - for bucket_id in buckets: - if buckets[bucket_id]: + if not self.batching.drop_remainder: + for bucket in buckets: + if bucket: name = "{}.batch.{}".format(self.name, batch_index) - data = {key: _make_datagen(buckets[bucket_id], key) - for key in buckets[bucket_id][0]} + data = {key: _make_datagen(bucket, key) + for key in bucket[0]} - yield Dataset(name=name, iterators=data) + yield Dataset( + name=name, iterators=data, batching=self.batching) batch_index += 1 # pylint: enable=too-many-locals,too-many-branches @@ -639,6 +608,7 @@ def subset(self, start: int, length: int) -> "Dataset": return Dataset( # type: ignore name=name, iterators=slices, + batching=self.batching, outputs=outputs, buffer_size=self.buffer_size, shuffled=self.shuffled) diff --git a/neuralmonkey/decoders/autoregressive.py b/neuralmonkey/decoders/autoregressive.py index 15cd47e7d..46fdffa7f 100644 --- a/neuralmonkey/decoders/autoregressive.py +++ b/neuralmonkey/decoders/autoregressive.py @@ -7,7 +7,6 @@ """ from typing import NamedTuple, Callable, Tuple, 
Optional, Any, List, Dict -import numpy as np import tensorflow as tf from neuralmonkey.dataset import Dataset @@ -20,8 +19,7 @@ from neuralmonkey.nn.utils import dropout from neuralmonkey.tf_utils import get_variable, get_state_shape_invariants from neuralmonkey.vocabulary import ( - Vocabulary, START_TOKEN, UNK_TOKEN_INDEX, START_TOKEN_INDEX, - PAD_TOKEN_INDEX) + Vocabulary, pad_batch, sentence_mask, UNK_TOKEN_INDEX, START_TOKEN_INDEX) class LoopState(NamedTuple( @@ -181,23 +179,29 @@ def embedding_size(self) -> int: @tensor def go_symbols(self) -> tf.Tensor: - return tf.fill([self.batch_size], START_TOKEN_INDEX) + return tf.fill([self.batch_size], + tf.constant(START_TOKEN_INDEX, dtype=tf.int64)) @property def input_types(self) -> Dict[str, tf.DType]: - return {self.data_id: tf.int32} + return {self.data_id: tf.string} @property def input_shapes(self) -> Dict[str, tf.TensorShape]: return {self.data_id: tf.TensorShape([None, None])} @tensor - def train_inputs(self) -> tf.Tensor: + def train_tokens(self) -> tf.Tensor: return self.dataset[self.data_id] + @tensor + def train_inputs(self) -> tf.Tensor: + return tf.transpose( + self.vocabulary.strings_to_indices(self.train_tokens)) + @tensor def train_mask(self) -> tf.Tensor: - return tf.to_float(tf.not_equal(self.train_inputs, PAD_TOKEN_INDEX)) + return sentence_mask(self.train_inputs) @tensor def decoding_w(self) -> tf.Variable: @@ -373,7 +377,7 @@ def get_initial_loop_state(self) -> LoopState: outputs = tf.zeros( shape=[0, self.batch_size], - dtype=tf.int32, + dtype=tf.int64, name="outputs") feedables = DecoderFeedables( @@ -480,18 +484,9 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: raise ValueError("When training, you must feed " "reference sentences") - go_symbol_idx = self.vocabulary.get_word_index(START_TOKEN) - fd[self.go_symbols] = np.full([len(dataset)], go_symbol_idx, - dtype=np.int32) - if sentences is not None: - sentences_list = list(sentences) - # train_mode=False, since we don't want to ize target words! 
- inputs, _ = self.vocabulary.sentences_to_tensor( - sentences_list, self.max_output_len, train_mode=False, - add_start_symbol=False, add_end_symbol=True, - pad_to_max_len=False) - - fd[self.train_inputs] = inputs + fd[self.train_tokens] = pad_batch( + list(sentences), self.max_output_len, add_start_symbol=False, + add_end_symbol=True) return fd diff --git a/neuralmonkey/decoders/beam_search_decoder.py b/neuralmonkey/decoders/beam_search_decoder.py index c4e39bd0b..de8a5fbae 100644 --- a/neuralmonkey/decoders/beam_search_decoder.py +++ b/neuralmonkey/decoders/beam_search_decoder.py @@ -476,7 +476,8 @@ def body(*args: Any) -> BeamSearchLoopState: topk_indices.set_shape([None, self.beam_size]) topk_scores.set_shape([None, self.beam_size]) - next_word_ids = tf.mod(topk_indices, len(self.vocabulary)) + next_word_ids = tf.to_int64( + tf.mod(topk_indices, len(self.vocabulary))) next_beam_ids = tf.div(topk_indices, len(self.vocabulary)) # batch offset for tf.gather_nd diff --git a/neuralmonkey/decoders/classifier.py b/neuralmonkey/decoders/classifier.py index bfbb19489..234471e2f 100644 --- a/neuralmonkey/decoders/classifier.py +++ b/neuralmonkey/decoders/classifier.py @@ -10,7 +10,7 @@ from neuralmonkey.model.model_part import ModelPart from neuralmonkey.model.stateful import Stateful from neuralmonkey.nn.mlp import MultilayerPerceptron -from neuralmonkey.vocabulary import Vocabulary +from neuralmonkey.vocabulary import Vocabulary, pad_batch class Classifier(ModelPart): @@ -64,7 +64,7 @@ def __init__(self, @property def input_types(self) -> Dict[str, tf.DType]: - return {self.data_id: tf.int32} + return {self.data_id: tf.string} @property def input_shapes(self) -> Dict[str, tf.TensorShape]: @@ -72,6 +72,10 @@ def input_shapes(self) -> Dict[str, tf.TensorShape]: @tensor def gt_inputs(self) -> tf.Tensor: + return self.vocabulary.strings_to_indices(self.targets) + + @tensor + def targets(self) -> tf.Tensor: return self.dataset[self.data_id] @tensor @@ -132,8 +136,8 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: sentences = dataset.maybe_get_series(self.data_id) if sentences is not None: - label_tensors, _ = self.vocabulary.sentences_to_tensor( - list(sentences), self.max_output_len) - fd[self.gt_inputs] = label_tensors[0] + labels = [l[0] for l in pad_batch(list(sentences), + self.max_output_len)] + fd[self.targets] = labels return fd diff --git a/neuralmonkey/decoders/ctc_decoder.py b/neuralmonkey/decoders/ctc_decoder.py index b75bd1a0a..44a525825 100644 --- a/neuralmonkey/decoders/ctc_decoder.py +++ b/neuralmonkey/decoders/ctc_decoder.py @@ -1,6 +1,5 @@ -from typing import cast, Iterable, List +from typing import Dict -import numpy as np import tensorflow as tf from typeguard import check_argument_types @@ -11,7 +10,8 @@ from neuralmonkey.model.model_part import ModelPart from neuralmonkey.model.stateful import TemporalStateful from neuralmonkey.tf_utils import get_variable -from neuralmonkey.vocabulary import Vocabulary, END_TOKEN +from neuralmonkey.vocabulary import (Vocabulary, pad_batch, END_TOKEN_INDEX, + PAD_TOKEN_INDEX) class CTCDecoder(ModelPart): @@ -48,11 +48,29 @@ def __init__(self, self.beam_width = beam_width # pylint: enable=too-many-arguments - # pylint: disable=no-self-use + @property + def input_types(self) -> Dict[str, tf.DType]: + return {self.data_id: tf.string} + + @property + def input_shapes(self) -> Dict[str, tf.TensorShape]: + return {self.data_id: tf.TensorShape([None, None])} + @tensor - def train_targets(self) -> tf.Tensor: - return 
tf.sparse_placeholder(tf.int32, name="targets") - # pylint: disable=no-self-use + def target_tokens(self) -> tf.Tensor: + return self.dataset[self.data_id] + + @tensor + def train_targets(self) -> tf.SparseTensor: + params = self.vocabulary.strings_to_indices(self.target_tokens) + + indices = tf.where(tf.not_equal(params, PAD_TOKEN_INDEX)) + values = tf.gather_nd(params, indices) + + return tf.cast( + tf.SparseTensor( + indices, values, tf.shape(params, out_type=tf.int64)), + tf.int32) @tensor def decoded(self) -> tf.Tensor: @@ -68,7 +86,7 @@ def decoded(self) -> tf.Tensor: return tf.sparse_tensor_to_dense( tf.sparse_transpose(decoded[0]), - default_value=self.vocabulary.get_word_index(END_TOKEN)) + default_value=END_TOKEN_INDEX) @property def train_loss(self) -> tf.Tensor: @@ -114,7 +132,7 @@ def logits(self) -> tf.Tensor: multiplication = tf.nn.conv2d( encoder_states, weights_4d, [1, 1, 1, 1], "SAME") - multiplication_3d = tf.squeeze(multiplication, squeeze_dims=[2]) + multiplication_3d = tf.squeeze(multiplication, axis=2) biases_3d = tf.expand_dims(tf.expand_dims(biases, 0), 0) @@ -124,29 +142,13 @@ def logits(self) -> tf.Tensor: def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: fd = ModelPart.feed_dict(self, dataset, train) - sentences = cast(Iterable[List[str]], - dataset.maybe_get_series(self.data_id)) + sentences = dataset.maybe_get_series(self.data_id) if sentences is None and train: - raise ValueError("When training, you must feed " - "reference sentences") + raise ValueError("You must feed reference sentences when training") if sentences is not None: - vectors, paddings = self.vocabulary.sentences_to_tensor( - list(sentences), train_mode=train, max_len=self.max_length) - - # sentences_to_tensor returns time-major tensors, targets need to - # be batch-major - vectors = vectors.T - paddings = paddings.T - - # Need to convert the data to a sparse representation - bool_mask = (paddings > 0.5) - indices = np.stack(np.where(bool_mask), axis=1) - values = vectors[bool_mask] - - fd[self.train_targets] = tf.SparseTensorValue( - indices=indices, values=values, - dense_shape=vectors.shape) + fd[self.target_tokens] = pad_batch(list(sentences), + self.max_length) return fd diff --git a/neuralmonkey/decoders/decoder.py b/neuralmonkey/decoders/decoder.py index 0677d29bc..1d6db2c49 100644 --- a/neuralmonkey/decoders/decoder.py +++ b/neuralmonkey/decoders/decoder.py @@ -367,13 +367,13 @@ def body(*args) -> LoopState: self.step_scope.reuse_variables() if sample: - next_symbols = tf.to_int32( - tf.squeeze(tf.multinomial(logits, num_samples=1), axis=1)) + next_symbols = tf.squeeze(tf.multinomial( + logits, num_samples=1), axis=1) elif train_mode: next_symbols = loop_state.constants.train_inputs[step] else: - next_symbols = tf.to_int32(tf.argmax(logits, axis=1)) - int_unfinished_mask = tf.to_int32( + next_symbols = tf.argmax(logits, axis=1) + int_unfinished_mask = tf.to_int64( tf.logical_not(loop_state.feedables.finished)) # Note this works only when PAD_TOKEN_INDEX is 0. 
Otherwise diff --git a/neuralmonkey/decoders/sequence_labeler.py b/neuralmonkey/decoders/sequence_labeler.py index 4f48aa8e1..71b504a0f 100644 --- a/neuralmonkey/decoders/sequence_labeler.py +++ b/neuralmonkey/decoders/sequence_labeler.py @@ -11,7 +11,7 @@ from neuralmonkey.model.parameterized import InitializerSpecs from neuralmonkey.model.model_part import ModelPart from neuralmonkey.tf_utils import get_variable -from neuralmonkey.vocabulary import Vocabulary, PAD_TOKEN_INDEX +from neuralmonkey.vocabulary import Vocabulary, pad_batch, sentence_mask class SequenceLabeler(ModelPart): @@ -40,19 +40,23 @@ def __init__(self, @property def input_types(self) -> Dict[str, tf.DType]: - return {self.data_id: tf.int32} + return {self.data_id: tf.string} @property def input_shapes(self) -> Dict[str, tf.TensorShape]: return {self.data_id: tf.TensorShape([None, None])} @tensor - def train_targets(self) -> tf.Tensor: + def target_tokens(self) -> tf.Tensor: return self.dataset[self.data_id] + @tensor + def train_targets(self) -> tf.Tensor: + return self.vocabulary.strings_to_indices(self.target_tokens) + @tensor def train_mask(self) -> tf.Tensor: - return tf.to_float(tf.not_equal(self.train_targets, PAD_TOKEN_INDEX)) + return sentence_mask(self.train_targets) @property def rnn_size(self) -> int: @@ -91,7 +95,7 @@ def logits(self) -> tf.Tensor: multiplication = tf.nn.conv2d( encoder_states, weights_4d, [1, 1, 1, 1], "SAME") - multiplication_3d = tf.squeeze(multiplication, squeeze_dims=[2]) + multiplication_3d = tf.squeeze(multiplication, axis=[2]) biases_3d = tf.expand_dims(tf.expand_dims(self.decoding_b, 0), 0) @@ -102,7 +106,7 @@ def logits(self) -> tf.Tensor: dmultiplication = tf.nn.conv2d( embedded_inputs, dweights_4d, [1, 1, 1, 1], "SAME") - dmultiplication_3d = tf.squeeze(dmultiplication, squeeze_dims=[2]) + dmultiplication_3d = tf.squeeze(dmultiplication, axis=[2]) logits = multiplication_3d + dmultiplication_3d + biases_3d return logits @@ -138,8 +142,6 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: sentences = dataset.maybe_get_series(self.data_id) if sentences is not None: - vectors, _ = self.vocabulary.sentences_to_tensor( - list(sentences), pad_to_max_len=False, train_mode=train) + fd[self.target_tokens] = pad_batch(list(sentences)) - fd[self.train_targets] = vectors.T return fd diff --git a/neuralmonkey/decoders/transformer.py b/neuralmonkey/decoders/transformer.py index b6cefb976..5e911b418 100644 --- a/neuralmonkey/decoders/transformer.py +++ b/neuralmonkey/decoders/transformer.py @@ -421,7 +421,7 @@ def get_initial_loop_state(self) -> LoopState: histories["decoded_symbols"] = tf.zeros( shape=[0, self.batch_size], - dtype=tf.int32, + dtype=tf.int64, name="decoded_symbols") histories["input_mask"] = tf.zeros( @@ -483,10 +483,9 @@ def body(*args) -> LoopState: if sample: next_symbols = tf.squeeze( tf.multinomial(logits, num_samples=1), axis=1) - next_symbols = tf.to_int32(next_symbols) else: - next_symbols = tf.to_int32(tf.argmax(logits, axis=1)) - int_unfinished_mask = tf.to_int32( + next_symbols = tf.argmax(logits, axis=1) + int_unfinished_mask = tf.to_int64( tf.logical_not(loop_state.feedables.finished)) # Note this works only when PAD_TOKEN_INDEX is 0. 
Otherwise diff --git a/neuralmonkey/encoders/attentive.py b/neuralmonkey/encoders/attentive.py index f293b62bf..0c902b9d3 100644 --- a/neuralmonkey/encoders/attentive.py +++ b/neuralmonkey/encoders/attentive.py @@ -67,7 +67,7 @@ def attention_weights(self) -> tf.Tensor: energies = tf.layers.dense(hidden, units=self.num_heads, use_bias=False, name="S2") # shape: [batch_size, max_time, num_heads] - weights = tf.nn.softmax(energies, dim=1) + weights = tf.nn.softmax(energies, axis=1) if mask is not None: weights *= tf.expand_dims(mask, -1) weights /= tf.reduce_sum(weights, axis=1, keepdims=True) + 1e-8 diff --git a/neuralmonkey/encoders/numpy_stateful_filler.py b/neuralmonkey/encoders/numpy_stateful_filler.py index 4eae2ff55..e7abd4841 100644 --- a/neuralmonkey/encoders/numpy_stateful_filler.py +++ b/neuralmonkey/encoders/numpy_stateful_filler.py @@ -164,5 +164,5 @@ def spatial_mask(self) -> tf.Tensor: def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: fd = ModelPart.feed_dict(self, dataset, train) - fd[self.spatial_input] = dataset.get_series(self.data_id) + fd[self.spatial_input] = list(dataset.get_series(self.data_id)) return fd diff --git a/neuralmonkey/encoders/sequence_cnn_encoder.py b/neuralmonkey/encoders/sequence_cnn_encoder.py index c08ffaff2..81aad8e4a 100644 --- a/neuralmonkey/encoders/sequence_cnn_encoder.py +++ b/neuralmonkey/encoders/sequence_cnn_encoder.py @@ -12,7 +12,7 @@ from neuralmonkey.model.model_part import ModelPart from neuralmonkey.model.stateful import Stateful from neuralmonkey.nn.utils import dropout -from neuralmonkey.vocabulary import Vocabulary, PAD_TOKEN_INDEX +from neuralmonkey.vocabulary import Vocabulary, pad_batch, sentence_mask from neuralmonkey.tf_utils import get_variable @@ -62,7 +62,7 @@ def __init__(self, @property def input_types(self) -> Dict[str, tf.DType]: - return {self.data_id: tf.int32} + return {self.data_id: tf.string} @property def input_shapes(self) -> Dict[str, tf.TensorShape]: @@ -70,11 +70,15 @@ def input_shapes(self) -> Dict[str, tf.TensorShape]: @tensor def inputs(self) -> tf.Tensor: + return self.vocabulary.strings_to_indices(self.input_tokens) + + @tensor + def input_tokens(self) -> tf.Tensor: return self.dataset[self.data_id] @tensor def input_mask(self) -> tf.Tensor: - return tf.to_float(tf.not_equal(self.inputs, PAD_TOKEN_INDEX)) + return sentence_mask(self.inputs) @tensor def embedded_inputs(self) -> tf.Tensor: @@ -128,14 +132,6 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: train: Boolean flag telling whether it is training time """ fd = ModelPart.feed_dict(self, dataset, train) - sentences = dataset.get_series(self.data_id) - vectors, _ = self.vocabulary.sentences_to_tensor( - list(sentences), self.max_input_len, pad_to_max_len=False, - train_mode=train) - - # as sentences_to_tensor returns lists of shape (time, batch), - # we need to transpose - fd[self.inputs] = list(zip(*vectors)) - + fd[self.input_tokens] = pad_batch(list(sentences), self.max_input_len) return fd diff --git a/neuralmonkey/experiment.py b/neuralmonkey/experiment.py index 3a9a3ee20..4b65b1f68 100644 --- a/neuralmonkey/experiment.py +++ b/neuralmonkey/experiment.py @@ -12,18 +12,15 @@ import numpy as np import tensorflow as tf from tensorflow.contrib.tensorboard.plugins import projector -from typeguard import check_argument_types -from neuralmonkey.checking import (check_dataset_and_coders, - CheckingException) -from neuralmonkey.dataset import BatchingScheme, Dataset +from neuralmonkey.checking import 
CheckingException +from neuralmonkey.dataset import Dataset from neuralmonkey.logging import Logging, log, debug, warn from neuralmonkey.config.configuration import Configuration from neuralmonkey.config.normalize import normalize_configuration from neuralmonkey.learning_utils import (training_loop, evaluation, run_on_dataset, print_final_evaluation) -from neuralmonkey.model.sequence import EmbeddedFactorSequence from neuralmonkey.runners.base_runner import ExecutionResult from neuralmonkey.runners.dataset_runner import DatasetRunner @@ -158,10 +155,23 @@ def register_inputs(self) -> None: feedables |= set.union( *[ex.feedables for ex in self.model.trainers]) + # collect input shapes and types + input_types = {} # type: Dict[str, tf.DType] + input_shapes = {} # type: Dict[str, tf.TensorShape] + + for feedable in feedables: + input_types.update(feedable.input_types) + input_shapes.update(feedable.input_shapes) + + dataset = {} # type: Dict[str, tf.Tensor] + for s_id, dtype in input_types.items(): + shape = input_shapes[s_id] + dataset[s_id] = tf.placeholder(dtype, shape, s_id) + for feedable in feedables: - feedable.register_input() + feedable.register_input(dataset) - self.model.dataset_runner.register_input() + self.model.dataset_runner.register_input(dataset) def build_model(self) -> None: """Build the configuration and the computational graph. @@ -211,20 +221,8 @@ def build_model(self) -> None: type(self)._current_experiment = None - if self.train_mode: - check_dataset_and_coders(self.model.train_dataset, - self.model.runners) - if isinstance(self.model.val_dataset, Dataset): - check_dataset_and_coders(self.model.val_dataset, - self.model.runners) - else: - for val_dataset in self.model.val_dataset: - check_dataset_and_coders(val_dataset, - self.model.runners) - - if self.train_mode and self.model.visualize_embeddings: - visualize_embeddings(self.model.visualize_embeddings, - self.model.output) + if self.train_mode and self.model.visualize_embeddings is not None: + self.visualize_embeddings() self._check_unused_initializers() @@ -263,17 +261,31 @@ def train(self) -> None: log("Saving final variables in {}".format(final_variables)) self.model.tf_manager.save(final_variables) + if self.model.test_datasets: + if self.model.tf_manager.best_score_index is not None: + self.model.tf_manager.restore_best_vars() + + for test_id, dataset in enumerate(self.model.test_datasets): + self.evaluate(dataset, write_out=True, + name="test_{}".format(test_id)) + log("Finished.") self._vars_loaded = True def load_variables(self, variable_files: List[str] = None) -> None: - """Load variables from files. + """Load variables of the built model from file(s). + + When variable files are not provided, Neural Monkey will try to infer + the name of a default checkpoint file in the following order: + 1. Look for the averaged checkpoints named `variables.data.avg` or + `variables.data.avg-0`. + 2. Look for the file `variables.data.best`, which usually contains the + name of the best scoring checkpoint from the run. + 3. Look for the final checkpoint saved in `variables.data.final`. Arguments: - variable_files: A list of checkpoint file prefixes. A TF checkpoint - is usually three files with a common prefix. This list should - have the same number of files as there are sessions in the - `tf_manager` object. + variable_files: A list of variable files to load. The length of + this list should match the number of sessions. 
""" if not self._model_built: self.build_model() @@ -283,12 +295,16 @@ def load_variables(self, variable_files: List[str] = None) -> None: variable_files = [self.get_path("variables.data.avg-0")] elif os.path.exists(self.get_path("variables.data.avg.index")): variable_files = [self.get_path("variables.data.avg")] - else: + elif os.path.exists(self.get_path("variables.data.best")): best_var_file = self.get_path("variables.data.best") with open(best_var_file, "r") as f_best: var_path = f_best.read().rstrip() variable_files = [os.path.join(self.config.args.output, var_path)] + elif os.path.exists(self.get_path("variables.data.final.index")): + variable_files = [self.get_path("variables.data.final")] + else: + raise RuntimeError("Cannot infer default variables file") log("Default variable file '{}' will be used for loading " "variables.".format(variable_files[0])) @@ -305,7 +321,6 @@ def load_variables(self, variable_files: List[str] = None) -> None: def run_model(self, dataset: Dataset, write_out: bool = False, - batch_size: int = None, log_progress: int = 0) -> Tuple[ List[ExecutionResult], Dict[str, List], Dict[str, List]]: """Run the model on a given dataset. @@ -314,7 +329,6 @@ def run_model(self, dataset: The dataset on which the model will be executed. write_out: Flag whether the outputs should be printed to a file defined in the dataset object. - batch_size: size of the minibatch log_progress: log progress every X seconds Returns: @@ -325,17 +339,7 @@ def run_model(self, if not self._vars_loaded: self.load_variables() - toklevel = self.model.runners_batching_scheme.token_level_batching - assert self.model.runners_batching_scheme.batch_bucket_span is None - - batching_scheme = BatchingScheme( - batch_size=batch_size or self.model.runners_batch_size, - batch_bucket_span=None, - token_level_batching=toklevel, - bucketing_ignore_series=[]) - with self.graph.as_default(): - # TODO: check_dataset_and_coders(dataset, self.model.runners) return run_on_dataset( self.model.tf_manager, self.model.runners, @@ -343,13 +347,11 @@ def run_model(self, dataset, self.model.postprocess, write_out=write_out, - log_progress=log_progress, - batching_scheme=batching_scheme) + log_progress=log_progress) def evaluate(self, dataset: Dataset, write_out: bool = False, - batch_size: int = None, log_progress: int = 0, name: str = None) -> Dict[str, Any]: """Run the model on a given dataset and evaluate the outputs. @@ -358,7 +360,6 @@ def evaluate(self, dataset: The dataset on which the model will be executed. write_out: Flag whether the outputs should be printed to a file defined in the dataset object. - batch_size: size of the minibatch log_progress: log progress every X seconds name: The name of the evaluated dataset @@ -368,14 +369,13 @@ def evaluate(self, run. """ execution_results, output_data, f_dataset = self.run_model( - dataset, write_out, batch_size, log_progress) + dataset, write_out, log_progress) evaluators = [(e[0], e[0], e[1]) if len(e) == 2 else e for e in self.model.evaluation] with self.graph.as_default(): eval_result = evaluation( - evaluators, f_dataset, self.model.runners, - execution_results, output_data) + evaluators, f_dataset, execution_results, output_data) if eval_result: print_final_evaluation(eval_result, name) @@ -419,6 +419,31 @@ def _check_unused_initializers(self) -> None: "Initializers were specified for the following non-existent " "variables: " + ", ".join(unused_initializers)) + def visualize_embeddings(self) -> None: + """Insert visualization of embeddings in TensorBoard. 
+ + Visualize the embeddings of `EmbeddedFactorSequence` objects specified + in the `main.visualize_embeddings` config attribute. + """ + tb_projector = projector.ProjectorConfig() + + for sequence in self.model.visualize_embeddings: + for i, (vocabulary, emb_matrix) in enumerate( + zip(sequence.vocabularies, sequence.embedding_matrices)): + + # TODO when vocabularies will have name parameter, change it + path = self.get_path("seq.{}-{}.tsv".format(sequence.name, i)) + vocabulary.save_wordlist(path) + + embedding = tb_projector.embeddings.add() + # pylint: disable=unsubscriptable-object + embedding.tensor_name = emb_matrix.name + embedding.metadata_path = path + # pylint: enable=unsubscriptable-object + + summary_writer = tf.summary.FileWriter(self.model.output) + projector.visualize_embeddings(summary_writer, tb_projector) + @classmethod def get_current(cls) -> "Experiment": """Return the experiment that is currently being built.""" @@ -430,11 +455,9 @@ def create_config(train_mode: bool = True) -> Configuration: config.add_argument("tf_manager", required=False, default=None) config.add_argument("batch_size", required=False, default=None, cond=lambda x: x is None or x > 0) - config.add_argument("batching_scheme", required=False, default=None) config.add_argument("output") config.add_argument("postprocess", required=False, default=None) config.add_argument("runners") - config.add_argument("runners_batch_size", required=False, default=None) if train_mode: config.add_argument("epochs", cond=lambda x: x >= 0) @@ -522,16 +545,3 @@ def save_git_info(git_commit_file: str, git_diff_file: str, ) else: warn("No git executable found. Not storing git commit and diffs") - - -def visualize_embeddings(sequences: List[EmbeddedFactorSequence], - output_dir: str) -> None: - check_argument_types() - - tb_projector = projector.ProjectorConfig() - - for sequence in sequences: - sequence.tb_embedding_visualization(output_dir, tb_projector) - - summary_writer = tf.summary.FileWriter(output_dir) - projector.visualize_embeddings(summary_writer, tb_projector) diff --git a/neuralmonkey/learning_utils.py b/neuralmonkey/learning_utils.py index d2b810ef8..50e0e0711 100644 --- a/neuralmonkey/learning_utils.py +++ b/neuralmonkey/learning_utils.py @@ -13,10 +13,10 @@ from termcolor import colored from neuralmonkey.logging import log, log_print, warn -from neuralmonkey.dataset import Dataset, BatchingScheme +from neuralmonkey.dataset import Dataset from neuralmonkey.tf_manager import TensorFlowManager from neuralmonkey.runners.base_runner import ( - BaseRunner, ExecutionResult, reduce_execution_results, GraphExecutor) + BaseRunner, ExecutionResult, GraphExecutor, OutputSeries) from neuralmonkey.runners.dataset_runner import DatasetRunner from neuralmonkey.trainers.generic_trainer import GenericTrainer from neuralmonkey.trainers.multitask_trainer import MultitaskTrainer @@ -42,11 +42,9 @@ def training_loop(cfg: Namespace) -> None: cfg: Experiment configuration namespace. 
""" _check_series_collisions(cfg.runners, cfg.postprocess) - - log_model_variables(cfg.trainers) - - initialize_model(cfg.tf_manager, cfg.initial_variables, - cfg.runners + cfg.trainers) + _log_model_variables(cfg.trainers) + _initialize_model(cfg.tf_manager, cfg.initial_variables, + cfg.runners + cfg.trainers) log("Initializing TensorBoard summary writer.") tb_writer = tf.summary.FileWriter(cfg.output, @@ -66,7 +64,7 @@ def training_loop(cfg: Namespace) -> None: try: for epoch_n in range(1, cfg.epochs + 1): - train_batches = cfg.train_dataset.batches(cfg.batching_scheme) + train_batches = cfg.train_dataset.batches() if epoch_n == 1 and cfg.train_start_offset: if cfg.train_dataset.shuffled and not cfg.train_dataset.lazy: @@ -89,14 +87,13 @@ def training_loop(cfg: Namespace) -> None: summaries=True) train_results, train_outputs, f_batch = run_on_dataset( cfg.tf_manager, cfg.runners, cfg.dataset_runner, batch, - cfg.postprocess, write_out=False, - batching_scheme=cfg.runners_batching_scheme) + cfg.postprocess, write_out=False) # ensure train outputs are iterable more than once train_outputs = { k: list(v) for k, v in train_outputs.items()} + train_evaluation = evaluation( - cfg.evaluation, f_batch, cfg.runners, train_results, - train_outputs) + cfg.evaluation, f_batch, train_results, train_outputs) _log_continuous_evaluation( tb_writer, cfg.main_metric, train_evaluation, @@ -121,14 +118,12 @@ def training_loop(cfg: Namespace) -> None: val_results, val_outputs, f_valset = run_on_dataset( cfg.tf_manager, cfg.runners, cfg.dataset_runner, - valset, cfg.postprocess, write_out=False, - batching_scheme=cfg.runners_batching_scheme) + valset, cfg.postprocess, write_out=False) # ensure val outputs are iterable more than once val_outputs = {k: list(v) for k, v in val_outputs.items()} val_evaluation = evaluation( - cfg.evaluation, f_valset, cfg.runners, val_results, - val_outputs) + cfg.evaluation, f_valset, val_results, val_outputs) valheader = ("Validation (epoch {}, batch number {}):" .format(epoch_n, batch_n)) @@ -194,27 +189,11 @@ def training_loop(cfg: Namespace) -> None: .format(cfg.main_metric, cfg.tf_manager.best_score, cfg.tf_manager.best_score_epoch)) - if cfg.test_datasets: - cfg.tf_manager.restore_best_vars() - - for test_id, dataset in enumerate(cfg.test_datasets): - test_results, test_outputs, f_testset = run_on_dataset( - cfg.tf_manager, cfg.runners, cfg.dataset_runner, dataset, - cfg.postprocess, write_out=True, - batching_scheme=cfg.runners_batching_scheme) - # ensure test outputs are iterable more than once - test_outputs = {k: list(v) for k, v in test_outputs.items()} - eval_result = evaluation(cfg.evaluation, f_testset, cfg.runners, - test_results, test_outputs) - print_final_evaluation(eval_result, "test_{}".format(test_id)) - - log("Finished.") - if interrupt is not None: raise interrupt # pylint: disable=raising-bad-type -def log_model_variables(trainers: List[Trainer]) -> None: +def _log_model_variables(trainers: List[Trainer]) -> None: var_list = list(set().union(*[t.var_list for t in trainers])) \ # type: List[tf.Variable] @@ -257,14 +236,14 @@ def log_model_variables(trainers: List[Trainer]) -> None: log("Total number of all parameters: {}".format(total_params)) -def initialize_model(tf_manager: TensorFlowManager, - initial_variables: Optional[List[str]], - executables: List[GraphExecutor]): +def _initialize_model(tf_manager: TensorFlowManager, + initial_variables: Optional[List[str]], + executables: List[GraphExecutor]): if initial_variables is None: # Assume we don't look 
at coder checkpoints when global # initial variables are supplied - tf_manager.initialize_model_parts(executables, save=True) + tf_manager.initialize_model_parts(executables) else: try: tf_manager.restore(initial_variables) @@ -297,7 +276,6 @@ def run_on_dataset(tf_manager: TensorFlowManager, dataset_runner: DatasetRunner, dataset: Dataset, postprocess: Postprocess, - batching_scheme: BatchingScheme, write_out: bool = False, log_progress: int = 0) -> Tuple[ List[ExecutionResult], @@ -318,7 +296,6 @@ def run_on_dataset(tf_manager: TensorFlowManager, postprocess: Dataset-level postprocessors write_out: Flag whether the outputs should be printed to a file defined in the dataset object. - batching_scheme: Scheme used for batching. log_progress: log progress every X seconds extra_fetches: Extra tensors to evaluate for each batch. @@ -340,8 +317,10 @@ def run_on_dataset(tf_manager: TensorFlowManager, feedables = set.union(*[runner.feedables for runner in runners]) feedables |= dataset_runner.feedables + fetched_input = {s: [] for s in dataset.series} # type: Dict[str, List] + processed_examples = 0 - for batch in dataset.batches(batching_scheme): + for batch in dataset.batches(): if 0 < log_progress < time.process_time() - last_log_time: log("Processed {} examples.".format(processed_examples)) last_log_time = time.process_time() @@ -358,15 +337,17 @@ def run_on_dataset(tf_manager: TensorFlowManager, for script_list, ex_result in zip(batch_results, execution_results): script_list.append(ex_result) + for s_id in batch.series: + fetched_input[s_id].extend(batch.get_series(s_id)) + # Transpose runner interim results. - all_results = [reduce_execution_results(res) for res in batch_results[:-1]] + all_results = [join_execution_results(res) for res in batch_results[:-1]] # TODO uncomment this when dataset runner starts outputting the dataset - # input_transposed = reduce_execution_results(batch_results[-1]).outputs + # input_transposed = join_execution_results(batch_results[-1]).outputs # fetched_input = { # k: [dic[k] for dic in input_transposed] for k in input_transposed[0]} - fetched_input = {s: list(dataset.get_series(s)) for s in dataset.series} fetched_input_lengths = {s: len(fetched_input[s]) for s in dataset.series} if len(set(fetched_input_lengths.values())) != 1: @@ -376,8 +357,12 @@ def run_on_dataset(tf_manager: TensorFlowManager, dataset_len = fetched_input_lengths[dataset.series[0]] # Convert execution results to dictionary. - result_data = {runner.output_series: result.outputs - for runner, result in zip(runners, all_results)} + result_data = {} # type: Dict[str, Union[List, np.ndarray]] + for s_id, data in ( + pair for res in all_results for pair in res.outputs.items()): + if s_id in result_data: + raise ValueError("Overwriting output series forbidden.") + result_data[s_id] = data # Run dataset-level postprocessing. if postprocess is not None: @@ -410,13 +395,50 @@ def run_on_dataset(tf_manager: TensorFlowManager, return all_results, result_data, fetched_input -def evaluation(evaluators, batch, runners, execution_results, result_data): +def join_execution_results( + execution_results: List[ExecutionResult]) -> ExecutionResult: + """Aggregate batch of execution results from a single runner.""" + + losses_sum = {loss: 0. 
for loss in execution_results[0].losses} + + def join(output_series: List[OutputSeries]) -> OutputSeries: + """Join a list of batches of results into a flat list of outputs.""" + joined = [] # type: List[Any] + + for item in output_series: + joined.extend(item) + + # If the list is a list of np.arrays, concatenate the list along first + # dimension (batch). Otherwise, return the list. + if joined and isinstance(joined[0], np.ndarray): + return np.array(joined) + + return joined + + outputs = {} # type: Dict[str, Any] + for key in execution_results[0].outputs.keys(): + outputs[key] = join([res.outputs[key] for res in execution_results]) + + for result in execution_results: + for l_id, loss in result.losses.items(): + losses_sum[l_id] += loss * result.size + + total_size = sum(res.size for res in execution_results) + losses = {l_id: loss / total_size for l_id, loss in losses_sum.items()} + + all_summaries = [ + summ for res in execution_results if res.summaries is not None + for summ in res.summaries] + + return ExecutionResult(outputs, losses, total_size, all_summaries) + + +def evaluation(evaluators, batch, execution_results, result_data): """Evaluate the model outputs. Args: evaluators: List of tuples of series and evaluation functions. batch: Batch of data against which the evaluation is done. - runners: List of runners (contains series ids and loss names). execution_results: Execution results that include the loss values. result_data: Dictionary from series names to list of outputs. @@ -427,9 +449,12 @@ def evaluation(evaluators, batch, runners, execution_results, result_data): eval_result = {} # losses - for runner, result in zip(runners, execution_results): - for name, value in zip(runner.loss_names, result.losses): - eval_result["{}/{}".format(runner.output_series, name)] = value + for result in execution_results: + if any(l in eval_result for l in result.losses): + # TODO(tf-data) this will go away with further exec_res refactor + raise ValueError("Duplicate loss result keys found.") + + eval_result.update(result.losses) # evaluation metrics for hypothesis_id, reference_id, function in evaluators: @@ -468,11 +493,8 @@ def _log_continuous_evaluation(tb_writer: tf.summary.FileWriter, if tb_writer: for result in execution_results: - for summaries in [result.scalar_summaries, - result.histogram_summaries, - result.image_summaries]: - if summaries is not None: - tb_writer.add_summary(summaries, seen_instances) + for summaries in result.summaries: + tb_writer.add_summary(summaries, seen_instances) external_str = \ tf.Summary(value=[tf.Summary.Value(tag=prefix + "_" + name, diff --git a/neuralmonkey/model/feedable.py b/neuralmonkey/model/feedable.py index 42652a4b6..604c72655 100644 --- a/neuralmonkey/model/feedable.py +++ b/neuralmonkey/model/feedable.py @@ -1,6 +1,8 @@ from abc import ABCMeta + +from typing import Any, Dict, List # pylint: disable=unused-import -from typing import Any, Dict, List, Optional +from typing import Optional # pylint: enable=unused-import import tensorflow as tf @@ -60,10 +62,5 @@ def dataset(self) -> Dict[str, tf.Tensor]: raise RuntimeError("Getting dataset before registering it.") return self._dataset - def register_input(self) -> None: - assert self.input_types.keys() == self.input_shapes.keys() - self._dataset = {} - - for s_id, dtype in self.input_types.items(): - shape = self.input_shapes[s_id] - self.dataset[s_id] = tf.placeholder(dtype, shape, s_id) + def register_input(self, dataset: Dict[str, tf.Tensor]) -> None: + self._dataset = dataset diff --git 
a/neuralmonkey/model/sequence.py b/neuralmonkey/model/sequence.py index ad0b4ca59..9e66aeb98 100644 --- a/neuralmonkey/model/sequence.py +++ b/neuralmonkey/model/sequence.py @@ -1,10 +1,8 @@ """Module which impements the sequence class and a few of its subclasses.""" -import os from typing import List, Dict import tensorflow as tf -from tensorflow.contrib.tensorboard.plugins import projector from typeguard import check_argument_types from neuralmonkey.dataset import Dataset @@ -14,7 +12,7 @@ from neuralmonkey.model.parameterized import InitializerSpecs from neuralmonkey.model.stateful import TemporalStateful from neuralmonkey.tf_utils import get_variable -from neuralmonkey.vocabulary import Vocabulary, PAD_TOKEN_INDEX +from neuralmonkey.vocabulary import Vocabulary, pad_batch, sentence_mask # pylint: disable=abstract-method @@ -133,40 +131,21 @@ def __init__(self, @property def input_types(self) -> Dict[str, tf.DType]: - return {d_id: tf.int32 for d_id in self.data_ids} + return {d_id: tf.string for d_id in self.data_ids} @property def input_shapes(self) -> Dict[str, tf.TensorShape]: return {d_id: tf.TensorShape([None, None]) for d_id in self.data_ids} + @tensor + def input_factor_indices(self) -> List[tf.Tensor]: + return [vocab.strings_to_indices(factor) for + vocab, factor in zip(self.vocabularies, self.input_factors)] + @tensor def input_factors(self) -> List[tf.Tensor]: return [self.dataset[s_id] for s_id in self.data_ids] - # TODO this should be placed into the abstract embedding class - def tb_embedding_visualization(self, logdir: str, - prj: projector): - """Link embeddings with vocabulary wordlist. - - Used for tensorboard visualization. - - Arguments: - logdir: directory where model is stored - projector: TensorBoard projector for storing linking info. - """ - for i in range(len(self.vocabularies)): - # the overriding is turned to true, because if the model would not - # be allowed to override the output folder it would failed earlier. 
- # TODO when vocabularies will have name parameter, change it - metadata_path = self.name + "_" + str(i) + ".tsv" - self.vocabularies[i].save_wordlist( - os.path.join(logdir, metadata_path), True, True) - - embedding = prj.embeddings.add() - # pylint: disable=unsubscriptable-object - embedding.tensor_name = self.embedding_matrices[i].name - embedding.metadata_path = metadata_path - @tensor def embedding_matrices(self) -> List[tf.Tensor]: """Return a list of embedding matrices for each factor.""" @@ -195,7 +174,7 @@ def temporal_states(self) -> tf.Tensor: """ embedded_factors = [] for (factor, embedding_matrix) in zip( - self.input_factors, self.embedding_matrices): + self.input_factor_indices, self.embedding_matrices): emb_factor = tf.nn.embedding_lookup(embedding_matrix, factor) # github.com/tensorflow/tensor2tensor/blob/v1.5.6/tensor2tensor/ @@ -214,8 +193,7 @@ def temporal_states(self) -> tf.Tensor: # pylint: disable=unsubscriptable-object @tensor def temporal_mask(self) -> tf.Tensor: - return tf.to_float(tf.not_equal( - self.input_factors[0], PAD_TOKEN_INDEX)) + return sentence_mask(self.input_factor_indices[0]) # pylint: enable=unsubscriptable-object def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: @@ -232,27 +210,11 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: fd = ModelPart.feed_dict(self, dataset, train) # for checking the lengths of individual factors - arr_strings = [] - last_paddings = None - - for factor_plc, name, vocabulary in zip( - self.input_factors, self.data_ids, self.vocabularies): - factors = dataset.get_series(name) - vectors, paddings = vocabulary.sentences_to_tensor( - list(factors), self.max_length, pad_to_max_len=False, - train_mode=train, add_start_symbol=self.add_start_symbol, - add_end_symbol=self.add_end_symbol) - - fd[factor_plc] = list(zip(*vectors)) - - arr_strings.append(paddings.tostring()) - last_paddings = paddings - - if len(set(arr_strings)) > 1: - raise ValueError("The lenghts of factors do not match") - - assert last_paddings is not None - # fd[self.mask] = list(zip(*last_paddings)) + for factor_plc, name in zip(self.input_factors, self.data_ids): + sentences = dataset.get_series(name) + fd[factor_plc] = pad_batch( + list(sentences), self.max_length, self.add_start_symbol, + self.add_end_symbol) return fd @@ -314,7 +276,7 @@ def __init__(self, @property def inputs(self) -> tf.Tensor: """Return a 2D placeholder for the sequence inputs.""" - return self.input_factors[0] + return self.input_factor_indices[0] @property def embedding_matrix(self) -> tf.Tensor: diff --git a/neuralmonkey/processors/wordpiece.py b/neuralmonkey/processors/wordpiece.py index fedf9974d..4a8fe4185 100644 --- a/neuralmonkey/processors/wordpiece.py +++ b/neuralmonkey/processors/wordpiece.py @@ -81,7 +81,7 @@ def wordpiece_encode(sentence: List[str], vocabulary: Vocabulary) -> List[str]: for end in range(token_len, current_subtoken_start, -1): subtoken = esc_token[current_subtoken_start:end] - if subtoken in vocabulary.word_to_index: + if subtoken in vocabulary: subtokens.append(subtoken) current_subtoken_start = end break diff --git a/neuralmonkey/readers/image_reader.py b/neuralmonkey/readers/image_reader.py index 3d92e4833..1eab16932 100644 --- a/neuralmonkey/readers/image_reader.py +++ b/neuralmonkey/readers/image_reader.py @@ -1,13 +1,19 @@ from typing import Callable, Iterable, List import os -from typeguard import check_argument_types + import numpy as np +from typeguard import check_argument_types from PIL import Image, 
ImageFile
+
+from neuralmonkey.logging import warn
+
+
 ImageFile.LOAD_TRUNCATED_IMAGES = True


 def image_reader(pad_w: int,
                  pad_h: int,
+                 channels: int = 3,
                  prefix: str = "",
                  rescale_w: bool = False,
                  rescale_h: bool = False,
@@ -17,7 +23,8 @@ def image_reader(pad_w: int,

     Args:
         pad_w: Width to which the images will be padded/cropped/resized.
-        pad_h: Height to with the images will be padded/corpped/resized.
+        pad_h: Height to which the images will be padded/cropped/resized.
+        channels: Number of channels in each image (default 3 for RGB).
         prefix: Prefix of the paths that are listed in a image files.
         rescale_w: If true, image is rescaled to have given width. It is
             cropped/padded otherwise.
@@ -57,6 +64,8 @@ def load(list_files: List[str]) -> Iterable[np.ndarray]:
             try:
                 image = Image.open(path).convert(mode)
             except IOError:
+                warn("Skipping image no. {} from file '{}'.".format(
+                    i + 1, path))
                 image = Image.new(mode, (pad_w, pad_h))

             image = _rescale_or_crop(image, pad_w, pad_h,
@@ -65,16 +74,22 @@ def load(list_files: List[str]) -> Iterable[np.ndarray]:
             image_np = np.array(image)

             if len(image_np.shape) == 2:
-                channels = 1
+                img_channels = 1
                 image_np = np.expand_dims(image_np, 2)
             elif len(image_np.shape) == 3:
-                channels = image_np.shape[2]
+                img_channels = image_np.shape[2]
             else:
                 raise ValueError(
                     ("Image should have either 2 (black and white) "
                      "or three dimensions (color channels), has {} "
                      "dimension.").format(len(image_np.shape)))

+            if channels != img_channels:
+                raise ValueError(
+                    "Image does not have the pre-declared number of "
+                    "channels {}, but {}.".format(
+                        channels, img_channels))
+
             yield _pad(image_np, pad_w, pad_h, channels)

     return load
diff --git a/neuralmonkey/readers/numpy_reader.py b/neuralmonkey/readers/numpy_reader.py
index 8db8784a0..3406735ea 100644
--- a/neuralmonkey/readers/numpy_reader.py
+++ b/neuralmonkey/readers/numpy_reader.py
@@ -15,12 +15,14 @@ def single_tensor(files: List[str]) -> np.ndarray:


 def from_file_list(prefix: str,
+                   shape: List[int],
                    suffix: str = "",
                    default_tensor_name: str = "arr_0") -> Callable:
     """Load a list of numpy arrays from a list of .npz numpy files.

     Args:
         prefix: A common prefix for the files in the list.
+        shape: The shape of the numpy arrays stored in the referenced files.
         suffix: An optional suffix that will be appended to each path
         default_tensor_name: Key of the tensors to load from the npz files.
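(Editor's aside, not part of the patch: a minimal usage sketch of the updated reader with its new mandatory shape argument. The index file "lists/features.txt", the prefix "data/img_features", and the [2048] feature shape are hypothetical examples.)

from neuralmonkey.readers.numpy_reader import from_file_list

# Build a reader that joins the prefix with every path listed in the index
# file, appends ".npz", and checks that each stored array has shape [2048].
reader = from_file_list(prefix="data/img_features",
                        shape=[2048],
                        suffix=".npz")

# The returned callable takes a list of index files (one .npz path per line)
# and yields the stored arrays one by one; a shape mismatch raises ValueError.
for array in reader(["lists/features.txt"]):
    print(array.shape)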
@@ -35,10 +37,11 @@ def load(files: List[str]) -> Iterable[np.ndarray]: for line in f_list: path = os.path.join(prefix, line.rstrip()) + suffix with np.load(path) as npz: - yield npz[default_tensor_name] - + arr = npz[default_tensor_name] + arr_shape = list(arr.shape) + if arr_shape != shape: + raise ValueError( + "Shapes do not match: expected {}, found {}" + .format(shape, arr_shape)) + yield arr return load - - -# pylint: disable=invalid-name -numpy_file_list_reader = from_file_list(prefix="") diff --git a/neuralmonkey/run.py b/neuralmonkey/run.py index f50711bd5..79ac7a05c 100644 --- a/neuralmonkey/run.py +++ b/neuralmonkey/run.py @@ -15,7 +15,6 @@ def load_runtime_config(config_path: str) -> argparse.Namespace: """Load a runtime configuration file.""" cfg = Configuration() cfg.add_argument("test_datasets") - cfg.add_argument("batch_size", cond=lambda x: x > 0) cfg.add_argument("variables", cond=lambda x: isinstance(x, list)) cfg.load_file(config_path) @@ -68,13 +67,9 @@ def main() -> None: dataset = dataset.subset(start, length) if exp.config.args.evaluation is None: - exp.run_model(dataset, - write_out=True, - batch_size=datasets_model.batch_size) + exp.run_model(dataset, write_out=True) else: - eval_result = exp.evaluate(dataset, - write_out=True, - batch_size=datasets_model.batch_size) + eval_result = exp.evaluate(dataset, write_out=True) results.append(eval_result) if args.json: diff --git a/neuralmonkey/runners/base_runner.py b/neuralmonkey/runners/base_runner.py index 8075563c2..e132331fb 100644 --- a/neuralmonkey/runners/base_runner.py +++ b/neuralmonkey/runners/base_runner.py @@ -1,5 +1,5 @@ from abc import abstractmethod, abstractproperty -from typing import (Any, Dict, Tuple, List, NamedTuple, Union, Set, TypeVar, +from typing import (Dict, Tuple, List, NamedTuple, Union, Set, TypeVar, Generic, Optional) import numpy as np import tensorflow as tf @@ -14,27 +14,28 @@ MP = TypeVar("MP", bound=GenericModelPart) Executor = TypeVar("Executor", bound="GraphExecutor") Runner = TypeVar("Runner", bound="BaseRunner") +OutputSeries = Union[List, np.ndarray] # pylint: enable=invalid-name class ExecutionResult(NamedTuple( "ExecutionResult", - [("outputs", List[Any]), - ("losses", List[float]), - ("scalar_summaries", tf.Summary), - ("histogram_summaries", tf.Summary), - ("image_summaries", tf.Summary)])): + [("outputs", Dict[str, OutputSeries]), + ("losses", Dict[str, float]), + ("size", int), + ("summaries", List[tf.Summary])])): """A data structure that represents the result of a graph execution. - The goal of each runner is to populate this structure and set it as its - ``self._result``. + The goal of each graph executor is to populate this structure using its + ``set_result`` function. Attributes: - outputs: A batch of outputs of the runner. + outputs: A dictionary mapping an output series to the batch of + outputs of the graph executor. losses: A (possibly empty) list of loss values computed during the run. - scalar_summaries: A TensorFlow summary object with scalar values. - histogram_summaries: A TensorFlow summary object with histograms. - image_summaries: A TensorFlow summary object with images. + size: The length of the output batch. 
+ summaries: A list of TensorFlow summary objects fetched by the graph + executor """ @@ -87,13 +88,12 @@ def __init__(self, self._result = None # type: Optional[ExecutionResult] - def set_result(self, outputs: List[Any], losses: List[float], - scalar_summaries: tf.Summary, - histogram_summaries: tf.Summary, - image_summaries: tf.Summary) -> None: - self._result = ExecutionResult( - outputs, losses, scalar_summaries, histogram_summaries, - image_summaries) + def set_result(self, + outputs: Dict[str, OutputSeries], + losses: Dict[str, float], + size: int, + summaries: List[tf.Summary]) -> None: + self._result = ExecutionResult(outputs, losses, size, summaries) @property def result(self) -> Optional[ExecutionResult]: @@ -161,6 +161,21 @@ def next_to_execute(self) -> NextExecute: fetches[loss] = tf.zeros([]) return fetches, [] + + def set_runner_result(self, outputs: OutputSeries, + losses: List[float], size: int = None, + summaries: List[tf.Summary] = None) -> None: + if summaries is None: + summaries = [] + + if size is None: + size = len(outputs) + + loss_names = ["{}/{}".format(self.executor.output_series, loss) + for loss in self.executor.loss_names] + + self.set_result({self.executor.output_series: outputs}, + dict(zip(loss_names, losses)), size, summaries) # pylint: enable=too-few-public-methods def __init__(self, @@ -178,22 +193,3 @@ def decoder_data_id(self) -> Optional[str]: @property def loss_names(self) -> List[str]: raise NotImplementedError() - - -def reduce_execution_results( - execution_results: List[ExecutionResult]) -> ExecutionResult: - """Aggregate execution results into one.""" - outputs = [] # type: List[Any] - losses_sum = [0. for _ in execution_results[0].losses] - for result in execution_results: - outputs.extend(result.outputs) - for i, loss in enumerate(result.losses): - losses_sum[i] += loss - # TODO aggregate TensorBoard summaries - if outputs and isinstance(outputs[0], np.ndarray): - outputs = np.array(outputs) - losses = [l / max(len(outputs), 1) for l in losses_sum] - return ExecutionResult(outputs, losses, - execution_results[0].scalar_summaries, - execution_results[0].histogram_summaries, - execution_results[0].image_summaries) diff --git a/neuralmonkey/runners/beamsearch_runner.py b/neuralmonkey/runners/beamsearch_runner.py index 92a384afc..14a6fe486 100644 --- a/neuralmonkey/runners/beamsearch_runner.py +++ b/neuralmonkey/runners/beamsearch_runner.py @@ -101,11 +101,9 @@ def prepare_results(self, output): # TODO: provide better summaries in case (issue #599) # we want to use the runner during training. 
- self.set_result(outputs=decoded_tokens, - losses=[np.mean(bs_scores) * len(bs_scores)], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=None) + self.set_runner_result( + outputs=decoded_tokens, + losses=[np.mean(bs_scores) * len(bs_scores)]) def _is_finished(self, results): finished = [ diff --git a/neuralmonkey/runners/ctc_debug_runner.py b/neuralmonkey/runners/ctc_debug_runner.py index 4c1e63e93..1e393619e 100644 --- a/neuralmonkey/runners/ctc_debug_runner.py +++ b/neuralmonkey/runners/ctc_debug_runner.py @@ -36,9 +36,7 @@ def collect_results(self, results: List[Dict]) -> None: decoded_instance.append(symbol) decoded_batch.append(decoded_instance) - self.set_result(outputs=decoded_batch, losses=[], - scalar_summaries=None, histogram_summaries=None, - image_summaries=None) + self.set_runner_result(outputs=decoded_batch, losses=[]) # pylint: enable=too-few-public-methods def __init__(self, diff --git a/neuralmonkey/runners/dataset_runner.py b/neuralmonkey/runners/dataset_runner.py index 64088c48d..812058dc1 100644 --- a/neuralmonkey/runners/dataset_runner.py +++ b/neuralmonkey/runners/dataset_runner.py @@ -13,11 +13,8 @@ class Executable(GraphExecutor.Executable["DatasetRunner"]): def collect_results(self, results: List[Dict]) -> None: res = results[0] - # convert bytes to str, probably here. - - data = [dict(zip(res, series)) for series in zip(*res.values())] - - self.set_result(data, [], None, None, None) + size = res["batch"] + self.set_result(res, {}, size, []) # pylint: enable=too-few-public-methods def __init__(self) -> None: @@ -27,4 +24,5 @@ def __init__(self) -> None: @tensor def fetches(self) -> Dict[str, tf.Tensor]: assert self.dataset is not None - return self.dataset + # TODO(tf-data) this will change to fetch real data + return {"batch": self.batch_size} diff --git a/neuralmonkey/runners/label_runner.py b/neuralmonkey/runners/label_runner.py index 90cdaac78..0ab0303b8 100644 --- a/neuralmonkey/runners/label_runner.py +++ b/neuralmonkey/runners/label_runner.py @@ -45,9 +45,7 @@ def collect_results(self, results: List[Dict]) -> None: if self.executor.postprocess is not None: decoded_labels = self.executor.postprocess(decoded_labels) - self.set_result(outputs=decoded_labels, losses=[loss], - scalar_summaries=None, histogram_summaries=None, - image_summaries=None) + self.set_runner_result(outputs=decoded_labels, losses=[loss]) # pylint: enable=too-few-public-methods def __init__(self, diff --git a/neuralmonkey/runners/logits_runner.py b/neuralmonkey/runners/logits_runner.py index 9d19a988e..9868fa3a0 100644 --- a/neuralmonkey/runners/logits_runner.py +++ b/neuralmonkey/runners/logits_runner.py @@ -52,11 +52,8 @@ def collect_results(self, results: List[Dict]) -> None: str_outputs = [["\t".join(l)] for l in outputs] - self.set_result(outputs=str_outputs, - losses=[train_loss, runtime_loss], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=None) + self.set_runner_result(outputs=str_outputs, + losses=[train_loss, runtime_loss]) def __init__(self, output_series: str, @@ -88,8 +85,8 @@ def __init__(self, self.normalize = normalize if pick_value is not None: if pick_value in self.decoder.vocabulary: - vocab_map = self.decoder.vocabulary.word_to_index - self.pick_index = vocab_map[pick_value] + self.pick_index = self.decoder.vocabulary.index_to_word.index( + pick_value) else: raise ValueError( "Value '{}' is not in vocabulary of decoder '{}'".format( diff --git a/neuralmonkey/runners/perplexity_runner.py 
b/neuralmonkey/runners/perplexity_runner.py index eadbb9f90..89797cac1 100644 --- a/neuralmonkey/runners/perplexity_runner.py +++ b/neuralmonkey/runners/perplexity_runner.py @@ -19,11 +19,8 @@ def collect_results(self, results: List[Dict]) -> None: perplexities = np.mean( [2 ** res["xents"] for res in results], axis=0) xent = float(np.mean([res["xents"] for res in results])) - self.set_result(outputs=perplexities.tolist(), - losses=[xent], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=None) + self.set_runner_result(outputs=perplexities.tolist(), + losses=[xent]) # pylint: enable=too-few-public-methods def __init__(self, @@ -39,5 +36,4 @@ def fetches(self) -> Dict[str, tf.Tensor]: @property def loss_names(self) -> List[str]: - # TODO(tf-data) Shouldn't be "xents" here? return ["xent"] diff --git a/neuralmonkey/runners/plain_runner.py b/neuralmonkey/runners/plain_runner.py index 215a46994..e2ef389d0 100644 --- a/neuralmonkey/runners/plain_runner.py +++ b/neuralmonkey/runners/plain_runner.py @@ -40,11 +40,8 @@ def collect_results(self, results: List[Dict]) -> None: if self.executor.postprocess is not None: decoded_tokens = self.executor.postprocess(decoded_tokens) - self.set_result(outputs=decoded_tokens, - losses=[train_loss, runtime_loss], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=None) + self.set_runner_result(outputs=decoded_tokens, + losses=[train_loss, runtime_loss]) # pylint: enable=too-few-public-methods def __init__(self, diff --git a/neuralmonkey/runners/regression_runner.py b/neuralmonkey/runners/regression_runner.py index c285a18f0..788f50b34 100644 --- a/neuralmonkey/runners/regression_runner.py +++ b/neuralmonkey/runners/regression_runner.py @@ -30,16 +30,12 @@ def collect_results(self, results: List[Dict]) -> None: predictions_sum += sess_result["prediction"] - predictions = predictions_sum / len(results) + predictions = (predictions_sum / len(results)).tolist() if self.executor.postprocess is not None: predictions = self.executor.postprocess(predictions) - self.set_result(outputs=predictions.tolist(), - losses=[mse_loss], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=None) + self.set_runner_result(outputs=predictions, losses=[mse_loss]) # pylint: enable=too-few-public-methods def __init__(self, diff --git a/neuralmonkey/runners/runner.py b/neuralmonkey/runners/runner.py index 3106ad9b5..dc3982fe4 100644 --- a/neuralmonkey/runners/runner.py +++ b/neuralmonkey/runners/runner.py @@ -54,13 +54,13 @@ def collect_results(self, results: List[Dict]) -> None: if self.executor.postprocess is not None: decoded_tokens = self.executor.postprocess(decoded_tokens) - image_summaries = results[0].get("image_summaries") + summaries = None + if "image_summaries" in results[0]: + summaries = [results[0]["image_summaries"]] - self.set_result(outputs=decoded_tokens, - losses=[train_loss, runtime_loss], - scalar_summaries=None, - histogram_summaries=None, - image_summaries=image_summaries) + self.set_runner_result( + outputs=decoded_tokens, losses=[train_loss, runtime_loss], + summaries=summaries) def __init__(self, output_series: str, diff --git a/neuralmonkey/runners/tensor_runner.py b/neuralmonkey/runners/tensor_runner.py index 6adf2fc67..b5404e83d 100644 --- a/neuralmonkey/runners/tensor_runner.py +++ b/neuralmonkey/runners/tensor_runner.py @@ -38,9 +38,7 @@ def collect_results(self, results: List[Dict]) -> None: else: batched = self._fetch_values_from_session(results[0]) - self.set_result(outputs=batched, losses=[], 
- scalar_summaries=None, histogram_summaries=None, - image_summaries=None) + self.set_runner_result(outputs=batched, losses=[]) def _fetch_values_from_session(self, sess_results: Dict) -> List: diff --git a/neuralmonkey/runners/word_alignment_runner.py b/neuralmonkey/runners/word_alignment_runner.py index 96e93239a..7a157b412 100644 --- a/neuralmonkey/runners/word_alignment_runner.py +++ b/neuralmonkey/runners/word_alignment_runner.py @@ -16,9 +16,7 @@ class WordAlignmentRunner(BaseRunner[BaseAttention]): class Executable(BaseRunner.Executable["WordAlignmentRunner"]): def collect_results(self, results: List[Dict]) -> None: - self.set_result(outputs=results[0]["alignment"], losses=[], - scalar_summaries=None, histogram_summaries=None, - image_summaries=None) + self.set_runner_result(outputs=results[0]["alignment"], losses=[]) # pylint: enable=too-few-public-methods def __init__(self, diff --git a/neuralmonkey/server/server.py b/neuralmonkey/server/server.py index 6c35162bd..745e78f8d 100644 --- a/neuralmonkey/server/server.py +++ b/neuralmonkey/server/server.py @@ -12,7 +12,7 @@ import numpy as np from neuralmonkey.config.configuration import Configuration -from neuralmonkey.dataset import Dataset +from neuralmonkey.dataset import Dataset, BatchingScheme from neuralmonkey.experiment import Experiment @@ -33,7 +33,8 @@ def get_file(filename): # pragma: no cover def run(data): # pragma: no cover exp = APP.config["experiment"] dataset = Dataset( - "request", data, {}, preprocessors=APP.config["preprocess"]) + "request", data, BatchingScheme(batch_size=1), {}, + preprocessors=APP.config["preprocess"]) _, response_data, _ = exp.run_model(dataset, write_out=False) diff --git a/neuralmonkey/tests/test_dataset.py b/neuralmonkey/tests/test_dataset.py index 07371a6c4..e8e1cdd1f 100644 --- a/neuralmonkey/tests/test_dataset.py +++ b/neuralmonkey/tests/test_dataset.py @@ -5,33 +5,23 @@ import tempfile import unittest -from neuralmonkey.dataset import Dataset, from_files, load, BatchingScheme +from neuralmonkey.dataset import Dataset, load, BatchingScheme from neuralmonkey.readers.plain_text_reader import UtfPlainTextReader -DEFAULT_BATCHING_SCHEME = BatchingScheme( - batch_size=3, - batch_bucket_span=None, - token_level_batching=False, - bucketing_ignore_series=[]) +DEFAULT_BATCHING_SCHEME = BatchingScheme(batch_size=3) class TestDataset(unittest.TestCase): - def test_nonexistent_file(self): + def test_nonexistent_file(self) -> None: with self.assertRaises(FileNotFoundError): load(name="name", series=["source"], data=[(["some_nonexistent_file"], UtfPlainTextReader)], + batching=DEFAULT_BATCHING_SCHEME, buffer_size=5) - def test_nonexistent_file_deprec(self): - with self.assertRaises(FileNotFoundError): - from_files( - name="name", - s_source=(["some_nonexistent_file"], UtfPlainTextReader), - lazy=True) - - def test_lazy_dataset(self): + def test_lazy_dataset(self) -> None: i = 0 # iteration counter def reader(files: List[str]) -> Iterable[List[str]]: @@ -43,7 +33,9 @@ def reader(files: List[str]) -> Iterable[List[str]]: dataset = load( name="data", series=["source", "source_prep"], - data=[([], reader), (lambda x: x, "source")], + data=[(["tests/data/train.tc.en"], reader), + (lambda x: x, "source")], + batching=DEFAULT_BATCHING_SCHEME, buffer_size=5) series = dataset.get_series("source_prep") @@ -53,28 +45,6 @@ def reader(files: List[str]) -> Iterable[List[str]]: self.assertEqual(i, j) self.assertEqual(i, 9) - def test_lazy_dataset_deprec(self): - i = 0 # iteration counter - - def reader(files: 
List[str]) -> Iterable[List[str]]: - del files - nonlocal i - for i in range(10): # pylint: disable=unused-variable - yield ["foo"] - - dataset = from_files( - name="data", - s_source=([], reader), - preprocessors=[("source", "source_prep", lambda x: x)], - lazy=True) - - series = dataset.get_series("source_prep") - - # Check that the reader is being iterated lazily - for j, _ in enumerate(series): - self.assertEqual(i, j) - self.assertEqual(i, 9) - def test_glob(self): filenames = sorted(["abc1", "abc2", "abcxx", "xyz"]) contents = ["a", "b", "c", "d"] @@ -87,23 +57,8 @@ def test_glob(self): name="dataset", series=["data"], data=[[os.path.join(tmp_dir, "abc?"), - os.path.join(tmp_dir, "xyz*")]]) - - series_iterator = dataset.get_series("data") - self.assertEqual(list(series_iterator), [["a"], ["b"], ["d"]]) - - def test_glob_deprec(self): - filenames = sorted(["abc1", "abc2", "abcxx", "xyz"]) - contents = ["a", "b", "c", "d"] - with tempfile.TemporaryDirectory() as tmp_dir: - for fname, text in zip(filenames, contents): - with open(os.path.join(tmp_dir, fname), "w") as file: - print(text, file=file) - - dataset = from_files( - name="dataset", - s_data=[os.path.join(tmp_dir, "abc?"), - os.path.join(tmp_dir, "xyz*")]) + os.path.join(tmp_dir, "xyz*")]], + batching=DEFAULT_BATCHING_SCHEME) series_iterator = dataset.get_series("data") self.assertEqual(list(series_iterator), [["a"], ["b"], ["d"]]) @@ -115,12 +70,13 @@ def test_batching_eager_noshuffle(self): } dataset = Dataset( - "dataset", iterators=iterators, shuffled=False) + "dataset", iterators=iterators, batching=DEFAULT_BATCHING_SCHEME, + shuffled=False) batches = [] for epoch in range(2): epoch = [] - for batch in dataset.batches(DEFAULT_BATCHING_SCHEME): + for batch in dataset.batches(): epoch.append({s: list(batch.get_series(s)) for s in iterators}) batches.append(epoch) @@ -138,12 +94,13 @@ def test_batching_lazy_noshuffle(self): } dataset = Dataset( - "dataset", iterators=iterators, shuffled=False, buffer_size=(3, 5)) + "dataset", iterators=iterators, batching=DEFAULT_BATCHING_SCHEME, + shuffled=False, buffer_size=(3, 5)) batches = [] for epoch in range(2): epoch = [] - for batch in dataset.batches(DEFAULT_BATCHING_SCHEME): + for batch in dataset.batches(): epoch.append({s: list(batch.get_series(s)) for s in iterators}) batches.append(epoch) @@ -160,12 +117,13 @@ def test_batching_eager_shuffle(self): "b": lambda: range(5, 10) } - dataset = Dataset("dataset", iterators=iterators, shuffled=True) + dataset = Dataset("dataset", iterators=iterators, + batching=DEFAULT_BATCHING_SCHEME, shuffled=True) batches = [] for epoch in range(2): epoch = [] - for batch in dataset.batches(DEFAULT_BATCHING_SCHEME): + for batch in dataset.batches(): epoch.append({s: list(batch.get_series(s)) for s in iterators}) batches.append(epoch) @@ -187,12 +145,13 @@ def test_batching_lazy_shuffle(self): } dataset = Dataset( - "dataset", iterators=iterators, shuffled=True, buffer_size=(3, 5)) + "dataset", iterators=iterators, batching=DEFAULT_BATCHING_SCHEME, + shuffled=True, buffer_size=(3, 5)) batches = [] for epoch in range(2): epoch = [] - for batch in dataset.batches(DEFAULT_BATCHING_SCHEME): + for batch in dataset.batches(): epoch.append({s: list(batch.get_series(s)) for s in iterators}) batches.append(epoch) @@ -215,15 +174,17 @@ def test_bucketing(self): for l in range(1, 50)) } - dataset = Dataset("dataset", iterators=iterators, shuffled=False) - # we use batch size 7 and bucket span 10 - scheme = BatchingScheme(7, 10, False, None, True) + scheme = 
BatchingScheme(bucket_boundaries=[9, 19, 29, 39, 49], + bucket_batch_sizes=[7, 7, 7, 7, 7, 7]) + + dataset = Dataset("dataset", iterators=iterators, + batching=scheme, shuffled=False) # we process the dataset in two epochs and save what did the batches # look like batches = [] - for batch in dataset.batches(scheme): + for batch in dataset.batches(): batches.append(list(batch.get_series("sentences"))) ref_batches = [ @@ -248,15 +209,17 @@ def test_bucketing_no_leftovers(self): for l in range(1, 50)) } - dataset = Dataset("dataset", iterators=iterators, shuffled=False) - # we use batch size 7 and bucket span 10 - scheme = BatchingScheme(7, 10, False, None, False) + scheme = BatchingScheme(bucket_boundaries=[9, 19, 29, 39, 49], + bucket_batch_sizes=[7, 7, 7, 7, 7, 7], + drop_remainder=True) + dataset = Dataset("dataset", iterators=iterators, batching=scheme, + shuffled=False) # we process the dataset in two epochs and save what did the batches # look like batches = [] - for batch in dataset.batches(scheme): + for batch in dataset.batches(): batches.append(list(batch.get_series("sentences"))) ref_batches = [ @@ -275,15 +238,16 @@ def test_buckets_similar_size(self): for l in range(6)] * 3 } - dataset = Dataset("dataset", iterators=iterators, shuffled=True) - # we use batch size 6 and bucket span 2 - scheme = BatchingScheme(6, 2, False, None) + scheme = BatchingScheme(bucket_boundaries=[1, 3, 5], + bucket_batch_sizes=[6, 6, 6, 6]) + dataset = Dataset("dataset", iterators=iterators, batching=scheme, + shuffled=True) # we process the dataset in two epochs and save what did the batches # look like batches = [] - for batch in dataset.batches(scheme): + for batch in dataset.batches(): batches.append(list(batch.get_series("sentences"))) # this setup should divide the data to 3 batches diff --git a/neuralmonkey/tests/test_decoder.py b/neuralmonkey/tests/test_decoder.py index 7c89865b1..cb7564c0f 100644 --- a/neuralmonkey/tests/test_decoder.py +++ b/neuralmonkey/tests/test_decoder.py @@ -3,38 +3,45 @@ """ Unit tests for the decoder. 
(Tests only initialization so far) """ import unittest -import copy +import tensorflow as tf from neuralmonkey.decoders.decoder import Decoder from neuralmonkey.vocabulary import Vocabulary -DECODER_PARAMS = dict( - encoders=[], - vocabulary=Vocabulary(), - data_id="foo", - name="test-decoder", - max_output_len=5, - dropout_keep_prob=1.0, - embedding_size=10, - rnn_size=10) - class TestDecoder(unittest.TestCase): + @classmethod + def setUpClass(cls): + tf.reset_default_graph() + + def setUp(self): + self.decoder_params = dict( + encoders=[], + vocabulary=Vocabulary(["a", "b", "c"]), + data_id="foo", + name="test-decoder", + max_output_len=5, + dropout_keep_prob=1.0, + embedding_size=10, + rnn_size=10) + + @classmethod + def tearDownClass(cls): + tf.reset_default_graph() + def test_init(self): - decoder = Decoder(**DECODER_PARAMS) + decoder = Decoder(**self.decoder_params) self.assertIsNotNone(decoder) def test_max_output_len(self): - dparams = copy.deepcopy(DECODER_PARAMS) - + dparams = self.decoder_params dparams["max_output_len"] = -10 with self.assertRaises(ValueError): Decoder(**dparams) def test_dropout(self): - dparams = copy.deepcopy(DECODER_PARAMS) - + dparams = self.decoder_params dparams["dropout_keep_prob"] = -0.5 with self.assertRaises(ValueError): Decoder(**dparams) @@ -44,8 +51,7 @@ def test_dropout(self): Decoder(**dparams) def test_embedding_size(self): - dparams = copy.deepcopy(DECODER_PARAMS) - + dparams = self.decoder_params dparams["embedding_size"] = None with self.assertRaises(ValueError): dec = Decoder(**dparams) @@ -56,7 +62,7 @@ def test_embedding_size(self): Decoder(**dparams) def test_cell_type(self): - dparams = copy.deepcopy(DECODER_PARAMS) + dparams = self.decoder_params dparams.update({"rnn_cell": "bogus_cell"}) with self.assertRaises(ValueError): diff --git a/neuralmonkey/tests/test_encoders_init.py b/neuralmonkey/tests/test_encoders_init.py index fe39373c6..b23197d24 100755 --- a/neuralmonkey/tests/test_encoders_init.py +++ b/neuralmonkey/tests/test_encoders_init.py @@ -9,9 +9,10 @@ from neuralmonkey.encoders.recurrent import SentenceEncoder from neuralmonkey.encoders.sentence_cnn_encoder import SentenceCNNEncoder from neuralmonkey.model.sequence import EmbeddedSequence -from neuralmonkey.tests.test_vocabulary import VOCABULARY +from neuralmonkey.vocabulary import Vocabulary +VOCABULARY = Vocabulary(["ich", "bin", "der", "walrus"]) INPUT_SEQUENCE = EmbeddedSequence("seq", VOCABULARY, "marmelade", 300) SENTENCE_ENCODER_GOOD = { diff --git a/neuralmonkey/tests/test_model_part.py b/neuralmonkey/tests/test_model_part.py index 7cfe717d6..303882cf7 100644 --- a/neuralmonkey/tests/test_model_part.py +++ b/neuralmonkey/tests/test_model_part.py @@ -17,24 +17,29 @@ class Test(unittest.TestCase): """Test capabilities of model part.""" + @classmethod + def setUpClass(cls): + tf.reset_default_graph() + cls.dataset = { + "id": tf.constant([["hello", "world"], ["test", "this"]]), + "data_id": tf.constant([["A", "B", "C"], ["D", "E", "F"]])} + def test_reuse(self): - vocabulary = Vocabulary() - vocabulary.add_word("a") - vocabulary.add_word("b") + vocabulary = Vocabulary(["a", "b"]) seq1 = EmbeddedSequence( name="seq1", vocabulary=vocabulary, data_id="id", embedding_size=10) - seq1.register_input() + seq1.register_input(self.dataset) seq2 = EmbeddedSequence( name="seq2", vocabulary=vocabulary, embedding_size=10, data_id="id") - seq2.register_input() + seq2.register_input(self.dataset) seq3 = EmbeddedSequence( name="seq3", @@ -42,7 +47,7 @@ def test_reuse(self): 
data_id="id", embedding_size=10, reuse=seq1) - seq3.register_input() + seq3.register_input(self.dataset) # blessing self.assertIsNotNone(seq1.embedding_matrix) @@ -62,20 +67,18 @@ def test_reuse(self): def test_save_and_load(self): """Try to save and load encoder.""" - vocabulary = Vocabulary() - vocabulary.add_word("a") - vocabulary.add_word("b") + vocabulary = Vocabulary(["a", "b"]) checkpoint_file = tempfile.NamedTemporaryFile(delete=False) checkpoint_file.close() encoder = SentenceEncoder( - name="enc", vocabulary=Vocabulary(), data_id="data_id", + name="enc", vocabulary=vocabulary, data_id="data_id", embedding_size=10, rnn_size=20, max_input_len=30, save_checkpoint=checkpoint_file.name, load_checkpoint=checkpoint_file.name) - encoder.input_sequence.register_input() + encoder.input_sequence.register_input(self.dataset) # NOTE: This assert needs to be here otherwise the model has # no parameters since the sentence encoder is initialized lazily diff --git a/neuralmonkey/tests/test_vocabulary.py b/neuralmonkey/tests/test_vocabulary.py index 3d37e5322..99e2cdbee 100755 --- a/neuralmonkey/tests/test_vocabulary.py +++ b/neuralmonkey/tests/test_vocabulary.py @@ -1,73 +1,69 @@ #!/usr/bin/env python3.5 import unittest +import tensorflow as tf +from neuralmonkey.vocabulary import Vocabulary, pad_batch -from neuralmonkey.vocabulary import Vocabulary -CORPUS = [ - "the colorless ideas slept furiously", - "pooh slept all night", - "working class hero is something to be", - "I am the working class walrus", - "walrus for president" -] +class TestVocabulary(tf.test.TestCase): -TOKENIZED_CORPUS = [s.split(" ") for s in CORPUS] + @classmethod + def setUpClass(cls): + tf.reset_default_graph() -VOCABULARY = Vocabulary() + cls.corpus = [ + "the colorless ideas slept furiously", + "pooh slept all night", + "working class hero is something to be", + "I am the working class walrus", + "walrus for president" + ] -for s in TOKENIZED_CORPUS: - VOCABULARY.add_tokenized_text(s) + cls.graph = tf.Graph() + with cls.graph.as_default(): + cls.tokenized_corpus = [s.split(" ") for s in cls.corpus] + words = [w for sent in cls.tokenized_corpus for w in sent] + cls.vocabulary = Vocabulary(list(set(words))) -class TestVocabulary(unittest.TestCase): + @classmethod + def tearDownClass(cls): + tf.reset_default_graph() def test_all_words_in(self): - for sentence in TOKENIZED_CORPUS: + for sentence in self.tokenized_corpus: for word in sentence: - self.assertTrue(word in VOCABULARY) + self.assertTrue(word in self.vocabulary) def test_unknown_word(self): - self.assertFalse("jindrisek" in VOCABULARY) + self.assertFalse("jindrisek" in self.vocabulary) def test_padding(self): - pass + padded = pad_batch(self.tokenized_corpus) + self.assertTrue(all(len(p) == 7 for p in padded)) def test_weights(self): pass def test_there_and_back_self(self): - vectors, _ = VOCABULARY.sentences_to_tensor(TOKENIZED_CORPUS, 20, - add_start_symbol=True, - add_end_symbol=True) - senteces_again = VOCABULARY.vectors_to_sentences(vectors[1:]) - - for orig_sentence, reconstructed_sentence in \ - zip(TOKENIZED_CORPUS, senteces_again): - self.assertSequenceEqual(orig_sentence, reconstructed_sentence) - - def test_min_freq(self): - vocabulary = Vocabulary() - vocabulary.correct_counts = True + with self.graph.as_default(): + with self.test_session() as sess: + sess.run(tf.tables_initializer()) - for sentence in TOKENIZED_CORPUS: - vocabulary.add_tokenized_text(sentence) + padded = tf.constant( + pad_batch(self.tokenized_corpus, max_length=20, + 
add_start_symbol=False, add_end_symbol=True)) - vocabulary.truncate_by_min_freq(2) + vectors = tf.transpose( + self.vocabulary.strings_to_indices(padded)) + f_vectors = sess.run(vectors) - self.assertTrue("walrus" in vocabulary) - self.assertFalse("colorless" in vocabulary) + sentences_again = self.vocabulary.vectors_to_sentences(f_vectors) - def test_count_fail(self): - - vocabulary = Vocabulary() - - for sentence in TOKENIZED_CORPUS: - vocabulary.add_tokenized_text(sentence) - - with self.assertRaises(ValueError): - vocabulary.truncate_by_min_freq(2) + for orig_sentence, reconstructed_sentence in \ + zip(self.tokenized_corpus, sentences_again): + self.assertSequenceEqual(orig_sentence, reconstructed_sentence) if __name__ == "__main__": diff --git a/neuralmonkey/tests/test_wordpiece.py b/neuralmonkey/tests/test_wordpiece.py index 5f9916359..5bdd9328d 100644 --- a/neuralmonkey/tests/test_wordpiece.py +++ b/neuralmonkey/tests/test_wordpiece.py @@ -5,37 +5,31 @@ from neuralmonkey.processors.wordpiece import ( WordpiecePreprocessor, WordpiecePostprocessor) -CORPUS = [ - "the colorless ideas slept furiously", - "pooh slept all night", - "working class hero is something to be", - "I am the working class walrus", - "walrus for president" -] - -TOKENIZED_CORPUS = [[a + "_" for a in s.split()] for s in CORPUS] - -# Create list of characters required to process the CORPUS with wordpieces -CORPUS_CHARS = [x for c in set("".join(CORPUS)) for x in [c, c + "_"]] -ESCAPE_CHARS = "\\_u0987654321;" -C_CARON = "\\269;" -A_ACUTE = "225" - class TestWordpieces(unittest.TestCase): @classmethod def setUpClass(cls): - vocabulary = Vocabulary() - - for c in CORPUS_CHARS + list(ESCAPE_CHARS): - vocabulary.add_word(c) - - for sent in TOKENIZED_CORPUS: - vocabulary.add_tokenized_text(sent) - - vocabulary.add_word(C_CARON) - vocabulary.add_word(A_ACUTE) + corpus = [ + "the colorless ideas slept furiously", + "pooh slept all night", + "working class hero is something to be", + "I am the working class walrus", + "walrus for president" + ] + + tokenized_corpus = [[a + "_" for a in s.split()] for s in corpus] + vocab_from_corpus = {w for sent in tokenized_corpus for w in sent} + + # Create list of characters required to process the CORPUS with + # wordpieces + corpus_chars = {x for c in set("".join(corpus)) for x in [c, c + "_"]} + escape_chars = "\\_u0987654321;" + c_caron = "\\269;" + a_acute = "225" + + words = corpus_chars | set(escape_chars) | vocab_from_corpus + vocabulary = Vocabulary(list(words) + [c_caron, a_acute]) cls.preprocessor = WordpiecePreprocessor(vocabulary) cls.postprocessor = WordpiecePostprocessor diff --git a/neuralmonkey/tf_manager.py b/neuralmonkey/tf_manager.py index e150d23da..056579f4e 100644 --- a/neuralmonkey/tf_manager.py +++ b/neuralmonkey/tf_manager.py @@ -84,7 +84,7 @@ def __init__(self, self.saver = None - self.best_score_index = 0 + self.best_score_index = None # type: Optional[int] self.best_score_epoch = 0 self.best_score_batch = 0 @@ -129,7 +129,6 @@ def init_saving(self, vars_prefix: str) -> None: for i in range(self.saver_max_to_keep)] self._best_vars_file = "{}.best".format(vars_prefix) - self._update_best_vars(var_index=0) def validation_hook(self, score: float, epoch: int, batch: int) -> None: if self._is_better(score, self.best_score): @@ -262,24 +261,23 @@ def restore(self, variable_files: Union[str, List[str]]) -> None: log("Variables loaded from {}".format(file_name)) def restore_best_vars(self) -> None: - # TODO warn when link does not exist + assert self.best_score_index 
is not None self.restore(self.variables_files[self.best_score_index]) def initialize_sessions(self) -> None: log("Initializing variables") init_op = tf.global_variables_initializer() + init_tables = tf.tables_initializer() for sess in self.sessions: - sess.run(init_op) + sess.run([init_op, init_tables]) log("Initializing tf.train.Saver") self.saver = tf.train.Saver(max_to_keep=None, var_list=[g for g in tf.global_variables() if "reward_" not in g.name]) - def initialize_model_parts(self, runners: Sequence[GraphExecutor], - save: bool = False) -> None: + def initialize_model_parts(self, runners: Sequence[GraphExecutor]) -> None: """Initialize model parts variables from their checkpoints.""" - if any(not hasattr(r, "parameterizeds") for r in runners): raise TypeError( "Args to initialize_model_parts must be trainers or runners") @@ -289,9 +287,6 @@ def initialize_model_parts(self, runners: Sequence[GraphExecutor], for session in self.sessions: coder.load(session) - if save: - self.save(self.variables_files[0]) - def _feed_dicts(dataset: Dataset, coders: Set[Feedable], train: bool = False): """Feed the coders with data from dataset. diff --git a/neuralmonkey/trainers/delayed_update_trainer.py b/neuralmonkey/trainers/delayed_update_trainer.py index 6af121323..dacea68e6 100644 --- a/neuralmonkey/trainers/delayed_update_trainer.py +++ b/neuralmonkey/trainers/delayed_update_trainer.py @@ -1,4 +1,6 @@ -from typing import Dict, List, Tuple +# pylint: disable=unused-import +from typing import Dict, List, Tuple, Optional +# pylint: enable=unused-import import tensorflow as tf from typeguard import check_argument_types @@ -24,15 +26,16 @@ def __init__(self, executor: "DelayedUpdateTrainer", super().__init__(executor, compute_losses, summaries, num_sessions) self.state = 0 - self.res_hist_sums = None - self.res_scal_sums = None - self.res_losses = None + self.res_sums = [] # type: List[tf.Summary] + self.res_losses = None # type: Optional[List[float]] + self.res_batch = None # type: Optional[int] def next_to_execute(self) -> NextExecute: if self.state == 0: # ACCUMULATING fetches = {"accumulators": self.executor.accumulate_ops, "counter": self.executor.cumulator_counter, + "batch_size": self.executor.batch_size, "losses": self.executor.objective_values} elif self.state == 1: # UPDATING @@ -54,6 +57,7 @@ def collect_results(self, results: List[Dict]) -> None: if self.state == 0: # ACCUMULATING self.res_losses = result["losses"] + self.res_batch = result["batch_size"] # Are we updating? 
counter = result["counter"] @@ -63,17 +67,19 @@ def collect_results(self, results: List[Dict]) -> None: return elif self.state == 1: if self.summaries: - self.res_scal_sums = result["scalar_summaries"] - self.res_hist_sums = result["histogram_summaries"] - + self.res_sums = [result["scalar_summaries"], + result["histogram_summaries"]] self.state = 2 return assert self.res_losses is not None - self.set_result([], losses=self.res_losses, - scalar_summaries=self.res_scal_sums, - histogram_summaries=self.res_hist_sums, - image_summaries=None) + assert self.res_batch is not None + + objective_names = [obj.name for obj in self.executor.objectives] + objective_names += ["L1", "L2"] + losses = dict(zip(objective_names, self.res_losses)) + + self.set_result({}, losses, self.res_batch, self.res_sums) # pylint: disable=too-many-arguments def __init__(self, @@ -129,7 +135,7 @@ def diff_buffer(self) -> tf.Variable: @tensor def cumulator_counter(self) -> tf.Variable: - return tf.Variable(0, trainable=False, name="self.cumulator_counter") + return tf.Variable(0, trainable=False, name="cumulator_counter") # pylint: enable=no-self-use @tensor diff --git a/neuralmonkey/trainers/generic_trainer.py b/neuralmonkey/trainers/generic_trainer.py index aef5a6dd5..7dea52623 100644 --- a/neuralmonkey/trainers/generic_trainer.py +++ b/neuralmonkey/trainers/generic_trainer.py @@ -6,6 +6,7 @@ from neuralmonkey.decorators import tensor from neuralmonkey.logging import warn +from neuralmonkey.model.feedable import Feedable from neuralmonkey.runners.base_runner import GraphExecutor, NextExecute from neuralmonkey.trainers.objective import ( Objective, Gradients, ObjectiveWeight) @@ -14,7 +15,7 @@ # pylint: disable=too-few-public-methods,too-many-locals,too-many-arguments -class GenericTrainer(GraphExecutor): +class GenericTrainer(GraphExecutor, Feedable): class Executable(GraphExecutor.Executable["GenericTrainer"]): @@ -39,15 +40,17 @@ def collect_results(self, results: List[Dict]) -> None: assert len(results) == 1 result = results[0] - scalar_summaries = ( - result["scalar_summaries"] if self.summaries else None) - histogram_summaries = ( - result["histogram_summaries"] if self.summaries else None) + summaries = [] + if self.summaries: + summaries.extend([result["scalar_summaries"], + result["histogram_summaries"]]) + + objective_names = [obj.name for obj in self.executor.objectives] + objective_names += ["L1", "L2"] + + losses = dict(zip(objective_names, result["losses"])) - self.set_result([], losses=result["losses"], - scalar_summaries=scalar_summaries, - histogram_summaries=histogram_summaries, - image_summaries=None) + self.set_result({}, losses, result["batch_size"], summaries) @staticmethod def default_optimizer(): @@ -63,6 +66,7 @@ def __init__(self, var_collection: str = None) -> None: check_argument_types() GraphExecutor.__init__(self, {obj.decoder for obj in objectives}) + Feedable.__init__(self) self.objectives = objectives self.l1_weight = l1_weight @@ -242,4 +246,5 @@ def summaries(self) -> Dict[str, tf.Tensor]: def fetches(self) -> Dict[str, tf.Tensor]: return {"train_op": self.train_op, "losses": self.objective_values, + "batch_size": self.batch_size, "_update_ops": tf.get_collection(tf.GraphKeys.UPDATE_OPS)} diff --git a/neuralmonkey/trainers/test_multitask_trainer.py b/neuralmonkey/trainers/test_multitask_trainer.py index ca7c4c2fa..2fa6194e5 100644 --- a/neuralmonkey/trainers/test_multitask_trainer.py +++ b/neuralmonkey/trainers/test_multitask_trainer.py @@ -42,6 +42,10 @@ class 
TestMultitaskTrainer(unittest.TestCase):
     def setUpClass(cls):
         tf.reset_default_graph()
+        cls.dataset = {
+            "id": tf.constant([["hello", "world"], ["test", "this"]]),
+            "data_id": tf.constant([["A", "B", "C"], ["D", "E", "F"]])}
+
     def setUp(self):
         self.mpart = TestMP("dummy_model_part")
         self.mpart_2 = TestMP("dummy_model_part_2")
@@ -53,14 +57,18 @@ def setUp(self):
         self.trainer2 = GenericTrainer([objective_2], clip_norm=1.0)

     def test_mt_trainer(self):
-        # TODO multitask trainer is likely broken by changes in tf-data branch
+        # TODO(tf-data) multitask trainer is likely broken by the changes
         trainer = MultitaskTrainer(
             [self.trainer1, self.trainer2, self.trainer1])

+        feedables = {self.mpart, self.mpart_2, self.trainer1, self.trainer2}
+        for feedable in feedables:
+            feedable.register_input(self.dataset)
+
         log("Blessing trainer fetches: {}".format(trainer.fetches))

-        self.assertSetEqual(trainer.feedables, {self.mpart, self.mpart_2})
+        self.assertSetEqual(trainer.feedables, feedables)
         self.assertSetEqual(trainer.parameterizeds,
                             {self.mpart, self.mpart_2})
         self.assertSetEqual(
diff --git a/neuralmonkey/util/word2vec.py b/neuralmonkey/util/word2vec.py
index 16739bcc8..51a7efdb8 100644
--- a/neuralmonkey/util/word2vec.py
+++ b/neuralmonkey/util/word2vec.py
@@ -8,8 +8,7 @@
 import numpy as np
 from typeguard import check_argument_types

-from neuralmonkey.vocabulary import (
-    Vocabulary, is_special_token, SPECIAL_TOKENS)
+from neuralmonkey.vocabulary import Vocabulary, SPECIAL_TOKENS


 class Word2Vec:
@@ -21,7 +20,7 @@ def __init__(self, path: str, encoding: str = "utf-8") -> None:

         # Create the vocabulary object, load the words and vectors from the
         # file
-        self.vocab = Vocabulary()
+        words = []  # type: List[str]
         embedding_vectors = []  # type: List[np.ndarray]

         with open(path, encoding=encoding) as f_data:
@@ -46,12 +45,14 @@ def __init__(self, path: str, encoding: str = "utf-8") -> None:

                 # Embedding of unknown token should be at index 3 to match the
                 # vocabulary implementation
-                if is_special_token(word):
+                if word in SPECIAL_TOKENS:
                     embedding_vectors[SPECIAL_TOKENS.index(word)] = vector
                 else:
-                    self.vocab.add_word(word)
+                    words.append(word)
                     embedding_vectors.append(vector)

+        self.vocab = Vocabulary(words)
+
         assert embedding_vectors[3] is not None
         assert emb_size is not None
diff --git a/neuralmonkey/vocabulary.py b/neuralmonkey/vocabulary.py
index 7055ab7e6..bb9306d8b 100644
--- a/neuralmonkey/vocabulary.py
+++ b/neuralmonkey/vocabulary.py
@@ -3,22 +3,18 @@
 This module implements the Vocabulary class and the helper functions that
 can be used to obtain a Vocabulary instance.
 """
-# pylint: disable=too-many-lines

 import collections
 import json
 import os
-import random

-# pylint: disable=unused-import
-from typing import List, Optional, Tuple, Dict, Union
-# pylint: enable=unused-import
+from typing import List, Set, Union

 import numpy as np
+import tensorflow as tf
 from typeguard import check_argument_types

-from neuralmonkey.logging import log, warn
-from neuralmonkey.dataset import Dataset
+from neuralmonkey.logging import log, warn, notice

 PAD_TOKEN = "<pad>"
 START_TOKEN = "<s>"
 END_TOKEN = "</s>"
@@ -33,24 +29,6 @@
 UNK_TOKEN_INDEX = 3


-def is_special_token(word: str) -> bool:
-    """Check whether word is a special token (such as <pad> or <s>).
-
-    Arguments:
-        word: The word to check
-
-    Returns:
-        True if the word is special, False otherwise.
- """ - return word in SPECIAL_TOKENS - - -# pylint: disable=unused-argument -def from_file(*args, **kwargs) -> "Vocabulary": - raise NotImplementedError("Use loading by from_wordlist") -# pylint: enable=unused-argument - - def from_wordlist(path: str, encoding: str = "utf-8", contains_header: bool = True, @@ -65,12 +43,13 @@ def from_wordlist(path: str, path: The path to the wordlist file encoding: The encoding of the wordlist file (defaults to UTF-8) contains_header: if the file have a header on first line - contains_frequencies: if the file contains frequencies in second column + contains_frequencies: if the file contains a second column Returns: The new Vocabulary instance. """ - vocabulary = Vocabulary() + check_argument_types() + vocabulary = [] # type: List[str] with open(path, encoding=encoding) as wordlist: line_number = 1 @@ -94,18 +73,30 @@ def from_wordlist(path: str, raise ValueError( "Vocabulary file {}:{}: line does not have two columns" .format(path, line_number)) - vocabulary.add_word(info[0], int(info[1])) + word = info[0] else: if "\t" in line: warn("Vocabulary file {}:{}: line contains a tabulator" .format(path, line_number)) - vocabulary.add_word(line) + word = line + + if line_number <= len(SPECIAL_TOKENS) + int(contains_header): + should_be = SPECIAL_TOKENS[ + line_number - 1 - int(contains_header)] + if word != should_be: + notice("Expected special token {} but encountered a " + "different word: {}".format(should_be, word)) + vocabulary.append(word) + line_number += 1 + continue + + vocabulary.append(word) line_number += 1 log("Vocabulary from wordlist loaded, containing {} words" .format(len(vocabulary))) - vocabulary.log_sample() - return vocabulary + log_sample(vocabulary) + return Vocabulary(vocabulary) def from_t2t_vocabulary(path: str, @@ -119,7 +110,8 @@ def from_t2t_vocabulary(path: str, Returns: The new Vocabulary instantce. """ - vocabulary = Vocabulary() + check_argument_types() + vocabulary = [] # type: List[str] with open(path, encoding=encoding) as wordlist: for line in wordlist: @@ -133,12 +125,13 @@ def from_t2t_vocabulary(path: str, if line in ["", ""]: continue - vocabulary.add_word(line) + vocabulary.append(line) log("Vocabulary form wordlist loaded, containing {} words" .format(len(vocabulary))) - vocabulary.log_sample() - return vocabulary + log_sample(vocabulary) + + return Vocabulary(vocabulary) def from_nematus_json(path: str, max_size: int = None, @@ -155,14 +148,15 @@ def from_nematus_json(path: str, max_size: int = None, pad_to_max_size: If specified, the vocabulary is padded with dummy symbols up to the specified maximum size. 
""" + check_argument_types() with open(path, "r", encoding="utf-8") as f_json: contents = json.load(f_json) - vocabulary = Vocabulary() + vocabulary = [] # type: List[str] for word in sorted(contents.keys(), key=lambda x: contents[x]): if contents[word] < 2: continue - vocabulary.add_word(word) + vocabulary.append(word) if max_size is not None and len(vocabulary) == max_size: break @@ -173,142 +167,32 @@ def from_nematus_json(path: str, max_size: int = None, current_length = len(vocabulary) for i in range(max_size - current_length + 2): # the "2" is ugly HACK word = "".format(i) - vocabulary.add_word(word) - - return vocabulary - - -# pylint: disable=too-many-arguments -# helper function, this number of parameters is needed -def from_dataset(datasets: List[Dataset], series_ids: List[str], max_size: int, - save_file: str = None, overwrite: bool = False, - min_freq: Optional[int] = None, - unk_sample_prob: float = 0.5) -> "Vocabulary": - """Load a vocabulary from a dataset with an option to save it. - - Arguments: - datasets: A list of datasets from which to create the vocabulary - series_ids: A list of ids of series of the datasets that should be used - producing the vocabulary - max_size: The maximum size of the vocabulary - save_file: A file to save the vocabulary to. If None (default), - the vocabulary will not be saved. - overwrite: Overwrite existing file. - min_freq: Do not include words with frequency smaller than this. - unk_sample_prob: The probability with which to sample unks out of - words with frequency 1. Defaults to 0.5. - - Returns: - The new Vocabulary instance. - """ - check_argument_types() - - vocabulary = Vocabulary(unk_sample_prob=unk_sample_prob) - vocabulary.correct_counts = True + vocabulary.append(word) - for dataset in datasets: - if dataset.lazy: - warn("Inferring vocabulary from lazy dataset!") - - for series_id in series_ids: - if series_id not in dataset: - warn("Data series '{}' not present in the dataset" - .format(series_id)) - - series = dataset.maybe_get_series(series_id) - if series is not None: - vocabulary.add_tokenized_text( - [token for sent in series for token in sent]) - - vocabulary.truncate(max_size) - - if min_freq is not None: - if min_freq > 1: - vocabulary.truncate_by_min_freq(min_freq) - - log("Vocabulary for series {} initialized, containing {} words" - .format(series_ids, len(vocabulary))) - - vocabulary.log_sample() - - if save_file is not None: - directory = os.path.dirname(save_file) - if directory and not os.path.exists(directory): - os.makedirs(directory) - vocabulary.save_wordlist(save_file, overwrite, True) - - return vocabulary - - -def initialize_vocabulary(directory: str, name: str, - datasets: List[Dataset] = None, - series_ids: List[str] = None, - max_size: int = None) -> "Vocabulary": - """Initialize a vocabulary. - - This function is supposed to initialize vocabulary when called from - the configuration file. It first checks whether the vocabulary is already - loaded on the provided path and if not, it tries to generate it from - the provided dataset. - - Args: - directory: Directory where the vocabulary should be stored. - - name: Name of the vocabulary which is also the name of the file - it is stored it. - - datasets: A a list of datasets from which the vocabulary can be - created. - - series_ids: A list of ids of series of the datasets that should be used - for producing the vocabulary. 
- - max_size: The maximum size of the vocabulary - - Returns: - The new vocabulary - """ - warn("Use of deprecated initialize_vocabulary method. " - "Did you think this through?") - - file_name = os.path.join(directory, name + ".pickle") - if os.path.exists(file_name): - return from_wordlist(file_name) - - if datasets is None or series_ids is None or max_size is None: - raise Exception("Vocabulary does not exist in '{}', " - "neither dataset and series_id were provided.") - - return from_dataset(datasets, series_ids, max_size, - save_file=file_name, overwrite=False) + return Vocabulary(vocabulary) class Vocabulary(collections.Sized): - def __init__(self, tokenized_text: List[str] = None, - unk_sample_prob: float = 0.0) -> None: + def __init__(self, words: List[str], num_oov_buckets: int = 0) -> None: """Create a new instance of a vocabulary. Arguments: - tokenized_text: The initial list of words to add. + words: The mapping of indices to words. """ - self.word_to_index = {} # type: Dict[str, int] - self.index_to_word = [] # type: List[str] - self.word_count = {} # type: Dict[str, int] - self.alphabet = {tok for tok in SPECIAL_TOKENS} - # flag if the word count are in use - self.correct_counts = False + self._vocabulary = SPECIAL_TOKENS + words + self._alphabet = {c for word in words for c in word} - self.unk_sample_prob = unk_sample_prob + self._index_to_string = ( + tf.contrib.lookup.index_to_string_table_from_tensor( + mapping=self._vocabulary, + default_value=UNK_TOKEN)) - self.add_word(PAD_TOKEN) - self.add_word(START_TOKEN) - self.add_word(END_TOKEN) - self.add_word(UNK_TOKEN) - - if tokenized_text: - self.add_tokenized_text(tokenized_text) + self._string_to_index = tf.contrib.lookup.index_table_from_tensor( + mapping=self._vocabulary, + num_oov_buckets=num_oov_buckets, + default_value=UNK_TOKEN_INDEX) def __len__(self) -> int: """Get the size of the vocabulary. @@ -316,7 +200,7 @@ def __len__(self) -> int: Returns: The number of distinct words in the vocabulary. """ - return len(self.index_to_word) + return len(self._vocabulary) def __contains__(self, word: str) -> bool: """Check if a word is in the vocabulary. @@ -327,157 +211,24 @@ def __contains__(self, word: str) -> bool: Returns: True if the word was added to the vocabulary, False otherwise. """ - return word in self.word_to_index + return word in self._vocabulary - def add_word(self, word: str, occurences: int = 1) -> None: - """Add a word to the vocablulary. + @property + def alphabet(self) -> Set[str]: + return self._alphabet - Arguments: - word: The word to add. If it's already there, increment the count. - occurences: increment the count of word by the number of occurences - """ - if word not in self: - self.word_to_index[word] = len(self.index_to_word) - self.index_to_word.append(word) - self.word_count[word] = 0 - if not is_special_token(word): - self.add_characters(word) - self.word_count[word] += occurences + @property + def index_to_word(self) -> List[str]: + return self._vocabulary - def add_characters(self, word: str) -> None: - self.alphabet |= {c for c in word} - - def add_tokenized_text(self, tokenized_text: List[str]) -> None: - """Add words from a list to the vocabulary. - - Arguments: - tokenized_text: The list of words to add. - """ - for word in tokenized_text: - self.add_word(word) - - def get_word_index(self, word: str) -> int: - """Return index of the specified word. - - Arguments: - word: The word to look up. 
- - Returns: - Index of the word or index of the unknown token if the word is not - present in the vocabulary. - """ - if word not in self: - return self.get_word_index(UNK_TOKEN) - return self.word_to_index[word] - - def get_unk_sampled_word_index(self, word): - """Return index of the specified word with sampling of unknown words. - - This method returns the index of the specified word in the vocabulary. - If the frequency of the word in the vocabulary is 1 (the word was only - seen once in the whole training dataset), with probability of - self.unk_sample_prob, generate the index of the unknown token instead. - - Arguments: - word: The word to look up. - - Returns: - Index of the word, index of the unknown token if sampled, or index - of the unknown token if the word is not present in the vocabulary. - """ - idx = self.word_to_index.get(word, self.get_word_index(UNK_TOKEN)) - freq = self.word_count.get(word, 0) - - if freq <= 1 and random.random() < self.unk_sample_prob: - if not self.correct_counts: - raise ValueError("The vocabulary does not have correct " - "word_counts to use with unknown sampling") - return self.get_word_index(UNK_TOKEN) - - return idx - - def truncate(self, size: int) -> None: - """Truncate the vocabulary to the requested size. - - The infrequent tokens are discarded. - - Arguments: - size: The final size of the vocabulary - """ - - if not self.correct_counts: - raise ValueError("The vocabulary does not have correct " - "word_counts to use for vocabulary truncate") - - # sort by frequency - # sorting words first makes vocabulary generation deterministic - words_by_freq = sorted(list(sorted(self.word_count.keys())), - key=lambda w: self.word_count[w]) - - # keep the least frequent words which are not special symbols - to_delete = len(self) - size - if to_delete < 0: - to_delete = 0 - warn("Actual vocabulary size ({}) is smaller than max_size ({})" - .format(len(self), size)) - words_to_delete = [] # type: List[str] - for word in words_by_freq: - if len(words_to_delete) == to_delete: - break - if not is_special_token(word): - words_to_delete.append(word) - - # sort by index ... bigger indices needs to be removed first - # to keep the lists propertly shaped - delete_words_by_index = sorted( - [(w, self.word_to_index[w]) for w in words_to_delete], - key=lambda p: -p[1]) - - for word, index in delete_words_by_index: - del self.word_count[word] - del self.index_to_word[index] - - self.word_to_index = {} - for index, word in enumerate(self.index_to_word): - self.word_to_index[word] = index - - def truncate_by_min_freq(self, min_freq: int) -> None: - """Truncate the vocabulary only keeping words with a minimum frequency. - - Arguments: - min_freq: The minimum frequency of included words. 
- """ - if min_freq > 1: - # count how many words there are with frequency < min_freq - # ignoring special tokens - infreq_word_count = sum([1 for w in self.word_count - if self.word_count[w] < min_freq - and not is_special_token(w)]) - log("Removing {} infrequent (<{}) words from vocabulary".format( - infreq_word_count, min_freq)) - new_size = len(self) - infreq_word_count - self.truncate(new_size) - - def sentences_to_tensor( - self, - sentences: List[List[str]], - max_len: int = None, - pad_to_max_len: bool = True, - train_mode: bool = False, - add_start_symbol: bool = False, - add_end_symbol: bool = False) -> Tuple[np.ndarray, np.ndarray]: + def strings_to_indices(self, + # add_start_symbol: bool = False, + # add_end_symbol: bool = False + sentences: tf.Tensor) -> tf.Tensor: """Generate the tensor representation for the provided sentences. Arguments: sentences: List of sentences as lists of tokens. - max_len: If specified, all sentences will be truncated to this - length. - pad_to_max_len: If True, the tensor will be padded to `max_len`, - even if all of the sentences are shorter. If False, the shape - of the tensor will be determined by the maximum length of the - sentences in the batch. - train_mode: Flag whether we are training or not - (enables/disables unk sampling). add_start_symbol: If True, the `` token will be added to the beginning of each sentence vector. Enabling this option extends the maximum length by one. @@ -488,50 +239,20 @@ def sentences_to_tensor( the maximum length. Returns: - A tuple of a sentence tensor and a padding weight vector. - - The shape of the tensor representing the sentences is either - `(batch_max_len, batch_size)` or `(batch_max_len+1, batch_size)`, - depending on the value of the `add_start_symbol` argument. - `batch_max_len` is the length of the longest sentence in the - batch (including the optional `` token), limited by `max_len` - (if specified). - - The shape of the padding vector is the same as of the sentence - vector. + Tensor of indices of the words. """ - if pad_to_max_len and max_len is not None: - batch_max_len = max_len - else: - batch_max_len = max(len(s) for s in sentences) - if add_end_symbol: - batch_max_len += 1 - if max_len is not None: - batch_max_len = min(max_len, batch_max_len) - - word_indices = np.full( - [batch_max_len, len(sentences)], self.get_word_index(PAD_TOKEN), - dtype=np.int32) - weights = np.zeros([batch_max_len, len(sentences)]) - - for i in range(batch_max_len): - for j, sent in enumerate(sentences): - if i < len(sent): - w_idx = (self.get_unk_sampled_word_index(sent[i]) - if train_mode else self.get_word_index(sent[i])) - word_indices[i, j] = w_idx - weights[i, j] = 1 - - elif i == len(sent) and add_end_symbol: - word_indices[i, j] = self.get_word_index(END_TOKEN) - weights[i, j] = 1 + return self._string_to_index.lookup(sentences) - if add_start_symbol: - word_indices = np.insert(word_indices, 0, - self.get_word_index(START_TOKEN), axis=0) - weights = np.insert(weights, 0, 1, axis=0) + def indices_to_strings(self, vectors: tf.Tensor) -> tf.Tensor: + """Convert tensors of indexes of vocabulary items to lists of words. + + Arguments: + vectors: An int Tensor with indices to the vocabulary. - return word_indices, weights + Returns: + A string Tensor with the corresponding words. + """ + return self._index_to_string.lookup(vectors) def vectors_to_sentences( self, @@ -539,7 +260,7 @@ def vectors_to_sentences( """Convert vectors of indexes of vocabulary items to lists of words. 
Arguments: - vectors: List of vectors of vocabulary indices. + vectors: TIME-MAJOR List of vectors of vocabulary indices. Returns: List of lists of words. @@ -549,6 +270,7 @@ def vectors_to_sentences( raise ValueError( "Cannot infer batch size because decoder returned an " "empty output.") + batch_size = vectors[0].shape[0] elif isinstance(vectors, np.ndarray): batch_size = vectors.shape[1] @@ -566,7 +288,6 @@ def vectors_to_sentences( return [s[:-1] if s and s[-1] == END_TOKEN else s for s in sentences] def save_wordlist(self, path: str, overwrite: bool = False, - save_frequencies: bool = False, encoding: str = "utf-8") -> None: """Save the vocabulary as a wordlist. @@ -577,8 +298,6 @@ def save_wordlist(self, path: str, overwrite: bool = False, path: The path to save the file to. overwrite: Flag whether to overwrite existing file. Defaults to False. - save_frequencies: flag if frequencies should be stored. This - parameter adds header into the output file. Raises: FileExistsError if the file exists and overwrite flag is @@ -589,35 +308,51 @@ def save_wordlist(self, path: str, overwrite: bool = False, "overwrite is disabled. {}".format(path)) with open(path, "w", encoding=encoding) as output_file: - if save_frequencies and self.correct_counts: - # this header is important for the TensorBoard to properly - # handle the frequencies. - # - # IMPORTANT NOTICE: when saving only wordlist without - # frequencies it MUST NOT contain the header. It is an - # exception from Tensorboard. More at - # https://www.tensorflow.org/get_started/embedding_viz - output_file.write("Word\tWord counts\n") - elif save_frequencies and not self.correct_counts: - log("Storing vocabulary without frequencies.") - - for i in range(len(self.index_to_word)): - output_file.write(self.index_to_word[i]) - if save_frequencies and self.correct_counts: - output_file.write( - "\t" + str(self.word_count[self.index_to_word[i]])) - - output_file.write("\n") - - def log_sample(self, size: int = 5) -> None: - """Log a sample of the vocabulary. + log("Storing vocabulary without frequencies.") - Arguments: - size: How many sample words to log. - """ - if size > len(self): - log("Vocabulary: {}".format(self.index_to_word)) + for word in self._vocabulary: + output_file.write("{}\n".format(word)) + + +def log_sample(vocabulary: List[str], size: int = 5) -> None: + """Log a sample of the vocabulary. + + Arguments: + size: How many sample words to log. 
+ """ + if size > len(vocabulary): + log("Vocabulary: {}".format(vocabulary)) + else: + sample_ids = np.random.permutation(np.arange(len(vocabulary)))[:size] + log("Sample of the vocabulary: {}".format( + [vocabulary[i] for i in sample_ids])) + + +def pad_batch(sentences: List[List[str]], + max_length: int = None, + add_start_symbol: bool = False, + add_end_symbol: bool = False) -> List[List[str]]: + + max_len = max(len(s) for s in sentences) + if add_end_symbol: + max_len += 1 + + if max_length is not None: + max_len = min(max_length, max_len) + + padded_sentences = [] + for sent in sentences: + if add_end_symbol: + padded = (sent + [END_TOKEN] + [PAD_TOKEN] * max_len)[:max_len] else: - sample_ids = np.random.permutation(np.arange(len(self)))[:size] - log("Sample of the vocabulary: {}".format( - [self.index_to_word[i] for i in sample_ids])) + padded = (sent + [PAD_TOKEN] * max_len)[:max_len] + + if add_start_symbol: + padded.insert(0, START_TOKEN) + padded_sentences.append(padded) + + return padded_sentences + + +def sentence_mask(sentences: tf.Tensor) -> tf.Tensor: + return tf.to_float(tf.not_equal(sentences, PAD_TOKEN_INDEX)) diff --git a/requirements-gpu.txt b/requirements-gpu.txt index 1c69f4d75..ebc6da3ab 100644 --- a/requirements-gpu.txt +++ b/requirements-gpu.txt @@ -10,4 +10,4 @@ pygments rouge==0.2.1 typeguard sacrebleu -tensorflow-gpu>=1.10.0,<1.11 +tensorflow-gpu>=1.12.0,<1.13 diff --git a/requirements.txt b/requirements.txt index 7df3e5453..27341cc05 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,4 @@ pygments rouge==0.2.1 typeguard sacrebleu -tensorflow>=1.10.0,<1.11 +tensorflow>=1.12.0,<1.13 diff --git a/scripts/imagenet_features.py b/scripts/imagenet_features.py index 6b9eb756b..205a20a86 100755 --- a/scripts/imagenet_features.py +++ b/scripts/imagenet_features.py @@ -15,7 +15,7 @@ import numpy as np import tensorflow as tf -from neuralmonkey.dataset import Dataset +from neuralmonkey.dataset import Dataset, BatchingScheme from neuralmonkey.encoders.imagenet_encoder import ImageNet from neuralmonkey.logging import log from neuralmonkey.readers.image_reader import single_image_for_imagenet @@ -91,7 +91,8 @@ def main(): image_paths = [] def process_images(): - dataset = Dataset("dataset", {"images": np.array(images)}, {}) + dataset = Dataset("dataset", {"images": np.array(images)}, + BatchingScheme(batch_size=1), {}) feed_dict = imagenet.feed_dict(dataset) fetch = imagenet.encoded if args.vector else imagenet.spatial_states diff --git a/tests/audio-classifier.ini b/tests/audio-classifier.ini index a0edf754f..255078174 100644 --- a/tests/audio-classifier.ini +++ b/tests/audio-classifier.ini @@ -3,7 +3,7 @@ name="audio classification" tf_manager= output="tests/outputs/audio-classifier" overwrite_output_dir=True -batch_size=5 +batch_size=2 epochs=1 train_dataset= diff --git a/tests/bahdanau.ini b/tests/bahdanau.ini index c1b2cb51d..bafa39157 100644 --- a/tests/bahdanau.ini +++ b/tests/bahdanau.ini @@ -23,21 +23,29 @@ class=tf_manager.TensorFlowManager num_threads=4 num_sessions=1 +[batching] +class=dataset.BatchingScheme +bucket_boundaries=[5, 10, 15, 20] +bucket_batch_sizes=[20, 15, 10, 5, 2] + [train_data] class=dataset.load series=["source", "target"] data=["tests/data/train.tc.en", "tests/data/train.tc.de"] +batching= [val_data] class=dataset.load series=["source", "target"] data=["tests/data/val.tc.en", "tests/data/val.tc.de"] +batching= [val_data_no_target] class=dataset.load series=["source"] data=["tests/data/val.tc.en"] outputs=[("encoded", 
"tests/outputs/bahdanau/encoded"), ("debugtensors", "tests/outputs/bahdanau/debugtensors")] +batching= [encoder_vocabulary] class=vocabulary.from_wordlist diff --git a/tests/beamsearch.ini b/tests/beamsearch.ini index 04d5182fa..aca06df04 100644 --- a/tests/beamsearch.ini +++ b/tests/beamsearch.ini @@ -13,7 +13,6 @@ postprocess=None evaluation=[("target_beam.rank001", "target", evaluators.BLEU)] logging_period=20 validation_period=60 -runners_batch_size=5 random_seed=1234 [tf_manager] diff --git a/tests/beamsearch_ensembles.ini b/tests/beamsearch_ensembles.ini index 7633aedcc..e25fdb779 100644 --- a/tests/beamsearch_ensembles.ini +++ b/tests/beamsearch_ensembles.ini @@ -13,7 +13,6 @@ postprocess=None evaluation=[("target_beam.rank001", "target", evaluators.BLEU)] logging_period=20 validation_period=60 -runners_batch_size=5 random_seed=1234 [tf_manager] diff --git a/tests/bpe.ini b/tests/bpe.ini index 1a1af63c4..c8a2062b1 100644 --- a/tests/bpe.ini +++ b/tests/bpe.ini @@ -17,7 +17,6 @@ val_preview_input_series=["source", "target", "target_bpe"] val_preview_output_series=["target_greedy"] logging_period=20 validation_period=60 -runners_batch_size=5 test_datasets=[,] [tf_manager] diff --git a/tests/captioning.ini b/tests/captioning.ini index 33ba6af96..2b9ddffb8 100644 --- a/tests/captioning.ini +++ b/tests/captioning.ini @@ -13,7 +13,6 @@ postprocess=None evaluation=[("target", evaluators.BLEU)] logging_period=1 validation_period=2 -runners_batch_size=1 test_datasets=[,] random_seed=1234 diff --git a/tests/classifier.ini b/tests/classifier.ini index d927bd960..82d30bf5f 100644 --- a/tests/classifier.ini +++ b/tests/classifier.ini @@ -13,7 +13,6 @@ postprocess=None evaluation=[("classification", evaluators.Accuracy)] logging_period=50 validation_period=100 -runners_batch_size=1 random_seed=1234 [tf_manager] diff --git a/tests/ctc.ini b/tests/ctc.ini index 4589f8b4c..34d6f689d 100644 --- a/tests/ctc.ini +++ b/tests/ctc.ini @@ -4,8 +4,8 @@ tf_manager= output="tests/outputs/ctc" overwrite_output_dir=True -batch_size=5 -epochs=1 +batch_size=4 +epochs=5 train_dataset= val_dataset= @@ -16,8 +16,8 @@ runners=[] evaluation=[("target", evaluators.WER)] -logging_period=2 -validation_period=2 +logging_period=1 +validation_period="5s" random_seed=123485 diff --git a/tests/data/flickr30k/1000092795.jpg.npz b/tests/data/flickr30k/1000092795.jpg.npz new file mode 100644 index 000000000..f24174411 Binary files /dev/null and b/tests/data/flickr30k/1000092795.jpg.npz differ diff --git a/tests/data/flickr30k/10002456.jpg.npz b/tests/data/flickr30k/10002456.jpg.npz new file mode 100644 index 000000000..24c0c50d7 Binary files /dev/null and b/tests/data/flickr30k/10002456.jpg.npz differ diff --git a/tests/data/flickr30k/1000268201.jpg.npz b/tests/data/flickr30k/1000268201.jpg.npz new file mode 100644 index 000000000..c9dfc42af Binary files /dev/null and b/tests/data/flickr30k/1000268201.jpg.npz differ diff --git a/tests/data/flickr30k/1000344755.jpg.npz b/tests/data/flickr30k/1000344755.jpg.npz new file mode 100644 index 000000000..527d214f8 Binary files /dev/null and b/tests/data/flickr30k/1000344755.jpg.npz differ diff --git a/tests/data/flickr30k/1000366164.jpg.npz b/tests/data/flickr30k/1000366164.jpg.npz new file mode 100644 index 000000000..289595f2c Binary files /dev/null and b/tests/data/flickr30k/1000366164.jpg.npz differ diff --git a/tests/data/flickr30k/1000523639.jpg.npz b/tests/data/flickr30k/1000523639.jpg.npz new file mode 100644 index 000000000..8e85776b9 Binary files /dev/null and 
b/tests/data/flickr30k/1000523639.jpg.npz differ diff --git a/tests/data/flickr30k/1000919630.jpg.npz b/tests/data/flickr30k/1000919630.jpg.npz new file mode 100644 index 000000000..53183dc90 Binary files /dev/null and b/tests/data/flickr30k/1000919630.jpg.npz differ diff --git a/tests/data/flickr30k/10010052.jpg.npz b/tests/data/flickr30k/10010052.jpg.npz new file mode 100644 index 000000000..cb31cca16 Binary files /dev/null and b/tests/data/flickr30k/10010052.jpg.npz differ diff --git a/tests/data/flickr30k/1001465944.jpg.npz b/tests/data/flickr30k/1001465944.jpg.npz new file mode 100644 index 000000000..28b0132b2 Binary files /dev/null and b/tests/data/flickr30k/1001465944.jpg.npz differ diff --git a/tests/data/flickr30k/1001545525.jpg.npz b/tests/data/flickr30k/1001545525.jpg.npz new file mode 100644 index 000000000..ca326a061 Binary files /dev/null and b/tests/data/flickr30k/1001545525.jpg.npz differ diff --git a/tests/data/flickr30k/1018148011.jpg.npz b/tests/data/flickr30k/1018148011.jpg.npz new file mode 100644 index 000000000..d0ac9d161 Binary files /dev/null and b/tests/data/flickr30k/1018148011.jpg.npz differ diff --git a/tests/data/flickr30k/1029450589.jpg.npz b/tests/data/flickr30k/1029450589.jpg.npz new file mode 100644 index 000000000..bfdbcd277 Binary files /dev/null and b/tests/data/flickr30k/1029450589.jpg.npz differ diff --git a/tests/data/flickr30k/1029737941.jpg.npz b/tests/data/flickr30k/1029737941.jpg.npz new file mode 100644 index 000000000..48e898474 Binary files /dev/null and b/tests/data/flickr30k/1029737941.jpg.npz differ diff --git a/tests/data/flickr30k/train_images.npz.txt b/tests/data/flickr30k/train_images.npz.txt new file mode 100644 index 000000000..cab358269 --- /dev/null +++ b/tests/data/flickr30k/train_images.npz.txt @@ -0,0 +1,10 @@ +1000092795.jpg.npz +10002456.jpg.npz +1000268201.jpg.npz +1000344755.jpg.npz +1000366164.jpg.npz +1000523639.jpg.npz +1000919630.jpg.npz +10010052.jpg.npz +1001465944.jpg.npz +1001545525.jpg.npz diff --git a/tests/data/flickr30k/val_images.npz.txt b/tests/data/flickr30k/val_images.npz.txt new file mode 100644 index 000000000..e6f72bdab --- /dev/null +++ b/tests/data/flickr30k/val_images.npz.txt @@ -0,0 +1,3 @@ +1018148011.jpg.npz +1029450589.jpg.npz +1029737941.jpg.npz diff --git a/tests/flat-multiattention.ini b/tests/flat-multiattention.ini index 15ae53ba9..ca637582b 100644 --- a/tests/flat-multiattention.ini +++ b/tests/flat-multiattention.ini @@ -12,8 +12,7 @@ runners=[, , ] random_seed=1234 @@ -22,32 +21,25 @@ class=tf_manager.TensorFlowManager num_threads=4 num_sessions=1 -[image_reader] -class=readers.image_reader.image_reader +[numpy_reader] +class=readers.numpy_reader.from_file_list prefix="tests/data/flickr30k" -pad_h=32 -pad_w=32 -mode="RGB" +shape=[8, 8, 2048] [train_data] class=dataset.load series=["source", "target", "images"] -data=["tests/data/flickr30k/train.en", "tests/data/flickr30k/train.de", ("tests/data/flickr30k/train_images.txt", )] +data=["tests/data/flickr30k/train.en", "tests/data/flickr30k/train.de", ("tests/data/flickr30k/train_images.npz.txt", )] [val_data] class=dataset.load series=["source", "target", "images"] -data=["tests/data/flickr30k/val.en", "tests/data/flickr30k/val.de", ("tests/data/flickr30k/val_images.txt", )] +data=["tests/data/flickr30k/val.en", "tests/data/flickr30k/val.de", ("tests/data/flickr30k/val_images.npz.txt", )] [imagenet] -class=encoders.cnn_encoder.CNNEncoder -name="cnn" +class=encoders.numpy_stateful_filler.SpatialFiller +input_shape=[8, 8, 2048] data_id="images" 
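The flat-multiattention test around this hunk stops training a CNN encoder on raw Flickr30k images and instead loads precomputed 8x8x2048 feature maps through readers.numpy_reader.from_file_list, feeding them to encoders.numpy_stateful_filler.SpatialFiller. Real feature files would be produced with something like the updated scripts/imagenet_features.py earlier in this diff; the sketch below only shows how dummy .npz fixtures of the configured shape could be generated for tests. The random data and the default NumPy array key are assumptions for illustration, not part of this change.

```python
# Sketch only: generate dummy .npz fixtures matching the shape declared in
# the [numpy_reader] section above ([8, 8, 2048]). Real features would come
# from scripts/imagenet_features.py; the random values and the default
# "arr_0" array key used by np.savez are assumptions for illustration.
import os
import numpy as np

prefix = "tests/data/flickr30k"                       # matches `prefix` in [numpy_reader]
file_list = os.path.join(prefix, "train_images.npz.txt")

with open(file_list, encoding="utf-8") as handle:
    for name in (line.strip() for line in handle if line.strip()):
        # One feature map per listed file, e.g. "1000092795.jpg.npz".
        features = np.random.rand(8, 8, 2048).astype(np.float32)
        np.savez(os.path.join(prefix, name), features)
```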
-batch_normalize=True -image_height=32 -image_width=32 -pixel_dim=3 -convolutions=[("C", 3, 1, "valid", 4), ("M", 2, 2, "same"), ("M", 2, 2, "same")] [encoder_vocabulary] class=vocabulary.from_wordlist diff --git a/tests/hier-multiattention.ini b/tests/hier-multiattention.ini index db8887d5f..f4a4b5c68 100644 --- a/tests/hier-multiattention.ini +++ b/tests/hier-multiattention.ini @@ -3,7 +3,6 @@ name="Configurations of Hierarchical Attention Captioning" tf_manager= output="tests/outputs/hier-multiattention" overwrite_output_dir=True -batching_scheme= epochs=1 train_dataset= val_dataset= @@ -12,17 +11,15 @@ runners=[, , ] random_seed=1234 [batch_scheme] class=dataset.BatchingScheme -batch_size=100 -token_level_batching=True -batch_bucket_span=5 -bucketing_ignore_series=["images"] +bucket_boundaries=[5, 10, 15, 20, 25] +bucket_batch_sizes=[5, 4, 3, 2, 1, 1] +ignore_series=["images"] [tf_manager] class=tf_manager.TensorFlowManager @@ -35,16 +32,19 @@ prefix="tests/data/flickr30k" pad_h=32 pad_w=32 mode="RGB" +channels=3 [train_data] class=dataset.load series=["source", "target", "images"] data=["tests/data/flickr30k/train.en", "tests/data/flickr30k/train.de", ("tests/data/flickr30k/train_images.txt", )] +batching= [val_data] class=dataset.load series=["source", "target", "images"] data=["tests/data/flickr30k/val.en", "tests/data/flickr30k/val.de", ("tests/data/flickr30k/val_images.txt", )] +batching= [imagenet] class=encoders.cnn_encoder.CNNEncoder diff --git a/tests/labeler.ini b/tests/labeler.ini index b7fdf0f75..9d2835a0b 100644 --- a/tests/labeler.ini +++ b/tests/labeler.ini @@ -12,7 +12,6 @@ trainer= evaluation=[("tags", evaluators.Accuracy)] batch_size=10 -runners_batch_size=10 epochs=2 validation_period="10s" diff --git a/tests/language-model.ini b/tests/language-model.ini index f246b3d71..8038ffae0 100644 --- a/tests/language-model.ini +++ b/tests/language-model.ini @@ -13,7 +13,6 @@ postprocess=None evaluation=[("perplexity", "target", )] logging_period=20 validation_period=60 -runners_batch_size=1 random_seed=1234 [tf_manager] @@ -30,8 +29,7 @@ name="perplexity" class=dataset.load series=["target"] data=["tests/data/train.tc.en"] -buffer_size=48 -shuffled=True +buffer_size=100 [val_data] class=dataset.load diff --git a/tests/nematus.ini b/tests/nematus.ini index 9d77bf1f1..cbb10809f 100644 --- a/tests/nematus.ini +++ b/tests/nematus.ini @@ -13,7 +13,6 @@ postprocess=None evaluation=[("target", evaluators.BLEU), ("target", evaluators.TER), ("target", evaluators.ChrF3)] logging_period=20 validation_period=60 -runners_batch_size=1 random_seed=4321 [tf_manager] diff --git a/tests/post-edit.ini b/tests/post-edit.ini index ae0dfd80f..5faa94adc 100644 --- a/tests/post-edit.ini +++ b/tests/post-edit.ini @@ -108,7 +108,6 @@ val_dataset= test_datasets=[] evaluation=[("target", ), ("target", evaluators.TER)] batch_size=2 -runners_batch_size=5 epochs=2 validation_period=2 logging_period=1 diff --git a/tests/regressor.ini b/tests/regressor.ini index 805982f04..0ec9c767a 100644 --- a/tests/regressor.ini +++ b/tests/regressor.ini @@ -13,7 +13,6 @@ postprocess=None evaluation=[("regression", evaluators.MSE)] logging_period=50 validation_period=100 -runners_batch_size=1 random_seed=1234 [tf_manager] diff --git a/tests/rl.ini b/tests/rl.ini index 6b685847e..c90ac2c18 100644 --- a/tests/rl.ini +++ b/tests/rl.ini @@ -13,7 +13,6 @@ postprocess=None evaluation=[("target", evaluators.BLEU)] logging_period=20 validation_period=60 -runners_batch_size=1 random_seed=1234 [tf_manager] @@ -25,8 +24,6 @@ 
num_sessions=1 class=dataset.load series=["source", "target"] data=["tests/data/train.tc.en", "tests/data/train.tc.de"] -buffer_size=48 -shuffled=False [val_data] class=dataset.load diff --git a/tests/self-critical.ini b/tests/self-critical.ini index 9bcd359e0..088e879dc 100644 --- a/tests/self-critical.ini +++ b/tests/self-critical.ini @@ -13,7 +13,6 @@ postprocess=None evaluation=[("target", evaluators.BLEU), ("target", evaluators.TER)] logging_period=20 validation_period=60 -runners_batch_size=1 random_seed=1234 [tf_manager] diff --git a/tests/small.ini b/tests/small.ini index 87d57e1f3..147eaf406 100644 --- a/tests/small.ini +++ b/tests/small.ini @@ -22,23 +22,29 @@ postprocess=None evaluation=[("target", $bleu), ("target", evaluators.TER), ("target", evaluators.ChrF3)] logging_period=20 validation_period=60 -runners_batch_size=1 random_seed=4321 [bleu] class=evaluators.BLEUEvaluator name="bleu" +[batching] +class=dataset.BatchingScheme +bucket_boundaries=[5, 10, 15, 20] +bucket_batch_sizes=[20, 15, 10, 5, 2] + [train_data] class=dataset.load series=["source", "target"] data=["tests/data/train.tc.en", "tests/data/train.tc.de"] +batching= buffer_size=48 [val_data] class=dataset.load series=["source", "target"] data=["tests/data/val.tc.en", "tests/data/val.tc.de"] +batching= [encoder_vocabulary] class=vocabulary.from_wordlist diff --git a/tests/small_sent_cnn.ini b/tests/small_sent_cnn.ini index 845f64790..7bc0875c8 100644 --- a/tests/small_sent_cnn.ini +++ b/tests/small_sent_cnn.ini @@ -13,7 +13,6 @@ postprocess=None evaluation=[("target", evaluators.BLEU), ("target", evaluators.TER)] logging_period=20 validation_period=60 -runners_batch_size=1 random_seed=1234 [tf_manager] diff --git a/tests/str.ini b/tests/str.ini index daad156a4..7d5d0baa4 100644 --- a/tests/str.ini +++ b/tests/str.ini @@ -13,7 +13,6 @@ postprocess=None evaluation=[("target_chars", evaluators.EditDistance)] logging_period=1 validation_period=4 -runners_batch_size=5 test_datasets=[,] random_seed=1234 @@ -30,12 +29,13 @@ pad_w=256 rescale_w=True rescale_h=True mode="F" +channels=1 [train_data] class=dataset.load series=["images", "target", "target_chars"] data=[("tests/data/str/train_files.txt", ), "tests/data/str/train_words.txt", (processors.helpers.preprocess_char_based, "target")] -shuffled=False +buffer_size=10000 [val_data] class=dataset.load diff --git a/tests/test_data.ini b/tests/test_data.ini index 9b258bb79..3014f1d02 100644 --- a/tests/test_data.ini +++ b/tests/test_data.ini @@ -1,15 +1,21 @@ [main] test_datasets=[,] +[batching] +class=dataset.BatchingScheme +batch_size=10 + [val_data] class=dataset.load ; test wildcards series=["source", "target"] data=["tests/data/val10.part?.tc.en", "tests/data/val10.tc.de"] outputs=[("target", "tests/outputs/tmpout-val10.tc.de")] +batching= [val_data_no_target] class=dataset.load series=["source"] data=["tests/data/val10.tc.en"] outputs=[("target", "tests/outputs/tmpout-val10.tc.de")] +batching= diff --git a/tests/test_data_ensembles_all.ini b/tests/test_data_ensembles_all.ini index 700d3fb79..a5acf0376 100644 --- a/tests/test_data_ensembles_all.ini +++ b/tests/test_data_ensembles_all.ini @@ -4,9 +4,13 @@ test_datasets=[] variables=["tests/outputs/beamsearch/variables.data.0", "tests/outputs/beamsearch/variables.data.1", "tests/outputs/beamsearch/variables.data.2", "tests/outputs/beamsearch/variables.data.3"] +[batching] +class=dataset.BatchingScheme +batch_size=10 [val_data] class=dataset.load series=["source", "target"] data=["tests/data/val.tc.en", 
"tests/data/val.tc.de"] outputs=[("target", "tests/outputs/ensemble_out.txt")] +batching= diff --git a/tests/test_data_ensembles_duplicate.ini b/tests/test_data_ensembles_duplicate.ini index 8a3625160..03c526d85 100644 --- a/tests/test_data_ensembles_duplicate.ini +++ b/tests/test_data_ensembles_duplicate.ini @@ -5,9 +5,13 @@ test_datasets=[] variables=["tests/outputs/beamsearch/variables.data.0", "tests/outputs/beamsearch/variables.data.0", "tests/outputs/beamsearch/variables.data.0", "tests/outputs/beamsearch/variables.data.0"] +[batching] +class=dataset.BatchingScheme +batch_size=10 [val_data] class=dataset.load series=["source", "target"] data=["tests/data/val.tc.en", "tests/data/val.tc.de"] outputs=[("target", "tests/outputs/ensemble_out.txt")] +batching= diff --git a/tests/test_data_ensembles_single.ini b/tests/test_data_ensembles_single.ini index f6a231829..1c4d67e64 100644 --- a/tests/test_data_ensembles_single.ini +++ b/tests/test_data_ensembles_single.ini @@ -4,9 +4,13 @@ test_datasets=[] variables=["tests/outputs/beamsearch/variables.data.0"] +[batching] +class=dataset.BatchingScheme +batch_size=10 [val_data] class=dataset.load series=["source", "target"] data=["tests/data/val.tc.en", "tests/data/val.tc.de"] outputs=[("target", "tests/outputs/ensemble_out.txt")] +batching= diff --git a/tests/tests_run.sh b/tests/tests_run.sh index 9609abe5a..4c7c1ad03 100755 --- a/tests/tests_run.sh +++ b/tests/tests_run.sh @@ -24,6 +24,7 @@ bin/neuralmonkey-train tests/transformer.ini bin/neuralmonkey-train tests/str.ini bin/neuralmonkey-train tests/flat-multiattention.ini bin/neuralmonkey-train tests/hier-multiattention.ini +bin/neuralmonkey-train tests/small_sent_cnn.ini # Testing environment variable substitution in config file NM_EXPERIMENT_NAME=small bin/neuralmonkey-train tests/small.ini @@ -33,8 +34,6 @@ bin/neuralmonkey-run tests/small.ini tests/test_data.ini --json /dev/stdout \ | python -c 'import sys,json; print(json.load(sys.stdin)[0]["target/bleu"])' unset NM_EXPERIMENT_NAME -bin/neuralmonkey-train tests/small_sent_cnn.ini - # Ensembles testing score_single=$(bin/neuralmonkey-run tests/beamsearch.ini tests/test_data_ensembles_single.ini --json /dev/stdout | python -c 'import sys,json;print(json.load(sys.stdin)[0]["target_beam.rank001/beam_search_score"])') score_ensemble=$(bin/neuralmonkey-run tests/beamsearch_ensembles.ini tests/test_data_ensembles_duplicate.ini --json /dev/stdout | python -c 'import sys,json;print(json.load(sys.stdin)[0]["target_beam.rank001/beam_search_score"])') diff --git a/tests/transformer.ini b/tests/transformer.ini index e897c89d4..37de70100 100644 --- a/tests/transformer.ini +++ b/tests/transformer.ini @@ -14,8 +14,7 @@ runners=[] postprocess=None evaluation=[("target", evaluators.BLEU)] logging_period=10 -validation_period=60 -runners_batch_size=1 +validation_period=30 random_seed=1234 [tf_manager]