Commit

tests and slicing
dakinggg committed Jun 20, 2024
1 parent d5e5fed commit 834066d
Showing 2 changed files with 81 additions and 0 deletions.
19 changes: 19 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -212,13 +212,32 @@ def apply_ffn(
        indices = None
        if not self.use_pad_tok_in_ffn and attention_mask is not None:
            assert unpad_input is not None
            attention_mask = self.slice_attention_mask(attention_mask, seq_len)
            m, indices, _, _ = unpad_input(m, attention_mask)
        n = self.ffn(m)
        if not self.use_pad_tok_in_ffn and attention_mask is not None:
            assert pad_input is not None
            n = pad_input(n, indices, batch_size, seq_len)
        return n

    def slice_attention_mask(
        self,
        attention_mask: torch.ByteTensor,
        seq_len: int,
    ) -> torch.ByteTensor:
        """Slice attention mask to the correct size.

        Can be overridden by subclasses to apply different slicing logic.

        Args:
            attention_mask (torch.ByteTensor): The attention mask.
            seq_len (int): The sequence length.

        Returns:
            torch.ByteTensor: The sliced attention mask.
        """
        return attention_mask


class FusedNormAttentionNorm(nn.Module):

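For illustration only (not part of this commit): a subclass can override slice_attention_mask to change which mask positions are kept before unpad_input is called in apply_ffn. The SlicedMPTBlock name and the trailing-window slicing below are assumptions, a minimal sketch of how the hook might be used rather than anything shipped in llm-foundry:

    # Hypothetical example -- not in this commit.
    import torch

    from llmfoundry.models.layers.blocks import MPTBlock


    class SlicedMPTBlock(MPTBlock):

        def slice_attention_mask(
            self,
            attention_mask: torch.ByteTensor,
            seq_len: int,
        ) -> torch.ByteTensor:
            # Keep only the last seq_len positions of a [batch, seq] padding
            # mask so it lines up with the hidden states passed to unpad_input.
            return attention_mask[:, -seq_len:]  # type: ignore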
62 changes: 62 additions & 0 deletions tests/models/layers/test_blocks.py
@@ -0,0 +1,62 @@
from typing import Optional
from unittest.mock import MagicMock

import pytest
import torch

from llmfoundry.models.layers import blocks
from llmfoundry.models.layers.blocks import MPTBlock

def test_default_attention_mask_slicing():
    attention_mask = torch.tensor([1, 1, 0, 1]).byte()
    assert isinstance(attention_mask, torch.ByteTensor)

    block = MPTBlock(
        d_model=4,
        n_heads=1,
        expansion_ratio=1,
    )

    output_mask = block.slice_attention_mask(
        attention_mask=attention_mask,
        seq_len=4,
    )

    assert torch.equal(output_mask, attention_mask)

def test_attention_mask_slicing_called(monkeypatch: pytest.MonkeyPatch):
    m = torch.randn(2, 4, 4)
    attention_mask = torch.tensor([1, 1, 1, 1]).byte()
    dummy_return_mask = torch.tensor([1, 1, 1, 0]).byte()
    assert isinstance(attention_mask, torch.ByteTensor)
    assert isinstance(dummy_return_mask, torch.ByteTensor)
    indices = torch.arange(4)

    unpad_mock = MagicMock(return_value=(m, indices, None, None))
    pad_mock = MagicMock(return_value=m)
    monkeypatch.setattr(blocks, 'unpad_input', unpad_mock)
    monkeypatch.setattr(blocks, 'pad_input', pad_mock)

    class MPTBlockTest(MPTBlock):

        def slice_attention_mask(
            self,
            attention_mask: Optional[torch.ByteTensor],
            seq_len: int,
        ) -> Optional[torch.ByteTensor]:
            del seq_len
            del attention_mask
            return dummy_return_mask  # type: ignore

    block = MPTBlockTest(
        d_model=4,
        n_heads=1,
        expansion_ratio=1,
        use_pad_tok_in_ffn=False,
    )

    block.apply_ffn(
        attention_mask=attention_mask,
        m=m,
    )

    assert unpad_mock.call_count == 1
    unpad_mock.assert_called_with(m, dummy_return_mask)
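The new tests can be run on their own with pytest (assuming the repository and its test dependencies are installed):

    pytest tests/models/layers/test_blocks.py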
