Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Support InternLM2 Reward models #11571

Merged
merged 3 commits into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,11 @@ of the whole prompt are extracted from the normalized hidden state corresponding
- Example HF Models
- :ref:`LoRA <lora-adapter>`
- :ref:`PP <distributed-serving>`
* - :code:`InternLM2ForRewardModel`
- InternLM2-based
- :code:`internlm/internlm2-1_8b-reward`, :code:`internlm/internlm2-7b-reward`, etc.
- ✅︎
- ✅︎
* - :code:`LlamaForCausalLM`
- Llama-based
- :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
Expand Down
2 changes: 2 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ class _HfExamplesInfo:
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
trust_remote_code=True),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
Expand Down
60 changes: 59 additions & 1 deletion vllm/model_executor/models/internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
Expand Down Expand Up @@ -433,3 +435,59 @@ def load_weights(self, weights: Iterable[Tuple[str,
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class InternLM2ForRewardModel(InternLM2ForCausalLM):

def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
model_type=InternLM2Model,
Isotr0py marked this conversation as resolved.
Show resolved Hide resolved
):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
model_type=model_type)

for attr in ("output", "logits_processor", "sampler"):
delattr(self, attr)

config = vllm_config.model_config.hf_config
self.v_head = RowParallelLinear(
config.hidden_size,
1,
bias=False,
input_is_parallel=False,
prefix=maybe_prefix(prefix, "v_head"),
)

pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.ALL,
normalize=False,
softmax=False,
)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
logits, _ = self.v_head(hidden_states)
return logits

def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GritLM": ("gritlm", "GritLM"),
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
"LlamaModel": ("llama", "LlamaForCausalLM"),
**{
Expand Down
Loading