Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Oct 8, 2024
1 parent f369b08 commit 7dd8f58
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ def _from_pretrained(
force_download: bool = False,
local_files_only: bool = False,
revision: Optional[str] = None,
trust_remote_code: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
token: Optional[Union[bool, str]] = None,
unet_file_name: str = ONNX_WEIGHTS_NAME,
Expand Down Expand Up @@ -258,10 +257,10 @@ def _from_pretrained(
if kwargs.get(model, None) is not None:
# this allows passing a model directly to from_pretrained
sessions[f"{model}_session"] = kwargs.pop(model)
elif path.is_file():
sessions[f"{model}_session"] = ORTModel.load_model(path, provider, session_options, provider_options)
else:
sessions[f"{model}_session"] = None
sessions[f"{model}_session"] = (
ORTModel.load_model(path, provider, session_options, provider_options) if path.is_file() else None
)

submodels = {}
for submodel in {"scheduler", "tokenizer", "tokenizer_2", "feature_extractor"}:
Expand All @@ -274,9 +273,9 @@ def _from_pretrained(
load_method = getattr(class_obj, "from_pretrained")
# Check if the module is in a subdirectory
if (model_save_path / submodel).is_dir():
submodels[submodel] = load_method(model_save_path / submodel, trust_remote_code=trust_remote_code)
submodels[submodel] = load_method(model_save_path / submodel)
else:
submodels[submodel] = load_method(model_save_path, trust_remote_code=trust_remote_code)
submodels[submodel] = load_method(model_save_path)

return cls(
**sessions,
Expand Down

0 comments on commit 7dd8f58

Please sign in to comment.