Commit

fix phi3-small (#1148)
tastelikefeet authored Jun 16, 2024
1 parent 962de3b commit f6c9e84
Showing 1 changed file with 44 additions and 10 deletions.
54 changes: 44 additions & 10 deletions swift/llm/utils/model.py
@@ -451,6 +451,7 @@ class LoRATM(NamedTuple):
    glm4v = ['self_attention.query_key_value']
    phi = ['Wqkv']
    phi3 = ['qkv_proj']
    phi3_small = ['query_key_value']  # phi3-small names its fused qkv projection differently from phi3's 'qkv_proj'
    internlm2 = ['wqkv']
    mamba = ['in_proj', 'x_proj', 'embeddings', 'out_proj']
    telechat = ['key_value', 'query']
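
For context, the LoRATM entries name the attention projection modules that LoRA adapters attach to; phi3-small exposes a fused 'query_key_value' layer rather than phi3's 'qkv_proj'. A rough peft-style sketch of what targeting that module looks like (illustrative only; ms-swift builds its LoRA config internally from the entry above):

from peft import LoraConfig

# Illustrative: target the fused qkv projection that phi3-small exposes.
lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=['query_key_value'])
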
@@ -1605,16 +1606,6 @@ def _output_device_map_hook(module, input, output):
    support_vllm=True,
    tags=['general'],
    hf_model_id='microsoft/Phi-3-mini-128k-instruct')
@register_model(
    ModelType.phi3_small_128k_instruct,
    'LLM-Research/Phi-3-small-128k-instruct',
    LoRATM.phi3,
    TemplateType.phi3,
    requires=['transformers>=4.36'],
    support_flash_attn=True,
    support_vllm=True,
    tags=['general'],
    hf_model_id='microsoft/Phi-3-small-128k-instruct')
@register_model(
    ModelType.phi3_medium_128k_instruct,
    'LLM-Research/Phi-3-medium-128k-instruct',
@@ -2361,6 +2352,49 @@ def get_model_tokenizer_with_flash_attn(model_dir: str,
        model_dir, torch_dtype, model_kwargs, load_model, model_config=model_config, **kwargs)


@register_model(
    ModelType.phi3_small_128k_instruct,
    'LLM-Research/Phi-3-small-128k-instruct',
    LoRATM.phi3_small,
    TemplateType.phi3,
    requires=['transformers>=4.36'],
    support_flash_attn=True,
    support_gradient_checkpointing=False,
    support_vllm=True,
    tags=['general'],
    hf_model_id='microsoft/Phi-3-small-128k-instruct')
def get_model_tokenizer_phi3_small(model_dir: str,
                                   torch_dtype: Dtype,
                                   model_kwargs: Dict[str, Any],
                                   load_model: bool = True,
                                   model_config=None,
                                   **kwargs):
    if model_config is None:
        model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    use_flash_attn = kwargs.pop('use_flash_attn', False)
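    # transformers >= 4.36 reads _attn_implementation from the config; older releases
    # used the private _flash_attn_2_enabled flag instead.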
    if version.parse(transformers.__version__) >= version.parse('4.36'):
        if use_flash_attn:
            model_config._attn_implementation = 'flash_attention_2'
    else:
        model_config._flash_attn_2_enabled = use_flash_attn
    model, tokenizer = get_model_tokenizer_from_repo(
        model_dir, torch_dtype, model_kwargs, load_model, model_config=model_config, **kwargs)

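    # Wrapper that records the incoming dtypes and casts the rotary-embedding outputs
    # back to them, so the query/key states keep their original dtype.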
    def rotary_emb(self, query_states, key_states, **kwargs):
        q_type = query_states.dtype
        k_type = key_states.dtype
        query_states, key_states = self.rotary_emb_origin(query_states, key_states, **kwargs)
        query_states = query_states.to(q_type)
        key_states = key_states.to(k_type)
        return query_states, key_states

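    # Patch the rotary embedding of each of the 32 decoder layers with the wrapper above.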
    for i in range(32):
        re = model.model.layers[i].self_attn.rotary_emb
        re.rotary_emb_origin = re.forward
        re.forward = MethodType(rotary_emb, re)
    return model, tokenizer


@register_model(
    ModelType.qwen2_57b_a14b_instruct_int4,
    'qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4',

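The new @register_model entry makes ModelType.phi3_small_128k_instruct resolvable through ms-swift's model registry, so the custom loader above runs whenever that model type is requested. A minimal usage sketch (the top-level get_model_tokenizer import path and signature are assumptions, not part of this diff):

import torch
from swift.llm import ModelType, get_model_tokenizer

# Illustrative sketch: the registry dispatches to get_model_tokenizer_phi3_small;
# use_flash_attn is forwarded through **kwargs and popped inside the loader.
model, tokenizer = get_model_tokenizer(
    ModelType.phi3_small_128k_instruct,
    torch.bfloat16,
    {'device_map': 'auto'},
    use_flash_attn=True)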