diff --git a/neuralmonkey/decoders/autoregressive.py b/neuralmonkey/decoders/autoregressive.py index 666a8b053..3ffdc6356 100644 --- a/neuralmonkey/decoders/autoregressive.py +++ b/neuralmonkey/decoders/autoregressive.py @@ -479,14 +479,16 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: @tensor def temporal_states(self) -> tf.Tensor: + # strip the last symbol which is </s> return tf.cond( self.train_mode, - lambda: tf.transpose(self.train_output_states, [1, 0, 2])[:, :-2], + lambda: tf.transpose(self.train_output_states, [1, 0, 2])[:, :-1], lambda: tf.transpose( - self.runtime_output_states, [1, 0, 2])[:, :-2]) + self.runtime_output_states, [1, 0, 2])[:, :-1]) @tensor def temporal_mask(self) -> tf.Tensor: + # strip the last symbol which is </s> return tf.cond( self.train_mode, lambda: tf.transpose(self.train_mask, [1, 0])[:, :-1], diff --git a/neuralmonkey/decoders/sequence_labeler.py b/neuralmonkey/decoders/sequence_labeler.py index 510a7ca28..f0a9b7974 100644 --- a/neuralmonkey/decoders/sequence_labeler.py +++ b/neuralmonkey/decoders/sequence_labeler.py @@ -17,7 +17,7 @@ class SequenceLabeler(ModelPart): Note that when the labeler is stacked on an autoregressive decoder, it labels the symbol that is currently generated by the decoder, i.e., the - decoder's state has not yet been updated by putting the decoded symbol on + decoder state has not yet been updated by putting the decoded symbol on its input. """