diff --git a/README.md b/README.md
index a572beaed..a3ed32bb4 100644
--- a/README.md
+++ b/README.md
@@ -67,7 +67,7 @@ You can contact us and communicate with us by adding our group:
 - 🍊 **Lightweight Training**: Supports lightweight fine-tuning methods like LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel.
 - **Distributed Training**: Supports distributed data parallel (DDP), device_map simple model parallelism, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques.
 - **Quantization Training**: Supports training quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
-- **RLHF Training**: Supports human alignment training methods such as DPO, CPO, SimPO, ORPO, KTO, RM for both pure text and multi-modal large models.
+- **RLHF Training**: Supports human alignment training methods such as DPO, CPO, SimPO, ORPO, KTO, RM, PPO for both pure text and multi-modal large models.
 - 🍓 **Multi-Modal Training**: Supports training on different modalities like images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding.
 - **Interface Training**: Provides capabilities for training, inference, evaluation, quantization through an interface, completing the whole large model pipeline.
 - **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components like loss, metric, trainer, loss-scale, callback, optimizer.
@@ -83,7 +83,7 @@ You can contact us and communicate with us by adding our group:
 - 🎉 2024.08.12: The SWIFT paper has been published on arXiv, and you can read it [here](https://arxiv.org/abs/2408.05517).
 - 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models.
 - 🔥 2024.07.29: Support for using [vllm](https://github.com/vllm-project/vllm) and [lmdeploy](https://github.com/InternLM/lmdeploy) to accelerate inference for large models and multimodal models. When performing infer/deploy/eval, you can specify `--infer_backend vllm/lmdeploy`.
-- 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM.
+- 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM/PPO.
 - 🔥 2024.02.01: Support for Agent training! The training algorithm is derived from [this paper](https://arxiv.org/pdf/2309.00986.pdf).
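The README entries above only add PPO to the list of supported alignment methods; the arguments that back it are introduced later in this patch (`rlhf_type='ppo'`, `reward_model`, `kl_coef`, `cliprange`, ...). As a rough, hedged sketch of how the new option could be driven from Python — assuming the `RLHFArguments`/`rlhf_main` entry points exposed by `swift.llm`, and with placeholder model and dataset names:

```python
# Hedged sketch, not part of this patch: launching PPO alignment training.
# reward_model, kl_coef and cliprange are the new RLHFArguments fields added
# in swift/llm/argument/rlhf_args.py below; model and dataset are placeholders.
from swift.llm import RLHFArguments, rlhf_main

args = RLHFArguments(
    rlhf_type='ppo',
    model='Qwen/Qwen2.5-7B-Instruct',                     # placeholder policy model
    reward_model='Qwen/Qwen2.5-Math-RM-72B',              # one of the reward models registered below
    dataset=['AI-ModelScope/alpaca-gpt4-data-en#1000'],   # placeholder dataset spec
    kl_coef=0.05,
    cliprange=0.2,
)
rlhf_main(args)
```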
diff --git a/README_CN.md b/README_CN.md
index c3796e0a5..fdfe9d5ed 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -64,7 +64,7 @@
 - 🍊 **轻量训练**:支持了LoRA、QLoRA、DoRA、LoRA+、ReFT、RS-LoRA、LLaMAPro、Adapter、GaLore、Q-Galore、LISA、UnSloth、Liger-Kernel等轻量微调方式。
 - **分布式训练**:支持分布式数据并行(DDP)、device_map简易模型并行、DeepSpeed ZeRO2 ZeRO3、FSDP等分布式训练技术。
 - **量化训练**:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。
-- **RLHF训练**:支持纯文本大模型和多模态大模型的DPO、CPO、SimPO、ORPO、KTO、RM等人类对齐训练方法。
+- **RLHF训练**:支持纯文本大模型和多模态大模型的DPO、CPO、SimPO、ORPO、KTO、RM、PPO等人类对齐训练方法。
 - 🍓 **多模态训练**:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。
 - **界面训练**:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。
 - **插件化与拓展**:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。
@@ -78,7 +78,7 @@
 - 🎉 2024.08.12: SWIFT论文已经发布到arXiv上,可以点击[这里](https://arxiv.org/abs/2408.05517)阅读。
 - 🔥 2024.08.05: 支持使用[evalscope](https://github.com/modelscope/evalscope/)作为后端进行大模型和多模态模型的评测。
 - 🔥 2024.07.29: 支持使用[vllm](https://github.com/vllm-project/vllm), [lmdeploy](https://github.com/InternLM/lmdeploy)对大模型和多模态大模型进行推理加速,在infer/deploy/eval时额外指定`--infer_backend vllm/lmdeploy`即可。
-- 🔥 2024.07.24: 支持对多模态大模型进行人类偏好对齐训练,包括DPO/ORPO/SimPO/CPO/KTO/RM。
+- 🔥 2024.07.24: 支持对多模态大模型进行人类偏好对齐训练,包括DPO/ORPO/SimPO/CPO/KTO/RM/PPO。
 - 🔥 2024.02.01: 支持Agent训练!训练算法源自这篇[论文](https://arxiv.org/pdf/2309.00986.pdf)。
 
 ## 🛠️ 安装
diff --git "a/docs/source/Customization/\350\207\252\345\256\232\344\271\211\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Customization/\350\207\252\345\256\232\344\271\211\346\225\260\346\215\256\351\233\206.md"
index 16b54234b..df19d9939 100644
--- "a/docs/source/Customization/\350\207\252\345\256\232\344\271\211\346\225\260\346\215\256\351\233\206.md"
+++ "b/docs/source/Customization/\350\207\252\345\256\232\344\271\211\346\225\260\346\215\256\351\233\206.md"
@@ -53,7 +53,7 @@ query-response格式:
 
 ### RLHF
 
-#### DPO/ORPO/CPO/SimPO/RM
+#### DPO/ORPO/CPO/SimPO/RM/PPO
 
 ```jsonl
 {"messages": [{"role": "system", "content": "你是个有用无害的助手"}, {"role": "user", "content": "告诉我明天的天气"}, {"role": "assistant", "content": "明天天气晴朗"}], "rejected_response": "我不知道"}
diff --git "a/docs/source/GetStarted/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source/GetStarted/\345\277\253\351\200\237\345\274\200\345\247\213.md"
index c69597316..9715f248c 100644
--- "a/docs/source/GetStarted/\345\277\253\351\200\237\345\274\200\345\247\213.md"
+++ "b/docs/source/GetStarted/\345\277\253\351\200\237\345\274\200\345\247\213.md"
@@ -8,7 +8,7 @@ ms-swift是魔搭社区提供的大模型与多模态大模型训练部署框架
 - 🍊 轻量训练:支持了LoRA、QLoRA、DoRA、LoRA+、ReFT、RS-LoRA、LLaMAPro、Adapter、GaLore、Q-Galore、LISA、UnSloth、Liger-Kernel等轻量微调方式。
 - 分布式训练:支持分布式数据并行(DDP)、device_map简易模型并行、DeepSpeed ZeRO2 ZeRO3、FSDP等分布式训练技术。
 - 量化训练:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。
-- RLHF训练:支持纯文本大模型和多模态大模型的DPO、CPO、SimPO、ORPO、KTO、RM等人类对齐训练方法。
+- RLHF训练:支持纯文本大模型和多模态大模型的DPO、CPO、SimPO、ORPO、KTO、RM、PPO等人类对齐训练方法。
 - 🍓 多模态训练:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。
 - 界面训练:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。
 - 插件化与拓展:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。
diff --git a/docs/source/Instruction/ReleaseNote3.0.md b/docs/source/Instruction/ReleaseNote3.0.md
index feda24658..85a8f7e8a 100644
--- a/docs/source/Instruction/ReleaseNote3.0.md
+++ b/docs/source/Instruction/ReleaseNote3.0.md
@@ -81,7 +81,6 @@
 
 ## 待完成
 
-1. RM/PPO能力3.0版本尚不支持,请使用2.6.1版本
-2. 自定义数据集评测3.0版本尚不支持,请使用2.6.1版本
-3. Megatron预训练能力3.0版本尚不支持,请使用2.6.1版本
-4. 文档和README,尤其是英文部分暂时未更新完整
+1. 自定义数据集评测3.0版本尚不支持,请使用2.6.1版本
+2. Megatron预训练能力3.0版本尚不支持,请使用2.6.1版本
+3. 文档和README暂时未更新完整
diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
index db8dd4201..1c652751f 100644
--- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
+++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
@@ -304,7 +304,7 @@ Vera使用`target_modules`, `target_regex`, `modules_to_save`三个参数.
 
 ### RLHF参数
 RLHF参数继承于[训练参数](#训练参数)
-- 🔥rlhf_type: 对齐算法类型,支持`dpo`, `orpo`, `simpo`, `kto`, `cpo`
+- 🔥rlhf_type: 对齐算法类型,支持`dpo`, `orpo`, `simpo`, `kto`, `cpo`, `rm`, `ppo`
 - ref_model: DPO等算法中的原始对比模型
 - ref_model_type: 同model_type
 - ref_model_revision: 同model_revision
diff --git a/docs/source_en/Customization/Custom-dataset.md b/docs/source_en/Customization/Custom-dataset.md
index f7b883f46..3bb38cfe3 100644
--- a/docs/source_en/Customization/Custom-dataset.md
+++ b/docs/source_en/Customization/Custom-dataset.md
@@ -52,7 +52,7 @@ The following provides the recommended dataset format for ms-swift, where the sy
 
 ### RLHF
 
-#### DPO/ORPO/CPO/SimPO/RM
+#### DPO/ORPO/CPO/SimPO/RM/PPO
 
 ```jsonl
 {"messages": [{"role": "system", "content": "You are a useful and harmless assistant"}, {"role": "user", "content": "Tell me tomorrow's weather"}, {"role": "assistant", "content": "Tomorrow's weather will be sunny"}], "rejected_response": "I don't know"}
diff --git a/docs/source_en/GetStarted/Quick-start.md b/docs/source_en/GetStarted/Quick-start.md
index c410e4484..38e69b32e 100644
--- a/docs/source_en/GetStarted/Quick-start.md
+++ b/docs/source_en/GetStarted/Quick-start.md
@@ -8,7 +8,7 @@ ms-swift is a comprehensive training and deployment framework for large language
 - 🍊 Lightweight Training: Supports lightweight fine-tuning methods like LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel, and more.
 - Distributed Training: Supports distributed data parallel (DDP), simple model parallelism via device_map, DeepSpeed ZeRO2 ZeRO3, FSDP, and other distributed training technologies.
 - Quantization Training: Provides training for quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
-- RLHF Training: Supports human alignment training methods like DPO, CPO, SimPO, ORPO, KTO, RM for both text-based and multimodal large models.
+- RLHF Training: Supports human alignment training methods like DPO, CPO, SimPO, ORPO, KTO, RM, PPO for both text-based and multimodal large models.
 - 🍓 Multimodal Training: Capable of training models for different modalities such as images, videos, and audios; supports tasks like VQA (Visual Question Answering), Captioning, OCR (Optical Character Recognition), and Grounding.
 - Interface-driven Training: Offers training, inference, evaluation, and quantization capabilities through an interface, enabling a complete workflow for large models.
 - Plugins and Extensions: Allows customization and extension of models and datasets, and supports customizations for components like loss, metric, trainer, loss-scale, callback, optimizer, etc.
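Both Custom-dataset pages above now list PPO under the same heading as DPO/ORPO/CPO/SimPO/RM, i.e. one jsonl record per sample with the preferred answer inside `messages` and the rejected answer in `rejected_response`. A minimal, stdlib-only sketch for producing such a record (the output path is arbitrary; field names are taken from the examples in this patch):

```python
# Hedged sketch, not part of this patch: writing one record of the
# DPO/ORPO/CPO/SimPO/RM/PPO jsonl format shown in the Custom-dataset docs above.
import json

record = {
    'messages': [
        {'role': 'system', 'content': 'You are a useful and harmless assistant'},
        {'role': 'user', 'content': "Tell me tomorrow's weather"},
        {'role': 'assistant', 'content': "Tomorrow's weather will be sunny"},  # chosen answer
    ],
    'rejected_response': "I don't know",  # rejected answer
}

with open('rlhf_pairs.jsonl', 'w', encoding='utf-8') as f:
    f.write(json.dumps(record, ensure_ascii=False) + '\n')
```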
diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md
index 28be6ae39..d163b10a6 100644
--- a/docs/source_en/Instruction/Command-line-parameters.md
+++ b/docs/source_en/Instruction/Command-line-parameters.md
@@ -308,7 +308,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine
 
 RLHF arguments inherit from the [training arguments](#training-arguments).
 
-- 🔥rlhf_type: Alignment algorithm type, supports `dpo`, `orpo`, `simpo`, `kto`, `cpo`.
+- 🔥rlhf_type: Alignment algorithm type, supports `dpo`, `orpo`, `simpo`, `kto`, `cpo`, `rm`, `ppo`.
 - ref_model: Original comparison model in algorithms like DPO.
 - ref_model_type: Same as model_type.
 - ref_model_revision: Same as model_revision.
diff --git a/docs/source_en/Instruction/ReleaseNote3.0.md b/docs/source_en/Instruction/ReleaseNote3.0.md
index c6c1c9cec..f46728886 100644
--- a/docs/source_en/Instruction/ReleaseNote3.0.md
+++ b/docs/source_en/Instruction/ReleaseNote3.0.md
@@ -94,7 +94,6 @@ The parameters marked as compatible in version 2.0 have been entirely removed.
 
 ## Pending Tasks
 
-1. RM/PPO capabilities are not supported in version 3.0. Please use version 2.6.1.
-2. Custom dataset evaluation is not supported in version 3.0. Please use version 2.6.1.
-3. Megatron pre-training capabilities are not supported in version 3.0. Please use version 2.6.1.
-4. Documentation and README, especially the English portions, are temporarily incomplete and will be updated.
+1. Custom dataset evaluation is not supported in version 3.0. Please use version 2.6.1.
+2. Megatron pre-training capabilities are not supported in version 3.0. Please use version 2.6.1.
+3. Documentation and README are temporarily incomplete and will be updated.
diff --git a/examples/notebook/qwen2.5-self-cognition/self-cognition-sft.ipynb b/examples/notebook/qwen2.5-self-cognition/self-cognition-sft.ipynb
index 1ecd96cad..fee5144b2 100644
--- a/examples/notebook/qwen2.5-self-cognition/self-cognition-sft.ipynb
+++ b/examples/notebook/qwen2.5-self-cognition/self-cognition-sft.ipynb
@@ -10,7 +10,7 @@
     "\n",
     "Are you ready? Let's begin the journey...\n",
     "\n",
-    "中文版:https://modelscope.cn/notebook/share/ipynb/4340fdeb/self-cognition-sft.ipynb"
+    "中文版:https://modelscope.cn/notebook/share/ipynb/313f6116/self-cognition-sft.ipynb"
    ]
   },
   {
diff --git a/swift/llm/argument/base_args/base_args.py b/swift/llm/argument/base_args/base_args.py
index 293ca0f7c..49b1d3897 100644
--- a/swift/llm/argument/base_args/base_args.py
+++ b/swift/llm/argument/base_args/base_args.py
@@ -111,7 +111,8 @@ def _init_custom_register(self) -> None:
             folder, fname = os.path.split(path)
             sys.path.append(folder)
             __import__(fname.rstrip('.py'))
-        logger.info(f'Successfully registered `{self.custom_register_path}`')
+        if self.custom_register_path:
+            logger.info(f'Successfully registered `{self.custom_register_path}`')
 
     def _init_adapters(self):
         if isinstance(self.adapters, str):
diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py
index 8d54f9eb7..48fba539f 100644
--- a/swift/llm/argument/rlhf_args.py
+++ b/swift/llm/argument/rlhf_args.py
@@ -25,7 +25,7 @@ class RLHFArguments(TrainArguments):
         desirable_weight (float): Weight for desirable outcomes in KTO. Default is 1.0.
         undesirable_weight (float): Weight for undesirable outcomes in KTO. Default is 1.0.
""" - rlhf_type: Literal['dpo', 'orpo', 'simpo', 'kto', 'cpo', 'rm'] = 'dpo' + rlhf_type: Literal['dpo', 'orpo', 'simpo', 'kto', 'cpo', 'rm', 'ppo'] = 'dpo' ref_model: Optional[str] = None ref_model_type: Optional[str] = field( default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'}) @@ -42,19 +42,38 @@ class RLHFArguments(TrainArguments): # KTO desirable_weight: float = 1.0 undesirable_weight: float = 1.0 + # PPO + reward_model: Optional[str] = None + reward_model_type: Optional[str] = field( + default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'}) + reward_model_revision: Optional[str] = None + local_rollout_forward_batch_size: int = 64 + kl_coef: float = 0.05 + cliprange: float = 0.2 + vf_coef: float = 0.1 + cliprange_value: float = 0.2 + gamma: float = 1.0 + lam: float = 0.95 + num_sample_generations: int = 10 def __post_init__(self): self._init_simpo() + self._init_ppo() self._set_default() super().__post_init__() - if self.rlhf_type not in ['cpo', 'orpo', 'rm'] and (self.train_type == 'full' or self.rlhf_type == 'ppo'): + if self.rlhf_type in ['dpo', 'kto'] and self.train_type == 'full' or self.rlhf_type == 'ppo': self.ref_model = self.ref_model or self.model self.ref_model_type = self.ref_model_type or self.model_type self.ref_model_revision = self.ref_model_revision or self.model_revision elif self.ref_model is not None: raise ValueError('CPO/ORPO or LoRA training does not require a ref_model to be passed in.') + def _init_ppo(self): + self.response_length = self.max_new_tokens + self.num_ppo_epochs = self.num_train_epochs + # TODO: streaming, MLLM + def _init_simpo(self): if self.rlhf_type != 'simpo': return diff --git a/swift/llm/model/__init__.py b/swift/llm/model/__init__.py index 754d71520..939db750a 100644 --- a/swift/llm/model/__init__.py +++ b/swift/llm/model/__init__.py @@ -4,6 +4,6 @@ from .model_arch import MODEL_ARCH_MAPPING, ModelArch, ModelKeys, MultiModelKeys, get_model_arch, register_model_arch from .register import (MODEL_MAPPING, Model, ModelGroup, ModelMeta, fix_do_sample_warning, get_default_device_map, get_default_torch_dtype, get_model_info_meta, get_model_name, get_model_tokenizer, - get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn, get_model_with_value_head, + get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn, load_by_unsloth, register_model) from .utils import HfConfigFactory, ModelInfo, git_clone_github, safe_snapshot_download diff --git a/swift/llm/model/constant.py b/swift/llm/model/constant.py index a87f901c7..82a89bd03 100644 --- a/swift/llm/model/constant.py +++ b/swift/llm/model/constant.py @@ -93,9 +93,11 @@ class LLMModelType: mamba = 'mamba' polylm = 'polylm' aya = 'aya' - + # bert modern_bert = 'modern_bert' bert = 'bert' + # reward model + reward_model = 'reward_model' class MLLMModelType: diff --git a/swift/llm/model/model/__init__.py b/swift/llm/model/model/__init__.py index a972ec64e..82ebf432f 100644 --- a/swift/llm/model/model/__init__.py +++ b/swift/llm/model/model/__init__.py @@ -1,2 +1,2 @@ from . 
import (baai, baichuan, bert, codefuse, deepseek, gemma, glm, internlm, llama, llava, llm, mamba, microsoft, - minicpm, mistral, mllm, mplug, openbuddy, qwen, telechat, yi) + minicpm, mistral, mllm, mplug, openbuddy, qwen, telechat, yi, reward_model) diff --git a/swift/llm/model/model/reward_model.py b/swift/llm/model/model/reward_model.py new file mode 100644 index 000000000..63b0bb0c9 --- /dev/null +++ b/swift/llm/model/model/reward_model.py @@ -0,0 +1,33 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from transformers import AutoConfig +from transformers import AutoModel +from swift.utils import get_logger +from ..constant import LLMModelType +from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_from_local, register_model + +logger = get_logger() + + +def get_model_tokenizer_reward_model(model_dir, *args, **kwargs): + model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + if 'AutoModel' in (getattr(model_config, 'auto_map', None) or {}): + kwargs['automodel_class'] = AutoModel + return get_model_tokenizer_from_local(model_dir, *args, **kwargs) + + +register_model( + ModelMeta( + LLMModelType.reward_model, [ + ModelGroup([ + Model('Qwen/Qwen2.5-Math-RM-72B', 'Qwen/Qwen2.5-Math-RM-72B'), + Model('Qwen/Qwen2-Math-RM-72B', 'Qwen/Qwen2-Math-RM-72B'), + ]), + ModelGroup([ + Model('Shanghai_AI_Laboratory/internlm2-1_8b-reward', 'internlm/internlm2-1_8b-reward'), + Model('Shanghai_AI_Laboratory/internlm2-7b-reward', 'internlm/internlm2-7b-reward'), + Model('Shanghai_AI_Laboratory/internlm2-20b-reward', 'internlm/internlm2-20b-reward'), + ]), + ], + None, + get_model_tokenizer_reward_model, + tags=['reward_model'])) diff --git a/swift/llm/model/register.py b/swift/llm/model/register.py index a98406eb8..8cf83a554 100644 --- a/swift/llm/model/register.py +++ b/swift/llm/model/register.py @@ -196,62 +196,6 @@ def get_model_tokenizer_from_local(model_dir: str, return model, tokenizer -def get_model_with_value_head(model) -> 'AutoModelForCausalLMWithValueHead': - from trl import AutoModelForCausalLMWithValueHead - lm_head_namings = ['lm_head', 'embed_out'] - if not any(hasattr(model, attribute) for attribute in lm_head_namings): - setattr(model, 'lm_head', None) # avoid ValueError - - model = AutoModelForCausalLMWithValueHead.from_pretrained(model) - - def patch_valuehead_model(model): - attr_list = [ - 'get_input_embeddings', 'vis_processor', 'extract_feature', 'get_rope_index', 'model', 'vision_tower', - 'img2emb', '_encode_image', '_merge_input_ids_with_image_features', 'prepare_inputs_embeds', - 'build_conversation_input_ids', 'config', 'get_slice_image_placeholder', 'transform', 'get_vllm_embedding', - 'forward_image', 'dtype', 'base_model_prefix', 'device', 'visual' - ] - for attr in attr_list: - if hasattr(model.pretrained_model, attr) and not hasattr(model, attr): - setattr(model, attr, getattr(model.pretrained_model, attr)) - - # PPO compatible - if not hasattr(model, 'score'): - setattr(model, 'score', model.v_head) - if model.base_model_prefix == '' and hasattr(model.pretrained_model, 'language_model'): - model.base_model_prefix = model.pretrained_model.language_model.base_model_prefix - - base_model_prefix = model.pretrained_model.base_model_prefix - if hasattr(model.pretrained_model, base_model_prefix): - setattr(model, base_model_prefix, getattr(model.pretrained_model, base_model_prefix)) - - patch_valuehead_model(model) - - # try to load local vhead weights - vhead_params = None - try: - from safetensors import safe_open - vhead_file 
= os.path.join(model.pretrained_model.model_dir, 'value_head.safetensors') - with safe_open(vhead_file, framework='pt', device='cpu') as f: - vhead_params = {key: f.get_tensor(key) for key in f.keys()} - except Exception: - pass - - try: - vhead_file = os.path.join(model.pretrained_model.model_dir, 'value_head.bin') - vhead_params = torch.load(vhead_file, map_location='cpu') - except Exception: - pass - - if vhead_params is not None: - model.load_state_dict(vhead_params, strict=False) - logger.info(f'Loading value head weights from {vhead_file}') - else: - logger.info('The local value head weight file was not detected.' - 'Ignore it if this is during the reward modeling phase,') - return model - - def get_model_tokenizer_with_flash_attn(model_dir: str, model_info: ModelInfo, model_kwargs: Dict[str, Any], diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index 906a3e116..3ec858b1b 100644 --- a/swift/llm/train/rlhf.py +++ b/swift/llm/train/rlhf.py @@ -32,17 +32,6 @@ def _prepare_template(self) -> None: # Avoid padding labels during the model's forward pass in multimodal models. self.template.loss_scale = 'last_round' - @classmethod - def prepare_model(cls, args, model, *_args, **kwargs): - model = super().prepare_model(args, model, *_args, **kwargs) - if args.rlhf_type == 'rm': - from trl import AutoModelForCausalLMWithValueHead - lm_head_namings = ['lm_head', 'embed_out'] - if not any(hasattr(model, attribute) for attribute in lm_head_namings): - model.lm_head = None # avoid error - model = AutoModelForCausalLMWithValueHead.from_pretrained(model) - patch_getattr(AutoModelForCausalLMWithValueHead, 'pretrained_model') - return model def _get_dataset(self): args = self.args
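Net effect of the model-side changes: the trl `AutoModelForCausalLMWithValueHead` wrapping (`get_model_with_value_head` in register.py and the `prepare_model` override in rlhf.py) is removed, and reward models are instead registered as ordinary models whose loader falls back to `AutoModel` whenever the checkpoint's `auto_map` provides one. A rough sketch of loading one of the reward models registered in `reward_model.py`, assuming the `get_model_tokenizer` helper re-exported from `swift.llm.model` accepts a `model_type` keyword:

```python
# Hedged sketch, not part of this patch: loading a registered reward model.
# The checkpoint id comes from the ModelGroup added in reward_model.py above;
# 'reward_model' is the LLMModelType constant added in constant.py.
from swift.llm.model import get_model_tokenizer

model, tokenizer = get_model_tokenizer(
    'Shanghai_AI_Laboratory/internlm2-1_8b-reward',
    model_type='reward_model',
)
print(type(model).__name__)  # expected: the AutoModel class resolved via the checkpoint's auto_map
```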