diff --git a/neuralmonkey/decoders/autoregressive.py b/neuralmonkey/decoders/autoregressive.py index 8753b700f..53e0d526e 100644 --- a/neuralmonkey/decoders/autoregressive.py +++ b/neuralmonkey/decoders/autoregressive.py @@ -422,14 +422,16 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: @tensor def temporal_states(self) -> tf.Tensor: + # strip the last symbol which is </s> return tf.cond( self.train_mode, - lambda: tf.transpose(self.train_output_states, [1, 0, 2])[:, :-2], + lambda: tf.transpose(self.train_output_states, [1, 0, 2])[:, :-1], lambda: tf.transpose( - self.runtime_output_states, [1, 0, 2])[:, :-2]) + self.runtime_output_states, [1, 0, 2])[:, :-1]) @tensor def temporal_mask(self) -> tf.Tensor: + # strip the last symbol which is </s> return tf.cond( self.train_mode, lambda: tf.transpose(self.train_mask, [1, 0])[:, :-1], diff --git a/neuralmonkey/decoders/sequence_labeler.py b/neuralmonkey/decoders/sequence_labeler.py index 6c94ef7bd..15d72f35b 100644 --- a/neuralmonkey/decoders/sequence_labeler.py +++ b/neuralmonkey/decoders/sequence_labeler.py @@ -18,7 +18,7 @@ class SequenceLabeler(ModelPart): Note that when the labeler is stacked on an autoregressive decoder, it labels the symbol that is currently generated by the decoder, i.e., the - decoder's state has not yet been updated by putting the decoded symbol on + decoder state has not yet been updated by putting the decoded symbol on its input. """