fixing bug when using past kv caches
ShashankMosaicML committed Dec 4, 2024
1 parent 58760fc commit f4ad493
Showing 2 changed files with 37 additions and 7 deletions.
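The fix in a nutshell: during incremental decoding with a populated KV cache, the incoming query tensor covers only the newly generated tokens, so q_idx inside FlexAttention's mask and score mods starts at 0 even though those tokens actually sit at the end of the sequence, while kv_idx ranges over the full concatenated cache. Comparing the two directly makes a causal mask hide almost every cached key. A minimal sketch of the failure and the fix (shapes and numbers are illustrative, not from the repo):

import torch

past_len, new_len = 8, 1                   # 8 cached tokens, decoding 1 new token
q_idx = torch.arange(new_len)              # positions of the new queries only: [0]
kv_idx = torch.arange(past_len + new_len)  # positions over the full cache: 0..8

# Buggy: the new token looks like position 0, so it may only attend to kv_idx 0.
buggy = q_idx[:, None] >= kv_idx[None, :]
# Fixed: shift q_idx by the cache length, as this commit does via `query_offset`.
fixed = (q_idx[:, None] + past_len) >= kv_idx[None, :]  # all 9 keys visible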
6 changes: 5 additions & 1 deletion llmfoundry/models/layers/attention.py
@@ -474,9 +474,11 @@ def flex_attn_fn(
)

check_valid_inputs(query, key, value)

query_offset = 0
if past_key_value is not None:
if len(past_key_value) != 0:
assert past_key_value[0].shape[1] == past_key_value[1].shape[1]
query_offset = past_key_value[0].shape[1]
key = torch.cat([past_key_value[0], key], dim=1)
value = torch.cat([past_key_value[1], value], dim=1)
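Concretely, with 8 cached tokens and a one-token decode step, query_offset becomes 8; after the concatenation the keys and values span positions 0 through 8, while the lone new query would still report q_idx == 0 to the mask and score mods unless the offset is added.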

@@ -564,10 +566,12 @@ def _check_mod_list(mod_list: list[dict[str, Any]], name: str):
B=query.shape[0],
block_mask_list=block_mask_list, # type: ignore
compiled_create_block_mask=compiled_create_block_mask,
query_offset=query_offset,
sequence_id_info=sequence_id_info,
)
score_mod = generate_score_mod(
score_mod_list=score_mod_list, # type: ignore
query_offset=query_offset,
sequence_id_info=sequence_id_info,
)

38 changes: 32 additions & 6 deletions llmfoundry/models/layers/flex_attn_utils.py
@@ -19,9 +19,10 @@ def _mask_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
query_offset: int,
sequence_id_info: Optional[dict[str, Any]],
) -> torch.Tensor:
del sequence_id_info, b, h, q_idx, kv_idx
del sequence_id_info, query_offset, b, h, q_idx, kv_idx
raise NotImplementedError

def _score_mod_fn(
@@ -31,9 +32,10 @@ def _score_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
query_offset: int,
sequence_id_info: Optional[dict[str, Any]],
) -> torch.Tensor:
del sequence_id_info, score, b, h, q_idx, kv_idx
del sequence_id_info, query_offset, score, b, h, q_idx, kv_idx
raise NotImplementedError

def __init__(self, mod_type: str) -> None:
@@ -51,9 +53,11 @@ def _mask_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
query_offset: int,
sequence_id_info: Optional[dict[str, Any]],
) -> torch.Tensor:
del sequence_id_info, b, h
q_idx = q_idx + query_offset
return q_idx >= kv_idx

def __init__(self) -> None:
@@ -69,9 +73,11 @@ def _mask_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
query_offset: int,
sequence_id_info: Optional[dict[str, Any]],
) -> torch.Tensor:
del sequence_id_info, b, h
q_idx = q_idx + query_offset
return q_idx - kv_idx <= self.sliding_window_size

def __init__(self, sliding_window_size: int) -> None:
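The same shift matters for the sliding window above: with a fresh decode step, q_idx - kv_idx is never positive, so the window check passes for every cached key and the window is silently disabled. A hedged numeric check (window size and cache length illustrative):

import torch

window, offset = 16, 100              # cache already holds 100 tokens
q_idx = torch.arange(1)               # the single new query
kv_idx = torch.arange(101)            # full cache plus the new token

keep = ((q_idx[:, None] + offset) - kv_idx[None, :]) <= window
# True only for kv positions >= 84; the causal mask ANDed on top trims the rest.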
@@ -88,9 +94,11 @@ def _mask_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
query_offset: int,
sequence_id_info: Optional[dict[str, Any]],
) -> torch.Tensor:
del h
q_idx = q_idx + query_offset
if sequence_id_info is None:
raise ValueError(
'sequence_id_info is required for SequenceIdMaskMod',
@@ -113,9 +121,10 @@ def _mask_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
query_offset: int,
sequence_id_info: Optional[dict[str, Any]],
) -> torch.Tensor:
del h, q_idx
del h, q_idx, query_offset
if sequence_id_info is None:
raise ValueError(
'sequence_id_info is required for SequenceIdMaskMod',
@@ -137,9 +146,11 @@ def _mask_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
query_offset: int,
sequence_id_info: Optional[dict[str, Any]],
) -> torch.Tensor:
del h
q_idx = q_idx + query_offset
if sequence_id_info is None:
raise ValueError(
'sequence_id_info is required for LocalGlobalMaskMod',
@@ -177,9 +188,11 @@ def _score_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
query_offset: int,
sequence_id_info: Optional[dict[str, Any]],
) -> torch.Tensor:
del sequence_id_info, b
q_idx = q_idx + query_offset
bias = -self.alibi_slopes[h] * torch.abs(q_idx - kv_idx)
return score + bias
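ALiBi's penalty depends on the true distance |q_idx - kv_idx|, so the offset is what keeps distances correct once a cache is in play. A small numeric check (slope value illustrative):

import torch

slope, offset = 0.5, 8
q_idx, kv_idx = torch.tensor(0), torch.tensor(0)  # new token vs. first cached key

wrong = -slope * torch.abs(q_idx - kv_idx)             # 0.0: treats distance 8 as 0
right = -slope * torch.abs((q_idx + offset) - kv_idx)  # -4.0: correct distance-8 bias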

@@ -198,9 +211,10 @@ def _score_mod_fn(
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
query_offset: int,
sequence_id_info: Optional[dict[str, Any]],
) -> torch.Tensor:
del sequence_id_info, b, h, q_idx, kv_idx
del sequence_id_info, query_offset, b, h, q_idx, kv_idx
return self.attn_logit_softcapping * torch.tanh(
score / self.attn_logit_softcapping,
)
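Softcapping, by contrast, is position-independent: it only rescales the score through tanh, so query_offset is accepted here for the uniform mod-function signature and immediately discarded.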
@@ -216,6 +230,7 @@ def generate_block_mask(
B: int,
block_mask_list: Optional[list[FlexAttentionMod]],
compiled_create_block_mask: Any,
query_offset: int,
sequence_id_info: Optional[dict[str, Any]],
):
if block_mask_list is None:
@@ -226,12 +241,17 @@ def generate_block_mask(
if i == 0:
block_mask_fn = partial(
block_mask.mod_fn,
query_offset=query_offset,
sequence_id_info=sequence_id_info,
)
else:
block_mask_fn = and_masks(
block_mask_fn,
partial(block_mask.mod_fn, sequence_id_info=sequence_id_info),
partial(
block_mask.mod_fn,
query_offset=query_offset,
sequence_id_info=sequence_id_info,
),
)
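For reference, this mirrors the stock FlexAttention helpers: and_masks returns a mask mod that is the logical AND of its arguments, and partial pins the extra keyword arguments these mod functions take beyond the standard (b, h, q_idx, kv_idx) signature. A minimal sketch against the public torch.nn.attention.flex_attention API (the mask functions are illustrative):

import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def window(b, h, q_idx, kv_idx):
    return (q_idx - kv_idx) <= 16

combined = and_masks(causal, window)  # keep a pair only if every mask keeps it
mask = create_block_mask(combined, B=None, H=None, Q_LEN=128, KV_LEN=128, device='cpu')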

block_mask = compiled_create_block_mask(
@@ -247,6 +267,7 @@ def generate_block_mask(

def generate_score_mod(
score_mod_list: Optional[list[FlexAttentionMod]],
query_offset: int,
sequence_id_info: Optional[dict[str, Any]],
):
if score_mod_list is None:
@@ -256,12 +277,17 @@
if i == 0:
wrapped_score_mod = partial(
score_mod.mod_fn,
query_offset=query_offset,
sequence_id_info=sequence_id_info,
)
else:
wrapped_score_mod = _wrap_score_mod_fns(
wrapped_score_mod,
partial(score_mod.mod_fn, sequence_id_info=sequence_id_info),
partial(
score_mod.mod_fn,
query_offset=query_offset,
sequence_id_info=sequence_id_info,
),
)

return wrapped_score_mod
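Score mods compose by chaining rather than boolean AND: each mod's adjusted score is fed into the next. _wrap_score_mod_fns is not shown in this diff, but a plausible sketch of such a wrapper (an assumption, not the repo's code) is:

def wrap_score_mod_fns(outer, inner):
    # Run `inner` first, then pass its adjusted score through `outer`.
    def wrapped(score, b, h, q_idx, kv_idx):
        return outer(inner(score, b, h, q_idx, kv_idx), b, h, q_idx, kv_idx)
    return wrapped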
