Support ppo #2783

Open
wants to merge 8 commits into main
4 changes: 2 additions & 2 deletions README.md
@@ -67,7 +67,7 @@ You can contact us and communicate with us by adding our group:
- 🍊 **Lightweight Training**: Supports lightweight fine-tuning methods like LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel.
- **Distributed Training**: Supports distributed data parallel (DDP), device_map simple model parallelism, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques.
- **Quantization Training**: Supports training quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
- **RLHF Training**: Supports human alignment training methods such as DPO, CPO, SimPO, ORPO, KTO, RM for both pure text and multi-modal large models.
- **RLHF Training**: Supports human alignment training methods such as DPO, CPO, SimPO, ORPO, KTO, RM, PPO for both pure text and multi-modal large models.
- 🍓 **Multi-Modal Training**: Supports training on different modalities like images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding.
- **Interface Training**: Provides capabilities for training, inference, evaluation, quantization through an interface, completing the whole large model pipeline.
- **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components like loss, metric, trainer, loss-scale, callback, optimizer.
@@ -83,7 +83,7 @@ You can contact us and communicate with us by adding our group:
- 🎉 2024.08.12: The SWIFT paper has been published on arXiv, and you can read it [here](https://arxiv.org/abs/2408.05517).
- 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models.
- 🔥 2024.07.29: Support for using [vllm](https://github.com/vllm-project/vllm) and [lmdeploy](https://github.com/InternLM/lmdeploy) to accelerate inference for large models and multimodal models. When performing infer/deploy/eval, you can specify `--infer_backend vllm/lmdeploy`.
- 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM.
- 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM/PPO.
- 🔥 2024.02.01: Support for Agent training! The training algorithm is derived from [this paper](https://arxiv.org/pdf/2309.00986.pdf).


4 changes: 2 additions & 2 deletions README_CN.md
@@ -64,7 +64,7 @@
- 🍊 **Lightweight Training**: Supports lightweight fine-tuning methods like LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel.
- **Distributed Training**: Supports distributed data parallel (DDP), device_map simple model parallelism, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques.
- **Quantization Training**: Supports training quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
- **RLHF Training**: Supports human alignment training methods such as DPO, CPO, SimPO, ORPO, KTO, RM for both pure text and multi-modal large models.
- **RLHF Training**: Supports human alignment training methods such as DPO, CPO, SimPO, ORPO, KTO, RM, PPO for both pure text and multi-modal large models.
- 🍓 **Multi-Modal Training**: Supports training on different modalities like images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding.
- **Interface Training**: Provides capabilities for training, inference, evaluation, and quantization through an interface, completing the whole large model pipeline.
- **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components like loss, metric, trainer, loss-scale, callback, optimizer.
@@ -78,7 +78,7 @@
- 🎉 2024.08.12: The SWIFT paper has been published on arXiv; you can read it [here](https://arxiv.org/abs/2408.05517).
- 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models.
- 🔥 2024.07.29: Support for using [vllm](https://github.com/vllm-project/vllm) and [lmdeploy](https://github.com/InternLM/lmdeploy) to accelerate inference for large models and multimodal models; specify `--infer_backend vllm/lmdeploy` when performing infer/deploy/eval.
- 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM.
- 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM/PPO.
- 🔥 2024.02.01: Support for Agent training! The training algorithm is derived from [this paper](https://arxiv.org/pdf/2309.00986.pdf).

## 🛠️ Installation
2 changes: 1 addition & 1 deletion docs/source/Customization/自定义数据集.md
@@ -53,7 +53,7 @@ query-response format:

### RLHF

#### DPO/ORPO/CPO/SimPO/RM
#### DPO/ORPO/CPO/SimPO/RM/PPO

```jsonl
{"messages": [{"role": "system", "content": "你是个有用无害的助手"}, {"role": "user", "content": "告诉我明天的天气"}, {"role": "assistant", "content": "明天天气晴朗"}], "rejected_response": "我不知道"}
2 changes: 1 addition & 1 deletion docs/source/GetStarted/快速开始.md
@@ -8,7 +8,7 @@ ms-swift is the training and deployment framework for large models and multimodal large models provided by the ModelScope community
- 🍊 Lightweight training: supports lightweight fine-tuning methods such as LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel.
- Distributed training: supports distributed data parallel (DDP), simple model parallelism via device_map, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques.
- Quantization training: supports training quantized models such as BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
- RLHF training: supports human alignment training methods such as DPO, CPO, SimPO, ORPO, KTO, RM for both pure text and multimodal large models.
- RLHF training: supports human alignment training methods such as DPO, CPO, SimPO, ORPO, KTO, RM, PPO for both pure text and multimodal large models.
- 🍓 Multimodal training: supports training models of different modalities such as images, videos, and audio, covering tasks like VQA, captioning, OCR, and grounding.
- Interface training: provides training, inference, evaluation, and quantization capabilities through a web interface, completing the whole large model pipeline.
- Plugins and extensions: supports custom model and dataset extensions, as well as customization of components such as loss, metric, trainer, loss-scale, callback, optimizer.
7 changes: 3 additions & 4 deletions docs/source/Instruction/ReleaseNote3.0.md
@@ -81,7 +81,6 @@

## Pending Tasks

1. RM/PPO capabilities are not supported in version 3.0. Please use version 2.6.1.
2. Custom dataset evaluation is not supported in version 3.0. Please use version 2.6.1.
3. Megatron pre-training capabilities are not supported in version 3.0. Please use version 2.6.1.
4. Documentation and README, especially the English parts, have not been fully updated yet.
1. Custom dataset evaluation is not supported in version 3.0. Please use version 2.6.1.
2. Megatron pre-training capabilities are not supported in version 3.0. Please use version 2.6.1.
3. Documentation and README have not been fully updated yet.
2 changes: 1 addition & 1 deletion docs/source/Instruction/命令行参数.md
@@ -304,7 +304,7 @@ Vera uses the three parameters `target_modules`, `target_regex`, and `modules_to_save`.
### RLHF Arguments
RLHF arguments inherit from the [training arguments](#训练参数).

- 🔥rlhf_type: Alignment algorithm type, supports `dpo`, `orpo`, `simpo`, `kto`, `cpo`.
- 🔥rlhf_type: Alignment algorithm type, supports `dpo`, `orpo`, `simpo`, `kto`, `cpo`, `rm`, `ppo`.
- ref_model: The original comparison model used in algorithms such as DPO.
- ref_model_type: Same as model_type.
- ref_model_revision: Same as model_revision.
2 changes: 1 addition & 1 deletion docs/source_en/Customization/Custom-dataset.md
@@ -52,7 +52,7 @@ The following provides the recommended dataset format for ms-swift, where the sy

### RLHF

#### DPO/ORPO/CPO/SimPO/RM
#### DPO/ORPO/CPO/SimPO/RM/PPO

```jsonl
{"messages": [{"role": "system", "content": "You are a useful and harmless assistant"}, {"role": "user", "content": "Tell me tomorrow's weather"}, {"role": "assistant", "content": "Tomorrow's weather will be sunny"}], "rejected_response": "I don't know"}
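The preference format covered by the heading above keeps the chosen answer as the final assistant message in `messages`, with the dispreferred answer in `rejected_response`. A minimal sketch of reading one such row with plain `json` (not ms-swift's own loader; the strings are placeholders taken from the example):

```python
import json

# One row in the format shown above (placeholder content).
line = json.dumps({
    'messages': [
        {'role': 'user', 'content': "Tell me tomorrow's weather"},
        {'role': 'assistant', 'content': "Tomorrow's weather will be sunny"},
    ],
    'rejected_response': "I don't know",
})

row = json.loads(line)
chosen = row['messages'][-1]['content']   # preferred (chosen) response
rejected = row['rejected_response']       # dispreferred response
print(chosen, '|', rejected)
```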
2 changes: 1 addition & 1 deletion docs/source_en/GetStarted/Quick-start.md
@@ -8,7 +8,7 @@ ms-swift is a comprehensive training and deployment framework for large language
- 🍊 Lightweight Training: Supports lightweight fine-tuning methods like LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel, and more.
- Distributed Training: Supports distributed data parallel (DDP), simple model parallelism via device_map, DeepSpeed ZeRO2 ZeRO3, FSDP, and other distributed training technologies.
- Quantization Training: Provides training for quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
- RLHF Training: Supports human alignment training methods like DPO, CPO, SimPO, ORPO, KTO, RM for both text-based and multimodal large models.
- RLHF Training: Supports human alignment training methods like DPO, CPO, SimPO, ORPO, KTO, RM, PPO for both text-based and multimodal large models.
- 🍓 Multimodal Training: Capable of training models for different modalities such as images, videos, and audios; supports tasks like VQA (Visual Question Answering), Captioning, OCR (Optical Character Recognition), and Grounding.
- Interface-driven Training: Offers training, inference, evaluation, and quantization capabilities through an interface, enabling a complete workflow for large models.
- Plugins and Extensions: Allows customization and extension of models and datasets, and supports customizations for components like loss, metric, trainer, loss-scale, callback, optimizer, etc.
2 changes: 1 addition & 1 deletion docs/source_en/Instruction/Command-line-parameters.md
@@ -308,7 +308,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine

RLHF arguments inherit from the [training arguments](#training-arguments).

- 🔥rlhf_type: Alignment algorithm type, supports `dpo`, `orpo`, `simpo`, `kto`, `cpo`.
- 🔥rlhf_type: Alignment algorithm type, supports `dpo`, `orpo`, `simpo`, `kto`, `cpo`, `rm`, `ppo`.
- ref_model: Original comparison model in algorithms like DPO.
- ref_model_type: Same as model_type.
- ref_model_revision: Same as model_revision.
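For context on how the new `rm`/`ppo` values connect to the PPO-specific arguments added in `rlhf_args.py` below, here is a hedged sketch via the Python entry point. It assumes `rlhf_main` and `RLHFArguments` are exported from `swift.llm` like the other training entry points; the model, reward model, and dataset values are placeholders, not recommendations:

```python
# Hedged sketch: PPO alignment training through the assumed Python API.
from swift.llm import RLHFArguments, rlhf_main

rlhf_main(
    RLHFArguments(
        rlhf_type='ppo',                          # newly supported value
        model='Qwen/Qwen2-7B-Instruct',           # placeholder policy model
        reward_model='Qwen/Qwen2-Math-RM-72B',    # a reward model registered by this PR
        dataset=['AI-ModelScope/alpaca-gpt4-data-en#1000'],  # placeholder dataset
        train_type='full',
    ))
```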
7 changes: 3 additions & 4 deletions docs/source_en/Instruction/ReleaseNote3.0.md
@@ -94,7 +94,6 @@ The parameters marked as compatible in version 2.0 have been entirely removed.

## Pending Tasks

1. RM/PPO capabilities are not supported in version 3.0. Please use version 2.6.1.
2. Custom dataset evaluation is not supported in version 3.0. Please use version 2.6.1.
3. Megatron pre-training capabilities are not supported in version 3.0. Please use version 2.6.1.
4. Documentation and README, especially the English portions, are temporarily incomplete and will be updated.
1. Custom dataset evaluation is not supported in version 3.0. Please use version 2.6.1.
2. Megatron pre-training capabilities are not supported in version 3.0. Please use version 2.6.1.
3. Documentation and README are temporarily incomplete and will be updated.
@@ -10,7 +10,7 @@
"\n",
"Are you ready? Let's begin the journey...\n",
"\n",
"中文版:https://modelscope.cn/notebook/share/ipynb/4340fdeb/self-cognition-sft.ipynb"
"中文版:https://modelscope.cn/notebook/share/ipynb/313f6116/self-cognition-sft.ipynb"
]
},
{
3 changes: 2 additions & 1 deletion swift/llm/argument/base_args/base_args.py
@@ -111,7 +111,8 @@ def _init_custom_register(self) -> None:
folder, fname = os.path.split(path)
sys.path.append(folder)
__import__(fname.rstrip('.py'))
logger.info(f'Successfully registered `{self.custom_register_path}`')
if self.custom_register_path:
    logger.info(f'Successfully registered `{self.custom_register_path}`')

def _init_adapters(self):
if isinstance(self.adapters, str):
23 changes: 21 additions & 2 deletions swift/llm/argument/rlhf_args.py
@@ -25,7 +25,7 @@ class RLHFArguments(TrainArguments):
desirable_weight (float): Weight for desirable outcomes in KTO. Default is 1.0.
undesirable_weight (float): Weight for undesirable outcomes in KTO. Default is 1.0.
"""
rlhf_type: Literal['dpo', 'orpo', 'simpo', 'kto', 'cpo', 'rm'] = 'dpo'
rlhf_type: Literal['dpo', 'orpo', 'simpo', 'kto', 'cpo', 'rm', 'ppo'] = 'dpo'
ref_model: Optional[str] = None
ref_model_type: Optional[str] = field(
    default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
@@ -42,19 +42,38 @@
# KTO
desirable_weight: float = 1.0
undesirable_weight: float = 1.0
# PPO
reward_model: Optional[str] = None
reward_model_type: Optional[str] = field(
    default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
reward_model_revision: Optional[str] = None
local_rollout_forward_batch_size: int = 64
kl_coef: float = 0.05
cliprange: float = 0.2
vf_coef: float = 0.1
cliprange_value: float = 0.2
gamma: float = 1.0
lam: float = 0.95
num_sample_generations: int = 10

def __post_init__(self):
    self._init_simpo()
    self._init_ppo()
    self._set_default()
    super().__post_init__()

    if self.rlhf_type not in ['cpo', 'orpo', 'rm'] and (self.train_type == 'full' or self.rlhf_type == 'ppo'):
    if self.rlhf_type in ['dpo', 'kto'] and self.train_type == 'full' or self.rlhf_type == 'ppo':
        self.ref_model = self.ref_model or self.model
        self.ref_model_type = self.ref_model_type or self.model_type
        self.ref_model_revision = self.ref_model_revision or self.model_revision
    elif self.ref_model is not None:
        raise ValueError('CPO/ORPO or LoRA training does not require a ref_model to be passed in.')

def _init_ppo(self):
    self.response_length = self.max_new_tokens
    self.num_ppo_epochs = self.num_train_epochs
    # TODO: streaming, MLLM

def _init_simpo(self):
    if self.rlhf_type != 'simpo':
        return
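The new PPO fields mirror the knobs of trl's PPO trainer configuration, and their defaults match the corresponding trl defaults. A hedged sketch of that mapping, assuming trl>=0.12 where the class is named `PPOConfig` (field names may differ across trl versions, and `output_dir` is only there to satisfy the underlying `TrainingArguments`):

```python
# Hedged sketch: how the RLHFArguments defaults above line up with trl's PPO config.
from trl import PPOConfig  # named PPOv2Config in some earlier trl releases

ppo_config = PPOConfig(
    output_dir='output',                  # placeholder, required by TrainingArguments
    kl_coef=0.05,                         # RLHFArguments.kl_coef
    cliprange=0.2,                        # RLHFArguments.cliprange (policy clip range)
    vf_coef=0.1,                          # RLHFArguments.vf_coef (value-loss weight)
    cliprange_value=0.2,                  # RLHFArguments.cliprange_value
    gamma=1.0,                            # RLHFArguments.gamma (discount factor)
    lam=0.95,                             # RLHFArguments.lam (GAE lambda)
    num_sample_generations=10,            # RLHFArguments.num_sample_generations
    local_rollout_forward_batch_size=64,  # RLHFArguments.local_rollout_forward_batch_size
    response_length=512,                  # taken from max_new_tokens via _init_ppo (placeholder value)
    num_ppo_epochs=1,                     # taken from num_train_epochs via _init_ppo (placeholder value)
)
```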
2 changes: 1 addition & 1 deletion swift/llm/model/__init__.py
@@ -4,6 +4,6 @@
from .model_arch import MODEL_ARCH_MAPPING, ModelArch, ModelKeys, MultiModelKeys, get_model_arch, register_model_arch
from .register import (MODEL_MAPPING, Model, ModelGroup, ModelMeta, fix_do_sample_warning, get_default_device_map,
get_default_torch_dtype, get_model_info_meta, get_model_name, get_model_tokenizer,
get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn, get_model_with_value_head,
get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn,
load_by_unsloth, register_model)
from .utils import HfConfigFactory, ModelInfo, git_clone_github, safe_snapshot_download
4 changes: 3 additions & 1 deletion swift/llm/model/constant.py
@@ -93,9 +93,11 @@ class LLMModelType:
mamba = 'mamba'
polylm = 'polylm'
aya = 'aya'

# bert
modern_bert = 'modern_bert'
bert = 'bert'
# reward model
reward_model = 'reward_model'


class MLLMModelType:
2 changes: 1 addition & 1 deletion swift/llm/model/model/__init__.py
@@ -1,2 +1,2 @@
from . import (baai, baichuan, bert, codefuse, deepseek, gemma, glm, internlm, llama, llava, llm, mamba, microsoft,
minicpm, mistral, mllm, mplug, openbuddy, qwen, telechat, yi)
minicpm, mistral, mllm, mplug, openbuddy, qwen, telechat, yi, reward_model)
33 changes: 33 additions & 0 deletions swift/llm/model/model/reward_model.py
@@ -0,0 +1,33 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from transformers import AutoConfig
from transformers import AutoModel
from swift.utils import get_logger
from ..constant import LLMModelType
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_from_local, register_model

logger = get_logger()


def get_model_tokenizer_reward_model(model_dir, *args, **kwargs):
    model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    if 'AutoModel' in (getattr(model_config, 'auto_map', None) or {}):
        kwargs['automodel_class'] = AutoModel
    return get_model_tokenizer_from_local(model_dir, *args, **kwargs)


register_model(
    ModelMeta(
        LLMModelType.reward_model, [
            ModelGroup([
                Model('Qwen/Qwen2.5-Math-RM-72B', 'Qwen/Qwen2.5-Math-RM-72B'),
                Model('Qwen/Qwen2-Math-RM-72B', 'Qwen/Qwen2-Math-RM-72B'),
            ]),
            ModelGroup([
                Model('Shanghai_AI_Laboratory/internlm2-1_8b-reward', 'internlm/internlm2-1_8b-reward'),
                Model('Shanghai_AI_Laboratory/internlm2-7b-reward', 'internlm/internlm2-7b-reward'),
                Model('Shanghai_AI_Laboratory/internlm2-20b-reward', 'internlm/internlm2-20b-reward'),
            ]),
        ],
        None,
        get_model_tokenizer_reward_model,
        tags=['reward_model']))
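A hedged usage sketch for the registration above: loading one of these reward models through the regular ms-swift entry point. It assumes `get_model_tokenizer` is importable from `swift.llm`, matching the export shown in `swift/llm/model/__init__.py` earlier in this diff:

```python
# Hedged sketch: load a registered reward model via ms-swift's model registry.
from swift.llm import get_model_tokenizer

# The model id matches a Model entry above, so it should resolve to
# LLMModelType.reward_model and route through get_model_tokenizer_reward_model.
model, tokenizer = get_model_tokenizer('Shanghai_AI_Laboratory/internlm2-1_8b-reward')
print(type(model))
```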
56 changes: 0 additions & 56 deletions swift/llm/model/register.py
@@ -196,62 +196,6 @@ def get_model_tokenizer_from_local(model_dir: str,
return model, tokenizer


def get_model_with_value_head(model) -> 'AutoModelForCausalLMWithValueHead':
    from trl import AutoModelForCausalLMWithValueHead
    lm_head_namings = ['lm_head', 'embed_out']
    if not any(hasattr(model, attribute) for attribute in lm_head_namings):
        setattr(model, 'lm_head', None)  # avoid ValueError

    model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

    def patch_valuehead_model(model):
        attr_list = [
            'get_input_embeddings', 'vis_processor', 'extract_feature', 'get_rope_index', 'model', 'vision_tower',
            'img2emb', '_encode_image', '_merge_input_ids_with_image_features', 'prepare_inputs_embeds',
            'build_conversation_input_ids', 'config', 'get_slice_image_placeholder', 'transform', 'get_vllm_embedding',
            'forward_image', 'dtype', 'base_model_prefix', 'device', 'visual'
        ]
        for attr in attr_list:
            if hasattr(model.pretrained_model, attr) and not hasattr(model, attr):
                setattr(model, attr, getattr(model.pretrained_model, attr))

        # PPO compatible
        if not hasattr(model, 'score'):
            setattr(model, 'score', model.v_head)
        if model.base_model_prefix == '' and hasattr(model.pretrained_model, 'language_model'):
            model.base_model_prefix = model.pretrained_model.language_model.base_model_prefix

        base_model_prefix = model.pretrained_model.base_model_prefix
        if hasattr(model.pretrained_model, base_model_prefix):
            setattr(model, base_model_prefix, getattr(model.pretrained_model, base_model_prefix))

    patch_valuehead_model(model)

    # try to load local vhead weights
    vhead_params = None
    try:
        from safetensors import safe_open
        vhead_file = os.path.join(model.pretrained_model.model_dir, 'value_head.safetensors')
        with safe_open(vhead_file, framework='pt', device='cpu') as f:
            vhead_params = {key: f.get_tensor(key) for key in f.keys()}
    except Exception:
        pass

    try:
        vhead_file = os.path.join(model.pretrained_model.model_dir, 'value_head.bin')
        vhead_params = torch.load(vhead_file, map_location='cpu')
    except Exception:
        pass

    if vhead_params is not None:
        model.load_state_dict(vhead_params, strict=False)
        logger.info(f'Loading value head weights from {vhead_file}')
    else:
        logger.info('The local value head weight file was not detected.'
                    'Ignore it if this is during the reward modeling phase,')
    return model


def get_model_tokenizer_with_flash_attn(model_dir: str,
model_info: ModelInfo,
model_kwargs: Dict[str, Any],
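For reference on what the removed helper did: it wrapped an already-loaded model with trl's value-head class and then re-attached attributes and any local `value_head` weights. A minimal, hedged sketch of that core wrapping step (the model id is a placeholder; this is not the PR's replacement code):

```python
# Hedged sketch of the core of the removed get_model_with_value_head helper.
from trl import AutoModelForCausalLMWithValueHead

vh_model = AutoModelForCausalLMWithValueHead.from_pretrained('Qwen/Qwen2-0.5B-Instruct')
print(type(vh_model.v_head))  # the scalar value head used for RM/PPO training
```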
11 changes: 0 additions & 11 deletions swift/llm/train/rlhf.py
@@ -32,17 +32,6 @@ def _prepare_template(self) -> None:
# Avoid padding labels during the model's forward pass in multimodal models.
self.template.loss_scale = 'last_round'

@classmethod
def prepare_model(cls, args, model, *_args, **kwargs):
    model = super().prepare_model(args, model, *_args, **kwargs)
    if args.rlhf_type == 'rm':
        from trl import AutoModelForCausalLMWithValueHead
        lm_head_namings = ['lm_head', 'embed_out']
        if not any(hasattr(model, attribute) for attribute in lm_head_namings):
            model.lm_head = None  # avoid error
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
        patch_getattr(AutoModelForCausalLMWithValueHead, 'pretrained_model')
    return model

def _get_dataset(self):
args = self.args