From 14037cd7a07fad2e73f29d542670934c8de40e0d Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 12 Jul 2024 14:27:03 +0200 Subject: [PATCH] fix --- optimum/exporters/tasks.py | 42 +++++++------------------------------- 1 file changed, 7 insertions(+), 35 deletions(-) diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index d3b3b69003f..eefc37a3716 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -176,12 +176,9 @@ def get_transformers_tasks_to_model_mapping(tasks_to_model_loader, framework="pt tasks_to_model_mapping[task_name] = {} for model_loader in model_loaders: - tasks_to_model_mapping[task_name][model_loader] = {} model_loader_class = getattr(auto_modeling_module, model_loader, None) if model_loader_class is not None: - tasks_to_model_mapping[task_name][model_loader].update( - model_loader_class._model_mapping._model_mapping - ) + tasks_to_model_mapping[task_name].update(model_loader_class._model_mapping._model_mapping) return tasks_to_model_mapping @@ -1519,9 +1516,7 @@ def determine_framework( subfolder: str = "", revision: Optional[str] = None, cache_dir: str = HUGGINGFACE_HUB_CACHE, - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, - framework: Optional[str] = None, ) -> str: """ Determines the framework to use for the export. @@ -1543,29 +1538,14 @@ def determine_framework( Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. cache_dir (`Optional[str]`, *optional*): Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used. - use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): - Deprecated. Please use the `token` argument instead. token (`Optional[Union[bool,str]]`, defaults to `None`): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). - framework (`Optional[str]`, *optional*): - The framework to use for the export. See above for priority if none provided. Returns: `str`: The framework to use for the export. """ - if framework is not None: - return framework - - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token all_files, request_exception = TasksManager.get_model_files( model_name_or_path, subfolder=subfolder, cache_dir=cache_dir, token=token, revision=revision @@ -1669,7 +1649,10 @@ def _infer_task_from_model_or_model_class( for task_name, model_mapping in tasks_to_model_mapping.items(): for model_type, model_class_name in model_mapping.items(): if target_class_name == model_class_name: - return task_name + inferred_task_name = task_name + break + if inferred_task_name is not None: + break if inferred_task_name is None: raise ValueError( @@ -1758,7 +1741,6 @@ def infer_task_from_model( subfolder: str = "", revision: Optional[str] = None, cache_dir: str = HUGGINGFACE_HUB_CACHE, - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, ) -> str: """ @@ -1775,8 +1757,6 @@ def infer_task_from_model( Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. cache_dir (`Optional[str]`, *optional*): Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used. - use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): - Deprecated. Please use the `token` argument instead. token (`Optional[Union[bool,str]]`, defaults to `None`): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). @@ -1786,15 +1766,6 @@ def infer_task_from_model( """ inferred_task_name = None - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - if isinstance(model, str): inferred_task_name = cls._infer_task_from_model_name_or_path( model, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token @@ -1811,8 +1782,9 @@ def infer_task_from_model( return inferred_task_name - @staticmethod + @classmethod def _infer_library_from_model_or_model_class( + cls, model: Optional[Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]] = None, model_class: Optional[Type[Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]]] = None, ):