Commit

precommit
dakinggg committed Sep 30, 2023
1 parent d5e0683 commit 53ba514
Showing 2 changed files with 14 additions and 9 deletions.
15 changes: 9 additions & 6 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -17,7 +17,7 @@
 from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
 from composer.models import HuggingFaceModel
 from composer.utils import dist, format_name_with_dist_and_time, parse_uri
-from transformers import PreTrainedTokenizerBase, PreTrainedModel
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils.huggingface_hub_utils import \
@@ -148,9 +148,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
                               str)  # pyright doesn't know about enter_result
 
             log.debug('Gathering state dict')
-            from torch.distributed.fsdp import \
-                FullyShardedDataParallel as FSDP
+            from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
             if state.is_model_ddp:
                 original_model = state.model.module.model
                 original_tokenizer = state.model.module.tokenizer
@@ -164,7 +163,10 @@ def _save_checkpoint(self, state: State, logger: Logger):
             assert isinstance(original_model, PreTrainedModel)
             assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
 
-            state_dict_context = fsdp_state_dict_type_context(original_model, state_dict_type='full') if ((not state.is_model_ddp) and isinstance(state.model.model, FSDP)) else contextlib.nullcontext()
+            state_dict_context = fsdp_state_dict_type_context(
+                original_model, state_dict_type='full') if (
+                    (not state.is_model_ddp) and isinstance(
+                        state.model.model, FSDP)) else contextlib.nullcontext()
             with state_dict_context:
                 state_dict = original_model.state_dict()
 
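Note for readers: `fsdp_state_dict_type_context(..., state_dict_type='full')` wraps PyTorch's FSDP state-dict-type context manager so that the `state_dict()` call above gathers full, unsharded parameters. A minimal sketch of the underlying PyTorch pattern, assuming a generic FSDP-wrapped `model` (the exact config the repo's helper uses may differ):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

    # Gather a full (unsharded) state dict; offloading to CPU and keeping
    # it only on rank 0 avoids GPU OOM when consolidating large models.
    full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_cfg):
        state_dict = model.state_dict()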
@@ -192,7 +194,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
             log.debug('Saving Hugging Face checkpoint to disk')
             new_model_instance.save_pretrained(temp_save_dir)
             if original_tokenizer is not None:
-                assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
+                assert isinstance(original_tokenizer,
+                                  PreTrainedTokenizerBase)
                 original_tokenizer.save_pretrained(temp_save_dir)
 
             # Only need to edit files for MPT because it has custom code
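The `save_pretrained` calls above write the checkpoint in standard Hugging Face format, so it can be reloaded without Composer. A rough usage sketch (the path is illustrative; `trust_remote_code=True` is needed for MPT because, as the comment notes, it ships custom code):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Reload the exported checkpoint like any other Hugging Face model.
    model = AutoModelForCausalLM.from_pretrained('path/to/checkpoint',
                                                 trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained('path/to/checkpoint')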
8 changes: 5 additions & 3 deletions tests/test_hf_conversion_script.py
@@ -19,7 +19,7 @@
 sys.path.append(repo_dir)
 import shutil
 from argparse import Namespace
-from typing import cast, Optional
+from typing import Optional, cast
 
 import pytest
 import torch
@@ -427,8 +427,10 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
         trust_remote_code=True,
     )
 
-    check_hf_model_equivalence(trainer.state.model.model.to(precision) if fsdp_state_dict_type is not None else trainer.state.model.module.model.to(precision),
-                               loaded_model)
+    check_hf_model_equivalence(
+        trainer.state.model.model.to(precision) if fsdp_state_dict_type
+        is not None else trainer.state.model.module.model.to(precision),
+        loaded_model)
     check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer)
 
     delete_transformers_cache()
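`check_hf_model_equivalence` is defined earlier in this test file; conceptually it asserts that the trained model and the reloaded checkpoint match. A hypothetical reduction of such a check (helper name and scope assumed here, not the repo's exact implementation):

    import torch

    def assert_models_equal(a, b):
        # Configs should match after the save/load round trip.
        assert a.config.to_dict() == b.config.to_dict()
        # Every parameter tensor should be identical.
        for (name_a, p_a), (name_b, p_b) in zip(a.named_parameters(),
                                                b.named_parameters()):
            assert name_a == name_b
            assert torch.equal(p_a.cpu(), p_b.cpu())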
