diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 8d5e25281829bc..5881411f7c17f8 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -129,21 +129,11 @@ def __init__(self, config: PhiConfig, layer_idx: int):
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
         self.is_causal = True
-
-        self.q_proj = nn.Linear(
-            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.k_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.v_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
-        )
-        self.o_proj = nn.Linear(
-            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
-        )
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
         self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
-        self.dense = self.o_proj
         self.qk_layernorm = config.qk_layernorm
         if self.qk_layernorm:
             self.q_layernorm = nn.LayerNorm(
@@ -217,7 +207,7 @@ def forward(
         )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
+        attn_output = self.dense(attn_output)
         return attn_output, attn_weights
diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py
index 46fd51e9afcb0d..ee15fed8dfb84f 100644
--- a/src/transformers/models/phi/modular_phi.py
+++ b/src/transformers/models/phi/modular_phi.py
@@ -25,8 +25,12 @@
 class PhiAttention(LlamaAttention):
     def __init__(self, config: PhiConfig, layer_idx: int):
         super().__init__(config, layer_idx)
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
+        del self.o_proj
         self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
-        self.dense = self.o_proj
         self.qk_layernorm = config.qk_layernorm
         if self.qk_layernorm:
             self.q_layernorm = nn.LayerNorm(
@@ -100,7 +104,7 @@ def forward(
         )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
+        attn_output = self.dense(attn_output)
         return attn_output, attn_weights
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index bc03855846fe15..23c5f33b0802e2 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -43,6 +43,7 @@
     logging,
     set_seed,
 )
+from transformers.cache_utils import DynamicCache
 from transformers.integrations import HfDeepSpeedConfig
 from transformers.integrations.deepspeed import (
     is_deepspeed_available,
@@ -1285,6 +1286,11 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
                 )
                 for i in range(model.config.num_hidden_layers)
             )
+            empty_pkv = (
+                DynamicCache.from_legacy_cache(empty_pkv)
+                if model_class._supports_cache_class
+                else empty_pkv
+            )

             cache_length = 9
             cache_shape = (batch_size, num_heads, cache_length, head_dim)
@@ -1295,6 +1301,11 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
                 )
                 for i in range(model.config.num_hidden_layers)
             )
+            non_empty_pkv = (
+                DynamicCache.from_legacy_cache(non_empty_pkv)
+                if model_class._supports_cache_class
+                else non_empty_pkv
+            )

             inps = copy.deepcopy(inputs_to_test[0])
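
Note (not part of the diff): a minimal, self-contained sketch of the cache conversion the test hunks above rely on. The shapes and layer count below are made-up example values; the only API used is DynamicCache.from_legacy_cache, which wraps the legacy per-layer (key, value) tuples into a Cache object for models that set _supports_cache_class = True.

# Illustrative sketch, assuming toy tensor shapes; mirrors the guarded
# DynamicCache.from_legacy_cache(...) calls added to the fx-tracing test.
import torch
from transformers.cache_utils import DynamicCache

batch_size, num_heads, cache_length, head_dim = 2, 4, 9, 32
num_hidden_layers = 2

# Legacy format: one (key, value) pair per layer, each tensor shaped
# (batch_size, num_heads, cache_length, head_dim).
legacy_pkv = tuple(
    (
        torch.rand(batch_size, num_heads, cache_length, head_dim),
        torch.rand(batch_size, num_heads, cache_length, head_dim),
    )
    for _ in range(num_hidden_layers)
)

# Models whose class sets _supports_cache_class = True expect a Cache object
# instead of the raw tuples, hence the conditional wrapping in the test.
cache = DynamicCache.from_legacy_cache(legacy_pkv)
print(cache.get_seq_length())  # 9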