Fused Multihead Attention for training (#985)
* WIP: Support fused MHA during training

* Workaround for Tensor vs int during tracing

* fused MHA for self-attention

* fix unit tests

* fix error for kv input feature size affected by factors. Fixes remaining integration tests. Added docstrings

* minor version & changelog

* Call `train()` at the start of training. Print info for resolving error if it isn't called.

* Revert "Call `train()` at the start of training. Print info for resolving error if it isn't called."

This reverts commit ff1efd6.

* Separate key-value params after loading from disk when model is in training mode

Co-authored-by: Michael Denkowski <[email protected]>
fhieber and mjdenkowski authored Dec 19, 2021
1 parent 71527d7 commit 6905c78
Showing 11 changed files with 307 additions and 45 deletions.
19 changes: 19 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,25 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.0.7]

### Changed

- Improve training speed by using `torch.nn.functional.multi_head_attention_forward` for self- and encoder-attention
  during training. This requires reorganizing the parameter layout of the key-value input projections, since Sockeye's
  attention otherwise interleaves key and value parameters for faster inference.
  Attention masks (both the source masks and the autoregressive masks) also need some shape adjustments, as the
  requirements of the fused MHA op differ slightly.
- Non-interleaved format for the joint key-value input projection (`in_features=hidden`, `out_features=2*hidden`):
  shape `(2*hidden, hidden)`.
- Interleaved format for the joint key-value input projection stores key and value parameters grouped by heads:
  shape `((num_heads * 2 * hidden_per_head), hidden)`.
- Models save and load key-value projection parameters in interleaved format.
- When `model.training == True`, key-value projection parameters are converted into non-interleaved format for
  `torch.nn.functional.multi_head_attention_forward`.
- When `model.training == False`, i.e. after `model.eval()` is called, key-value projection parameters are converted
  back into interleaved format in place (see the sketch below).
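
The layout change above can be pictured with a small, self-contained sketch. The helper names (`interleave`, `separate`) and the toy sizes are illustrative assumptions, not Sockeye's actual functions; the sketch only shows how a joint key-value weight of shape `(2*hidden, hidden)` maps to the head-grouped layout `((num_heads * 2 * hidden_per_head), hidden)` and back.

```python
# Illustrative sketch of the two key-value projection layouts described above.
# `interleave` / `separate` are assumed helper names, not Sockeye's API.
import torch

hidden, num_heads = 8, 2
head_dim = hidden // num_heads  # hidden_per_head

# Non-interleaved layout (keys first, then values), as in the changelog entry above.
# Shape: (2 * hidden, hidden)
kv_weight = torch.randn(2 * hidden, hidden)

def interleave(kv: torch.Tensor) -> torch.Tensor:
    """Group key and value rows per head: shape ((num_heads * 2 * head_dim), hidden)."""
    k, v = kv[:hidden], kv[hidden:]
    k = k.reshape(num_heads, head_dim, hidden)
    v = v.reshape(num_heads, head_dim, hidden)
    # For each head: that head's key rows followed by its value rows.
    return torch.cat((k, v), dim=1).reshape(num_heads * 2 * head_dim, hidden)

def separate(kv_interleaved: torch.Tensor) -> torch.Tensor:
    """Inverse of interleave(): back to (2 * hidden, hidden), keys first."""
    kv = kv_interleaved.reshape(num_heads, 2, head_dim, hidden)
    k, v = kv[:, 0].reshape(hidden, hidden), kv[:, 1].reshape(hidden, hidden)
    return torch.cat((k, v), dim=0)

assert torch.equal(separate(interleave(kv_weight)), kv_weight)
```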

## [3.0.6]

### Fixed
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.0.6'
__version__ = '3.0.7'
25 changes: 15 additions & 10 deletions sockeye/decoder_pt.py
@@ -186,19 +186,24 @@ def init_state_from_encoder(self,
:param target_embed: Target-side embedding layer output. Shape: (batch, target_length, target_embedding_dim).
:return: Initial states.
"""
source_max_len = encoder_outputs.size()[1]
if target_embed is None: # Inference: initial step = 0. Shape: (batch_size, 1)
steps = pt.zeros_like(encoder_valid_length).unsqueeze(1)
# (batch * heads, 1, source_max_len)
source_mask = layers_pt.prepare_source_length_mask(encoder_valid_length, self.config.attention_heads,
source_max_len)
# Shape: (batch, heads, 1, src_max_len)
source_mask = source_mask.view(-1, self.config.attention_heads, 1, source_max_len)
else: # Training: steps up to target length. Shape: (1, target_length)
target_length = target_embed.size()[1]
steps = pt.arange(0, target_length, device=target_embed.device).unsqueeze(0)
# (batch * heads, 1, source_max_len)
source_mask = layers_pt.prepare_source_length_mask(encoder_valid_length, self.config.attention_heads,
source_max_len)
source_mask = source_mask.repeat(1, target_length, 1) # Shape: (batch * heads, trg_max_len, src_max_len)

# inverted source_length_mask for attention masking, (batch_size * heads, 1, source_max_len)
source_max_len = encoder_outputs.size()[1]
source_mask = layers_pt.prepare_source_length_mask(encoder_valid_length,
self.config.attention_heads,
source_max_len).view(-1,
self.config.attention_heads,
source_max_len)
# Shape: (batch, heads, trg_max_len, src_max_len)
source_mask = source_mask.view(-1, self.config.attention_heads, target_length, source_max_len)

if self.inference_only:
# Encoder projection caching, therefore we don't pass the encoder_outputs
@@ -241,7 +246,7 @@ def forward(self, step_input: pt.Tensor, states: List[pt.Tensor]) -> Tuple[pt.Te
autoregr_states = other[self.config.num_layers:]
else:
if any(layer.needs_mask for layer in self.layers):
target_mask = self.autoregressive_mask(step_input) # mask: (1, length, length)
target_mask = self.autoregressive_mask(step_input) # mask: (length, length)
steps, source_encoded, source_mask, *autoregr_states = states
enc_att_kv = [None for _ in range(self.config.num_layers)]

@@ -250,8 +255,8 @@ def forward(self, step_input: pt.Tensor, states: List[pt.Tensor]) -> Tuple[pt.Te
states_iter = iter(autoregr_states)
autoregr_states = [list(islice(states_iter, 0, layer.num_state_tensors)) for layer in self.layers] # type: ignore

batch, heads, source_max_len = source_mask.size()
source_mask_view = source_mask.view(batch * heads, 1, source_max_len)
batch, heads, target_max_len, source_max_len = source_mask.size()
source_mask_view = source_mask.view(batch * heads, target_max_len, source_max_len)

# target: (batch_size, length, model_size)
target = self.pos_embedding(step_input, steps)
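In the hunk above, the autoregressive mask comment changes from `(1, length, length)` to `(length, length)`: `torch.nn.functional.multi_head_attention_forward` accepts a 2D `(tgt_len, src_len)` mask (or a 3D `(batch * heads, tgt_len, src_len)` one), so the leading singleton dimension is dropped. Below is a minimal sketch of such a 2D causal mask, assuming a boolean convention where `True` marks positions that may not be attended to; it is illustrative, not Sockeye's `autoregressive_mask`.

```python
# Illustrative sketch (not Sockeye's autoregressive_mask): a 2D causal mask
# of shape (length, length). True marks future positions that must not be attended to.
import torch

def causal_mask(length: int) -> torch.Tensor:
    return torch.triu(torch.ones(length, length, dtype=torch.bool), diagonal=1)

mask = causal_mask(4)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```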
6 changes: 4 additions & 2 deletions sockeye/encoder_pt.py
@@ -195,8 +195,10 @@ def forward(self, data: pt.Tensor, valid_length: pt.Tensor) -> Tuple[pt.Tensor,
if self.dropout is not None:
data = self.dropout(data)

# inverted length_mask for attention masking, (batch_size * heads, 1, max_len)
att_mask = layers_pt.prepare_source_length_mask(valid_length, self.config.attention_heads, data.size()[1])
_, max_len, __ = data.size()
# length_mask for source attention masking. Shape: (batch_size * heads, 1, max_len)
att_mask = layers_pt.prepare_source_length_mask(valid_length, self.config.attention_heads, max_length=max_len)
att_mask = att_mask.repeat(1, max_len, 1)

data = data.transpose(1, 0) # batch to time major
for layer in self.layers:
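The encoder hunk above repeats the length mask across the query dimension, since the 3D `attn_mask` passed to the fused op is expected to have shape `(batch * heads, query_len, key_len)` rather than a broadcastable singleton query dimension. The sketch below illustrates the shapes under assumed semantics; the helper and its padding convention are illustrative, not Sockeye's actual `prepare_source_length_mask`.

```python
# Illustrative sketch (assumed helper, not Sockeye's prepare_source_length_mask):
# build a padding mask of shape (batch * heads, 1, max_len), where True marks padded
# source positions, then repeat it across the query dimension as in the hunk above.
import torch

def length_mask(valid_length: torch.Tensor, heads: int, max_len: int) -> torch.Tensor:
    positions = torch.arange(max_len).unsqueeze(0)            # (1, max_len)
    mask = positions >= valid_length.unsqueeze(1)             # (batch, max_len), True = padding
    return mask.repeat_interleave(heads, dim=0).unsqueeze(1)  # (batch * heads, 1, max_len)

valid_length = torch.tensor([5, 3])                           # two sequences in the batch
heads, max_len = 8, 6
att_mask = length_mask(valid_length, heads, max_len)          # (16, 1, 6)
att_mask = att_mask.repeat(1, max_len, 1)                     # (16, 6, 6): one row per query position
```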
