Skip to content

Commit

Permalink
fix circular import and remove extra imports from config
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Jun 19, 2024
1 parent 0ffd73b commit 986f1bb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 12 deletions.
11 changes: 1 addition & 10 deletions llmfoundry/models/mpt/configuration_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,12 @@
is_flash_v2_installed,
)

# NOTE: All utils are imported directly even if unused so that
# HuggingFace can detect all the needed files to copy into its modules folder.
# Otherwise, certain modules are missing.
# isort: off
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.layer_builders import build_norm, build_fc, build_ffn # type: ignore (see note)
from llmfoundry.models.layers.dmoe import dMoE # type: ignore (see note)
from llmfoundry.layers_registry import norms # type: ignore (see note)
from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note)
from llmfoundry.models.utils.config_defaults import (
attn_config_defaults,
ffn_config_defaults,
init_config_defaults,
fc_type_defaults,
) # type: ignore (see note)
)


class MPTConfig(PretrainedConfig):
Expand Down
5 changes: 3 additions & 2 deletions llmfoundry/utils/huggingface_hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,9 @@ def edit_files_for_hf_compatibility(
existing_relative_imports = get_all_relative_imports(
os.path.join(folder, entrypoint),
)
# Add in self so we don't create a circular import
existing_relative_imports.add(os.path.splitext(entrypoint)[0])
# Add in all entrypoints so we don't create a circular import
for sub_entrypoint in entrypoint_files:
existing_relative_imports.add(os.path.splitext(sub_entrypoint)[0])
missing_relative_imports = all_relative_imports - existing_relative_imports
add_relative_imports(
os.path.join(folder, entrypoint),
Expand Down

0 comments on commit 986f1bb

Please sign in to comment.