Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 12, 2024
1 parent 48afefd commit 68fabd4
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 181 deletions.
13 changes: 9 additions & 4 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,15 @@ def main_export(
original_task = task
task = TasksManager.map_from_synonym(task)

framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
library_name = TasksManager.infer_library_from_model(
model_name_or_path, subfolder=subfolder, library_name=library_name
)
if framework is None:
framework = TasksManager.determine_framework(
model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
)

if library_name is None:
library_name = TasksManager.infer_library_from_model(
model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
)

torch_dtype = None
if framework == "pt":
Expand Down
5 changes: 3 additions & 2 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,15 +999,16 @@ def onnx_export_from_model(
>>> onnx_export_from_model(model, output="gpt2_onnx/")
```
"""
library_name = TasksManager._infer_library_from_model(model)

TasksManager.standardize_model_attributes(model, library_name)
TasksManager.standardize_model_attributes(model)

if hasattr(model.config, "export_model_type"):
model_type = model.config.export_model_type.replace("_", "-")
else:
model_type = model.config.model_type.replace("_", "-")

library_name = TasksManager.infer_library_from_model(model)

custom_architecture = library_name == "transformers" and model_type not in TasksManager._SUPPORTED_MODEL_TYPE

if task is not None:
Expand Down
Loading

0 comments on commit 68fabd4

Please sign in to comment.