Skip to content

Commit

Permalink
yo
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 committed Jul 29, 2024
1 parent 4bb1dec commit 8805071
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,11 @@ def __init__(

self.mlflow_logging_config = mlflow_logging_config
if 'metadata' in self.mlflow_logging_config:
self.pretrained_model_name = self.mlflow_logging_config['metadata'].get(
'pretrained_model_name',
None,
)
self.pretrained_model_name = self.mlflow_logging_config[
'metadata'].get(
'pretrained_model_name',
None,
)
else:
self.pretrained_model_name = None

Expand Down Expand Up @@ -540,14 +541,11 @@ def tensor_hook(
if self.pretrained_model_name is not None:
new_model_instance.name_or_path = self.pretrained_model_name
if self.using_peft:
new_model_instance.base_model.name_or_path = self.pretrained_model_name
for k in new_model_instance.peft_config.keys():
new_model_instance.base_model.name_or_path = self.pretrained_model_name
new_model_instance.peft_config[
k
].base_model_name_or_path = self.pretrained_model_name
print("PEFT CONFIG IS:")
for k,v in new_model_instance.peft_config.items():
print("key:", k, "value:", v)

log.debug('Saving Hugging Face checkpoint to disk')
# This context manager casts the TE extra state in io.BytesIO format to tensor format
Expand Down

0 comments on commit 8805071

Please sign in to comment.