diff --git a/neuralmonkey/attention/stateful_context.py b/neuralmonkey/attention/stateful_context.py index df97fd906..b638e551b 100644 --- a/neuralmonkey/attention/stateful_context.py +++ b/neuralmonkey/attention/stateful_context.py @@ -63,7 +63,7 @@ def attention(self, AttentionLoopState]: context = tf.reshape(self.attention_states, [-1, self.context_vector_size]) - weights = tf.ones(shape=[tf.shape(context)[0]]) + weights = tf.ones(shape=[self.batch_size, 1]) next_contexts = tf.concat( [loop_state.contexts, tf.expand_dims(context, 0)], 0) @@ -77,8 +77,7 @@ def attention(self, def initial_loop_state(self) -> AttentionLoopState: return empty_attention_loop_state( - self.batch_size, - tf.shape(self.attention_states)[1], + self.batch_size, 1, self.context_vector_size) def finalize_loop(self, key: str,