Skip to content

Commit

Permalink
fixed load peft weight. #47
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Jun 20, 2023
1 parent 9e7c163 commit adf0335
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion textgen/chatglm/chatglm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def train_model(
# The two files above have a different name depending on how they were saved, but are actually the same.
if os.path.exists(checkpoint_name):
logger.info(f"Restarting from {checkpoint_name}")
adapters_weights = torch.load(checkpoint_name)
adapters_weights = torch.load(checkpoint_name, map_location='cpu')
set_peft_model_state_dict(self.model, adapters_weights)
else:
logger.warning(f"Checkpoint {checkpoint_name} not found")
Expand Down
2 changes: 1 addition & 1 deletion textgen/gpt/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def train_model(
# The two files above have a different name depending on how they were saved, but are actually the same.
if os.path.exists(checkpoint_name):
logger.info(f"Restarting from {checkpoint_name}")
adapters_weights = torch.load(checkpoint_name)
adapters_weights = torch.load(checkpoint_name, map_location='cpu')
set_peft_model_state_dict(self.model, adapters_weights)
else:
logger.warning(f"Checkpoint {checkpoint_name} not found")
Expand Down

0 comments on commit adf0335

Please sign in to comment.