Skip to content

Commit

Permalink
remove the need for the config to be in the subfolder
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Oct 9, 2024
1 parent 049b00f commit f17cf1d
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions optimum/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from .exporters import TasksManager
from .utils import CONFIG_NAME
from .utils.file_utils import find_files_matching_pattern


if TYPE_CHECKING:
Expand Down Expand Up @@ -380,27 +381,24 @@ def from_pretrained(
)
model_id, revision = model_id.split("@")

config_folder = (
subfolder if find_files_matching_pattern(model_id, cls.config_name)[0].parent == subfolder else ""
)

library_name = TasksManager.infer_library_from_model(
model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
model_id, subfolder=config_folder, revision=revision, cache_dir=cache_dir, token=token
)

if library_name == "timm":
config = PretrainedConfig.from_pretrained(
model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
model_id, subfolder=config_folder, revision=revision, cache_dir=cache_dir, token=token
)

if config is None:
if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME:
if CONFIG_NAME in os.listdir(os.path.join(model_id, subfolder)):
config = AutoConfig.from_pretrained(
os.path.join(model_id, subfolder), trust_remote_code=trust_remote_code
)
elif CONFIG_NAME in os.listdir(model_id):
if os.path.isdir(os.path.join(model_id, config_folder)) and cls.config_name == CONFIG_NAME:
if CONFIG_NAME in os.listdir(os.path.join(model_id, config_folder)):
config = AutoConfig.from_pretrained(
os.path.join(model_id, CONFIG_NAME), trust_remote_code=trust_remote_code
)
logger.info(
f"config.json not found in the specified subfolder {subfolder}. Using the top level config.json."
os.path.join(model_id, config_folder), trust_remote_code=trust_remote_code
)
else:
raise OSError(f"config.json not found in {model_id} local folder")
Expand All @@ -411,7 +409,7 @@ def from_pretrained(
cache_dir=cache_dir,
token=token,
force_download=force_download,
subfolder=subfolder,
subfolder=config_folder,
trust_remote_code=trust_remote_code,
)
elif isinstance(config, (str, os.PathLike)):
Expand All @@ -421,7 +419,7 @@ def from_pretrained(
cache_dir=cache_dir,
token=token,
force_download=force_download,
subfolder=subfolder,
subfolder=config_folder,
trust_remote_code=trust_remote_code,
)

Expand Down

0 comments on commit f17cf1d

Please sign in to comment.