Skip to content

Commit

Permalink
final
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 10, 2024
1 parent 7b937be commit 48afefd
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1631,9 +1631,6 @@ def _infer_task_from_model_or_model_class(
target_class_name = model.__class__.__name__ if model is not None else model_class.__name__
target_class_module = model.__class__.__module__ if model is not None else model_class.__module__

print(target_class_name)
print(target_class_module)

if target_class_name.startswith("AutoModel"):
# transfromers models (auto)
for task_name, model_loader_class_names in cls._TRANSFORMERS_TASKS_TO_MODEL_LOADERS.items():
Expand Down Expand Up @@ -1667,14 +1664,15 @@ def _infer_task_from_model_or_model_class(
break
elif target_class_module.startswith("transformers"):
# transformers models
for task_name, model_loader_class_names in (
cls._TRANSFORMERS_TASKS_TO_MODEL_MAPPINGS.items()
+ cls._TRANSFORMERS_TASKS_TO_TF_MODEL_MAPPINGS.items()
):
if isinstance(model_loader_class_names, str):
model_loader_class_names = (model_loader_class_names,)
for model_loader_class_name in model_loader_class_names:
for model_type, model_class_name in model_loader_class_name.items():
if target_class_name.startswith("TF"):
task_name_to_model_mappings = cls._TRANSFORMERS_TASKS_TO_TF_MODEL_MAPPINGS
else:
task_name_to_model_mappings = cls._TRANSFORMERS_TASKS_TO_MODEL_MAPPINGS
for task_name, model_mappings in task_name_to_model_mappings.items():
if isinstance(model_mappings, dict):
model_mappings = (model_mappings,)
for model_mapping in model_mappings:
for model_type, model_class_name in model_mapping.items():
if target_class_name == model_class_name:
inferred_task_name = task_name
break
Expand Down

0 comments on commit 48afefd

Please sign in to comment.