From adcae3841ac638aa911e6a185d3f1617e58495f8 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Fri, 10 Jan 2025 19:08:32 +0400 Subject: [PATCH] fix infer task from model_name if model from sentence transformer (#2151) * fix infer task from model_name if model from sentence transformer * use library_name for infer task --- optimum/exporters/onnx/__main__.py | 2 +- optimum/exporters/tasks.py | 34 ++++++++++++++++++++-------- optimum/exporters/tflite/__main__.py | 9 ++++++-- optimum/exporters/tflite/convert.py | 2 +- 4 files changed, 34 insertions(+), 13 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 280c6fc6554..20bea423cb3 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -256,7 +256,7 @@ def main_export( if task == "auto": try: - task = TasksManager.infer_task_from_model(model_name_or_path) + task = TasksManager.infer_task_from_model(model_name_or_path, library_name=library_name) except KeyError as e: raise KeyError( f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 5651537162a..3793a56068a 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1782,6 +1782,7 @@ def _infer_task_from_model_name_or_path( revision: Optional[str] = None, cache_dir: str = HUGGINGFACE_HUB_CACHE, token: Optional[Union[bool, str]] = None, + library_name: Optional[str] = None, ) -> str: inferred_task_name = None @@ -1803,13 +1804,14 @@ def _infer_task_from_model_name_or_path( raise RuntimeError( f"Hugging Face Hub is not reachable and we cannot infer the task from a cached model. Make sure you are not offline, or otherwise please specify the `task` (or `--task` in command-line) argument ({', '.join(TasksManager.get_all_tasks())})." ) - library_name = cls.infer_library_from_model( - model_name_or_path, - subfolder=subfolder, - revision=revision, - cache_dir=cache_dir, - token=token, - ) + if library_name is None: + library_name = cls.infer_library_from_model( + model_name_or_path, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + ) if library_name == "timm": inferred_task_name = "image-classification" @@ -1828,6 +1830,8 @@ def _infer_task_from_model_name_or_path( break if inferred_task_name is not None: break + elif library_name == "sentence_transformers": + inferred_task_name = "feature-extraction" elif library_name == "transformers": pipeline_tag = model_info.pipeline_tag transformers_info = model_info.transformersInfo @@ -1864,6 +1868,7 @@ def infer_task_from_model( revision: Optional[str] = None, cache_dir: str = HUGGINGFACE_HUB_CACHE, token: Optional[Union[bool, str]] = None, + library_name: Optional[str] = None, ) -> str: """ Infers the task from the model repo, model instance, or model class. @@ -1882,7 +1887,9 @@ def infer_task_from_model( token (`Optional[Union[bool,str]]`, defaults to `None`): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). - + library_name (`Optional[str]`, defaults to `None`): + The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers". See `TasksManager.infer_library_from_model` for the priority should + none be provided. Returns: `str`: The task name automatically detected from the HF hub repo, model instance, or model class. """ @@ -1895,6 +1902,7 @@ def infer_task_from_model( revision=revision, cache_dir=cache_dir, token=token, + library_name=library_name, ) elif type(model) == type: inferred_task_name = cls._infer_task_from_model_or_model_class(model_class=model) @@ -2170,6 +2178,9 @@ def get_model_from_task( none be provided. model_kwargs (`Dict[str, Any]`, *optional*): Keyword arguments to pass to the model `.from_pretrained()` method. + library_name (`Optional[str]`, defaults to `None`): + The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers". See `TasksManager.infer_library_from_model` for the priority should + none be provided. Returns: The instance of the model. @@ -2189,7 +2200,12 @@ def get_model_from_task( original_task = task if task == "auto": task = TasksManager.infer_task_from_model( - model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token + model_name_or_path, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + library_name=library_name, ) model_type = None diff --git a/optimum/exporters/tflite/__main__.py b/optimum/exporters/tflite/__main__.py index 0c4c7b994fa..4d8d4ee7b7e 100644 --- a/optimum/exporters/tflite/__main__.py +++ b/optimum/exporters/tflite/__main__.py @@ -46,7 +46,7 @@ def main(): task = args.task if task == "auto": try: - task = TasksManager.infer_task_from_model(args.model) + task = TasksManager.infer_task_from_model(args.model, library_name="transformers") except KeyError as e: raise KeyError( "The task could not be automatically inferred. Please provide the argument --task with the task " @@ -58,7 +58,12 @@ def main(): ) model = TasksManager.get_model_from_task( - task, args.model, framework="tf", cache_dir=args.cache_dir, trust_remote_code=args.trust_remote_code + task, + args.model, + framework="tf", + cache_dir=args.cache_dir, + trust_remote_code=args.trust_remote_code, + library_name="transformers", ) tflite_config_constructor = TasksManager.get_exporter_config_constructor( diff --git a/optimum/exporters/tflite/convert.py b/optimum/exporters/tflite/convert.py index c1a2010355a..fb0706cacd5 100644 --- a/optimum/exporters/tflite/convert.py +++ b/optimum/exporters/tflite/convert.py @@ -194,7 +194,7 @@ def prepare_converter_for_quantization( if task is None: from ...exporters import TasksManager - task = TasksManager.infer_task_from_model(model) + task = TasksManager.infer_task_from_model(model, library_name="transformers") preprocessor_kwargs = {} if isinstance(preprocessor, PreTrainedTokenizerBase):