From d0da72f4a18bbd9fd8aa28dbdf72a7e9cc77f19f Mon Sep 17 00:00:00 2001 From: Jindrich Libovicky Date: Tue, 13 Mar 2018 10:40:30 +0100 Subject: [PATCH] fix striping --- neuralmonkey/decoders/autoregressive.py | 6 ++++-- neuralmonkey/decoders/sequence_labeler.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/neuralmonkey/decoders/autoregressive.py b/neuralmonkey/decoders/autoregressive.py index 666a8b053..3ffdc6356 100644 --- a/neuralmonkey/decoders/autoregressive.py +++ b/neuralmonkey/decoders/autoregressive.py @@ -479,14 +479,16 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: @tensor def temporal_states(self) -> tf.Tensor: + # strip the last symbol which is return tf.cond( self.train_mode, - lambda: tf.transpose(self.train_output_states, [1, 0, 2])[:, :-2], + lambda: tf.transpose(self.train_output_states, [1, 0, 2])[:, :-1], lambda: tf.transpose( - self.runtime_output_states, [1, 0, 2])[:, :-2]) + self.runtime_output_states, [1, 0, 2])[:, :-1]) @tensor def temporal_mask(self) -> tf.Tensor: + # strip the last symbol which is return tf.cond( self.train_mode, lambda: tf.transpose(self.train_mask, [1, 0])[:, :-1], diff --git a/neuralmonkey/decoders/sequence_labeler.py b/neuralmonkey/decoders/sequence_labeler.py index 510a7ca28..f0a9b7974 100644 --- a/neuralmonkey/decoders/sequence_labeler.py +++ b/neuralmonkey/decoders/sequence_labeler.py @@ -17,7 +17,7 @@ class SequenceLabeler(ModelPart): Note that when the labeler is stacked on an autoregressive decoder, it labels the symbol that is currently generated by the decoder, i.e., the - decoder's state has not yet been updated by putting the decoded symbol on + decoder state has not yet been updated by putting the decoded symbol on its input. """