diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 9b29afa566b..ce1d68536ac 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -510,13 +510,12 @@ def _from_pretrained( if file_name is None: if model_path.is_dir(): - onnx_files = list(model_path.glob("*.onnx")) + onnx_files = list((model_path / subfolder).glob("*.onnx")) else: repo_files, _ = TasksManager.get_model_files( model_id, revision=revision, cache_dir=cache_dir, token=token ) repo_files = map(Path, repo_files) - pattern = "*.onnx" if subfolder == "" else f"{subfolder}/*.onnx" onnx_files = [p for p in repo_files if p.match(pattern)] @@ -983,10 +982,9 @@ def _cached_file( token = use_auth_token model_path = Path(model_path) - # locates a file in a local folder and repo, downloads and cache it if necessary. if model_path.is_dir(): - model_cache_path = model_path / file_name + model_cache_path = model_path / subfolder / file_name preprocessors = maybe_load_preprocessors(model_path.as_posix()) else: model_cache_path = hf_hub_download( diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 665f253c480..abf508a80c3 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -28,6 +28,7 @@ import requests import timm import torch +from huggingface_hub import HfApi from huggingface_hub.constants import default_cache_path from parameterized import parameterized from PIL import Image @@ -1263,6 +1264,19 @@ def test_trust_remote_code(self): torch.allclose(pt_logits, ort_logits, atol=1e-4), f" Maxdiff: {torch.abs(pt_logits - ort_logits).max()}" ) + @parameterized.expand(("", "onnx")) + def test_loading_with_config_in_root(self, subfolder): + # config.json file in the root directory and not in the subfolder + model_id = "sentence-transformers-testing/stsb-bert-tiny-onnx" + # hub model + ORTModelForFeatureExtraction.from_pretrained(model_id, subfolder=subfolder, export=subfolder == "") + # local model + api = HfApi() + with tempfile.TemporaryDirectory() as tmpdirname: + local_dir = Path(tmpdirname) / "model" + api.snapshot_download(repo_id=model_id, local_dir=local_dir) + ORTModelForFeatureExtraction.from_pretrained(local_dir, subfolder=subfolder, export=subfolder == "") + class ORTModelForQuestionAnsweringIntegrationTest(ORTModelTestMixin): SUPPORTED_ARCHITECTURES = [