bug fix, adding test
ShashankMosaicML committed Jun 23, 2024
1 parent 06d03c1 commit ec42e72
Showing 4 changed files with 75 additions and 2 deletions.
4 changes: 4 additions & 0 deletions llmfoundry/models/layers/attention.py
@@ -750,6 +750,7 @@ def __init__(
        device: Optional[str] = None,
        bias: bool = True,
        sliding_window_size: int = -1,
        reuse_kv_layer_idx: Optional[int] = None,
    ):
        super().__init__(
            d_model=d_model,
@@ -766,6 +767,7 @@ def __init__(
            device=device,
            bias=bias,
            sliding_window_size=sliding_window_size,
            reuse_kv_layer_idx=reuse_kv_layer_idx,
        )


@@ -791,6 +793,7 @@ def __init__(
        device: Optional[str] = None,
        bias: bool = True,
        sliding_window_size: int = -1,
        reuse_kv_layer_idx: Optional[int] = None,
    ):
        super().__init__(
            d_model=d_model,
@@ -807,6 +810,7 @@ def __init__(
            device=device,
            bias=bias,
            sliding_window_size=sliding_window_size,
            reuse_kv_layer_idx=reuse_kv_layer_idx,
        )


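The hunks above only thread the new reuse_kv_layer_idx argument through to the parent constructor. As a rough illustration of what the option enables (a minimal standalone sketch, not llmfoundry's attention implementation; KVReuseAttention and its projection names are invented for this example), a layer configured to reuse another layer's key/value cache skips its own K/V projections and attends over the tensors handed in as prev_layer_key_value:

import torch
import torch.nn as nn
from typing import Optional, Tuple


class KVReuseAttention(nn.Module):
    # Toy single-layer attention: if reuse_kv_layer_idx is set, this layer does
    # not project its own keys/values and instead attends over the (key, value)
    # pair computed by an earlier layer, passed in as prev_layer_key_value.
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        reuse_kv_layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.reuse_kv_layer_idx = reuse_kv_layer_idx
        self.Wq = nn.Linear(d_model, d_model)
        if reuse_kv_layer_idx is None:
            # Only layers that compute their own KV need these projections.
            self.Wk = nn.Linear(d_model, d_model)
            self.Wv = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(
        self,
        x: torch.Tensor,
        prev_layer_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        b, s, d = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(b, s, self.n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.Wq(x))
        if self.reuse_kv_layer_idx is not None:
            assert prev_layer_key_value is not None, 'expected KV from an earlier layer'
            k, v = prev_layer_key_value
        else:
            k, v = split_heads(self.Wk(x)), split_heads(self.Wv(x))
        ctx = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
        ctx = ctx.transpose(1, 2).reshape(b, s, d)
        return self.out_proj(ctx), (k, v)

A block stack would record the (k, v) returned by the source layer and hand it to the reusing layer, which mirrors the prev_layer_key_value argument added in blocks.py below.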
1 change: 1 addition & 0 deletions llmfoundry/models/layers/blocks.py
@@ -187,6 +187,7 @@ def forward(
            needs_weights=output_attentions,
            alibi_slopes=alibi_slopes,
            flash_attn_padding_info=flash_attn_padding_info,
            prev_layer_key_value=prev_layer_key_value,
        )
        x = x + self.resid_attn_dropout(b)
        m = x
6 changes: 5 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -500,9 +500,13 @@ def _construct_blocks_with_overrides(

        module_list = []
        layer_description_list = []
        if len(model_modules_order_expanded) != config.n_layers:
            raise ValueError(
                f'The specified block overrides do not match the number of layers: {len(model_modules_order_expanded)} vs {config.n_layers}.',
            )

        for i in range(config.n_layers):
            module_name = model_modules_order_expanded[i]

            override_config = {}
            if module_name != 'default':
                override_config = copy.deepcopy(
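For context on the new length check, here is a simplified sketch of the expansion it guards. The expand_block_overrides helper and its tiling rule are invented for illustration, based on the start/repeating_pattern/end sections and per-entry repeat counts exercised by test_construct_blocks and the new test below; it is not llmfoundry's exact expansion logic:

from typing import Any, Dict, List


def expand_block_overrides(block_overrides: Dict[str, Any], n_layers: int) -> List[str]:
    # Expand each section's entries, honoring the per-entry 'repeat' count.
    def expand(section: List[Dict[str, Any]]) -> List[str]:
        names: List[str] = []
        for entry in section:
            names.extend([entry['name']] * entry.get('repeat', 1))
        return names

    start = expand(block_overrides.get('start', []))
    end = expand(block_overrides.get('end', []))
    pattern = expand(block_overrides.get('repeating_pattern', []))

    middle_len = n_layers - len(start) - len(end)
    if pattern:
        # Tile the repeating pattern a whole number of times across the middle.
        middle = pattern * max(middle_len // len(pattern), 0)
    else:
        middle = ['default'] * max(middle_len, 0)
    expanded = start + middle + end

    # Same condition as the check added in _construct_blocks_with_overrides.
    if len(expanded) != n_layers:
        raise ValueError(
            f'The specified block overrides do not match the number of layers: '
            f'{len(expanded)} vs {n_layers}.',
        )
    return expanded

For example, with n_layers=6, a one-entry start section, and a repeating pattern of length four, the expansion yields only five module names, so the ValueError fires instead of silently constructing too few layers.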
66 changes: 65 additions & 1 deletion tests/models/test_model.py
@@ -72,12 +72,17 @@ def _load_tokenizer_cfg(cfg: Union[Dict[str, Any], DictConfig]) -> Dict:
def _get_objs(
    request: pytest.FixtureRequest,
    conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml',
    model_config_overrides: Optional[Dict] = None,
    attn_impl: str = 'torch',
):
    warnings.filterwarnings(
        action='ignore',
        message='Torchmetrics v0.9 introduced a new argument class property',
    )
    test_cfg = get_config(conf_path=conf_path)
    if model_config_overrides is not None:
        for k, v in model_config_overrides.items():
            test_cfg.model[k] = v

    # Read FSDP Config as a dict
    fsdp_config = test_cfg.get('fsdp_config', None)
@@ -97,7 +102,7 @@ def _get_objs(
    device = 'cuda' if is_gpu else 'cpu'
    test_cfg.precision = 'amp_bf16' if is_gpu else 'fp32'
    test_cfg.model.attn_config = {
        'attn_impl': 'torch',
        'attn_impl': attn_impl,
    }
    test_cfg.model.init_device = device
    test_cfg.device = device
@@ -2724,3 +2729,62 @@ def test_construct_blocks(start: list, repeating_pattern: list, end: list):
    assert block_list[4].attn.reuse_kv_layer_idx is None
    assert block_list[5].attn.sliding_window_size == 512
    assert block_list[5].attn.reuse_kv_layer_idx is None


@pytest.mark.gpu
@pytest.mark.parametrize(
    'conf_path',
    [
        'scripts/train/yamls/pretrain/testing.yaml',
    ],
)
def test_reuse_prev_layer_kv_cache(
    request: pytest.FixtureRequest,
    conf_path: str,
    batch_size: int = 2,
):
    model_config_overrides = {
        'block_overrides': {
            'start': [
                {
                    'name': 'default',
                    'repeat': 1,
                },
                {
                    'name': 'kv_reuse_layer',
                    'repeat': 1,
                },
            ],
            'overrides': {
                'kv_reuse_layer': {
                    'attn_config': {
                        'reuse_kv_layer_idx': -1,
                    },
                },
            },
        },
        'use_cache': True,
    }
    test_cfg, model, _ = _get_objs(
        request=request,
        conf_path=conf_path,
        model_config_overrides=model_config_overrides,
        attn_impl='flash',
    )

    batch = gen_random_batch(batch_size, test_cfg)

    assert batch['input_ids'].shape == torch.Size([
        batch_size,
        test_cfg.max_seq_len,
    ])
    model.train()
    with get_precision_context(test_cfg.precision):
        outputs = model(batch)
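        # With reuse_kv_layer_idx set to -1 on the second block, layer 1 is
        # expected to reuse the key/value tensors computed by layer 0, so the
        # two entries of outputs.past_key_values should hold identical tensors.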
        assert len(outputs.past_key_values) == 2
        assert torch.all(
            outputs.past_key_values[0][0] == outputs.past_key_values[1][0],
        )
        assert torch.all(
            outputs.past_key_values[0][1] == outputs.past_key_values[1][1],
        )
