Skip to content

Commit

Permalink
fix </s> striping
Browse files Browse the repository at this point in the history
  • Loading branch information
jlibovicky committed Mar 13, 2018
1 parent 9b42a2c commit f535288
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions neuralmonkey/decoders/autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,14 +422,16 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:

@tensor
def temporal_states(self) -> tf.Tensor:
# strip the last symbol which is </s>
return tf.cond(
self.train_mode,
lambda: tf.transpose(self.train_output_states, [1, 0, 2])[:, :-2],
lambda: tf.transpose(self.train_output_states, [1, 0, 2])[:, :-1],
lambda: tf.transpose(
self.runtime_output_states, [1, 0, 2])[:, :-2])
self.runtime_output_states, [1, 0, 2])[:, :-1])

@tensor
def temporal_mask(self) -> tf.Tensor:
# strip the last symbol which is </s>
return tf.cond(
self.train_mode,
lambda: tf.transpose(self.train_mask, [1, 0])[:, :-1],
Expand Down
2 changes: 1 addition & 1 deletion neuralmonkey/decoders/sequence_labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class SequenceLabeler(ModelPart):
Note that when the labeler is stacked on an autoregressive decoder, it
labels the symbol that is currently generated by the decoder, i.e., the
decoder's state has not yet been updated by putting the decoded symbol on
decoder state has not yet been updated by putting the decoded symbol on
its input.
"""

Expand Down

0 comments on commit f535288

Please sign in to comment.