Skip to content

Commit

Permalink
[BugFix] Fix reward model support
Browse files Browse the repository at this point in the history
  • Loading branch information
sufubao committed Oct 28, 2024
1 parent 721f037 commit 85a1840
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_

input_embdings = None
last_input = self._norm(last_input, infer_state, layer_weight)
last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, token_num)

score = torch.mm(layer_weight.lm_head_weight_, last_input)
score = torch.mm(last_input, layer_weight.lm_head_weight_)

return score
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def load_hf_weights(self, weights):
if "model.tok_embeddings.weight" in weights:
self.wte_weight_ = self._cuda(weights["model.tok_embeddings.weight"][split_start:split_end, :])
if "v_head.weight" in weights:
self.lm_head_weight_ = self._cuda(weights["v_head.weight"])
self.lm_head_weight_ = self._cuda(weights["v_head.weight"]).transpose(0, 1)
if "model.norm.weight" in weights:
self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def forward(self, batch_id, is_prefill):
kwargs, run_reqs = prepare_prefill_inputs(batch, self.radix_cache, self.is_multimodal)

scores: torch.Tensor = self.model.forward(**kwargs)
scores = scores[0].detach().cpu().numpy()
scores = scores.unsqueeze(1).detach().cpu().numpy()

next_token_id = 1
next_token_logprob = 1.0
Expand Down

0 comments on commit 85a1840

Please sign in to comment.