# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Beam search decoding for RNN decoders.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib.seq2seq import \
    AttentionWrapperState, AttentionWrapper, tile_batch

from MyBeamSearchDecoder import BeamSearchDecoder
from MyDynamicDecode import dynamic_decode

from texar.modules.decoders.rnn_decoder_base import RNNDecoderBase

# pylint: disable=too-many-arguments, protected-access, too-many-locals
# pylint: disable=invalid-name

__all__ = [
    "beam_search_decode",
]
def _get_initial_state(initial_state,
                       tiled_initial_state,
                       cell,
                       batch_size,
                       beam_width,
                       dtype):
    """Returns an initial state tiled to batch size `batch_size * beam_width`.

    If `tiled_initial_state` is not given, `initial_state` is tiled (or a
    zero state is created when `initial_state` is `None`). For attention
    cells, a plain cell state is wrapped into an `AttentionWrapperState`.
    """
    if tiled_initial_state is None:
        if isinstance(initial_state, AttentionWrapperState):
            raise ValueError(
                '`initial_state` must not be an AttentionWrapperState. Use '
                'a plain cell state instead, which will be wrapped into an '
                'AttentionWrapperState automatically.')
        if initial_state is None:
            tiled_initial_state = cell.zero_state(batch_size * beam_width,
                                                  dtype)
        else:
            tiled_initial_state = tile_batch(initial_state,
                                             multiplier=beam_width)

    if isinstance(cell, AttentionWrapper) and \
            not isinstance(tiled_initial_state, AttentionWrapperState):
        zero_state = cell.zero_state(batch_size * beam_width, dtype)
        tiled_initial_state = zero_state.clone(cell_state=tiled_initial_state)

    return tiled_initial_state
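
# For reference, `tile_batch` repeats each batch entry `beam_width` times
# along the batch dimension, so a state of shape `[batch_size, ...]` becomes
# `[batch_size * beam_width, ...]` with entries ordered [b0, b0, ..., b1,
# b1, ...]. A minimal sketch (hypothetical shapes, not used by this module):
#
#     state = tf.zeros([2, 4])                 # [batch_size, hidden]
#     tiled = tile_batch(state, multiplier=3)  # -> shape [6, 4]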
def beam_search_decode(decoder_or_cell,
                       embedding,
                       start_tokens,
                       end_token,
                       beam_width,
                       vocab_size,
                       initial_state=None,
                       tiled_initial_state=None,
                       output_layer=None,
                       length_penalty_weight=0.0,
                       max_decoding_length=None,
                       output_time_major=False,
                       **kwargs):
"""Performs beam search sampling decoding.
Args:
decoder_or_cell: An instance of
subclass of :class:`~texar.modules.RNNDecoderBase`,
or an instance of :tf_main:`RNNCell <contrib/rnn/RNNCell>`. The
decoder or RNN cell to perform decoding.
embedding: A callable that takes a vector tensor of indexes (e.g.,
an instance of subclass of :class:`~texar.modules.EmbedderBase`),
or the :attr:`params` argument for
:tf_main:`tf.nn.embedding_lookup <nn/embedding_lookup>`.
start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
end_token: `int32` scalar, the token that marks end of decoding.
beam_width (int): Python integer, the number of beams.
initial_state (optional): Initial state of decoding. If `None`
(default), zero state is used.
The state must **not** be tiled with
:tf_main:`tile_batch <contrib/seq2seq/tile_batch>`.
If you have an already-tiled initial state, use
:attr:`tiled_initial_state` instead.
In the case of attention RNN decoder, `initial_state` must
**not** be an :tf_main:`AttentionWrapperState
<contrib/seq2seq/AttentionWrapperState>`. Instead, it must be a
state of the wrapped `RNNCell`, which state will be wrapped into
`AttentionWrapperState` automatically.
Ignored if :attr:`tiled_initial_state` is given.
tiled_initial_state (optional): Initial state that has been tiled
(typicaly with :tf_main:`tile_batch <contrib/seq2seq/tile_batch>`)
so that the batch dimension has size `batch_size * beam_width`.
In the case of attention RNN decoder, this can be either a state
of the wrapped `RNNCell`, or an `AttentionWrapperState`.
If not given, :attr:`initial_state` is used.
output_layer (optional): A :tf_main:`Layer <layers/Layer>` instance to
apply to the RNN output prior to storing the result or sampling. If
`None` and :attr:`decoder_or_cell` is a decoder, the decoder's
output layer will be used.
length_penalty_weight: Float weight to penalize length.
Disabled with `0.0` (default).
max_decoding_length (optional): A int scalar Tensor indicating the
maximum allowed number of decoding steps. If `None` (default),
decoding will continue until the end token is encountered.
output_time_major (bool): If `True`, outputs are returned as
time major tensors. If `False` (default), outputs are returned
as batch major tensors.
**kwargs: Other keyword arguments for :tf_main:`dynamic_decode
<contrib/seq2seq/dynamic_decode>` except argument
`maximum_iterations` which is set to :attr:`max_decoding_length`.
Returns:
A tuple `(outputs, final_state, sequence_length)`, where
- outputs: An instance of :tf_main:`FinalBeamSearchDecoderOutput \
<contrib/seq2seq/FinalBeamSearchDecoderOutput>`.
- final_state: An instance of :tf_main:`BeamSearchDecoderState \
<contrib/seq2seq/BeamSearchDecoderState>`.
- sequence_length: A Tensor of shape `[batch_size]` containing \
the lengths of samples.
Example:
.. code-block:: python
## Beam search with basic RNN decoder
embedder = WordEmbedder(vocab_size=data.vocab.size)
decoder = BasicRNNDecoder(vocab_size=data.vocab.size)
outputs, _, _, = beam_search_decode(
decoder_or_cell=decoder,
embedding=embedder,
start_tokens=[data.vocab.bos_token_id] * 100,
end_token=data.vocab.eos_token_id,
beam_width=5,
max_decoding_length=60)
sample_ids = sess.run(outputs.predicted_ids)
sample_text = tx.utils.map_ids_to_strs(sample_id[:,:,0], data.vocab)
print(sample_text)
# [
# the first sequence sample .
# the second sequence sample .
# ...
# ]
.. code-block:: python
## Beam search with attention RNN decoder
# Encodes the source
enc_embedder = WordEmbedder(data.source_vocab.size, ...)
encoder = UnidirectionalRNNEncoder(...)
enc_outputs, enc_state = encoder(
inputs=enc_embedder(data_batch['source_text_ids']),
sequence_length=data_batch['source_length'])
# Decodes while attending to the source
dec_embedder = WordEmbedder(vocab_size=data.target_vocab.size, ...)
decoder = AttentionRNNDecoder(
memory=enc_outputs,
memory_sequence_length=data_batch['source_length'],
vocab_size=data.target_vocab.size)
# Beam search
outputs, _, _, = beam_search_decode(
decoder_or_cell=decoder,
embedding=dec_embedder,
start_tokens=[data.vocab.bos_token_id] * 100,
end_token=data.vocab.eos_token_id,
beam_width=5,
initial_state=enc_state,
max_decoding_length=60)
"""
    # Resolve the beam-search cell from either a Texar decoder or a raw cell.
    if isinstance(decoder_or_cell, RNNDecoderBase):
        cell = decoder_or_cell._get_beam_search_cell(beam_width=beam_width)
    elif isinstance(decoder_or_cell, tf.contrib.rnn.RNNCell):
        cell = decoder_or_cell
    else:
        raise ValueError("`decoder_or_cell` must be an instance of a "
                         "subclass of either `RNNDecoderBase` or `RNNCell`.")

    start_tokens = tf.convert_to_tensor(
        start_tokens, dtype=tf.int32, name="start_tokens")
    if start_tokens.get_shape().ndims != 1:
        raise ValueError("`start_tokens` must be a vector")
    batch_size = tf.size(start_tokens)

    initial_state = _get_initial_state(
        initial_state, tiled_initial_state, cell,
        batch_size, beam_width, tf.float32)

    if output_layer is None and isinstance(decoder_or_cell, RNNDecoderBase):
        output_layer = decoder_or_cell.output_layer

    def _decode():
        beam_decoder = BeamSearchDecoder(
            cell=cell,
            embedding=embedding,
            start_tokens=start_tokens,
            end_token=end_token,
            vocab_size=vocab_size,
            initial_state=initial_state,
            beam_width=beam_width,
            output_layer=output_layer,
            length_penalty_weight=length_penalty_weight)

        if 'maximum_iterations' in kwargs:
            raise ValueError('Use `max_decoding_length` to set the maximum '
                             'allowed number of decoding steps.')
        outputs, final_state, _ = dynamic_decode(
            decoder=beam_decoder,
            output_time_major=output_time_major,
            maximum_iterations=max_decoding_length,
            **kwargs)

        return outputs, final_state, final_state.lengths

    if isinstance(decoder_or_cell, RNNDecoderBase):
        # Decode within the decoder's variable scope so its weights are
        # reused rather than re-created.
        vs = decoder_or_cell.variable_scope
        with tf.variable_scope(vs, reuse=tf.AUTO_REUSE):
            return _decode()
    else:
        return _decode()
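
# A minimal usage sketch of the raw-`RNNCell` code path, assuming the custom
# `MyBeamSearchDecoder` / `MyDynamicDecode` modules mirror the
# `tf.contrib.seq2seq` interfaces. All names and sizes below are hypothetical
# and for illustration only.
if __name__ == "__main__":
    demo_vocab_size = 100
    demo_embedding = tf.get_variable(
        "demo_embedding", [demo_vocab_size, 32])
    demo_cell = tf.contrib.rnn.GRUCell(32)
    demo_output_layer = tf.layers.Dense(demo_vocab_size, use_bias=False)

    # batch_size = 2; token 1 starts decoding, token 2 ends it.
    demo_outputs, _, demo_lengths = beam_search_decode(
        decoder_or_cell=demo_cell,
        embedding=demo_embedding,
        start_tokens=[1, 1],
        end_token=2,
        beam_width=4,
        vocab_size=demo_vocab_size,
        output_layer=demo_output_layer,
        max_decoding_length=10)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # predicted_ids is batch-major here:
        # [batch_size, max_time, beam_width]
        print(sess.run(demo_outputs.predicted_ids).shape)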