Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 9, 2024
1 parent ff7ea35 commit 1fad1bd
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1589,7 +1589,7 @@ def _infer_task_from_model_or_model_class(
raise ValueError("Either a model or a model class must be provided, but none were given here.")

target_name = model.__class__.__name__ if model is not None else model_class.__name__
task_name = None
inferred_task_name = None
iterable = ()
for _, model_loader in cls._LIBRARY_TO_MODEL_LOADERS_TO_TASKS_MAP.items():
iterable += (model_loader.items(),)
Expand All @@ -1601,7 +1601,7 @@ def _infer_task_from_model_or_model_class(
for auto_cls_name, task in itertools.chain.from_iterable(iterable):
if any((target_name.startswith("Auto"), target_name.startswith("TFAuto"))):
if target_name == auto_cls_name:
task_name = task
inferred_task_name = task
break
continue

Expand All @@ -1614,20 +1614,20 @@ def _infer_task_from_model_or_model_class(
continue
model_mapping = auto_cls._model_mapping._model_mapping
if target_name in model_mapping.values():
task_name = task
inferred_task_name = task
break

for task_name, model_type, pipeline_class in DIFFUSION_PIPELINES_MAPPING:
if target_name == pipeline_class:
task_name = task_name
if target_name == pipeline_class.__name__:
inferred_task_name = task_name
break

if task_name is None:
if inferred_task_name is None:
raise ValueError(
"The task name could not be automatically inferred. If using the command-line, please provide the argument --task task-name. Example: `--task text-classification`."
)

return task_name
return inferred_task_name

@classmethod
def _infer_task_from_model_name_or_path(
Expand Down Expand Up @@ -1755,8 +1755,10 @@ def infer_task_from_model(
token=token,
)
elif issubclass(model, object):
# checks if it's a model class
task_name = cls._infer_task_from_model_or_model_class(model_class=model)
elif isinstance(model, object):
# checks if it's a model instance
task_name = cls._infer_task_from_model_or_model_class(model=model)

if task_name is None:
Expand Down Expand Up @@ -1903,7 +1905,6 @@ def standardize_model_attributes(

if library_name == "diffusers":
for task_name, model_type, pipeline_class in DIFFUSION_PIPELINES_MAPPING:
print(task_name, model_type, pipeline_class)
if isinstance(model, pipeline_class):
# `model_type` is a class attribute in Transformers, let's avoid modifying it.
model.config.export_model_type = model_type
Expand Down

0 comments on commit 1fad1bd

Please sign in to comment.