support qwen1.5 megatron #1564

Merged on Aug 6, 2024 (35 commits)

Commits
9a69546  update (Jintao-Huang, Aug 1, 2024)
bddd3e2  Merge branch 'main' into megatron_model (Jintao-Huang, Aug 1, 2024)
3953146  update (Jintao-Huang, Aug 1, 2024)
8cc7660  update (Jintao-Huang, Aug 1, 2024)
20f8e32  update (Jintao-Huang, Aug 1, 2024)
2c54114  update (Jintao-Huang, Aug 1, 2024)
3bc2c66  update (Jintao-Huang, Aug 1, 2024)
12308c6  Merge branch 'main' into megatron_model (Jintao-Huang, Aug 1, 2024)
73db1d7  Merge branch 'main' into megatron_model (Jintao-Huang, Aug 2, 2024)
80deef2  update (Jintao-Huang, Aug 2, 2024)
49f4629  Merge branch 'main' into megatron_model (Jintao-Huang, Aug 2, 2024)
f8a505d  support megatron (Jintao-Huang, Aug 2, 2024)
057e95f  update docs (Jintao-Huang, Aug 3, 2024)
eb8c7bd  support qwen1.5 (Jintao-Huang, Aug 3, 2024)
fec2685  support chat (Jintao-Huang, Aug 4, 2024)
0d20b2d  support deploy info (Jintao-Huang, Aug 4, 2024)
5ff020e  fix lmdeploy bug (Jintao-Huang, Aug 4, 2024)
6c2173e  Merge branch 'main' into megatron_model (Jintao-Huang, Aug 5, 2024)
4a0e29e  Merge branch 'main' into megatron_model (Jintao-Huang, Aug 5, 2024)
20e02ef  Merge branch 'main' into megatron_model (Jintao-Huang, Aug 5, 2024)
f4d484d  update (Jintao-Huang, Aug 5, 2024)
3362738  update (Jintao-Huang, Aug 5, 2024)
05f7f86  lint pass (Jintao-Huang, Aug 5, 2024)
325b919  update (Jintao-Huang, Aug 5, 2024)
f5a0251  update (Jintao-Huang, Aug 5, 2024)
b46dc66  support qwen1.5 (Jintao-Huang, Aug 5, 2024)
b606952  update docs (Jintao-Huang, Aug 5, 2024)
0b18a54  fix bugs (Jintao-Huang, Aug 5, 2024)
8b093d0  update (Jintao-Huang, Aug 5, 2024)
d77f148  update (Jintao-Huang, Aug 6, 2024)
d9cb8e0  Merge branch 'main' into megatron_model (Jintao-Huang, Aug 6, 2024)
a50d001  update (Jintao-Huang, Aug 6, 2024)
46e0c5c  update (Jintao-Huang, Aug 6, 2024)
42d5fd4  update (Jintao-Huang, Aug 6, 2024)
d40603e  update docs (Jintao-Huang, Aug 6, 2024)

Files changed
2 changes: 1 addition & 1 deletion docs/source/LLM/Megatron训练文档.md
@@ -1,4 +1,4 @@
-# Megatron训练文档 (测试版)
+# Megatron训练文档
 
 ## 目录
 - [环境准备](#环境准备)
776 changes: 388 additions & 388 deletions docs/source/LLM/支持的模型和数据集.md

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions docs/source/LLM/自定义与拓展.md
@@ -64,21 +64,21 @@ system,query,response
 ```jsonl
 {"system": "00000", "query": "11111", "response": "22222"}
 {"query": "aaaaa", "response": "bbbbb"}
-{"query": "AAAAA", "response": "BBBBB"}
+{"system": "00001", "query": "AAAAA", "response": "BBBBB"}
 ```
 
 多轮对话:
 
 ```jsonl
 {"system": "00000", "query": "55555", "response": "66666"}
 {"query": "eeeee", "response": "fffff", "history": []}
-{"query": "EEEEE", "response": "FFFFF", "history": [["query1", "response1"], ["query2", "response2"]]}
+{"system": "00001", "query": "EEEEE", "response": "FFFFF", "history": [["query1", "response1"], ["query2", "response2"]]}
 ```
 
 ```json
 [{"system": "00000", "query": "55555", "response": "66666"},
 {"query": "eeeee", "response": "fffff", "history": []},
-{"query": "EEEEE", "response": "FFFFF", "history": [["query1", "response1"], ["query2", "response2"]]}]
+{"system": "00001", "query": "EEEEE", "response": "FFFFF", "history": [["query1", "response1"], ["query2", "response2"]]}]
 ```
 
 **格式2:**
@@ -102,7 +102,7 @@ system,query,response
 ```jsonl
 {"system": "00000", "conversation": [{"human": "11111", "assistant": "22222"}]}
 {"conversation": [{"human": "aaaaa", "assistant": "bbbbb"}]}
-{"conversation": [{"human": "AAAAA", "assistant": "BBBBB"}, {"human": "CCCCC", "assistant": "DDDDD"}, {"human": "EEEEE", "assistant": "FFFFF"}]}
+{"system": "00001", "conversation": [{"human": "AAAAA", "assistant": "BBBBB"}, {"human": "CCCCC", "assistant": "DDDDD"}, {"human": "EEEEE", "assistant": "FFFFF"}]}
 ```
 
 **格式5:**
8 changes: 4 additions & 4 deletions docs/source_en/LLM/Customization.md
@@ -58,21 +58,21 @@ system,query,response
 ```jsonl
 {"system": "00000", "query": "11111", "response": "22222"}
 {"query": "aaaaa", "response": "bbbbb"}
-{"query": "AAAAA", "response": "BBBBB"}
+{"system": "00001", "query": "AAAAA", "response": "BBBBB"}
 ```
 
 Multi-Round Dialogue
 
 ```jsonl
 {"system": "00000", "query": "55555", "response": "66666"}
 {"query": "eeeee", "response": "fffff", "history": []}
-{"query": "EEEEE", "response": "FFFFF", "history": [["query1", "response1"], ["query2", "response2"]]}
+{"system": "00001", "query": "EEEEE", "response": "FFFFF", "history": [["query1", "response1"], ["query2", "response2"]]}
 ```
 
 ```json
 [{"system": "00000", "query": "55555", "response": "66666"},
 {"query": "eeeee", "response": "fffff", "history": []},
-{"query": "EEEEE", "response": "FFFFF", "history": [["query1", "response1"], ["query2", "response2"]]}]
+{"system": "00001", "query": "EEEEE", "response": "FFFFF", "history": [["query1", "response1"], ["query2", "response2"]]}]
 ```
 
 **Format 2:**
@@ -96,7 +96,7 @@ Multi-Round Dialogue
 ```jsonl
 {"system": "00000", "conversation": [{"human": "11111", "assistant": "22222"}]}
 {"conversation": [{"human": "aaaaa", "assistant": "bbbbb"}]}
-{"conversation": [{"human": "AAAAA", "assistant": "BBBBB"}, {"human": "CCCCC", "assistant": "DDDDD"}, {"human": "EEEEE", "assistant": "FFFFF"}]}
+{"system": "00001", "conversation": [{"human": "AAAAA", "assistant": "BBBBB"}, {"human": "CCCCC", "assistant": "DDDDD"}, {"human": "EEEEE", "assistant": "FFFFF"}]}
 ```
 
 **Format 5:**
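
The edited samples above show that `system` can now be set per row: a row that includes it uses its own value, and a row that omits it falls back to a default. A minimal reading sketch of that convention (the helper names `read_samples` and `DEFAULT_SYSTEM` are hypothetical, and the fallback behavior is an assumption, not code from this PR):

```python
import json

DEFAULT_SYSTEM = 'You are a helpful assistant.'  # hypothetical default

def read_samples(path):
    """Yield jsonl samples, filling in assumed per-row defaults."""
    with open(path, encoding='utf-8') as f:
        for line in f:
            sample = json.loads(line)
            sample.setdefault('system', DEFAULT_SYSTEM)  # per-row value wins
            sample.setdefault('history', [])             # single-turn if absent
            yield sample
```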
2 changes: 1 addition & 1 deletion docs/source_en/LLM/Megatron-training.md
@@ -1,4 +1,4 @@
-# Megatron Training Documentation (Beta)
+# Megatron Training Documentation
 
 ## Table of Contents
 - [Environment Preparation](#Environment-Preparation)
776 changes: 388 additions & 388 deletions docs/source_en/LLM/Supported-models-datasets.md

Large diffs are not rendered by default.

18 changes: 10 additions & 8 deletions scripts/utils/run_model_info.py
@@ -9,9 +9,9 @@ def get_model_info_table():
     model_name_list = ModelType.get_model_name_list()
     result = [
         '| Model Type | Model ID | Default Lora Target Modules | Default Template |'
-        ' Support Flash Attn | Support vLLM | Support LMDeploy | Requires | Tags | HF Model ID |\n'
+        ' Support Flash Attn | Support vLLM | Support LMDeploy | Support Megatron | Requires | Tags | HF Model ID |\n'
         '| --------- | -------- | --------------------------- | ---------------- |'
-        ' ------------------ | ------------ | ----------------- | -------- | ---- | ----------- |\n'
+        ' ------------------ | ------------ | ---------------- | ---------------- | -------- | ---- | ----------- |\n'
     ] * 2
     res_llm: List[Any] = []
     res_mllm: List[Any] = []
@@ -31,6 +31,8 @@ def get_model_info_table():
         support_vllm = bool_mapping[support_vllm]
         support_lmdeploy = model_info.get('support_lmdeploy', False)
         support_lmdeploy = bool_mapping[support_lmdeploy]
+        support_megatron = model_info.get('support_megatron', False)
+        support_megatron = bool_mapping[support_megatron]
         requires = ', '.join(model_info['requires'])
         tags = model_info.get('tags', [])
         if 'multi-modal' in tags:
@@ -46,7 +48,7 @@ def get_model_info_table():
             hf_model_id = '-'
         r = [
            model_name, model_id, lora_target_modules, template, support_flash_attn, support_vllm, support_lmdeploy,
-           requires, tags_str, hf_model_id
+           support_megatron, requires, tags_str, hf_model_id
         ]
         if is_multi_modal:
             res_mllm.append(r)
@@ -57,13 +59,13 @@ def get_model_info_table():
     for i, res in enumerate([res_llm, res_mllm]):
         for r in res:
             ms_url = f'https://modelscope.cn/models/{r[1]}/summary'
-            if r[9] != '-':
-                hf_url = f'https://huggingface.co/{r[9]}'
-                hf_model_id_str = f'[{r[9]}]({hf_url})'
+            if r[10] != '-':
+                hf_url = f'https://huggingface.co/{r[10]}'
+                hf_model_id_str = f'[{r[10]}]({hf_url})'
             else:
                 hf_model_id_str = '-'
-            text[i] += (
-                f'|{r[0]}|[{r[1]}]({ms_url})|{r[2]}|{r[3]}|{r[4]}|{r[5]}|{r[6]}|{r[7]}|{r[8]}|{hf_model_id_str}|\n')
+            text[i] += (f'|{r[0]}|[{r[1]}]({ms_url})|{r[2]}|{r[3]}|{r[4]}|{r[5]}|{r[6]}|{r[7]}|{r[8]}'
+                        f'|{r[9]}|{hf_model_id_str}|\n')
         result[i] += text[i]
 
     for i, fpath in enumerate(fpaths):
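
For reference, a sketch of the row layout after this change (the values are made up; only the index positions mirror the code above). Inserting `support_megatron` at index 7 shifts `requires`, `tags`, and the HF model ID one slot to the right, which is why the link check now reads `r[10]` instead of `r[9]`:

```python
row = [
    'qwen1half-7b-chat',       # r[0]  model_name (illustrative)
    'qwen/Qwen1.5-7B-Chat',    # r[1]  model_id (ModelScope)
    'q_proj, k_proj, v_proj',  # r[2]  lora_target_modules
    'qwen',                    # r[3]  default template
    '✔',                       # r[4]  support_flash_attn
    '✔',                       # r[5]  support_vllm
    '✔',                       # r[6]  support_lmdeploy
    '✔',                       # r[7]  support_megatron (the new column)
    '-',                       # r[8]  requires
    '-',                       # r[9]  tags
    'Qwen/Qwen1.5-7B-Chat',    # r[10] hf_model_id, previously at r[9]
]
assert row[10] != '-'  # the HF link branch now checks index 10
```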
2 changes: 1 addition & 1 deletion swift/llm/export.py
@@ -298,7 +298,7 @@ def llm_export(args: ExportArguments) -> None:
         megatron_args = MegatronArguments(**res)
         extra_args = megatron_args.parse_to_megatron()
         patch_megatron(tokenizer)
-        convert_hf_to_megatron(model, extra_args, args.check_model_forward, args.torch_dtype)
+        convert_hf_to_megatron(model, extra_args, args.torch_dtype)
         fpath = os.path.join(args.megatron_output_dir, 'export_args.json')
         with open(fpath, 'w', encoding='utf-8') as f:
             json.dump(check_json_format(args.__dict__), f, ensure_ascii=False, indent=2)
1 change: 0 additions & 1 deletion swift/llm/megatron/argument.py
@@ -86,7 +86,6 @@ class MegatronMixin:
     tensor_model_parallel_size: int = 1
     pipeline_model_parallel_size: int = 1
     seed: int = 42
-    transformer_impl: str = 'transformer_engine'
     sequence_parallel: bool = False
 
     apply_query_key_layer_scaling: bool = False  # fp16
5 changes: 0 additions & 5 deletions swift/llm/megatron/convert.py
@@ -11,7 +11,6 @@
 def convert_hf_to_megatron(
     hf_model,
     extra_args: Dict[str, Any],
-    check_model_forward: bool = False,
     save_torch_dtype: Optional[torch.dtype] = None,
 ) -> None:
     from megatron.training.initialize import initialize_megatron
@@ -24,10 +23,6 @@ def convert_hf_to_megatron(
     convert_module.convert_checkpoint_from_transformers_to_megatron(hf_model, mg_model, args)
     if save_torch_dtype is not None:
         mg_model.to(save_torch_dtype)
-    if check_model_forward and hasattr(convert_module, 'check_hf_mg_forward'):
-        if save_torch_dtype is not None:
-            hf_model.to(save_torch_dtype)
-        convert_module.check_hf_mg_forward(hf_model, mg_model, args)
     convert_module.save_mgmodel(mg_model, args)
108 changes: 65 additions & 43 deletions swift/llm/megatron/model.py
@@ -6,63 +6,85 @@
 MEGATRON_MODEL_MAPPING = {}
 
 
-def register_megatron_model(model_type_list: List[str], convert_module: str, get_function: Optional[Callable] = None):
-    model_info = {'convert_module': convert_module}
+def register_megatron_model(
+        model_type_list: List[str],
+        convert_module: str,
+        model_module: str,  # GPTModel
+        config_cls,  # transformer_config_cls
+        get_function: Optional[Callable] = None):
+    megatron_model_info = {
+        'convert_module': convert_module,
+        'model_module': model_module,
+        'config_cls': config_cls,
+    }
+    res_model_type_list = []
+    for model_type in model_type_list:
+        model_info = MODEL_MAPPING[model_type]
+        support_megatron = model_info.get('support_megatron', False)
+        if support_megatron:
+            res_model_type_list.append(model_type)
+    model_type_list = res_model_type_list
 
     if get_function is not None:
-        model_info['get_function'] = get_function
+        megatron_model_info['get_function'] = get_function
         for model_type in model_type_list:
-            MEGATRON_MODEL_MAPPING[model_type] = model_info
+            MEGATRON_MODEL_MAPPING[model_type] = megatron_model_info
         return
 
     def _register_model(get_function: Callable) -> Callable:
-        model_info['get_function'] = get_function
+        megatron_model_info['get_function'] = get_function
         for model_type in model_type_list:
-            MEGATRON_MODEL_MAPPING[model_type] = model_info
+            MEGATRON_MODEL_MAPPING[model_type] = megatron_model_info
         return get_function
 
     return _register_model
 
 
+qwen1half_model_type = [model_type for model_type in MODEL_MAPPING.keys() if model_type.startswith('qwen1half')]
+
+
+@register_megatron_model(
+    [model_type for model_type in qwen1half_model_type if ('32b' not in model_type and '110b' not in model_type)],
+    'qwen.hf2mcore_qwen1_5_dense_mha', 'qwen1_5', 'QwenTransformerConfig')
 @register_megatron_model([model_type for model_type in MODEL_MAPPING.keys() if model_type.startswith('qwen2')],
-                         'qwen.hf2mcore_qwen2_dense_and_moe_gqa')
-def get_qwen2_model(pre_process=True, post_process=True):
-    from megatron.training import get_args
-    from megatron.training.arguments import core_transformer_config_from_args
-    from megatron_patch.model.qwen2.transformer_config import Qwen2TransformerConfig
-    from megatron_patch.model.qwen2.layer_specs import (get_gpt_layer_local_spec,
-                                                        get_gpt_layer_with_transformer_engine_spec)
-    from megatron_patch.model.qwen2.model import GPTModel
-
-    args = get_args()
-    config = core_transformer_config_from_args(args, Qwen2TransformerConfig)
-    use_te = args.transformer_impl == 'transformer_engine'
-
-    if use_te:
-        transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm,
-                                                                            args.qk_layernorm)
-    else:
-        transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)
-
-    model = GPTModel(
-        config=config,
-        transformer_layer_spec=transformer_layer_spec,
-        vocab_size=args.padded_vocab_size,
-        max_sequence_length=args.max_position_embeddings,
-        pre_process=pre_process,
-        post_process=post_process,
-        fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
-        parallel_output=True,
-        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
-        position_embedding_type=args.position_embedding_type,
-        rotary_percent=args.rotary_percent,
-        rotary_base=args.rotary_base,
-        seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor)
-    return model
+                         'qwen.hf2mcore_qwen2_dense_and_moe_gqa', 'qwen2', 'Qwen2TransformerConfig')
+def get_model_provider(gpt_model_cls, transformer_config_cls, layer_spec_module):
+
+    def model_provider(pre_process=True, post_process=True):
+        from megatron.training import get_args
+        from megatron.training.arguments import core_transformer_config_from_args
+        args = get_args()
+        config = core_transformer_config_from_args(args, transformer_config_cls)
+        transformer_layer_spec = layer_spec_module.get_gpt_layer_with_transformer_engine_spec(
+            args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)
+        model = gpt_model_cls(
+            config=config,
+            transformer_layer_spec=transformer_layer_spec,
+            vocab_size=args.padded_vocab_size,
+            max_sequence_length=args.max_position_embeddings,
+            pre_process=pre_process,
+            post_process=post_process,
+            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
+            parallel_output=True,
+            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
+            position_embedding_type=args.position_embedding_type,
+            rotary_percent=args.rotary_percent,
+            rotary_base=args.rotary_base,
+            seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor)
+        return model
+
+    return model_provider
 
 
 def get_megatron_model_convert(model_type: str):
     model_info = MEGATRON_MODEL_MAPPING[model_type]
-    model_provider = model_info['get_function']
-    convert_module = model_info['convert_module']
-    convert_module = importlib.import_module(convert_module)
+    model_module = model_info['model_module']
+    config_cls = model_info['config_cls']
+
+    gpt_model_cls = importlib.import_module(f'megatron_patch.model.{model_module}.model').GPTModel
+    transformer_config_cls = getattr(
+        importlib.import_module(f'megatron_patch.model.{model_module}.transformer_config'), config_cls)
+    layer_spec_module = importlib.import_module(f'megatron_patch.model.{model_module}.layer_specs')
+    model_provider = model_info['get_function'](gpt_model_cls, transformer_config_cls, layer_spec_module)
+    convert_module = importlib.import_module(f"toolkits.model_checkpoints_convertor.{model_info['convert_module']}")
     return model_provider, convert_module
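
A usage sketch of the refactored registry (the model_type below is illustrative; any type registered with `support_megatron=True` resolves the same way, and the Pai-Megatron-Patch checkout must already be on `sys.path`):

```python
from swift.llm.megatron.utils import init_megatron_env
from swift.llm.megatron.model import get_megatron_model_convert

init_megatron_env()  # checks out Pai-Megatron-Patch and normalizes file names
model_provider, convert_module = get_megatron_model_convert('qwen1half-7b-chat')
# model_provider(pre_process=True, post_process=True) builds the Megatron
# GPTModel; it reads its hyperparameters from megatron.training.get_args(),
# so Megatron must be initialized before it is called. convert_module is the
# matching toolkits.model_checkpoints_convertor module used during export.
```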
30 changes: 22 additions & 8 deletions swift/llm/megatron/utils.py
@@ -1,8 +1,9 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
+import shutil
 import sys
 from functools import partial, wraps
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Mapping, Optional
 
 import torch
 import torch.distributed as dist
@@ -29,7 +30,16 @@ def init_megatron_env() -> None:
         'https://github.com/alibaba/Pai-Megatron-Patch', commit_hash='6fd5d050b240fd959f0ba69f1e9cd9a053e5a81d')
     os.environ['PAI_MEGATRON_PATCH_PATH'] = megatron_patch_path
     sys.path.append(os.environ['PAI_MEGATRON_PATCH_PATH'])
-    sys.path.append(os.path.join(megatron_patch_path, 'toolkits/model_checkpoints_convertor'))
+
+    # rename qwen1.5 -> qwen1_5 files
+    qwen1_5_folders = ['toolkits/model_checkpoints_convertor/qwen']
+    for folder in qwen1_5_folders:
+        dir_path = os.path.join(megatron_patch_path, folder)
+        for fname in os.listdir(dir_path):
+            old_path = os.path.join(dir_path, fname)
+            new_path = os.path.join(dir_path, fname.replace('qwen1.', 'qwen1_'))
+            if old_path != new_path:
+                shutil.move(old_path, new_path)
 
 
 def patch_megatron(tokenizer):
@@ -53,14 +63,18 @@ def _initialize_distributed(*_args, **kwargs):
 
     initialize._initialize_distributed = _initialize_distributed
 
-    _old_load_checkpoint = training.load_checkpoint
+    _old_load_state_dict = torch.nn.Module.load_state_dict
 
-    @wraps(_old_load_checkpoint)
-    def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=False):
-        # default: strict=False
-        return _old_load_checkpoint(model, optimizer, opt_param_scheduler, load_arg=load_arg, strict=strict)
+    @wraps(_old_load_state_dict)
+    def _load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, *args, **kwargs):
+        if strict:
+            keys = self.state_dict().keys() ^ state_dict.keys()
+            new_keys = [k for k in keys if not k.endswith('_extra_state')]
+            if keys and not new_keys:
+                strict = False
+        return _old_load_state_dict(self, state_dict, strict, *args, **kwargs)
 
-    training.load_checkpoint = load_checkpoint
+    torch.nn.Module.load_state_dict = _load_state_dict
 
     _old_training_log = training.training_log
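
The patched `load_state_dict` only relaxes `strict` when every mismatched key is a TransformerEngine `_extra_state` buffer; a genuine weight mismatch still loads strictly. A self-contained sketch of that rule (`should_relax_strict` is a hypothetical name; the set logic is the same as above):

```python
def should_relax_strict(model_keys, ckpt_keys) -> bool:
    # Symmetric difference: keys present on one side but not the other.
    mismatched = set(model_keys) ^ set(ckpt_keys)
    # TransformerEngine '_extra_state' buffers are safe to ignore.
    real = [k for k in mismatched if not k.endswith('_extra_state')]
    return bool(mismatched) and not real

assert should_relax_strict(['w', 'attn._extra_state'], ['w'])  # only extra_state differs
assert not should_relax_strict(['w'], ['b'])                   # real mismatch stays strict
```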
2 changes: 1 addition & 1 deletion swift/llm/sft.py
@@ -78,10 +78,10 @@ def llm_sft_megatron(args: SftArguments) -> Dict[str, Any]:
 
     res = MegatronArguments.load_megatron_config(tokenizer.model_dir)
     res.update(MegatronArguments.from_sft_args(args, train_dataset, val_dataset))
+    model_provider, _ = get_megatron_model_convert(args.model_type)
     megatron_args = MegatronArguments(**res)
     extra_args = megatron_args.parse_to_megatron()
 
-    model_provider, _ = get_megatron_model_convert(args.model_type)
     train_valid_test_datasets_provider = partial(
         _train_valid_test_datasets_provider, train_dataset=train_dataset, val_dataset=val_dataset, template=template)
     train_valid_test_datasets_provider.is_distributed = True
1 change: 0 additions & 1 deletion swift/llm/utils/argument.py
@@ -1522,7 +1522,6 @@ class ExportArguments(InferArguments):
     hf_output_dir: Optional[str] = None
     tp: int = 1
     pp: int = 1
-    check_model_forward: bool = False
 
     # The parameter has been defined in InferArguments.
     # merge_lora, hub_token