Decoder stateful #663

Open
wants to merge 10 commits into master
21 changes: 20 additions & 1 deletion neuralmonkey/decoders/autoregressive.py
@@ -15,6 +15,7 @@
from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
from neuralmonkey.logging import log, warn
from neuralmonkey.model.sequence import EmbeddedSequence
from neuralmonkey.model.stateful import TemporalStateful
from neuralmonkey.nn.utils import dropout
from neuralmonkey.tf_utils import get_variable, get_state_shape_invariants
from neuralmonkey.vocabulary import Vocabulary, START_TOKEN, UNK_TOKEN_INDEX
@@ -93,7 +94,7 @@ class DecoderFeedables(NamedTuple(


# pylint: disable=too-many-public-methods,too-many-instance-attributes
class AutoregressiveDecoder(ModelPart):
class AutoregressiveDecoder(ModelPart, TemporalStateful):

# pylint: disable=too-many-arguments
def __init__(self,
@@ -475,3 +476,21 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
fd[self.train_mask] = weights

return fd

@tensor
def temporal_states(self) -> tf.Tensor:
# strip the last symbol which is </s>
return tf.cond(
self.train_mode,
lambda: tf.transpose(self.train_output_states, [1, 0, 2])[:, :-1],
lambda: tf.transpose(
self.runtime_output_states, [1, 0, 2])[:, :-1])

Member

This does not work this way. tf.cond needs all the placeholders, even when the condition eventually evaluates the other way.

Contributor Author

Right, I knew that at some point already. I think this was supposed to be fixed in some later version of TF. Where do we stand on compatibility with newer versions?
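A minimal sketch of the behaviour described above (illustrative placeholder names, not taken from the codebase): in TF 1.x, tensors captured from outside a tf.cond branch are routed through Switch ops whose inputs are evaluated unconditionally, so a placeholder referenced only in the untaken branch still has to be fed.

import tensorflow as tf

train_mode = tf.placeholder(tf.bool, shape=[])
train_states = tf.placeholder(tf.float32, [None, None, 8])
runtime_states = tf.placeholder(tf.float32, [None, None, 8])

states = tf.cond(
    train_mode,
    lambda: tf.transpose(train_states, [1, 0, 2])[:, :-1],
    lambda: tf.transpose(runtime_states, [1, 0, 2])[:, :-1])

with tf.Session() as sess:
    # Raises InvalidArgumentError asking for the train placeholder, even
    # though train_mode is False and the first branch is never taken.
    sess.run(states, {train_mode: False,
                      runtime_states: [[[0.0] * 8, [1.0] * 8]]})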

@tensor
def temporal_mask(self) -> tf.Tensor:
# strip the last symbol which is </s>
return tf.cond(
self.train_mode,
lambda: tf.transpose(self.train_mask, [1, 0])[:, :-1],
lambda: tf.to_float(tf.transpose(
self.runtime_mask, [1, 0])[:, :-1]))
77 changes: 48 additions & 29 deletions neuralmonkey/decoders/sequence_labeler.py
@@ -1,38 +1,47 @@
from typing import Optional, Union

import tensorflow as tf
from typeguard import check_argument_types

from neuralmonkey.dataset import Dataset
from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
from neuralmonkey.encoders.recurrent import RecurrentEncoder
from neuralmonkey.encoders.facebook_conv import SentenceEncoder
from neuralmonkey.model.stateful import TemporalStateful
from neuralmonkey.vocabulary import Vocabulary
from neuralmonkey.decorators import tensor
from neuralmonkey.tf_utils import get_variable


class SequenceLabeler(ModelPart):
"""Classifier assing a label to each encoder's state."""
"""Classifier assigning a label to each input state.

If the labeler's input has an underlying input sequence with embeddings, these
are used as an additional input to the labeler.

Note that when the labeler is stacked on an autoregressive decoder, it
labels the symbol that is currently generated by the decoder, i.e., the
decoder state has not yet been updated by putting the decoded symbol on
its input.
"""

# pylint: disable=too-many-arguments
def __init__(self,
name: str,
encoder: Union[RecurrentEncoder, SentenceEncoder],
input_sequence: TemporalStateful,
vocabulary: Vocabulary,
data_id: str,
dropout_keep_prob: float = 1.0,
save_checkpoint: Optional[str] = None,
load_checkpoint: Optional[str] = None,
save_checkpoint: str = None,
load_checkpoint: str = None,
initializers: InitializerSpecs = None) -> None:
check_argument_types()
ModelPart.__init__(self, name, save_checkpoint, load_checkpoint,
initializers)

self.encoder = encoder
self.input_sequence = input_sequence
self.vocabulary = vocabulary
self.data_id = data_id
self.dropout_keep_prob = dropout_keep_prob

self.rnn_size = int(self.encoder.temporal_states.get_shape()[-1])
self.input_size = int(
self.input_sequence.temporal_states.get_shape()[-1])

with self.use_scope():
self.train_targets = tf.placeholder(
@@ -45,7 +54,7 @@ def __init__(self,
def decoding_w(self) -> tf.Variable:
return get_variable(
name="state_to_word_W",
shape=[self.rnn_size, len(self.vocabulary)],
shape=[self.input_size, len(self.vocabulary)],
initializer=tf.glorot_normal_initializer())

@tensor
@@ -57,7 +66,8 @@ def decoding_b(self) -> tf.Variable:

@tensor
def decoding_residual_w(self) -> tf.Variable:
input_dim = self.encoder.input_sequence.dimension
input_dim = (
self.input_sequence.input_sequence.dimension) # type: ignore
return get_variable(
name="emb_to_word_W",
shape=[input_dim, len(self.vocabulary)],
@@ -71,25 +81,27 @@ def logits(self) -> tf.Tensor:

# TODO dropout needs to be revisited

encoder_states = tf.expand_dims(self.encoder.temporal_states, 2)
inputs_states = tf.expand_dims(self.input_sequence.temporal_states, 2)
weights_4d = tf.expand_dims(tf.expand_dims(self.decoding_w, 0), 0)

multiplication = tf.nn.conv2d(
encoder_states, weights_4d, [1, 1, 1, 1], "SAME")
inputs_states, weights_4d, [1, 1, 1, 1], "SAME")
multiplication_3d = tf.squeeze(multiplication, squeeze_dims=[2])

biases_3d = tf.expand_dims(tf.expand_dims(self.decoding_b, 0), 0)
logits = multiplication_3d + biases_3d

embedded_inputs = tf.expand_dims(
self.encoder.input_sequence.temporal_states, 2)
dweights_4d = tf.expand_dims(
tf.expand_dims(self.decoding_residual_w, 0), 0)
if hasattr(self.input_sequence, "input_sequence"):
inputs_input = self.input_sequence.input_sequence # type: ignore
embedded_inputs = tf.expand_dims(inputs_input.temporal_states, 2)
dweights_4d = tf.expand_dims(
tf.expand_dims(self.decoding_residual_w, 0), 0)

dmultiplication = tf.nn.conv2d(
embedded_inputs, dweights_4d, [1, 1, 1, 1], "SAME")
dmultiplication_3d = tf.squeeze(dmultiplication, squeeze_dims=[2])
dmultiplication = tf.nn.conv2d(
embedded_inputs, dweights_4d, [1, 1, 1, 1], "SAME")
dmultiplication_3d = tf.squeeze(dmultiplication, squeeze_dims=[2])

logits = multiplication_3d + dmultiplication_3d + biases_3d
logits += dmultiplication_3d
return logits

@tensor
@@ -102,13 +114,20 @@ def decoded(self) -> tf.Tensor:

@tensor
def cost(self) -> tf.Tensor:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=self.train_targets, logits=self.logits)

# loss is now of shape [batch, time]. Need to mask it now by
# element-wise multiplication with weights placeholder
weighted_loss = loss * self.train_weights
return tf.reduce_sum(weighted_loss)
min_time = tf.minimum(tf.shape(self.train_targets)[1],
tf.shape(self.logits)[1])

# In case the labeler is stacked on a decoder which also emits an end
# symbol (or for some reason emits more symbols than we have in the
# ground-truth labels), we trim the sequences to the length of the
# shorter one.

# pylint: disable=unsubscriptable-object
return tf.contrib.seq2seq.sequence_loss(
logits=self.logits[:, :min_time],
targets=self.train_targets[:, :min_time],
weights=self.input_sequence.temporal_mask[:, :min_time])
# pylint: enable=unsubscriptable-object

Member

Do we really want to take the minimum here during training? Isn't there a risk that the network learns to correctly generate (with the decoder layer) e.g. only sequences of length 1, which it will nevertheless label correctly?

Shouldn't the logits be padded to the train_targets length, or shorter sequences be penalized in some other way?
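Regarding the padding question above, a hypothetical helper (not part of this PR; names are illustrative) that pads the logits along the time axis up to the target length, so a decoder that emits too few symbols is still penalized on the missing positions:

import tensorflow as tf


def padded_labeler_loss(logits: tf.Tensor,
                        targets: tf.Tensor,
                        weights: tf.Tensor) -> tf.Tensor:
    """Pad logits to the target length before computing the loss.

    Zero logits at the padded positions amount to a uniform distribution
    over the vocabulary, so the padded steps still contribute to the loss.
    """
    target_len = tf.shape(targets)[1]
    logit_len = tf.shape(logits)[1]
    padded = tf.pad(
        logits,
        [[0, 0], [0, tf.maximum(0, target_len - logit_len)], [0, 0]])

    return tf.contrib.seq2seq.sequence_loss(
        logits=padded[:, :target_len],
        targets=targets,
        weights=weights)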
Member

Is it really OK to replace the sum over the logits with a mean here? At least with respect to backward compatibility (of the training hyperparameters) this will not be entirely kosher.
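On the sum-versus-mean point: sequence_loss averages over both time and batch by default, so if backward compatibility of the training hyperparameters matters, the old summed behaviour could be recovered by turning the averaging off, roughly as in this sketch (argument names are illustrative):

import tensorflow as tf


def summed_sequence_loss(logits: tf.Tensor,
                         targets: tf.Tensor,
                         weights: tf.Tensor) -> tf.Tensor:
    """Cross-entropy summed over batch and time, as in the old cost."""
    # With both averaging flags off, sequence_loss returns a [batch, time]
    # tensor of per-position losses; summing it reproduces the previous
    # tf.reduce_sum(loss * weights).
    crossent = tf.contrib.seq2seq.sequence_loss(
        logits=logits, targets=targets, weights=weights,
        average_across_timesteps=False,
        average_across_batch=False)
    return tf.reduce_sum(crossent)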


@property
def train_loss(self) -> tf.Tensor:
18 changes: 12 additions & 6 deletions neuralmonkey/decoders/transformer.py
@@ -126,10 +126,15 @@ def __init__(self,

self.encoder_states = get_attention_states(self.encoder)
self.encoder_mask = get_attention_mask(self.encoder)
self.dimension = (
self.encoder_states.get_shape()[2].value) # type: ignore

if self.embedding_size != self.dimension:
# This assertion (and the "int" type declaration below) is here because
# mypy cannot handle the tf.Tensor type.
assert self.encoder_states is not None

self.model_dimension = (
self.encoder_states.get_shape()[2].value)  # type: int

if self.embedding_size != self.model_dimension:
raise ValueError("Model dimension and input embedding size "
"do not match")

Member

Instead of renaming, wouldn't it be better to move lines 130-136 into an overridden `dimension` property? As it is, this just creates confusion in the code.

@@ -140,7 +145,7 @@

@property
def output_dimension(self) -> int:
return self.dimension
return self.model_dimension

def embed_inputs(self, inputs: tf.Tensor) -> tf.Tensor:
embedded = tf.nn.embedding_lookup(self.embedding_matrix, inputs)
@@ -156,7 +161,7 @@ def embed_inputs(self, inputs: tf.Tensor) -> tf.Tensor:
embedded *= math.sqrt(embedding_size)

length = tf.shape(inputs)[1]
return embedded + position_signal(self.dimension, length)
return embedded + position_signal(self.model_dimension, length)

@tensor
def embedded_train_inputs(self) -> tf.Tensor:
@@ -241,7 +246,8 @@ def feedforward_sublayer(self, layer_input: tf.Tensor) -> tf.Tensor:
ff_hidden = dropout(ff_hidden, self.dropout_keep_prob, self.train_mode)

# Feed-forward output projection
ff_output = tf.layers.dense(ff_hidden, self.dimension, name="output")
ff_output = tf.layers.dense(
ff_hidden, self.model_dimension, name="output")

# Apply dropout on the output projection
ff_output = dropout(ff_output, self.dropout_keep_prob, self.train_mode)
2 changes: 1 addition & 1 deletion neuralmonkey/runners/label_runner.py
@@ -87,7 +87,7 @@ def get_executable(self,
num_sessions: int) -> LabelRunExecutable:
fetches = {
"label_logprobs": self._decoder.logprobs,
"input_mask": self._decoder.encoder.input_sequence.temporal_mask}
"input_mask": self._decoder.input_sequence.temporal_mask}

if compute_losses:
fetches["loss"] = self._decoder.cost
17 changes: 15 additions & 2 deletions tests/bpe.ini
@@ -10,7 +10,7 @@ epochs=2
train_dataset=<train_data>
val_dataset=<val_data>
trainer=<trainer>
runners=[<runner>]
runners=[<runner>,<lab_runner>]
evaluation=[("target", evaluators.BLEU), ("target_greedy", "target", evaluators.BLEU)]
val_preview_num_examples=5
val_preview_input_series=["source", "target", "target_bpe"]
@@ -94,10 +94,18 @@ data_id="target_bpe"
max_output_len=10
vocabulary=<bpe_vocabulary>

[labeler]
class=decoders.sequence_labeler.SequenceLabeler
name="tagger"
input_sequence=<decoder>
data_id="target_bpe"
dropout_keep_prob=0.5
vocabulary=<bpe_vocabulary>

[trainer]
; This block just fills the arguments of the trainer __init__ method.
class=trainers.cross_entropy_trainer.CrossEntropyTrainer
decoders=[<decoder>]
decoders=[<decoder>,<labeler>]
l2_weight=1.0e-8
clip_norm=1.0
optimizer=<adadelta>
@@ -114,3 +122,8 @@ class=runners.GreedyRunner
decoder=<decoder>
postprocess=<bpe_postprocess>
output_series="target_greedy"

[lab_runner]
class=runners.LabelRunner
decoder=<labeler>
output_series="tags"
2 changes: 1 addition & 1 deletion tests/labeler.ini
@@ -63,7 +63,7 @@ vocabulary=<source_vocabulary>
[decoder]
class=decoders.sequence_labeler.SequenceLabeler
name="tagger"
encoder=<encoder>
input_sequence=<encoder>
data_id="tags"
dropout_keep_prob=0.5
vocabulary=<tags_vocabulary>