diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index df768a8e59e..f8d2b3a19f1 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1589,7 +1589,7 @@ def _infer_task_from_model_or_model_class( raise ValueError("Either a model or a model class must be provided, but none were given here.") target_name = model.__class__.__name__ if model is not None else model_class.__name__ - task_name = None + inferred_task_name = None iterable = () for _, model_loader in cls._LIBRARY_TO_MODEL_LOADERS_TO_TASKS_MAP.items(): iterable += (model_loader.items(),) @@ -1601,7 +1601,7 @@ def _infer_task_from_model_or_model_class( for auto_cls_name, task in itertools.chain.from_iterable(iterable): if any((target_name.startswith("Auto"), target_name.startswith("TFAuto"))): if target_name == auto_cls_name: - task_name = task + inferred_task_name = task break continue @@ -1614,20 +1614,20 @@ def _infer_task_from_model_or_model_class( continue model_mapping = auto_cls._model_mapping._model_mapping if target_name in model_mapping.values(): - task_name = task + inferred_task_name = task break for task_name, model_type, pipeline_class in DIFFUSION_PIPELINES_MAPPING: - if target_name == pipeline_class: - task_name = task_name + if target_name == pipeline_class.__name__: + inferred_task_name = task_name break - if task_name is None: + if inferred_task_name is None: raise ValueError( "The task name could not be automatically inferred. If using the command-line, please provide the argument --task task-name. Example: `--task text-classification`." ) - return task_name + return inferred_task_name @classmethod def _infer_task_from_model_name_or_path( @@ -1755,8 +1755,10 @@ def infer_task_from_model( token=token, ) elif issubclass(model, object): + # checks if it's a model class task_name = cls._infer_task_from_model_or_model_class(model_class=model) elif isinstance(model, object): + # checks if it's a model instance task_name = cls._infer_task_from_model_or_model_class(model=model) if task_name is None: @@ -1903,7 +1905,6 @@ def standardize_model_attributes( if library_name == "diffusers": for task_name, model_type, pipeline_class in DIFFUSION_PIPELINES_MAPPING: - print(task_name, model_type, pipeline_class) if isinstance(model, pipeline_class): # `model_type` is a class attribute in Transformers, let's avoid modifying it. model.config.export_model_type = model_type