[model] support Qwen2.5-RM (#621)
sufubao authored Nov 28, 2024
1 parent b8ef073 commit 89ef6b5
Showing 5 changed files with 90 additions and 2 deletions.
22 changes: 22 additions & 0 deletions lightllm/models/qwen2_reward/layer_infer/post_layer_infer.py
@@ -0,0 +1,22 @@
import torch

from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
from lightllm.models.qwen2_reward.layer_weights.pre_and_post_layer_weight import Qwen2RewardPreAndPostLayerWeight
from einops import rearrange


class Qwen2RewardPostLayerInfer(LlamaPostLayerInfer):
    def token_forward(
        self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen2RewardPreAndPostLayerWeight
    ):
        last_input, token_num = self._slice_get_last_input(input_embdings, infer_state)

        input_embdings = None
        last_input = self._norm(last_input, infer_state, layer_weight)

        last_input = torch.addmm(layer_weight.score_up_bias, last_input, layer_weight.score_up_weight)
        last_input = torch.nn.functional.relu(last_input)
        score = torch.addmm(layer_weight.score_down_bias, last_input, layer_weight.score_down_weight)

        return score
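
For context on the head above: it applies Linear -> ReLU -> Linear to the last token's normalized hidden state, and torch.addmm(bias, mat1, mat2) fuses the bias add with the matmul (bias + mat1 @ mat2). A self-contained sketch with made-up sizes (the real dimensions come from the checkpoint config):

import torch

hidden_size, inner_size = 8, 4             # illustrative only
h = torch.randn(2, hidden_size)            # last-token hidden states for 2 sequences

w1 = torch.randn(hidden_size, inner_size)  # plays the role of score_up_weight (pre-transposed)
b1 = torch.randn(inner_size)               # score_up_bias
w2 = torch.randn(inner_size, 1)            # score_down_weight (pre-transposed)
b2 = torch.randn(1)                        # score_down_bias

x = torch.addmm(b1, h, w1)                 # b1 + h @ w1
x = torch.nn.functional.relu(x)
score = torch.addmm(b2, x, w2)             # one scalar reward per sequence
print(score.shape)                         # torch.Size([2, 1])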
50 changes: 50 additions & 0 deletions lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,50 @@
import torch
import numpy as np
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight, MultiROWMMWeight


class Qwen2RewardPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
    def __init__(self, tp_rank, world_size, data_type, network_config, mode):
        super().__init__(tp_rank, world_size, data_type, network_config, mode)
        return

    def load_hf_weights(self, weights):
        vob_size = self.network_config_["vocab_size"]
        split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64)
        split_start = split_indexes[self.tp_rank_]
        split_end = split_indexes[self.tp_rank_ + 1]
        if "model.embed_tokens.weight" in weights:
            self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :])
            tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
            if tie_word_embeddings:
                self.lm_head_weight_ = self.wte_weight_

        if "model.norm.weight" in weights:
            self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])

        if "score.0.weight" in weights:
            self.score_up_weight = self._cuda(weights["score.0.weight"]).transpose(0, 1)
        if "score.0.bias" in weights:
            self.score_up_bias = self._cuda(weights["score.0.bias"])

        if "score.2.weight" in weights:
            self.score_down_weight = self._cuda(weights["score.2.weight"]).transpose(0, 1)
        if "score.2.bias" in weights:
            self.score_down_bias = self._cuda(weights["score.2.bias"])

        return

    def verify_load(self):
        errors = "weights load not ok"
        weights = [
            self.wte_weight_,
            self.final_norm_weight_,
            self.score_up_weight,
            self.score_up_bias,
            self.score_down_weight,
            self.score_down_bias,
        ]
        for i in range(len(weights)):
            assert weights[i] is not None, "index:" + str(i) + " " + errors
        return
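
The np.linspace call above row-shards the embedding table across tensor-parallel ranks: it produces world_size + 1 boundaries, with rank r owning rows [split_indexes[r], split_indexes[r + 1]). A quick sketch of the boundary arithmetic (sizes are made up):

import numpy as np

vocab_size, world_size = 10, 4          # illustrative; real values come from the config
split_indexes = np.linspace(0, vocab_size, world_size + 1, dtype=np.int64)
print(split_indexes)                    # [ 0  2  5  7 10]: near-equal shards covering every row
for rank in range(world_size):
    print(f"rank {rank} owns rows [{split_indexes[rank]}, {split_indexes[rank + 1]})")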
12 changes: 12 additions & 0 deletions lightllm/models/qwen2_reward/model.py
@@ -0,0 +1,12 @@
from lightllm.models.qwen2_reward.layer_infer.post_layer_infer import Qwen2RewardPostLayerInfer
from lightllm.models.qwen2_reward.layer_weights.pre_and_post_layer_weight import Qwen2RewardPreAndPostLayerWeight
from lightllm.models.qwen2.model import Qwen2TpPartModel


class Qwen2RewardTpPartModel(Qwen2TpPartModel):

    pre_and_post_weight_class = Qwen2RewardPreAndPostLayerWeight
    post_layer_infer_class = Qwen2RewardPostLayerInfer

    def __init__(self, kvargs):
        super().__init__(kvargs)
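
Qwen2RewardTpPartModel reuses the entire Qwen2 stack and only swaps two hook classes. A sketch of the pattern (not lightllm's actual base-class code, just the general shape of such class-attribute hooks):

class BaseModelSketch:
    # Subclasses override these class attributes; the base class instantiates
    # whatever they point at, so no shared logic needs to change.
    pre_and_post_weight_class = None
    post_layer_infer_class = None

    def _init_components(self, *args, **kwargs):
        self.pre_post_weight = self.pre_and_post_weight_class(*args, **kwargs)
        self.post_infer = self.post_layer_infer_class(*args, **kwargs)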
@@ -31,6 +31,7 @@
 from lightllm.models.internvl.model import InternVLLlamaTpPartModel, InternVLPhi3TpPartModel
 from lightllm.models.internvl.model import InternVLInternlm2TpPartModel
 from lightllm.models.qwen2_vl.model import Qwen2VLTpPartModel
+from lightllm.models.qwen2_reward.model import Qwen2RewardTpPartModel
 from lightllm.utils.infer_utils import set_random_seed
 from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
 from lightllm.utils.log_utils import init_logger
@@ -183,7 +184,10 @@ def init_model(self, kvargs):
             self.model = LlavaTpPartModel(model_kvargs)
             self.is_multimodal = True
         elif self.model_type == "qwen2":
-            self.model = Qwen2TpPartModel(model_kvargs)
+            if model_cfg["architectures"][0] == "Qwen2ForRewardModel":
+                self.model = Qwen2RewardTpPartModel(model_kvargs)
+            else:
+                self.model = Qwen2TpPartModel(model_kvargs)
         elif self.model_type == "qwen2_vl":
             self.model = Qwen2VLTpPartModel(model_kvargs)
             self.is_multimodal = True
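
For reference, this dispatch keys off config.json: a Qwen2.5 reward checkpoint declares "architectures": ["Qwen2ForRewardModel"], while a standard Qwen2 chat model declares ["Qwen2ForCausalLM"]. A minimal sketch of the same check outside the server (the path is hypothetical):

import json

with open("/path/to/Qwen2.5-RM/config.json") as f:   # hypothetical checkpoint path
    model_cfg = json.load(f)

if model_cfg["architectures"][0] == "Qwen2ForRewardModel":
    print("reward model: use Qwen2RewardTpPartModel")
else:
    print("causal LM: use Qwen2TpPartModel")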
@@ -23,7 +23,7 @@ def forward(self, batch_id, is_prefill):
         kwargs, run_reqs = prepare_prefill_inputs(batch, self.radix_cache, self.is_multimodal)

         scores: torch.Tensor = self.model.forward(**kwargs)
-        scores = scores.unsqueeze(1).detach().cpu().numpy()
+        scores = scores.unsqueeze(1).detach().cpu().float().numpy()

         next_token_id = 1
         next_token_logprob = 1.0
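
The added .float() matters because NumPy has no bfloat16 dtype: calling .numpy() on a bf16 scores tensor raises a TypeError, so the scores are upcast to float32 first. A quick repro:

import torch

scores = torch.tensor([0.5, -1.25], dtype=torch.bfloat16)
try:
    scores.numpy()                      # fails: NumPy cannot represent bfloat16
except TypeError as e:
    print("direct conversion fails:", e)
print(scores.float().numpy())           # works after upcasting to float32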