Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into longjie/add_automat…
Browse files Browse the repository at this point in the history
…ic_model_parallel_via_fx
  • Loading branch information
zhenglongjiepheonix committed Jul 23, 2024
2 parents 8d2cabb + 5ece6e8 commit c9c7571
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ def get_transformers_tasks_to_model_mapping(tasks_to_model_loader, framework="pt
for model_loader in model_loaders:
model_loader_class = getattr(auto_modeling_module, model_loader, None)
if model_loader_class is not None:
# we can just update the model_type to model_class mapping since we only need one either way
# we can just update the model_type to model_class mapping since
# we can only have one task->model_type->model_class either way
# e.g. we merge automatic-speech-recognition's SpeechSeq2Seq and CTC models
tasks_to_model_mapping[task_name].update(model_loader_class._model_mapping._model_mapping)

return tasks_to_model_mapping
Expand Down Expand Up @@ -1767,6 +1769,7 @@ def _infer_library_from_model_or_model_class(
model: Optional[Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]] = None,
model_class: Optional[Type[Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]]] = None,
):
inferred_library_name = None
if model is not None and model_class is not None:
raise ValueError("Either a model or a model class must be provided, but both were given here.")
if model is None and model_class is None:
Expand All @@ -1775,20 +1778,20 @@ def _infer_library_from_model_or_model_class(
target_class_module = model.__class__.__module__ if model is not None else model_class.__module__

if target_class_module.startswith("sentence_transformers"):
library_name = "sentence_transformers"
inferred_library_name = "sentence_transformers"
elif target_class_module.startswith("transformers"):
library_name = "transformers"
inferred_library_name = "transformers"
elif target_class_module.startswith("diffusers"):
library_name = "diffusers"
inferred_library_name = "diffusers"
elif target_class_module.startswith("timm"):
library_name = "timm"
inferred_library_name = "timm"

if library_name is None:
if inferred_library_name is None:
raise ValueError(
"The library name could not be automatically inferred. If using the command-line, please provide the argument --library {transformers,diffusers,timm,sentence_transformers}. Example: `--library diffusers`."
)

return library_name
return inferred_library_name

@classmethod
def _infer_library_from_model_name_or_path(
Expand Down

0 comments on commit c9c7571

Please sign in to comment.