From f6c9e84e11ecdad667eb6ce92aeb603147573b06 Mon Sep 17 00:00:00 2001
From: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com>
Date: Sun, 16 Jun 2024 19:44:09 +0800
Subject: [PATCH] fix phi3-small (#1148)

---
 swift/llm/utils/model.py | 54 ++++++++++++++++++++++++++++++++--------
 1 file changed, 44 insertions(+), 10 deletions(-)

diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py
index abebce951..ea6773c05 100644
--- a/swift/llm/utils/model.py
+++ b/swift/llm/utils/model.py
@@ -451,6 +451,7 @@ class LoRATM(NamedTuple):
     glm4v = ['self_attention.query_key_value']
     phi = ['Wqkv']
     phi3 = ['qkv_proj']
+    phi3_small = ['query_key_value']  # what the hell???
     internlm2 = ['wqkv']
     mamba = ['in_proj', 'x_proj', 'embeddings', 'out_proj']
     telechat = ['key_value', 'query']
@@ -1605,16 +1606,6 @@ def _output_device_map_hook(module, input, output):
     support_vllm=True,
     tags=['general'],
     hf_model_id='microsoft/Phi-3-mini-128k-instruct')
-@register_model(
-    ModelType.phi3_small_128k_instruct,
-    'LLM-Research/Phi-3-small-128k-instruct',
-    LoRATM.phi3,
-    TemplateType.phi3,
-    requires=['transformers>=4.36'],
-    support_flash_attn=True,
-    support_vllm=True,
-    tags=['general'],
-    hf_model_id='microsoft/Phi-3-small-128k-instruct')
 @register_model(
     ModelType.phi3_medium_128k_instruct,
     'LLM-Research/Phi-3-medium-128k-instruct',
@@ -2361,6 +2352,49 @@ def get_model_tokenizer_with_flash_attn(model_dir: str,
         model_dir, torch_dtype, model_kwargs, load_model, model_config=model_config, **kwargs)
 
 
+@register_model(
+    ModelType.phi3_small_128k_instruct,
+    'LLM-Research/Phi-3-small-128k-instruct',
+    LoRATM.phi3_small,
+    TemplateType.phi3,
+    requires=['transformers>=4.36'],
+    support_flash_attn=True,
+    support_gradient_checkpointing=False,
+    support_vllm=True,
+    tags=['general'],
+    hf_model_id='microsoft/Phi-3-small-128k-instruct')
+def get_model_tokenizer_phi3_small(model_dir: str,
+                                   torch_dtype: Dtype,
+                                   model_kwargs: Dict[str, Any],
+                                   load_model: bool = True,
+                                   model_config=None,
+                                   **kwargs):
+    if model_config is None:
+        model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
+    use_flash_attn = kwargs.pop('use_flash_attn', False)
+    if version.parse(transformers.__version__) >= version.parse('4.36'):
+        if use_flash_attn:
+            model_config._attn_implementation = 'flash_attention_2'
+    else:
+        model_config._flash_attn_2_enabled = use_flash_attn
+    model, tokenizer = get_model_tokenizer_from_repo(
+        model_dir, torch_dtype, model_kwargs, load_model, model_config=model_config, **kwargs)
+
+    def rotary_emb(self, query_states, key_states, **kwargs):
+        q_type = query_states.dtype
+        k_type = key_states.dtype
+        query_states, key_states = self.rotory_emb_origin(query_states, key_states, **kwargs)
+        query_states = query_states.to(q_type)
+        key_states = key_states.to(k_type)
+        return query_states, key_states
+
+    for i in range(32):
+        re = model.model.layers[i].self_attn.rotary_emb
+        re.rotory_emb_origin = re.forward
+        re.forward = MethodType(rotary_emb, re)
+    return model, tokenizer
+
+
 @register_model(
     ModelType.qwen2_57b_a14b_instruct_int4,
     'qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4',
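
Note on the technique used above: the new get_model_tokenizer_phi3_small loader saves each layer's original rotary-embedding forward, calls it, and casts the returned query/key states back to their incoming dtypes, so the (trust_remote_code) rotary embedding cannot silently change the dtype seen by later attention code. The following is a minimal, self-contained sketch of that monkey-patching pattern only; DummyRotaryEmbedding, patch_rotary_emb_dtype, and _forward_origin are names invented here for illustration and are not part of the swift codebase or the Phi-3-small implementation.

    from types import MethodType

    import torch
    import torch.nn as nn


    class DummyRotaryEmbedding(nn.Module):
        """Stand-in for a rotary-embedding module (illustrative only).

        Like the behavior the patch guards against, it returns tensors in a
        different dtype than it received (here it upcasts to float32).
        """

        def forward(self, query_states, key_states, **kwargs):
            return query_states.float(), key_states.float()


    def patch_rotary_emb_dtype(rotary_module: nn.Module) -> None:
        """Wrap rotary_module.forward so outputs keep the input dtypes."""

        def rotary_emb(self, query_states, key_states, **kwargs):
            q_type = query_states.dtype
            k_type = key_states.dtype
            # Call the saved original forward, then cast back to the incoming dtypes.
            query_states, key_states = self._forward_origin(query_states, key_states, **kwargs)
            return query_states.to(q_type), key_states.to(k_type)

        # Keep a reference to the original forward and bind the wrapper to the instance.
        rotary_module._forward_origin = rotary_module.forward
        rotary_module.forward = MethodType(rotary_emb, rotary_module)


    if __name__ == '__main__':
        re_module = DummyRotaryEmbedding()
        patch_rotary_emb_dtype(re_module)
        q = torch.randn(1, 4, 8, dtype=torch.float16)
        k = torch.randn(1, 4, 8, dtype=torch.float16)
        q_out, k_out = re_module(q, k)
        assert q_out.dtype == torch.float16 and k_out.dtype == torch.float16
        print('dtype preserved:', q_out.dtype, k_out.dtype)

In the patch itself the same idea is applied per decoder layer via model.model.layers[i].self_attn.rotary_emb, with types.MethodType binding the wrapper as that module's forward.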