From 2a4aea2a4db1b907f19bc656d3d8dce11f8da376 Mon Sep 17 00:00:00 2001 From: Jindrich Libovicky Date: Fri, 22 Feb 2019 17:04:36 +0100 Subject: [PATCH 1/6] generalize labeler and allow re-use embeddings for labeling --- neuralmonkey/decoders/sequence_labeler.py | 163 ++++++++++++++-------- neuralmonkey/runners/label_runner.py | 2 +- tests/labeler.ini | 2 +- 3 files changed, 106 insertions(+), 61 deletions(-) diff --git a/neuralmonkey/decoders/sequence_labeler.py b/neuralmonkey/decoders/sequence_labeler.py index 71b504a0f..7d92b4446 100644 --- a/neuralmonkey/decoders/sequence_labeler.py +++ b/neuralmonkey/decoders/sequence_labeler.py @@ -1,29 +1,34 @@ -from typing import Dict, Union +from typing import List, Dict, Callable import tensorflow as tf from typeguard import check_argument_types from neuralmonkey.dataset import Dataset from neuralmonkey.decorators import tensor -from neuralmonkey.encoders.recurrent import RecurrentEncoder -from neuralmonkey.encoders.facebook_conv import SentenceEncoder +from neuralmonkey.model.stateful import TemporalStateful from neuralmonkey.model.feedable import FeedDict from neuralmonkey.model.parameterized import InitializerSpecs from neuralmonkey.model.model_part import ModelPart -from neuralmonkey.tf_utils import get_variable +from neuralmonkey.model.sequence import EmbeddedSequence +from neuralmonkey.nn.utils import dropout from neuralmonkey.vocabulary import Vocabulary, pad_batch, sentence_mask class SequenceLabeler(ModelPart): """Classifier assing a label to each encoder's state.""" - # pylint: disable=too-many-arguments + # pylint: disable=too-many-arguments,too-many-locals def __init__(self, name: str, - encoder: Union[RecurrentEncoder, SentenceEncoder], + encoders: List[TemporalStateful], vocabulary: Vocabulary, data_id: str, + max_output_len: int, + hidden_dim: int = None, + activation: Callable = tf.nn.relu, dropout_keep_prob: float = 1.0, + add_start_symbol: bool = False, + add_end_symbol: bool = False, reuse: ModelPart = None, save_checkpoint: str = None, load_checkpoint: str = None, @@ -32,11 +37,16 @@ def __init__(self, ModelPart.__init__(self, name, reuse, save_checkpoint, load_checkpoint, initializers) - self.encoder = encoder + self.encoders = encoders self.vocabulary = vocabulary self.data_id = data_id + self.max_output_len = max_output_len + self.hidden_dim = hidden_dim + self.activation = activation self.dropout_keep_prob = dropout_keep_prob - # pylint: enable=too-many-arguments + self.add_start_symbol = add_start_symbol + self.add_end_symbol = add_end_symbol + # pylint: enable=too-many-arguments,too-many-locals @property def input_types(self) -> Dict[str, tf.DType]: @@ -52,64 +62,36 @@ def target_tokens(self) -> tf.Tensor: @tensor def train_targets(self) -> tf.Tensor: - return self.vocabulary.strings_to_indices(self.target_tokens) + return self.vocabulary.strings_to_indices( + self.dataset[self.data_id]) @tensor def train_mask(self) -> tf.Tensor: return sentence_mask(self.train_targets) - @property - def rnn_size(self) -> int: - return int(self.encoder.temporal_states.get_shape()[-1]) - - @tensor - def decoding_w(self) -> tf.Variable: - return get_variable( - name="state_to_word_W", - shape=[self.rnn_size, len(self.vocabulary)]) - @tensor - def decoding_b(self) -> tf.Variable: - return get_variable( - name="state_to_word_b", - shape=[len(self.vocabulary)], - initializer=tf.zeros_initializer()) + def concatenated_inputs(self) -> tf.Tensor: + return tf.concat( + [inp.temporal_states for inp in self.encoders], axis=2) @tensor - def 
decoding_residual_w(self) -> tf.Variable: - input_dim = self.encoder.input_sequence.dimension - return get_variable( - name="emb_to_word_W", - shape=[input_dim, len(self.vocabulary)]) + def states(self) -> tf.Tensor: + states = dropout( + self.concatenated_inputs, self.dropout_keep_prob, self.train_mode) + + if self.hidden_dim is not None: + hidden = tf.layers.dense( + states, self.hidden_dim, self.activation, + name="hidden_layer") + # pylint: disable=redefined-variable-type + states = dropout(hidden, self.dropout_keep_prob, self.train_mode) + # pylint: enable=redefined-variable-type + return states @tensor def logits(self) -> tf.Tensor: - # To multiply 3-D matrix (encoder hidden states) by a 2-D matrix - # (weights), we use 1-by-1 convolution (similar trick can be found in - # attention computation) - - # TODO dropout needs to be revisited - - encoder_states = tf.expand_dims(self.encoder.temporal_states, 2) - weights_4d = tf.expand_dims(tf.expand_dims(self.decoding_w, 0), 0) - - multiplication = tf.nn.conv2d( - encoder_states, weights_4d, [1, 1, 1, 1], "SAME") - multiplication_3d = tf.squeeze(multiplication, axis=[2]) - - biases_3d = tf.expand_dims(tf.expand_dims(self.decoding_b, 0), 0) - - embedded_inputs = tf.expand_dims( - self.encoder.input_sequence.temporal_states, 2) - dweights_4d = tf.expand_dims( - tf.expand_dims(self.decoding_residual_w, 0), 0) - - dmultiplication = tf.nn.conv2d( - embedded_inputs, dweights_4d, [1, 1, 1, 1], "SAME") - dmultiplication_3d = tf.squeeze(dmultiplication, axis=[2]) - - logits = multiplication_3d + dmultiplication_3d + biases_3d - return logits + return tf.layers.dense( + self.states, len(self.vocabulary), name="logits") @tensor def logprobs(self) -> tf.Tensor: @@ -120,14 +102,17 @@ def decoded(self) -> tf.Tensor: return tf.argmax(self.logits, 2) @tensor - def cost(self) -> tf.Tensor: + def train_xents(self) -> tf.Tensor: loss = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=self.train_targets, logits=self.logits) # loss is now of shape [batch, time]. 
Need to mask it now by # element-wise multiplication with weights placeholder - weighted_loss = loss * self.train_mask - return tf.reduce_sum(weighted_loss) + return loss * self.train_mask + + @tensor + def cost(self) -> tf.Tensor: + return tf.reduce_sum(self.train_xents) / tf.reduce_sum(self.train_mask) @property def train_loss(self) -> tf.Tensor: @@ -142,6 +127,66 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: sentences = dataset.maybe_get_series(self.data_id) if sentences is not None: - fd[self.target_tokens] = pad_batch(list(sentences)) + fd[self.target_tokens] = pad_batch( + list(sentences), self.max_output_len, self.add_start_symbol, + self.add_end_symbol) return fd + + +class EmbeddingsLabeler(SequenceLabeler): + """SequenceLabeler that uses an embedding matrix for output projection.""" + + # pylint: disable=too-many-arguments,too-many-locals + def __init__(self, + name: str, + encoders: List[TemporalStateful], + embedded_sequence: EmbeddedSequence, + data_id: str, + max_output_len: int, + hidden_dim: int = None, + activation: Callable = tf.nn.relu, + train_embeddings: bool = True, + dropout_keep_prob: float = 1.0, + add_start_symbol: bool = False, + add_end_symbol: bool = False, + reuse: ModelPart = None, + save_checkpoint: str = None, + load_checkpoint: str = None, + initializers: InitializerSpecs = None) -> None: + + check_argument_types() + SequenceLabeler.__init__( + self, name, encoders, embedded_sequence.vocabulary, data_id, + max_output_len, hidden_dim=hidden_dim, activation=activation, + dropout_keep_prob=dropout_keep_prob, + add_start_symbol=add_start_symbol, add_end_symbol=add_end_symbol, + reuse=reuse, save_checkpoint=save_checkpoint, + load_checkpoint=load_checkpoint, initializers=initializers) + + self.embedded_sequence = embedded_sequence + self.train_embeddings = train_embeddings + # pylint: enable=too-many-arguments,too-many-locals + + @tensor + def logits(self) -> tf.Tensor: + embeddings = self.embedded_sequence.embedding_matrix + if not self.train_embeddings: + embeddings = tf.stop_gradient(embeddings) + + states = self.states + # pylint: disable=no-member + states_dim = self.states.get_shape()[-1].value + # pylint: enable=no-member + embedding_dim = self.embedded_sequence.embedding_sizes[0] + # pylint: disable=redefined-variable-type + if states_dim != embedding_dim: + states = tf.layers.dense( + states, embedding_dim, name="project_for_embeddings") + # pylint: enable=redefined-variable-type + + reshaped_states = tf.reshape(states, [-1, embedding_dim]) + reshaped_logits = tf.matmul( + reshaped_states, embeddings, transpose_b=True, name="logits") + return tf.reshape( + reshaped_logits, [self.batch_size, -1, len(self.vocabulary)]) diff --git a/neuralmonkey/runners/label_runner.py b/neuralmonkey/runners/label_runner.py index 0ab0303b8..011710e7e 100644 --- a/neuralmonkey/runners/label_runner.py +++ b/neuralmonkey/runners/label_runner.py @@ -60,7 +60,7 @@ def __init__(self, def fetches(self) -> Dict[str, tf.Tensor]: return { "label_logprobs": self.decoder.logprobs, - "input_mask": self.decoder.encoder.input_sequence.temporal_mask, + "input_mask": self.decoder.encoders[0].temporal_mask, "loss": self.decoder.cost} @property diff --git a/tests/labeler.ini b/tests/labeler.ini index 9d2835a0b..768447eab 100644 --- a/tests/labeler.ini +++ b/tests/labeler.ini @@ -59,7 +59,7 @@ vocabulary= [decoder] class=decoders.sequence_labeler.SequenceLabeler name="tagger" -encoder= +encoders=[] data_id="tags" dropout_keep_prob=0.5 vocabulary= From 
76c6f117dfb00e93f594a9e89e7b72582ab23d95 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Wed, 13 Mar 2019 16:06:40 +0100 Subject: [PATCH 2/6] Fixing tests and adressing reviews --- neuralmonkey/decoders/sequence_labeler.py | 27 +++++++++++++++++--- neuralmonkey/readers/string_vector_reader.py | 2 +- neuralmonkey/runners/label_runner.py | 2 +- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/neuralmonkey/decoders/sequence_labeler.py b/neuralmonkey/decoders/sequence_labeler.py index 7d92b4446..60bbdfc71 100644 --- a/neuralmonkey/decoders/sequence_labeler.py +++ b/neuralmonkey/decoders/sequence_labeler.py @@ -23,7 +23,7 @@ def __init__(self, encoders: List[TemporalStateful], vocabulary: Vocabulary, data_id: str, - max_output_len: int, + max_output_len: int = None, hidden_dim: int = None, activation: Callable = tf.nn.relu, dropout_keep_prob: float = 1.0, @@ -56,6 +56,20 @@ def input_types(self) -> Dict[str, tf.DType]: def input_shapes(self) -> Dict[str, tf.TensorShape]: return {self.data_id: tf.TensorShape([None, None])} + @tensor + def input_mask(self) -> tf.Tensor: + mask_main = self.encoders[0].temporal_mask + + asserts = [ + tf.assert_equal( + mask_main, enc.temporal_mask, + message=("Encoders '{}' and '{}' does not have equal temporal " + "masks.".format(self.encoders[0].name, enc.name))) + for enc in self.encoders[1:]] + + with tf.control_dependencies(asserts): + return mask_main + @tensor def target_tokens(self) -> tf.Tensor: return self.dataset[self.data_id] @@ -71,8 +85,10 @@ def train_mask(self) -> tf.Tensor: @tensor def concatenated_inputs(self) -> tf.Tensor: - return tf.concat( - [inp.temporal_states for inp in self.encoders], axis=2) + # Validate shapes first + with tf.control_dependencies(self.input_mask): + return tf.concat( + [inp.temporal_states for inp in self.encoders], axis=2) @tensor def states(self) -> tf.Tensor: @@ -112,6 +128,8 @@ def train_xents(self) -> tf.Tensor: @tensor def cost(self) -> tf.Tensor: + # Cross entropy mean over all words in the batch + # (could also be done as a mean over sentences) return tf.reduce_sum(self.train_xents) / tf.reduce_sum(self.train_mask) @property @@ -143,7 +161,7 @@ def __init__(self, encoders: List[TemporalStateful], embedded_sequence: EmbeddedSequence, data_id: str, - max_output_len: int, + max_output_len: int = None, hidden_dim: int = None, activation: Callable = tf.nn.relu, train_embeddings: bool = True, @@ -183,6 +201,7 @@ def logits(self) -> tf.Tensor: if states_dim != embedding_dim: states = tf.layers.dense( states, embedding_dim, name="project_for_embeddings") + states = dropout(states, self.dropout_keep_prob, self.train_mode) # pylint: enable=redefined-variable-type reshaped_states = tf.reshape(states, [-1, embedding_dim]) diff --git a/neuralmonkey/readers/string_vector_reader.py b/neuralmonkey/readers/string_vector_reader.py index d6545b2a3..439a23838 100644 --- a/neuralmonkey/readers/string_vector_reader.py +++ b/neuralmonkey/readers/string_vector_reader.py @@ -13,7 +13,7 @@ def process_line(line: str, lineno: int, path: str) -> np.ndarray: return np.array(numbers, dtype=dtype) - def reader(files: List[str])-> Iterable[List[np.ndarray]]: + def reader(files: List[str]) -> Iterable[List[np.ndarray]]: for path in files: current_line = 0 diff --git a/neuralmonkey/runners/label_runner.py b/neuralmonkey/runners/label_runner.py index 011710e7e..8b21a1a28 100644 --- a/neuralmonkey/runners/label_runner.py +++ b/neuralmonkey/runners/label_runner.py @@ -60,7 +60,7 @@ def __init__(self, def fetches(self) -> Dict[str, 
tf.Tensor]: return { "label_logprobs": self.decoder.logprobs, - "input_mask": self.decoder.encoders[0].temporal_mask, + "input_mask": self.decoder.input_mask, "loss": self.decoder.cost} @property From 86bca5866b00008eec0c9356e379b3c49afcfbd3 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Wed, 13 Mar 2019 16:49:58 +0100 Subject: [PATCH 3/6] Update pylint, fix bugs introduced while fixing bugs --- neuralmonkey/config/builder.py | 3 +-- neuralmonkey/decoders/sequence_labeler.py | 4 ++-- neuralmonkey/learning_utils.py | 6 ++---- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/neuralmonkey/config/builder.py b/neuralmonkey/config/builder.py index c7524a556..f2093df94 100644 --- a/neuralmonkey/config/builder.py +++ b/neuralmonkey/config/builder.py @@ -47,8 +47,7 @@ def create(self) -> Any: if exc.name == module_name: # type: ignore raise Exception( "Cannot import module {}.".format(module_name)) - else: - raise + raise try: clazz = getattr(module, class_name) diff --git a/neuralmonkey/decoders/sequence_labeler.py b/neuralmonkey/decoders/sequence_labeler.py index 60bbdfc71..07f915cab 100644 --- a/neuralmonkey/decoders/sequence_labeler.py +++ b/neuralmonkey/decoders/sequence_labeler.py @@ -64,7 +64,7 @@ def input_mask(self) -> tf.Tensor: tf.assert_equal( mask_main, enc.temporal_mask, message=("Encoders '{}' and '{}' does not have equal temporal " - "masks.".format(self.encoders[0].name, enc.name))) + "masks.".format(str(self.encoders[0]), str(enc)))) for enc in self.encoders[1:]] with tf.control_dependencies(asserts): @@ -86,7 +86,7 @@ def train_mask(self) -> tf.Tensor: @tensor def concatenated_inputs(self) -> tf.Tensor: # Validate shapes first - with tf.control_dependencies(self.input_mask): + with tf.control_dependencies([self.input_mask]): return tf.concat( [inp.temporal_states for inp in self.encoders], axis=2) diff --git a/neuralmonkey/learning_utils.py b/neuralmonkey/learning_utils.py index 50e0e0711..ed6a0dec4 100644 --- a/neuralmonkey/learning_utils.py +++ b/neuralmonkey/learning_utils.py @@ -260,15 +260,13 @@ def _check_series_collisions(runners: List[BaseRunner], if series in runners_outputs: raise Exception(("Output series '{}' is multiple times among the " "runners' outputs.").format(series)) - else: - runners_outputs.add(series) + runners_outputs.add(series) if postprocess is not None: for series, _ in postprocess: if series in runners_outputs: raise Exception(("Postprocess output series '{}' " "already exists.").format(series)) - else: - runners_outputs.add(series) + runners_outputs.add(series) def run_on_dataset(tf_manager: TensorFlowManager, From e51a98e2fd3f40a0b5d4fc6bf2f69aa55d4e2025 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Thu, 14 Mar 2019 14:58:53 +0100 Subject: [PATCH 4/6] refactoring xents in autoregressive decoder to (batch, time) --- neuralmonkey/decoders/autoregressive.py | 13 ++++++++++--- neuralmonkey/runners/perplexity_runner.py | 5 ++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/neuralmonkey/decoders/autoregressive.py b/neuralmonkey/decoders/autoregressive.py index 46fdffa7f..16da07142 100644 --- a/neuralmonkey/decoders/autoregressive.py +++ b/neuralmonkey/decoders/autoregressive.py @@ -288,16 +288,21 @@ def train_xents(self) -> tf.Tensor: tf.one_hot(labels, len(self.vocabulary)), logits, label_smoothing=self.label_smoothing)) + # Return losses of shape (batch, time). Losses on invalid positions + # are zero. 
return tf.contrib.seq2seq.sequence_loss( tf.transpose(self.train_logits, perm=[1, 0, 2]), train_targets, tf.transpose(self.train_mask), average_across_batch=False, + average_across_timesteps=False, softmax_loss_function=softmax_function) @tensor def train_loss(self) -> tf.Tensor: - return tf.reduce_mean(self.train_xents) + # Cross entropy mean over all words in the batch + # (could also be done as a mean over sentences) + return tf.reduce_sum(self.train_xents) / tf.reduce_sum(self.train_mask) @property def cost(self) -> tf.Tensor: @@ -344,11 +349,13 @@ def runtime_xents(self) -> tf.Tensor: logits=batch_major_logits[:, :min_time], targets=train_targets[:, :min_time], weights=tf.transpose(self.train_mask)[:, :min_time], - average_across_batch=False) + average_across_batch=False, + average_across_timesteps=False) @tensor def runtime_loss(self) -> tf.Tensor: - return tf.reduce_mean(self.runtime_xents) + return (tf.reduce_sum(self.runtime_xents) + / tf.reduce_sum(self.runtime_mask)) @tensor def runtime_logprobs(self) -> tf.Tensor: diff --git a/neuralmonkey/runners/perplexity_runner.py b/neuralmonkey/runners/perplexity_runner.py index 89797cac1..c6aa02797 100644 --- a/neuralmonkey/runners/perplexity_runner.py +++ b/neuralmonkey/runners/perplexity_runner.py @@ -32,7 +32,10 @@ def __init__(self, @tensor def fetches(self) -> Dict[str, tf.Tensor]: - return {"xents": self.decoder.train_xents} + # decoder.train_xents are (batch, time) + # we average xents over the time dimension + return {"xents": (tf.reduce_sum(self.decoder.train_xents, axis=1) + / tf.reduce_sum(self.decoder.train_mask, axis=1))} @property def loss_names(self) -> List[str]: From 1bdbfaf3431c6d9eb175e5226524c22f91ff5768 Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Thu, 14 Mar 2019 15:07:28 +0100 Subject: [PATCH 5/6] fix: runtime mask is bool, need to convert --- neuralmonkey/decoders/autoregressive.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neuralmonkey/decoders/autoregressive.py b/neuralmonkey/decoders/autoregressive.py index 16da07142..bdf80bcb1 100644 --- a/neuralmonkey/decoders/autoregressive.py +++ b/neuralmonkey/decoders/autoregressive.py @@ -355,7 +355,7 @@ def runtime_xents(self) -> tf.Tensor: @tensor def runtime_loss(self) -> tf.Tensor: return (tf.reduce_sum(self.runtime_xents) - / tf.reduce_sum(self.runtime_mask)) + / tf.reduce_sum(tf.to_float(self.runtime_mask))) @tensor def runtime_logprobs(self) -> tf.Tensor: From 1a9d54acf19622bcf8348e3008240153cfae6b3f Mon Sep 17 00:00:00 2001 From: Jindra Helcl Date: Thu, 14 Mar 2019 15:27:39 +0100 Subject: [PATCH 6/6] transposing decoder train mask for perplexity runner --- neuralmonkey/runners/perplexity_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/neuralmonkey/runners/perplexity_runner.py b/neuralmonkey/runners/perplexity_runner.py index c6aa02797..27089e5a3 100644 --- a/neuralmonkey/runners/perplexity_runner.py +++ b/neuralmonkey/runners/perplexity_runner.py @@ -34,8 +34,9 @@ def __init__(self, def fetches(self) -> Dict[str, tf.Tensor]: # decoder.train_xents are (batch, time) # we average xents over the time dimension - return {"xents": (tf.reduce_sum(self.decoder.train_xents, axis=1) - / tf.reduce_sum(self.decoder.train_mask, axis=1))} + return {"xents": tf.reduce_sum( + self.decoder.train_xents, axis=1) / tf.reduce_sum( + tf.transpose(self.decoder.train_mask), axis=1)} @property def loss_names(self) -> List[str]:
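
For reference, the two techniques these patches introduce — reusing the input embedding matrix as the output projection in EmbeddingsLabeler, and normalizing the masked cross entropy by the number of valid tokens — can be sketched roughly as follows. This is a minimal TF 1.x illustration with made-up shapes and variable names, not code from the repository:

import tensorflow as tf

# Made-up dimensions, for illustration only.
batch, time, state_dim, emb_dim, vocab_size = 16, 20, 512, 300, 10000

states = tf.placeholder(tf.float32, [batch, time, state_dim])  # concatenated encoder states
targets = tf.placeholder(tf.int32, [batch, time])              # gold label indices
mask = tf.placeholder(tf.float32, [batch, time])               # 1.0 on valid positions

# Hypothetical embedding matrix shared with the input sequence.
embeddings = tf.get_variable("word_embeddings", [vocab_size, emb_dim])

# Project the states to the embedding size, then reuse the embedding matrix
# as the output projection by multiplying with its transpose.
projected = tf.layers.dense(states, emb_dim, name="project_for_embeddings")
flat = tf.reshape(projected, [-1, emb_dim])
logits = tf.reshape(
    tf.matmul(flat, embeddings, transpose_b=True),
    [batch, time, vocab_size])

# Per-token cross-entropy mean: masked sum of losses divided by the number
# of valid tokens, as in the new cost / train_loss properties.
xents = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=targets, logits=logits) * mask
cost = tf.reduce_sum(xents) / tf.reduce_sum(mask)

# Per-sentence cross entropies, analogous to what the perplexity runner fetches.
sentence_xents = tf.reduce_sum(xents, axis=1) / tf.reduce_sum(mask, axis=1)

The design choice here is weight tying: the projection back onto the vocabulary shares parameters with the input embeddings, which saves a vocab-sized weight matrix and keeps input and output representations in the same space; the extra dense layer is only needed when the concatenated encoder states do not already match the embedding dimension.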