diff --git a/lightllm/models/qwen2_reward/layer_infer/post_layer_infer.py b/lightllm/models/qwen2_reward/layer_infer/post_layer_infer.py
new file mode 100644
index 00000000..22ec8fd4
--- /dev/null
+++ b/lightllm/models/qwen2_reward/layer_infer/post_layer_infer.py
@@ -0,0 +1,22 @@
+import torch
+
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
+from lightllm.models.qwen2_reward.layer_weights.pre_and_post_layer_weight import Qwen2RewardPreAndPostLayerWeight
+from einops import rearrange
+
+
+class Qwen2RewardPostLayerInfer(LlamaPostLayerInfer):
+    def token_forward(
+        self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen2RewardPreAndPostLayerWeight
+    ):
+        last_input, token_num = self._slice_get_last_input(input_embdings, infer_state)
+
+        input_embdings = None
+        last_input = self._norm(last_input, infer_state, layer_weight)
+
+        last_input = torch.addmm(layer_weight.score_up_bias, last_input, layer_weight.score_up_weight)
+        last_input = torch.nn.functional.relu(last_input)
+        score = torch.addmm(layer_weight.score_down_bias, last_input, layer_weight.score_down_weight)
+
+        return score
diff --git a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 00000000..d0b4d195
--- /dev/null
+++ b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,50 @@
+import torch
+import numpy as np
+from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight, MultiROWMMWeight
+
+
+class Qwen2RewardPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
+    def __init__(self, tp_rank, world_size, data_type, network_config, mode):
+        super().__init__(tp_rank, world_size, data_type, network_config, mode)
+        return
+
+    def load_hf_weights(self, weights):
+        vob_size = self.network_config_["vocab_size"]
+        split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64)
+        split_start = split_indexes[self.tp_rank_]
+        split_end = split_indexes[self.tp_rank_ + 1]
+        if "model.embed_tokens.weight" in weights:
+            self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :])
+            tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
+            if tie_word_embeddings:
+                self.lm_head_weight_ = self.wte_weight_
+
+        if "model.norm.weight" in weights:
+            self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])
+
+        if "score.0.weight" in weights:
+            self.score_up_weight = self._cuda(weights["score.0.weight"]).transpose(0, 1)
+        if "score.0.bias" in weights:
+            self.score_up_bias = self._cuda(weights["score.0.bias"])
+
+        if "score.2.weight" in weights:
+            self.score_down_weight = self._cuda(weights["score.2.weight"]).transpose(0, 1)
+        if "score.2.bias" in weights:
+            self.score_down_bias = self._cuda(weights["score.2.bias"])
+
+        return
+
+    def verify_load(self):
+        errors = "weights load not ok"
+        weights = [
+            self.wte_weight_,
+            self.final_norm_weight_,
+            self.score_up_weight,
+            self.score_up_bias,
+            self.score_down_weight,
+            self.score_down_bias,
+        ]
+        for i in range(len(weights)):
+            assert weights[i] is not None, "index:" + str(i) + " " + errors
+        return
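
Editor's note: the two files above implement the reward head as Linear -> ReLU -> Linear over the last token's normalized hidden state, with the HF "score.0"/"score.2" weights pre-transposed at load time so inference can call torch.addmm(bias, x, W) directly. The following standalone sketch (not part of the patch; shapes are hypothetical, the real hidden_size comes from the model config) shows the same math:

import torch

hidden = 8  # hypothetical hidden_size, kept tiny for the demo
h = torch.randn(4, hidden)  # normalized last-token hidden states, one per request

# HF nn.Linear stores weight as (out_features, in_features); the patch
# transposes once at load time, mirrored here:
w_up = torch.randn(hidden, hidden).transpose(0, 1)  # score.0.weight, pre-transposed
b_up = torch.randn(hidden)                          # score.0.bias
w_down = torch.randn(1, hidden).transpose(0, 1)     # score.2.weight, pre-transposed
b_down = torch.randn(1)                             # score.2.bias

x = torch.addmm(b_up, h, w_up)          # (4, hidden): b_up + h @ w_up
x = torch.nn.functional.relu(x)         # score.1 in the HF checkpoint is this ReLU
score = torch.addmm(b_down, x, w_down)  # (4, 1): one scalar reward per request
print(score.shape)                      # torch.Size([4, 1])
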
diff --git a/lightllm/models/qwen2_reward/model.py b/lightllm/models/qwen2_reward/model.py
new file mode 100644
index 00000000..68cdf76a
--- /dev/null
+++ b/lightllm/models/qwen2_reward/model.py
@@ -0,0 +1,12 @@
+from lightllm.models.qwen2_reward.layer_infer.post_layer_infer import Qwen2RewardPostLayerInfer
+from lightllm.models.qwen2_reward.layer_weights.pre_and_post_layer_weight import Qwen2RewardPreAndPostLayerWeight
+from lightllm.models.qwen2.model import Qwen2TpPartModel
+
+
+class Qwen2RewardTpPartModel(Qwen2TpPartModel):
+
+    pre_and_post_weight_class = Qwen2RewardPreAndPostLayerWeight
+    post_layer_infer_class = Qwen2RewardPostLayerInfer
+
+    def __init__(self, kvargs):
+        super().__init__(kvargs)
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index c3968300..8c592c62 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -31,6 +31,7 @@
 from lightllm.models.internvl.model import InternVLLlamaTpPartModel, InternVLPhi3TpPartModel
 from lightllm.models.internvl.model import InternVLInternlm2TpPartModel
 from lightllm.models.qwen2_vl.model import Qwen2VLTpPartModel
+from lightllm.models.qwen2_reward.model import Qwen2RewardTpPartModel
 from lightllm.utils.infer_utils import set_random_seed
 from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
 from lightllm.utils.log_utils import init_logger
@@ -183,7 +184,10 @@ def init_model(self, kvargs):
             self.model = LlavaTpPartModel(model_kvargs)
             self.is_multimodal = True
         elif self.model_type == "qwen2":
-            self.model = Qwen2TpPartModel(model_kvargs)
+            if model_cfg["architectures"][0] == "Qwen2ForRewardModel":
+                self.model = Qwen2RewardTpPartModel(model_kvargs)
+            else:
+                self.model = Qwen2TpPartModel(model_kvargs)
         elif self.model_type == "qwen2_vl":
             self.model = Qwen2VLTpPartModel(model_kvargs)
             self.is_multimodal = True
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py
index 18e7bb1a..fe789dec 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py
@@ -23,7 +23,7 @@ def forward(self, batch_id, is_prefill):
         kwargs, run_reqs = prepare_prefill_inputs(batch, self.radix_cache, self.is_multimodal)
 
         scores: torch.Tensor = self.model.forward(**kwargs)
-        scores = scores.unsqueeze(1).detach().cpu().numpy()
+        scores = scores.unsqueeze(1).detach().cpu().float().numpy()
 
         next_token_id = 1
         next_token_logprob = 1.0
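
Editor's note: the one-line .float() change in the last hunk matters because numpy has no bfloat16 dtype, so Tensor.numpy() raises a TypeError on bf16 tensors; upcasting to float32 on the CPU first makes the conversion valid. A minimal sketch (not part of the patch) demonstrating both paths:

import torch

scores = torch.tensor([0.25, 0.75], dtype=torch.bfloat16)  # stand-in for model output

try:
    scores.unsqueeze(1).detach().cpu().numpy()  # pre-patch path: fails for bf16
except TypeError as exc:
    print(f"bf16 -> numpy fails: {exc}")

arr = scores.unsqueeze(1).detach().cpu().float().numpy()  # patched path
print(arr.dtype, arr.shape)  # float32 (2, 1)
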