Skip to content

Commit

Permalink
Fix dynamic module import error (huggingface#21646)
Browse files Browse the repository at this point in the history
* fix dynamic module import error

---------

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored Feb 17, 2023
1 parent 8a4c319 commit 7f1cdf1
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions src/transformers/dynamic_module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import os
import re
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Dict, Optional, Union

Expand Down Expand Up @@ -143,9 +145,25 @@ def get_class_in_module(class_name, module_path):
"""
Import a module on the cache directory for modules and extract a class from it.
"""
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)
return getattr(module, class_name)
with tempfile.TemporaryDirectory() as tmp_dir:
module_dir = Path(HF_MODULES_CACHE) / os.path.dirname(module_path)
module_file_name = module_path.split(os.path.sep)[-1] + ".py"

# Copy to a temporary directory. We need to do this in another process to avoid strange and flaky error
# `ModuleNotFoundError: No module named 'transformers_modules.[module_dir_name].modeling'`
shutil.copy(f"{module_dir}/{module_file_name}", tmp_dir)
# On Windows, we need this character `r` before the path argument of `os.remove`
cmd = f'import os; os.remove(r"{module_dir}{os.path.sep}{module_file_name}")'
subprocess.run(["python", "-c", cmd])

# copy back the file that we want to import
shutil.copyfile(f"{tmp_dir}/{module_file_name}", f"{module_dir}/{module_file_name}")

# import the module
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)

return getattr(module, class_name)


def get_cached_module_file(
Expand Down Expand Up @@ -212,7 +230,7 @@ def get_cached_module_file(
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
submodule = "local"
submodule = pretrained_model_name_or_path.split(os.path.sep)[-1]
else:
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)

Expand Down Expand Up @@ -240,7 +258,7 @@ def get_cached_module_file(
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
create_dynamic_module(full_submodule)
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
if submodule == "local":
if submodule == pretrained_model_name_or_path.split(os.path.sep)[-1]:
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
# that hash, to only copy when there is a modification but it seems overkill for now).
# The only reason we do the copy is to avoid putting too many folders in sys.path.
Expand Down

0 comments on commit 7f1cdf1

Please sign in to comment.