From 2df8a8baba429689e9e1d204e067ed78882bd363 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 27 Dec 2024 23:59:19 +0800
Subject: [PATCH 1/2] init internlm2 reward model

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 docs/source/models/supported_models.md  |  5 +++
 tests/models/registry.py                |  2 +
 vllm/model_executor/models/internlm2.py | 60 ++++++++++++++++++++++++-
 vllm/model_executor/models/registry.py  |  1 +
 4 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 545a2ccaa5634..485a212be4796 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -445,6 +445,11 @@ of the whole prompt are extracted from the normalized hidden state corresponding
     - Example HF Models
     - :ref:`LoRA <lora>`
     - :ref:`PP <distributed_serving>`
+  * - :code:`InternLM2ForRewardModel`
+    - InternLM2-based
+    - :code:`internlm/internlm2-1_8b-reward`, :code:`internlm/internlm2-7b-reward`, etc.
+    - ✅︎
+    - ✅︎
   * - :code:`LlamaForCausalLM`
     - Llama-based
     - :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 819ef957a07f3..94bf8932cba19 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -138,6 +138,8 @@ class _HfExamplesInfo:
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
+    "InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
+                                               trust_remote_code=True),
     "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"),  # noqa: E501
     "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
     "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index 41b9f110d771f..a840dcc7149c6 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -18,14 +18,16 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (is_pp_missing_parameter,
@@ -433,3 +435,59 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class InternLM2ForRewardModel(InternLM2ForCausalLM):
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        model_type=InternLM2Model,
+    ):
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         model_type=model_type)
+
+        for attr in ("output", "logits_processor", "sampler"):
+            delattr(self, attr)
+
+        config = vllm_config.model_config.hf_config
+        self.v_head = RowParallelLinear(
+            config.hidden_size,
+            1,
+            bias=False,
+            input_is_parallel=False,
+            prefix=maybe_prefix(prefix, "v_head"),
+        )
+
+        pooler_config = vllm_config.model_config.pooler_config
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.ALL,
+            normalize=False,
+            softmax=False,
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
+        logits, _ = self.v_head(hidden_states)
+        return logits
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index b32a3421d5841..13b20ffbce97a 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -112,6 +112,7 @@
     "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
     "GritLM": ("gritlm", "GritLM"),
+    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
     "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
     "LlamaModel": ("llama", "LlamaForCausalLM"),
     **{

From 3af45c29d5ca31523049210984b46e5e50cab137 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sat, 28 Dec 2024 12:16:39 +0800
Subject: [PATCH 2/2] Update vllm/model_executor/models/internlm2.py

Co-authored-by: Cyrus Leung
---
 vllm/model_executor/models/internlm2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index a840dcc7149c6..28c23edd4c8e8 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -444,7 +444,7 @@ def __init__(
         *,
         vllm_config: VllmConfig,
         prefix: str = "",
-        model_type=InternLM2Model,
+        model_type: Type[InternLM2Model] = InternLM2Model,
     ):
         super().__init__(vllm_config=vllm_config,
                          prefix=prefix,
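
For reviewers who want to try the change locally, here is a minimal usage sketch (not part of the patch). It assumes a vLLM build that includes this PR, plus the `task="reward"` runner and the offline `LLM.encode()` pooling API; the shape comment reflects an assumption that `PoolingType.ALL` emits one `v_head` score per prompt token.

```python
# Minimal sketch, not part of the patch: score a prompt with the new
# InternLM2 reward model. `task="reward"` and `LLM.encode()` are assumed
# to be available in the vLLM version this PR targets.
from vllm import LLM

llm = LLM(
    model="internlm/internlm2-1_8b-reward",
    task="reward",
    trust_remote_code=True,  # the checkpoint ships custom HF modeling code
)

# encode() runs the pooling path: model forward -> v_head -> PoolingType.ALL.
(output,) = llm.encode("The quick brown fox jumps over the lazy dog.")

# Assumption: with PoolingType.ALL and a 1-dim v_head, `data` holds one
# scalar score per prompt token; the last token's score is the usual
# sequence-level reward.
scores = output.outputs.data
print(f"per-token scores: {scores!r}")
```

Because the class deletes `output`, `logits_processor`, and `sampler` in `__init__`, the checkpoint's `output.weight` (the LM head) is simply skipped at load time and only the `v_head` projection is added on top of the base `InternLM2Model`.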