Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 12, 2024
1 parent 68fabd4 commit 14037cd
Showing 1 changed file with 7 additions and 35 deletions.
42 changes: 7 additions & 35 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,9 @@ def get_transformers_tasks_to_model_mapping(tasks_to_model_loader, framework="pt

tasks_to_model_mapping[task_name] = {}
for model_loader in model_loaders:
tasks_to_model_mapping[task_name][model_loader] = {}
model_loader_class = getattr(auto_modeling_module, model_loader, None)
if model_loader_class is not None:
tasks_to_model_mapping[task_name][model_loader].update(
model_loader_class._model_mapping._model_mapping
)
tasks_to_model_mapping[task_name].update(model_loader_class._model_mapping._model_mapping)

return tasks_to_model_mapping

Expand Down Expand Up @@ -1519,9 +1516,7 @@ def determine_framework(
subfolder: str = "",
revision: Optional[str] = None,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
framework: Optional[str] = None,
) -> str:
"""
Determines the framework to use for the export.
Expand All @@ -1543,29 +1538,14 @@ def determine_framework(
Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id.
cache_dir (`Optional[str]`, *optional*):
Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used.
use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`):
Deprecated. Please use the `token` argument instead.
token (`Optional[Union[bool,str]]`, defaults to `None`):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`).
framework (`Optional[str]`, *optional*):
The framework to use for the export. See above for priority if none provided.
Returns:
`str`: The framework to use for the export.
"""
if framework is not None:
return framework

if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
FutureWarning,
)
if token is not None:
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
token = use_auth_token

all_files, request_exception = TasksManager.get_model_files(
model_name_or_path, subfolder=subfolder, cache_dir=cache_dir, token=token, revision=revision
Expand Down Expand Up @@ -1669,7 +1649,10 @@ def _infer_task_from_model_or_model_class(
for task_name, model_mapping in tasks_to_model_mapping.items():
for model_type, model_class_name in model_mapping.items():
if target_class_name == model_class_name:
return task_name
inferred_task_name = task_name
break
if inferred_task_name is not None:
break

if inferred_task_name is None:
raise ValueError(
Expand Down Expand Up @@ -1758,7 +1741,6 @@ def infer_task_from_model(
subfolder: str = "",
revision: Optional[str] = None,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
) -> str:
"""
Expand All @@ -1775,8 +1757,6 @@ def infer_task_from_model(
Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id.
cache_dir (`Optional[str]`, *optional*):
Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used.
use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`):
Deprecated. Please use the `token` argument instead.
token (`Optional[Union[bool,str]]`, defaults to `None`):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`).
Expand All @@ -1786,15 +1766,6 @@ def infer_task_from_model(
"""
inferred_task_name = None

if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
FutureWarning,
)
if token is not None:
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
token = use_auth_token

if isinstance(model, str):
inferred_task_name = cls._infer_task_from_model_name_or_path(
model, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
Expand All @@ -1811,8 +1782,9 @@ def infer_task_from_model(

return inferred_task_name

@staticmethod
@classmethod
def _infer_library_from_model_or_model_class(
cls,
model: Optional[Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]] = None,
model_class: Optional[Type[Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]]] = None,
):
Expand Down

0 comments on commit 14037cd

Please sign in to comment.