From 6d40dd11d9b9324c89e81c67f6ac113f640294d0 Mon Sep 17 00:00:00 2001 From: Jintao Date: Tue, 23 Apr 2024 17:44:43 +0800 Subject: [PATCH] fix bugs (#776) --- README.md | 5 +- README_CN.md | 5 +- ROADMAP.md | 56 ------------------- ...11\344\270\216\346\213\223\345\261\225.md" | 8 +-- docs/source_en/LLM/Customization.md | 2 +- .../lora_ddp_ds/sft.sh | 2 +- .../lora_mp_ddp/sft.sh | 2 +- swift/llm/utils/argument.py | 4 +- swift/llm/utils/dataset.py | 34 +++++++++-- swift/llm/utils/model.py | 1 + swift/llm/utils/preprocess.py | 24 ++++++-- swift/trainers/trainers.py | 35 +----------- tests/llm/data/alpaca.jsonl | 2 +- 13 files changed, 70 insertions(+), 110 deletions(-) delete mode 100644 ROADMAP.md diff --git a/README.md b/README.md index eea780265..ea8c3e53b 100644 --- a/README.md +++ b/README.md @@ -292,6 +292,7 @@ swift sft \ ``` #### Deepspeed Training +Deepspeed supports training of quantized GPTQ and AWQ models. ZeRO2: ```shell @@ -432,6 +433,7 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \ ``` ### Supported Models +The complete list of supported models and datasets can be found at [Supported Models and Datasets List](https://idealab.alibaba-inc.com/docs/source/LLM/Supported-Models-and-Datasets.md). #### LLMs @@ -470,7 +472,8 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \ | c4ai-command-r | [c4ai](https://cohere.com/command) | Multilingual | 35B-104B | chat model | | WizardLM2 | [WizardLM2 series models](https://github.com/nlpxucan/WizardLM) | English | 7B-8x22B
including quantized versions | chat model
MoE model | | Atom | [Atom](https://github.com/LlamaFamily/Llama-Chinese) | Chinese | 7B| base model
chat model| -| Chinese-LLaMA-Alpaca-2 | [Chinese-LLaMA-Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2) | Chinese | 1.3B-13B| base model
chat model
long text model| +| Chinese-LLaMA-Alpaca-2 | [Chinese-LLaMA-Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2) | Chinese | 1.3B-13B| base model
chat model
long text model | +| ModelScope-Agent | [ModelScope Agent series models](https://github.com/modelscope/modelscope-agent) | Chinese | 7B-14B| agent model | #### MLLMs diff --git a/README_CN.md b/README_CN.md index d25465a54..a9387db3b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -290,6 +290,7 @@ swift sft \ ``` #### Deepspeed训练 +Deepspeed支持对GPTQ和AWQ量化模型进行训练. ZeRO2: ```shell @@ -429,6 +430,7 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \ ``` ### 支持的模型 +完整的支持模型和数据集可以查看[支持的模型和数据集列表](docs/source/LLM/支持的模型和数据集.md). #### 大语言模型 @@ -467,7 +469,8 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \ | c4ai-command-r | [c4ai](https://cohere.com/command) | 多语种 | 35B-104B | chat模型 | | WizardLM2 | [WizardLM2系列模型](https://github.com/nlpxucan/WizardLM) | 多语种 | 7B-8x22B
包含量化版本 | chat模型
MoE模型 | | Atom | [Atom](https://github.com/LlamaFamily/Llama-Chinese) | 中文 | 7B| base模型
chat模型| -| Chinese-LLaMA-Alpaca-2 | [Chinese-LLaMA-Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2) | 中文 | 1.3B-13B| base模型
chat模型
长文本模型| +| Chinese-LLaMA-Alpaca-2 | [Chinese-LLaMA-Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2) | 中文 | 1.3B-13B| base模型
chat模型
长文本模型 | +| ModelScope-Agent | [ModelScope Agent系列](https://github.com/modelscope/modelscope-agent) | 中文 | 7B-14B| agent模型 | #### 多模态大模型 diff --git a/ROADMAP.md b/ROADMAP.md deleted file mode 100644 index 65297907e..000000000 --- a/ROADMAP.md +++ /dev/null @@ -1,56 +0,0 @@ -# SWIFT V1.7 Features - -The development of SWIFT V1.7 is between Feb/1/2024 and Feb/29/2024 ideally. - -## Data processing - -- Support dataset processing to improve dataset performance. @[tastelikefeet](https://github.com/modelscope/swift/issues?q=is%3Apr+author%3Atastelikefeet) -- Support more dataset, including pretraining dataset @[Jintao-Huang](https://github.com/modelscope/swift/issues?q=is%3Apr+author%3AJintao-Huang) - -## Model Evaluation - -- Support model evaluation @[Jintao-Huang](https://github.com/modelscope/swift/issues?q=is%3Apr+author%3AJintao-Huang) -- Do evaluation of important tuners and techniques @[tastelikefeet](https://github.com/modelscope/swift/issues?q=is%3Apr+author%3Atastelikefeet) - -## Exporting - -- Support model exporting to GPTQ&AWQ(Maybe CPP) @[Jintao-Huang](https://github.com/modelscope/swift/issues?q=is%3Apr+author%3AJintao-Huang) - -## Deployment - -- Support the deployment of vanilla PyTorch code @[Jintao-Huang](https://github.com/modelscope/swift/issues?q=is%3Apr+author%3AJintao-Huang) -- Support VLLM+AWQ/Vanilla AWQ @[Jintao-Huang](https://github.com/modelscope/swift/issues?q=is%3Apr+author%3AJintao-Huang) - -## WEB-UI - -- Support DPO @[slin000111](https://github.com/slin000111) -- Support deployment(instead of inference) @[slin000111](https://github.com/slin000111) -- Support VL/Audio models @[slin000111](https://github.com/slin000111) -- Support task management @[slin000111](https://github.com/slin000111) -- Support an alternative port of tensorboard @[slin000111](https://github.com/slin000111) - -## Training - -- Support FSDP @[Jintao-Huang](https://github.com/modelscope/swift/issues?q=is%3Apr+author%3AJintao-Huang) -- RRHF dataset improvement @[tastelikefeet](https://github.com/modelscope/swift/issues?q=is%3Apr+author%3Atastelikefeet) - -## SD - -- LCM Training @[slin000111](https://github.com/slin000111) - -## Multi Modal - -- Support more models and datasets @[Jintao-Huang](https://github.com/modelscope/swift/issues?q=is%3Apr+author%3AJintao-Huang) -- Documentation @[Jintao-Huang](https://github.com/modelscope/swift/issues?q=is%3Apr+author%3AJintao-Huang) - -## To be Assigned - -- *Support ShengTeng GPU when training and inference* -- *Support windows for web-ui* -- *Support More LLM Models*: - - codefuse-ai/CodeFuse-DeepSeek-33B - - codefuse-ai/CodeFuse-13B - - 01ai/Yi-34B-Chat-4bits - - 01ai/Yi-34B-Chat-8bits - - 01ai/Yi-6B-Chat-4bits - - 01ai/Yi-6B-Chat-8bits diff --git "a/docs/source/LLM/\350\207\252\345\256\232\344\271\211\344\270\216\346\213\223\345\261\225.md" "b/docs/source/LLM/\350\207\252\345\256\232\344\271\211\344\270\216\346\213\223\345\261\225.md" index 681515972..2c4a6b793 100644 --- "a/docs/source/LLM/\350\207\252\345\256\232\344\271\211\344\270\216\346\213\223\345\261\225.md" +++ "b/docs/source/LLM/\350\207\252\345\256\232\344\271\211\344\270\216\346\213\223\345\261\225.md" @@ -26,11 +26,11 @@ 2. `--custom_val_dataset_path`: 默认值为`[]`, 表示不使用自定义验证数据集. 如果你指定了`custom_train_dataset_path`, 则自定义数据集的验证集将按照命令行参数`dataset_test_ratio`进行切割. -脚本支持的文件格式包含`csv`, `json`, `jsonl`格式. 你需要将传入的文件符合以下数据集格式. csv格式的文件只支持指令微调, 即没有history的情况. json, jsonl格式的文件支持system, history. +脚本支持的文件格式包含`csv`, `json`, `jsonl`格式. 你需要将传入的文件符合以下数据集格式. 以下格式都支持system. `json`, `jsonl`格式的文件支持多轮对话 (`csv`不支持). **格式1:** -Pre-Training +预训练: ```csv response @@ -45,7 +45,7 @@ AAAAA {"response": "AAAAA"} ``` -Single-Round Dialogue +单轮对话: ```csv query,response @@ -60,7 +60,7 @@ AAAAA,BBBBB {"query": "AAAAA", "response": "BBBBB"} ``` -Multi-Round Dialogue +多轮对话: ```jsonl {"query": "55555", "response": "66666"} diff --git a/docs/source_en/LLM/Customization.md b/docs/source_en/LLM/Customization.md index 1cdf6a8cc..6482da73d 100644 --- a/docs/source_en/LLM/Customization.md +++ b/docs/source_en/LLM/Customization.md @@ -26,7 +26,7 @@ The corresponding example sh script can be found [here](https://github.com/model 2. `--custom_val_dataset_path`: The default value is `[]`, indicating not to use a custom validation dataset. If you specify `custom_train_dataset_path`, then the validation set of the custom dataset will be split according to the command line argument `dataset_test_ratio`. -The script supports file formats including `csv`, `json`, and `jsonl`. You need to ensure the passed in files conform to the following dataset formats. csv files only support instruction tuning, i.e. the case without history. json and jsonl files support system and history. +The supported file formats for the script include `csv`, `json`, and `jsonl`. You need to ensure that the incoming files conform to the following dataset formats. Both `json` and `jsonl` formats support multi-turn dialogues (`csv` does not support this). **Format 1:** diff --git a/examples/pytorch/llm/scripts/openbuddy_mistral_7b_chat/lora_ddp_ds/sft.sh b/examples/pytorch/llm/scripts/openbuddy_mistral_7b_chat/lora_ddp_ds/sft.sh index 4c6af7048..3c838d963 100644 --- a/examples/pytorch/llm/scripts/openbuddy_mistral_7b_chat/lora_ddp_ds/sft.sh +++ b/examples/pytorch/llm/scripts/openbuddy_mistral_7b_chat/lora_ddp_ds/sft.sh @@ -8,7 +8,7 @@ torchrun \ --nproc_per_node=$nproc_per_node \ --master_port 29500 \ llm_sft.py \ - --model_id_or_path OpenBuddy/openbuddy-mistral-7b-v13.1 \ + --model_id_or_path OpenBuddy/openbuddy-mistral-7b-v17.1-32k \ --model_revision master \ --sft_type lora \ --tuner_backend peft \ diff --git a/examples/pytorch/llm/scripts/openbuddy_mistral_7b_chat/lora_mp_ddp/sft.sh b/examples/pytorch/llm/scripts/openbuddy_mistral_7b_chat/lora_mp_ddp/sft.sh index 3ed5e6055..eb3dcaa1c 100644 --- a/examples/pytorch/llm/scripts/openbuddy_mistral_7b_chat/lora_mp_ddp/sft.sh +++ b/examples/pytorch/llm/scripts/openbuddy_mistral_7b_chat/lora_mp_ddp/sft.sh @@ -8,7 +8,7 @@ torchrun \ --nproc_per_node=$nproc_per_node \ --master_port 29500 \ llm_sft.py \ - --model_id_or_path OpenBuddy/openbuddy-mistral-7b-v13.1 \ + --model_id_or_path OpenBuddy/openbuddy-mistral-7b-v17.1-32k \ --model_revision master \ --sft_type lora \ --tuner_backend peft \ diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py index 29d684071..807cd7186 100644 --- a/swift/llm/utils/argument.py +++ b/swift/llm/utils/argument.py @@ -309,7 +309,7 @@ class SftArguments(ArgumentsBase): }) output_dir: str = 'output' add_output_dir_suffix: Optional[bool] = None - ddp_backend: Literal['nccl', 'gloo', 'mpi', 'ccl'] = None + ddp_backend: Optional[Literal['nccl', 'gloo', 'mpi', 'ccl']] = None ddp_find_unused_parameters: Optional[bool] = None ddp_broadcast_buffers: Optional[bool] = None @@ -658,6 +658,8 @@ def __post_init__(self) -> None: else: torch.cuda.set_device(local_rank) self.seed += rank # Avoid the same dropout + if self.ddp_backend is None: + self.ddp_backend = 'nccl' if self.ddp_backend == 'gloo' and self.quantization_bit != 0: raise ValueError('not supported, please use `nccl`') diff --git a/swift/llm/utils/dataset.py b/swift/llm/utils/dataset.py index 413fd9a6e..4ac6757b2 100644 --- a/swift/llm/utils/dataset.py +++ b/swift/llm/utils/dataset.py @@ -67,6 +67,8 @@ class DatasetName: open_orca_gpt4 = 'open-orca-gpt4' sharegpt_gpt4 = 'sharegpt-gpt4' sharegpt_gpt4_mini = 'sharegpt-gpt4-mini' + deepctrl_sft_zh = 'deepctrl-sft-zh' + deepctrl_sft_en = 'deepctrl-sft-en' # agent ms_agent = 'ms-agent' ms_agent_for_agentfabric_default = 'ms-agent-for-agentfabric-default' @@ -549,7 +551,7 @@ def _preprocess_aishell1_dataset(dataset: HfDataset) -> HfDataset: def _repair_agent_conversations(conversations: str, - use_mini: bool) -> Dict[str, str]: + use_mini: bool) -> List[Dict[str, str]]: if use_mini: pattern = r'\d\. {"plugin_name": "(.+?)"' else: @@ -562,13 +564,14 @@ def _repair_agent_conversations(conversations: str, find_list = re.findall(pattern, conversations[:idx]) if len(set(find_list)) <= 1: return - conversations = ast.literal_eval(conversations) + if isinstance(conversations, str): + conversations = ast.literal_eval(conversations) if len(conversations) == 1: return return conversations -def _repair_ms_bench(conversations: str) -> Dict[str, str]: +def _repair_ms_bench(conversations: str) -> List[Dict[str, str]]: if isinstance(conversations, str): conversations = ast.literal_eval(conversations) default_system = 'You are a helpful assistant.' @@ -684,6 +687,22 @@ def map_row(row): get_dataset_from_repo, tags=['chat', 'agent', 'multi-round']) +register_dataset( + DatasetName.deepctrl_sft_zh, + 'AI-ModelScope/deepctrl-sft-data', [['default', 'train']], + None, + SmartPreprocessor(), + get_dataset_from_repo, + tags=['chat', 'general', 'sft', 'multi-round']) + +register_dataset( + DatasetName.deepctrl_sft_en, + 'AI-ModelScope/deepctrl-sft-data', [['en', 'train']], + None, + SmartPreprocessor(), + get_dataset_from_repo, + tags=['chat', 'general', 'sft', 'multi-round']) + advertise_gen_prompt = """Task: Generating advertisements based on keywords. Keywords: {query} Advertisements:""" @@ -1066,7 +1085,8 @@ def _preprocess_sharegpt(dataset: HfDataset) -> HfDataset: response = [] history: List[History] = [] for d in tqdm(dataset): - conversation = ast.literal_eval(d['conversation']) + if isinstance(d['conversation'], str): + conversation = ast.literal_eval(d['conversation']) query.append(conversation[-1]['human']) response.append(conversation[-1]['assistant']) h = [] @@ -1316,9 +1336,11 @@ def _preprocess_leetcode_python(dataset: HfDataset) -> HfDataset: ] -def _repair_conversations_agent_instruct(s: str) -> str: +def _repair_conversations_agent_instruct(s: str) -> List[Dict[str, Any]]: s = s.replace('}\n {', '},\n {') - return ast.literal_eval(s) + if isinstance(s, str): + s = ast.literal_eval(s) + return s register_dataset( diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py index 347e06e8f..dec0c8c09 100644 --- a/swift/llm/utils/model.py +++ b/swift/llm/utils/model.py @@ -2377,6 +2377,7 @@ def _git_clone_github(github_url: str, command = f'git -C {git_cache_dir} clone {github_url} {local_repo_name}' logger.info(f'Run the command: `{command}`') os.system(command) + logger.info(f'local_repo_path: {local_repo_path}') return local_repo_path diff --git a/swift/llm/utils/preprocess.py b/swift/llm/utils/preprocess.py index d59af1892..41dcc606b 100644 --- a/swift/llm/utils/preprocess.py +++ b/swift/llm/utils/preprocess.py @@ -39,9 +39,16 @@ def __init__(self, def __call__(self, dataset: HfDataset) -> HfDataset: query: List[str] = [] response = [] - for d in tqdm(dataset): - inst, inp, output = d['instruction'], d.get('input', - None), d['output'] + system = None + history = None + for i, d in tqdm(enumerate(dataset)): + inst, inp = d['instruction'], d.get('input', None) + h, output = d.pop('history', None), d['output'] + sys = d.pop('system', None) + if history is None and h is not None: + history = [None for _ in range(i - 1)] + if system is None and sys is not None: + system = [None for _ in range(i - 1)] if output is None: continue if inp is None or len(inp) == 0: @@ -52,7 +59,16 @@ def __call__(self, dataset: HfDataset) -> HfDataset: q = f'{inst}\n{inp}' query.append(q) response.append(output) - dataset = HfDataset.from_dict({'query': query, 'response': response}) + if history is not None: + history.append(h) + if system is not None: + system.append(sys) + d_dict = {'query': query, 'response': response} + if history is not None: + d_dict['history'] = history + if system is not None: + d_dict['system'] = system + dataset = HfDataset.from_dict(d_dict) return dataset diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py index 4ef06733b..d73b772fe 100644 --- a/swift/trainers/trainers.py +++ b/swift/trainers/trainers.py @@ -268,40 +268,9 @@ def compute_loss(self, model, inputs, return_outputs=None): def get_train_dataloader(self): - def __iter__(self): - self._num_yielded = 0 - if self._iterator is None: - self._iterator = self.__original_iter__() - return self - - def __next__(self): - if self._num_yielded >= len(self): - raise StopIteration - self._num_yielded += 1 - try: - return next(self._iterator) - except StopIteration: - self._iterator = self.__original_iter__() - return next(self._iterator) - if not use_torchacc(): - origin_loader = super().get_train_dataloader() - grad_acc_steps = self.args.gradient_accumulation_steps - if grad_acc_steps is None or grad_acc_steps <= 1: - return origin_loader - - length = len(origin_loader) // grad_acc_steps * grad_acc_steps - origin_loader_type = type(origin_loader) - loader = type( - origin_loader_type.__name__, (origin_loader_type, ), { - '__len__': lambda _: length, - '__iter__': __iter__, - '__next__': __next__ - })( - origin_loader.dataset) - loader.__dict__.update(origin_loader.__dict__) - loader.__original_iter__ = origin_loader.__iter__ - return loader + return super().get_train_dataloader() + else: if trainer.is_datasets_available(): import datasets diff --git a/tests/llm/data/alpaca.jsonl b/tests/llm/data/alpaca.jsonl index 661be4216..89802b51b 100644 --- a/tests/llm/data/alpaca.jsonl +++ b/tests/llm/data/alpaca.jsonl @@ -1,3 +1,3 @@ -{"instruction": "11111", "input": "22222", "output": "33333"} +{"instruction": "11111", "input": "22222", "output": "33333", "history": [["aaaaa", "bbbbb"]], "system": "system123"} {"instruction": "aaaaa", "output": "ccccc"} {"instruction": "AAAAA", "input": "BBBBB", "output": "CCCCC"}