From f17cf1d53dedf7fbf849605bcf00775753e961f2 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 9 Oct 2024 12:42:33 +0200 Subject: [PATCH] remove the need for the config to be in the subfolder --- optimum/modeling_base.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/optimum/modeling_base.py b/optimum/modeling_base.py index 29521b7c0c6..2147660bf99 100644 --- a/optimum/modeling_base.py +++ b/optimum/modeling_base.py @@ -28,6 +28,7 @@ from .exporters import TasksManager from .utils import CONFIG_NAME +from .utils.file_utils import find_files_matching_pattern if TYPE_CHECKING: @@ -380,27 +381,24 @@ def from_pretrained( ) model_id, revision = model_id.split("@") + config_folder = ( + subfolder if find_files_matching_pattern(model_id, cls.config_name)[0].parent == subfolder else "" + ) + library_name = TasksManager.infer_library_from_model( - model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token + model_id, subfolder=config_folder, revision=revision, cache_dir=cache_dir, token=token ) if library_name == "timm": config = PretrainedConfig.from_pretrained( - model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token + model_id, subfolder=config_folder, revision=revision, cache_dir=cache_dir, token=token ) if config is None: - if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME: - if CONFIG_NAME in os.listdir(os.path.join(model_id, subfolder)): - config = AutoConfig.from_pretrained( - os.path.join(model_id, subfolder), trust_remote_code=trust_remote_code - ) - elif CONFIG_NAME in os.listdir(model_id): + if os.path.isdir(os.path.join(model_id, config_folder)) and cls.config_name == CONFIG_NAME: + if CONFIG_NAME in os.listdir(os.path.join(model_id, config_folder)): config = AutoConfig.from_pretrained( - os.path.join(model_id, CONFIG_NAME), trust_remote_code=trust_remote_code - ) - logger.info( - f"config.json not found in the specified subfolder {subfolder}. Using the top level config.json." + os.path.join(model_id, config_folder), trust_remote_code=trust_remote_code ) else: raise OSError(f"config.json not found in {model_id} local folder") @@ -411,7 +409,7 @@ def from_pretrained( cache_dir=cache_dir, token=token, force_download=force_download, - subfolder=subfolder, + subfolder=config_folder, trust_remote_code=trust_remote_code, ) elif isinstance(config, (str, os.PathLike)): @@ -421,7 +419,7 @@ def from_pretrained( cache_dir=cache_dir, token=token, force_download=force_download, - subfolder=subfolder, + subfolder=config_folder, trust_remote_code=trust_remote_code, )