Skip to content

Commit

Permalink
update meta_weights
Browse files Browse the repository at this point in the history
  • Loading branch information
baishihao committed Nov 21, 2024
1 parent ded7eea commit 5fc5194
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 4 deletions.
2 changes: 0 additions & 2 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ def __init__(self, kvargs):
self.quant_type = kvargs.get("quant_type", None)
self.quant_cfg_path = kvargs.get("quant_cfg", None)
self.mem_fraction = kvargs.get("mem_fraction", 0.9)
self.disable_qk_absorb = kvargs.get("disable_qk_absorb", False)
self.disable_vo_absorb = kvargs.get("disable_vo_absorb", False)

self._init_datatype()
self._init_config()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def load_hf_weights(self, weights):
self.weight = weight[:, start:end]
if self.bias_name in weights:
bias = weights[self.bias_name].to(self.data_type_)
self.bias = bias.cuda(self.tp_rank_) / self.world_size_
self.bias = (bias / self.world_size_).cuda(self.tp_rank_)
if weight is None:
return
self._post_load_weights()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, weight_name, data_type, bias_name=None):

def load_hf_weights(self, weights):
if self.weight_name in weights:
self.weight = weights[self.weight_name].to(self.data_type_).cuda(self.tp_rank_) + 1
self.weight = (weights[self.weight_name].to(self.data_type_) + 1).cuda(self.tp_rank_)


class TpNormWeight(NormWeight):
Expand Down

0 comments on commit 5fc5194

Please sign in to comment.