
add Diff Attention and other changes, but still with errors
weak-kajuma committed Oct 11, 2024
1 parent 3bd9e34 commit 269055e
Showing 1 changed file with 34 additions and 8 deletions.
42 changes: 34 additions & 8 deletions src/transformers/models/diffllama/modeling_diffllama.py
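For orientation: this commit implements the differential attention operator from the Differential Transformer paper (Ye et al., 2024). Queries and keys are split into two half-width head groups, two softmax attention maps are computed, and the second map is subtracted from the first with a learned weight lambda. A minimal single-head sketch of that computation, separate from the diff below (no RoPE, masking, or KV cache; all names here are illustrative, not part of the commit):

import torch

def diff_attention(q1, q2, k1, k2, v, lambda_full):
    # q1, q2, k1, k2: (batch, seq, head_dim); v: (batch, seq, 2 * head_dim)
    scale = q1.size(-1) ** -0.5
    a1 = torch.softmax(q1 @ k1.transpose(-2, -1) * scale, dim=-1)
    a2 = torch.softmax(q2 @ k2.transpose(-2, -1) * scale, dim=-1)
    # the second map acts as a noise estimate subtracted from the first
    return (a1 - lambda_full * a2) @ v

q1, q2, k1, k2 = (torch.randn(2, 5, 8) for _ in range(4))
v = torch.randn(2, 5, 16)
print(diff_attention(q1, q2, k1, k2, v, lambda_full=0.8).shape)  # torch.Size([2, 5, 16])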
@@ -276,7 +276,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


+def lambda_init_fn(layer_idx):
+    return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
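+# lambda_init_fn implements the depth-dependent schedule from the Differential Transformer
+# paper: 0.8 - 0.6 * exp(-0.3 * layer_idx) with 0-indexed layers, so shallow layers stay
+# close to standard softmax attention and deeper layers subtract more of the second map.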

# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DiffLlama
class DiffLlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -294,18 +297,26 @@ def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
-        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+        # half-width heads: each of the two softmax maps gets head_dim // 2
+        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) // 2
+        self.scaling = self.head_dim ** -0.5
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        # the following attributes are not used
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True

-        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
-        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
-        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size // self.num_key_value_groups, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size // self.num_key_value_groups, bias=config.attention_bias)
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+        # attn_output concatenates num_heads value heads of width 2 * head_dim
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim * 2, self.hidden_size, bias=config.attention_bias)

+        self.lambda_init = lambda_init_fn(layer_idx)
+        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
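+        # these four vectors reparameterize the scalar mixing weight computed in forward():
+        # lambda_full = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init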

# TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
self.rotary_emb = DiffLlamaRotaryEmbedding(config=self.config)

@@ -322,6 +333,7 @@ def forward(
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
@@ -345,9 +357,9 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, 2 * self.head_dim).transpose(1, 2)
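+        # queries and keys are split into two half-width head groups, one per softmax map,
+        # while each value head keeps the combined 2 * head_dim width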

if position_embeddings is None:
logger.warning_once(
@@ -359,6 +371,9 @@
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
+        # NOTE: cos/sin from DiffLlamaRotaryEmbedding appear to be sized for the full
+        # pre-split head width (2 * head_dim), while query/key heads are head_dim wide;
+        # they may need slicing to head_dim before being applied here.
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
@@ -368,6 +383,7 @@

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
+        query_states = query_states * self.scaling
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        # scaling was already applied to the queries above, so dividing by
+        # sqrt(head_dim) again would double-scale the logits
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))

if attention_mask is not None: # no matter the length, we just slice it
@@ -377,6 +393,12 @@
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

+        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(query_states)
+        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(query_states)
+        lambda_full = lambda_1 - lambda_2 + self.lambda_init
+        # with a KV cache the key length can exceed q_len, so take it from key_states
+        target_len = key_states.size(-2)
+        attn_weights = attn_weights.view(bsz, self.num_heads, 2, q_len, target_len)
+        attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
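+        # differential attention: the second softmax map serves as a learned noise
+        # estimate, and subtracting it with weight lambda_full cancels common-mode
+        # attention noise across the two maps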
attn_output = torch.matmul(attn_weights, value_states)

-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+        # value heads are 2 * head_dim wide, so the combined output is too
+        if attn_output.size() != (bsz, self.num_heads, q_len, 2 * self.head_dim):
@@ -635,6 +657,7 @@ class DiffLlamaDecoderLayer(nn.Module):
def __init__(self, config: DiffLlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
+        self.layer_idx = layer_idx

self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

@@ -696,7 +719,10 @@ def forward(

# Fully Connected
residual = hidden_states
+        # the Differential Transformer paper applies headwise GroupNorm to the attention
+        # output and rescales it by (1 - lambda_init); this commit applies that scale
+        # after the pre-MLP norm instead
         hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = hidden_states * (1 - lambda_init_fn(self.layer_idx))
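+        # e.g. layer_idx = 0 gives scale 1 - (0.8 - 0.6) = 0.8, while deep layers
+        # approach 1 - 0.8 = 0.2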

hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

