From c60103369ae0774e30e4eb86c461f72bd3d85b7b Mon Sep 17 00:00:00 2001
From: KexinFeng
Date: Sat, 21 Oct 2023 09:12:14 -0700
Subject: [PATCH] llama_dtype_fix

---
 src/transformers/models/llama/modeling_llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index b67719ac327162..02aed6d5e07630 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -496,7 +496,7 @@ def forward(
         )
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
-        attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output.to(self.o_proj.weight.dtype))
 
         if not output_attentions:
             attn_weights = None
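
Context for the change (not part of the patch): under some mixed-precision setups the attention kernel returns half-precision activations while o_proj's weight is kept in a different dtype, and the linear projection then fails with a dtype-mismatch error. The sketch below is a toy illustration of that failure mode and of the cast the patch adds; the layer size and tensors are made up for the example and are not the actual transformers module.

# Toy illustration (hypothetical sizes, standalone torch code):
# the attention output comes back in float16 while the output projection's
# weight is float32, so calling the projection directly raises a dtype
# mismatch; casting to the weight's dtype first, as the patch does, avoids it.
import torch
import torch.nn as nn

hidden_size = 8
o_proj = nn.Linear(hidden_size, hidden_size, bias=False)   # weight dtype: float32
attn_output = torch.randn(1, 4, hidden_size, dtype=torch.float16)

# o_proj(attn_output)  # would raise a dtype-mismatch RuntimeError here

attn_output = o_proj(attn_output.to(o_proj.weight.dtype))  # the patched pattern
print(attn_output.dtype)                                    # torch.float32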