-
Notifications
You must be signed in to change notification settings - Fork 211
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
90 additions
and
2 deletions.
There are no files selected for viewing
22 changes: 22 additions & 0 deletions
22
lightllm/models/qwen2_reward/layer_infer/post_layer_infer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import torch | ||
|
||
from lightllm.models.llama.infer_struct import LlamaInferStateInfo | ||
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer | ||
from lightllm.models.qwen2_reward.layer_weights.pre_and_post_layer_weight import Qwen2RewardPreAndPostLayerWeight | ||
from einops import rearrange | ||
|
||
|
||
class Qwen2RewardPostLayerInfer(LlamaPostLayerInfer): | ||
def token_forward( | ||
self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen2RewardPreAndPostLayerWeight | ||
): | ||
last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) | ||
|
||
input_embdings = None | ||
last_input = self._norm(last_input, infer_state, layer_weight) | ||
|
||
last_input = torch.addmm(layer_weight.score_up_bias, last_input, layer_weight.score_up_weight) | ||
last_input = torch.nn.functional.relu(last_input) | ||
score = torch.addmm(layer_weight.score_down_bias, last_input, layer_weight.score_down_weight) | ||
|
||
return score |
50 changes: 50 additions & 0 deletions
50
lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import torch | ||
import numpy as np | ||
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight | ||
from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight, MultiROWMMWeight | ||
|
||
|
||
class Qwen2RewardPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): | ||
def __init__(self, tp_rank, world_size, data_type, network_config, mode): | ||
super().__init__(tp_rank, world_size, data_type, network_config, mode) | ||
return | ||
|
||
def load_hf_weights(self, weights): | ||
vob_size = self.network_config_["vocab_size"] | ||
split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64) | ||
split_start = split_indexes[self.tp_rank_] | ||
split_end = split_indexes[self.tp_rank_ + 1] | ||
if "model.embed_tokens.weight" in weights: | ||
self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) | ||
tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) | ||
if tie_word_embeddings: | ||
self.lm_head_weight_ = self.wte_weight_ | ||
|
||
if "model.norm.weight" in weights: | ||
self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) | ||
|
||
if "score.0.weight" in weights: | ||
self.score_up_weight = self._cuda(weights["score.0.weight"]).transpose(0, 1) | ||
if "score.0.bias" in weights: | ||
self.score_up_bias = self._cuda(weights["score.0.bias"]) | ||
|
||
if "score.2.weight" in weights: | ||
self.score_down_weight = self._cuda(weights["score.2.weight"]).transpose(0, 1) | ||
if "score.2.bias" in weights: | ||
self.score_down_bias = self._cuda(weights["score.2.bias"]) | ||
|
||
return | ||
|
||
def verify_load(self): | ||
errors = "weights load not ok" | ||
weights = [ | ||
self.wte_weight_, | ||
self.final_norm_weight_, | ||
self.score_up_weight, | ||
self.score_up_bias, | ||
self.score_down_weight, | ||
self.score_down_bias, | ||
] | ||
for i in range(len(weights)): | ||
assert weights[i] is not None, "index:" + str(i) + " " + errors | ||
return |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from lightllm.models.qwen2_reward.layer_infer.post_layer_infer import Qwen2RewardPostLayerInfer | ||
from lightllm.models.qwen2_reward.layer_weights.pre_and_post_layer_weight import Qwen2RewardPreAndPostLayerWeight | ||
from lightllm.models.qwen2.model import Qwen2TpPartModel | ||
|
||
|
||
class Qwen2RewardTpPartModel(Qwen2TpPartModel): | ||
|
||
pre_and_post_weight_class = Qwen2RewardPreAndPostLayerWeight | ||
post_layer_infer_class = Qwen2RewardPostLayerInfer | ||
|
||
def __init__(self, kvargs): | ||
super().__init__(kvargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters