fix test
Cyrilvallez committed Dec 17, 2024
1 parent 2f666b3 commit bafa020
Showing 3 changed files with 22 additions and 17 deletions.
src/transformers/models/phi/modeling_phi.py (20 changes: 5 additions & 15 deletions)
@@ -129,21 +129,11 @@ def __init__(self, config: PhiConfig, layer_idx: int):
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
         self.is_causal = True
-
-        self.q_proj = nn.Linear(
-            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.k_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.v_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.o_proj = nn.Linear(
-            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
-        )
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
         self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
-        self.dense = self.o_proj
         self.qk_layernorm = config.qk_layernorm
         if self.qk_layernorm:
             self.q_layernorm = nn.LayerNorm(
@@ -217,7 +207,7 @@ def forward(
         )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
+        attn_output = self.dense(attn_output)
         return attn_output, attn_weights


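For readers scanning the diff: the modeling change above restores Phi's original projection layout, with bias=True on every linear layer and the output projection registered as `dense` (the parameter name used by existing Phi checkpoints) rather than the Llama-style `o_proj`. A minimal standalone sketch of that layout, with sizes hard-coded purely for illustration (the real module reads them from `PhiConfig`):

```python
import torch
from torch import nn


class PhiAttentionProjections(nn.Module):
    """Illustrative sketch only, not the real PhiAttention: it shows the projection
    layout restored by this diff, with sizes hard-coded instead of read from PhiConfig."""

    def __init__(self, hidden_size=2048, num_heads=32, num_key_value_heads=32, head_dim=64):
        super().__init__()
        # Phi keeps bias=True on every projection and names the output projection
        # "dense", so checkpoint keys look like "...self_attn.dense.weight".
        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=True)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=True)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=True)
        self.dense = nn.Linear(num_heads * head_dim, hidden_size, bias=True)


if __name__ == "__main__":
    # The parameter names are what matters for loading existing Phi checkpoints.
    print(sorted(name for name, _ in PhiAttentionProjections().named_parameters()))
```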
src/transformers/models/phi/modular_phi.py (8 changes: 6 additions & 2 deletions)
@@ -25,8 +25,12 @@
 class PhiAttention(LlamaAttention):
     def __init__(self, config: PhiConfig, layer_idx: int):
         super().__init__(config, layer_idx)
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
+        del self.o_proj
         self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
-        self.dense = self.o_proj
         self.qk_layernorm = config.qk_layernorm
         if self.qk_layernorm:
             self.q_layernorm = nn.LayerNorm(
@@ -100,7 +104,7 @@ def forward(
         )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
+        attn_output = self.dense(attn_output)
         return attn_output, attn_weights


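The `del self.o_proj` line is the modular-transformers idiom for dropping a submodule inherited from the parent class after registering its replacement under a different name. A toy sketch of the mechanism, with made-up class names standing in for LlamaAttention and PhiAttention:

```python
from torch import nn


class ParentAttention(nn.Module):
    """Stand-in for the parent class, which registers the output projection as o_proj."""

    def __init__(self, hidden_size=64):
        super().__init__()
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)


class ChildAttention(ParentAttention):
    """Stand-in for the subclass: re-registers the projection as dense and drops o_proj."""

    def __init__(self, hidden_size=64):
        super().__init__(hidden_size)
        self.dense = nn.Linear(hidden_size, hidden_size, bias=True)
        # Deleting the attribute also removes it from the module registry,
        # so state_dict() only contains "dense.*" keys, matching the checkpoint.
        del self.o_proj


if __name__ == "__main__":
    print(list(ChildAttention().state_dict().keys()))  # ['dense.weight', 'dense.bias']
```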
tests/test_modeling_common.py (11 changes: 11 additions & 0 deletions)
@@ -43,6 +43,7 @@
     logging,
     set_seed,
 )
+from transformers.cache_utils import DynamicCache
 from transformers.integrations import HfDeepSpeedConfig
 from transformers.integrations.deepspeed import (
     is_deepspeed_available,
@@ -1285,6 +1286,11 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
                     )
                     for i in range(model.config.num_hidden_layers)
                 )
+                empty_pkv = (
+                    DynamicCache.from_legacy_cache(empty_pkv)
+                    if model_class._supports_cache_class
+                    else empty_pkv
+                )

                 cache_length = 9
                 cache_shape = (batch_size, num_heads, cache_length, head_dim)
@@ -1295,6 +1301,11 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
                     )
                     for i in range(model.config.num_hidden_layers)
                 )
+                non_empty_pkv = (
+                    DynamicCache.from_legacy_cache(non_empty_pkv)
+                    if model_class._supports_cache_class
+                    else non_empty_pkv
+                )

                 inps = copy.deepcopy(inputs_to_test[0])

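On the test change: `DynamicCache.from_legacy_cache` converts the legacy tuple-of-tuples `past_key_values` into a `DynamicCache`, and the test only wraps when `model_class._supports_cache_class` is set, since other models still expect the legacy tuple format. A small standalone example of the conversion, with arbitrary shapes not taken from the test:

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each tensor shaped
# (batch_size, num_heads, seq_len, head_dim).
batch_size, num_heads, seq_len, head_dim, num_layers = 2, 4, 9, 8, 3
legacy_pkv = tuple(
    (
        torch.rand(batch_size, num_heads, seq_len, head_dim),
        torch.rand(batch_size, num_heads, seq_len, head_dim),
    )
    for _ in range(num_layers)
)

cache = DynamicCache.from_legacy_cache(legacy_pkv)
print(cache.get_seq_length())  # 9
```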
