Merge pull request #1 from whyNLP/dev-lckv-dialog
feat: support dialog attention
why-in-Shanghaitech authored May 29, 2024
2 parents 0ac6b7f + da7aaea commit 503c82f
Showing 5 changed files with 353 additions and 23 deletions.
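At its core, the change adds a `_get_qkv_encoder_and_cache` path to the attention classes: when a dialog turn arrives with both a KV cache from earlier turns (`past_key_value`) and keys/values supplied through `encoder_outputs`, the cached pair is concatenated in front of the encoder pair before attention is computed. A minimal sketch of that combination step, assuming `(batch, num_heads, seq_len, head_dim)` tensors; the helper name is illustrative and not part of the repository:

import torch
from typing import Optional, Tuple

KV = Tuple[torch.Tensor, torch.Tensor]

def combine_cached_and_encoder_kv(past_key_value: Optional[KV], encoder_outputs: KV) -> KV:
    # Keys/values are assumed to be shaped (batch, num_heads, seq_len, head_dim),
    # so earlier turns are prepended along the sequence axis (dim=2),
    # mirroring the torch.cat calls in _get_qkv_encoder_and_cache below.
    key_states, value_states = encoder_outputs
    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    return key_states, value_states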
3 changes: 0 additions & 3 deletions README.md
@@ -173,9 +173,6 @@ bash run_streaming.sh
 
 See the script for more details. The [codes](test_streaming.py) follow the [official implementation](https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py) with minimal modification.
 
-> [!WARNING]
-> The script `run_streaming.py` is not supported yet.
-
 ### Evaluation
 
 We use [LM-Harness](https://github.com/EleutherAI/lm-evaluation-harness) to evaluate the model. You may run the following command:
213 changes: 196 additions & 17 deletions models/modeling_llama_opt.py
@@ -152,9 +152,7 @@ def forward(
             )

         bsz, q_len, _ = hidden_states.size()
-        if encoder_outputs is not None:
-            kv_seq_len = encoder_outputs[0].shape[-2]
-        elif past_key_value is not None:
+        if past_key_value is not None:
             kv_seq_len = q_len + past_key_value[0].shape[-2]
         else:
             kv_seq_len = q_len
@@ -163,10 +161,8 @@
             forward_func = self._forward_dummy
         elif q_len == kv_seq_len:
             forward_func = self._forward_training
-        elif q_len == 1:
-            forward_func = self._forward_decoding
         else:
-            raise ValueError(f"Invalid q_len: {q_len} and kv_seq_len: {kv_seq_len}")
+            forward_func = self._forward_decoding

         return forward_func(
             hidden_states,
@@ -310,9 +306,11 @@ def _forward_decoding(
             )

         if attention_mask is not None:
-            warnings.warn(
-                "Attention mask is not supported for decoding. We just ignore it."
-            )
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask[:, :, :, 1:]

         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -414,9 +412,7 @@ def forward(
         output_attentions = False

         bsz, q_len, _ = hidden_states.size()
-        if encoder_outputs is not None:
-            kv_seq_len = encoder_outputs[0].shape[-2]
-        elif past_key_value is not None:
+        if past_key_value is not None:
             kv_seq_len = q_len + past_key_value[0].shape[-2]
         else:
             kv_seq_len = q_len
@@ -425,10 +421,8 @@
             forward_func = self._forward_dummy
         elif q_len == kv_seq_len:
             forward_func = self._forward_training
-        elif q_len == 1:
-            forward_func = self._forward_decoding
         else:
-            raise ValueError(f"Invalid q_len: {q_len} and kv_seq_len: {kv_seq_len}")
+            forward_func = self._forward_decoding

         return forward_func(
             hidden_states,
@@ -619,7 +613,16 @@ def _get_qkv(
         use_cache: bool = False,
         **kwargs,
     ):
-        if encoder_outputs is not None:
+        if encoder_outputs is not None and past_key_value is not None:
+            output = self._get_qkv_encoder_and_cache(
+                hidden_states,
+                position_ids,
+                past_key_value,
+                encoder_outputs,
+                use_cache,
+                **kwargs
+            )
+        elif encoder_outputs is not None:
             output = self._get_qkv_encoder(
                 hidden_states,
                 position_ids,
@@ -694,6 +697,51 @@ def _get_qkv_encoder(
         value_states = value_states[:, :, :-1, :]

         return query_states, key_states, value_states, kv_seq_len, past_key_value
+
+    def _get_qkv_encoder_and_cache(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ):
+        """Combine the kv from cache and encoder outputs"""
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        key_states, value_states = encoder_outputs
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states = apply_rotary_pos_emb_q(query_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        if use_cache:
+            _key_states = self.k_proj(hidden_states)
+            _value_states = self.v_proj(hidden_states)
+            _key_states = _key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+            _value_states = _value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+            cos, sin = self.rotary_emb(_value_states, seq_len=kv_seq_len)
+            _key_states = apply_rotary_pos_emb_q(_key_states, cos, sin, position_ids)
+            past_key_value = (_key_states, _value_states)
+        else:
+            past_key_value = None
+
+        # remove the last token
+        key_states = key_states[:, :, :-1, :]
+        value_states = value_states[:, :, :-1, :]
+
+        return query_states, key_states, value_states, kv_seq_len, past_key_value


 class LlamaFlashAttention2(LlamaFlashAttention2Base):
@@ -706,7 +754,16 @@ def _get_qkv(
         use_cache: bool = False,
         **kwargs,
     ):
-        if encoder_outputs is not None:
+        if encoder_outputs is not None and past_key_value is not None:
+            output = self._get_qkv_encoder_and_cache(
+                hidden_states,
+                position_ids,
+                past_key_value,
+                encoder_outputs,
+                use_cache,
+                **kwargs
+            )
+        elif encoder_outputs is not None:
             output = self._get_qkv_encoder(
                 hidden_states,
                 position_ids,
@@ -781,6 +838,51 @@ def _get_qkv_encoder(
         value_states = value_states[:, :, :-1, :]

         return query_states, key_states, value_states, kv_seq_len, past_key_value
+
+    def _get_qkv_encoder_and_cache(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ):
+        """Combine the kv from cache and encoder outputs"""
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        key_states, value_states = encoder_outputs
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states = apply_rotary_pos_emb_q(query_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        if use_cache:
+            _key_states = self.k_proj(hidden_states)
+            _value_states = self.v_proj(hidden_states)
+            _key_states = _key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+            _value_states = _value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+            cos, sin = self.rotary_emb(_value_states, seq_len=kv_seq_len)
+            _key_states = apply_rotary_pos_emb_q(_key_states, cos, sin, position_ids)
+            past_key_value = (_key_states, _value_states)
+        else:
+            past_key_value = None
+
+        # remove the last token
+        key_states = key_states[:, :, :-1, :]
+        value_states = value_states[:, :, :-1, :]
+
+        return query_states, key_states, value_states, kv_seq_len, past_key_value


 class LlamaAttentionMiddle(LlamaAttention):
@@ -865,6 +967,42 @@ def _get_qkv_encoder(
         value_states = value_states[:, :, :-1, :]

         return query_states, key_states, value_states, kv_seq_len, past_key_value
+
+    def _get_qkv_encoder_and_cache(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ):
+        """Combine the kv from cache and encoder outputs"""
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        key_states, value_states = encoder_outputs
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states = apply_rotary_pos_emb_q(query_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = None
+
+        # remove the last token
+        key_states = key_states[:, :, :-1, :]
+        value_states = value_states[:, :, :-1, :]
+
+        return query_states, key_states, value_states, kv_seq_len, past_key_value


 class LlamaFlashAttention2Middle(LlamaFlashAttention2):
@@ -949,6 +1087,42 @@ def _get_qkv_encoder(
         value_states = value_states[:, :, :-1, :]

         return query_states, key_states, value_states, kv_seq_len, past_key_value
+
+    def _get_qkv_encoder_and_cache(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ):
+        """Combine the kv from cache and encoder outputs"""
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        key_states, value_states = encoder_outputs
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states = apply_rotary_pos_emb_q(query_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = None
+
+        # remove the last token
+        key_states = key_states[:, :, :-1, :]
+        value_states = value_states[:, :, :-1, :]
+
+        return query_states, key_states, value_states, kv_seq_len, past_key_value


 class LlamaDecoderLayer(_LlamaDecoderLayer):
@@ -1711,6 +1885,11 @@ def forward_predict_prompt(
         if use_cache:
             layer_types = [int(x) for x in self.config.layer_types.split("_")]
             memory = outputs[1][self.config.target_layer]
+            if past_key_values is not None:
+                key_states, value_states = memory
+                key_states = torch.cat([past_key_values[self.config.target_layer][0], key_states], dim=-2)
+                value_states = torch.cat([past_key_values[self.config.target_layer][1], value_states], dim=-2)
+                memory = (key_states, value_states)
             new_past_key_values = tuple(
                 outputs[1][idx] if tp == 0 else memory
                 for idx, tp in enumerate(layer_types)
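The hunk above merges the shared `memory` (the target layer's key/value pair) with what is already stored in `past_key_values`, so each dialog turn extends the cache rather than replacing it. Below is a rough sketch of how the next-step cache is then assembled from `layer_types`, assuming `0` marks layers that keep their own KV while any other value reuses the target layer's memory (names mirror the hunk; the surrounding variables are assumptions):

from typing import List, Tuple
import torch

KV = Tuple[torch.Tensor, torch.Tensor]

def assemble_new_past_key_values(
    per_layer_outputs: List[KV],   # outputs[1] in the hunk: this turn's KV per layer
    memory: KV,                    # target layer's KV, already merged with past turns
    layer_types: List[int],        # parsed from config.layer_types, e.g. "0_1_1_..." (hypothetical value)
) -> Tuple[KV, ...]:
    # Layers typed 0 store their own keys/values; every other layer
    # points at the single shared memory from the target layer.
    return tuple(
        per_layer_outputs[idx] if tp == 0 else memory
        for idx, tp in enumerate(layer_types)
    )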