diff --git a/lightllm/models/internlm2_reward/layer_infer/post_layer_infer.py b/lightllm/models/internlm2_reward/layer_infer/post_layer_infer.py
index 2760d838..42061c78 100644
--- a/lightllm/models/internlm2_reward/layer_infer/post_layer_infer.py
+++ b/lightllm/models/internlm2_reward/layer_infer/post_layer_infer.py
@@ -14,8 +14,6 @@ def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_
         input_embdings = None

         last_input = self._norm(last_input, infer_state, layer_weight)
-        last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, token_num)
-
-        score = torch.mm(layer_weight.lm_head_weight_, last_input)
+        score = torch.mm(last_input, layer_weight.lm_head_weight_)
         return score

diff --git a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py
index 8484a558..235eed71 100644
--- a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py
+++ b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py
@@ -16,7 +16,7 @@ def load_hf_weights(self, weights):
         if "model.tok_embeddings.weight" in weights:
             self.wte_weight_ = self._cuda(weights["model.tok_embeddings.weight"][split_start:split_end, :])
         if "v_head.weight" in weights:
-            self.lm_head_weight_ = self._cuda(weights["v_head.weight"])
+            self.lm_head_weight_ = self._cuda(weights["v_head.weight"]).transpose(0, 1)
         if "model.norm.weight" in weights:
             self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])

diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py
index 1695d7cc..18e7bb1a 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py
@@ -23,7 +23,7 @@ def forward(self, batch_id, is_prefill):
         kwargs, run_reqs = prepare_prefill_inputs(batch, self.radix_cache, self.is_multimodal)

         scores: torch.Tensor = self.model.forward(**kwargs)
-        scores = scores[0].detach().cpu().numpy()
+        scores = scores.unsqueeze(1).detach().cpu().numpy()

         next_token_id = 1
         next_token_logprob = 1.0
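
Reviewer note, not part of the patch: the three hunks hang together. The old code transposed the activations on every forward pass (rearrange plus .contiguous(), which materializes a copy), while the new code transposes the [1, hidden] v_head weight once at load time and scores with a single torch.mm over the [batch, hidden] hidden states. Below is a minimal sketch of the equivalence; names and sizes (v_head, batch, hidden) are illustrative, not lightllm code.

import torch

# Illustrative sizes; real values come from the model config.
batch, hidden = 4, 8
last_input = torch.randn(batch, hidden)  # normalized last-token hidden states
v_head = torch.randn(1, hidden)          # "v_head.weight" as stored in the HF checkpoint

# Old path: transpose (and copy) the activations on every call.
old_score = torch.mm(v_head, last_input.t().contiguous())  # shape [1, batch]

# New path: transpose the weight once at load time, then a plain mm.
v_head_t = v_head.transpose(0, 1)                          # shape [hidden, 1]
new_score = torch.mm(last_input, v_head_t)                 # shape [batch, 1]

# Same numbers, only the layout changed: (W @ x.T).T == x @ W.T
assert torch.allclose(old_score.t(), new_score)

The output layout changing from one row per head to one row per token is also why the backend hunk drops the scores[0] indexing and adjusts the shape before moving the scores to numpy.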