Skip to content

Commit

Permalink
fix </s> striping
Browse files Browse the repository at this point in the history
  • Loading branch information
jlibovicky authored and jindrahelcl committed Jul 11, 2018
1 parent bc0caeb commit d0da72f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions neuralmonkey/decoders/autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,14 +479,16 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:

@tensor
def temporal_states(self) -> tf.Tensor:
# strip the last symbol which is </s>
return tf.cond(
self.train_mode,
lambda: tf.transpose(self.train_output_states, [1, 0, 2])[:, :-2],
lambda: tf.transpose(self.train_output_states, [1, 0, 2])[:, :-1],
lambda: tf.transpose(
self.runtime_output_states, [1, 0, 2])[:, :-2])
self.runtime_output_states, [1, 0, 2])[:, :-1])

@tensor
def temporal_mask(self) -> tf.Tensor:
# strip the last symbol which is </s>
return tf.cond(
self.train_mode,
lambda: tf.transpose(self.train_mask, [1, 0])[:, :-1],
Expand Down
2 changes: 1 addition & 1 deletion neuralmonkey/decoders/sequence_labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class SequenceLabeler(ModelPart):
Note that when the labeler is stacked on an autoregressive decoder, it
labels the symbol that is currently generated by the decoder, i.e., the
decoder's state has not yet been updated by putting the decoded symbol on
decoder state has not yet been updated by putting the decoded symbol on
its input.
"""

Expand Down

0 comments on commit d0da72f

Please sign in to comment.