Add timm models export (#1587)
* add timm export

* fix errors

* fix errors

* updated tests

* updated tests

* fix style

* fix tests

* add models for tests

* add models for tests

* update docs

* fix test

* fix test

---------

Co-authored-by: Félix Marty <[email protected]>
mht-sharma and fxmarty authored Dec 13, 2023
1 parent afe2e3c commit 7f4d7ee
Showing 10 changed files with 221 additions and 124 deletions.
67 changes: 62 additions & 5 deletions docs/source/exporters/onnx/overview.mdx
@@ -108,11 +108,68 @@ Supported architectures from [🤗 Diffusers](https://huggingface.co/docs/diffusers/index):
- Stable Diffusion

Supported architectures from [🤗 Timm](https://huggingface.co/docs/timm/index):
- Resnext50-32x4d
- Resnext50d-32x4d
- Resnext101-32x4d
- Resnext101-32x8d
- Resnext101-64x4d
- Adversarial Inception v3
- AdvProp (EfficientNet)
- Big Transfer (BiT)
- CSP-DarkNet
- CSP-ResNet
- CSP-ResNeXt
- DenseNet
- Deep Layer Aggregation
- Dual Path Network (DPN)
- ECA-ResNet
- EfficientNet
- EfficientNet (Knapsack Pruned)
- Ensemble Adversarial Inception ResNet v2
- FBNet
- (Gluon) Inception v3
- (Gluon) ResNet
- (Gluon) ResNeXt
- (Gluon) SENet
- (Gluon) SE-ResNeXt
- (Gluon) Xception
- HRNet
- Instagram ResNeXt WSL
- Inception ResNet v2
- Inception v3
- Inception v4
- (Legacy) SE-ResNet
- (Legacy) SE-ResNeXt
- (Legacy) SENet
- MixNet
- MnasNet
- MobileNet v2
- MobileNet v3
- NASNet
- Noisy Student (EfficientNet)
- PNASNet
- RegNetX
- RegNetY
- Res2Net
- Res2NeXt
- ResNeSt
- ResNet
- ResNet-D
- ResNeXt
- RexNet
- SE-ResNet
- SelecSLS
- SE-ResNeXt
- SK-ResNet
- SK-ResNeXt
- SPNASNet
- SSL ResNet
- SWSL ResNet
- SWSL ResNeXt
- (Tensorflow) EfficientNet
- (Tensorflow) EfficientNet CondConv
- (Tensorflow) EfficientNet Lite
- (Tensorflow) Inception v3
- (Tensorflow) MixNet
- (Tensorflow) MobileNet v3
- TResNet
- Wide ResNet
- Xception

Supported architectures from [Sentence Transformers](https://github.com/UKPLab/sentence-transformers):
- All Transformer and CLIP-based models.
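Below is a minimal sketch of how one of the timm checkpoints listed above can be exported through the `main_export` entry point touched by this PR. The checkpoint name, output directory, and explicit task are illustrative choices, not values taken from this diff.

```python
# Minimal sketch: export a timm image-classification checkpoint to ONNX.
# Checkpoint name and output directory are placeholders.
from optimum.exporters.onnx import main_export

main_export(
    "timm/resnext50_32x4d.tv2_in1k",   # any supported timm hub checkpoint
    output="resnext50_onnx",           # directory that will receive the ONNX model
    task="image-classification",
)
```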
7 changes: 5 additions & 2 deletions optimum/exporters/onnx/__main__.py
@@ -73,6 +73,7 @@ def _get_submodels_and_onnx_configs(
fn_get_submodels: Optional[Callable] = None,
preprocessors: Optional[List[Any]] = None,
legacy: bool = False,
library_name: str = "transformers",
model_kwargs: Optional[Dict] = None,
):
is_stable_diffusion = "stable-diffusion" in task
@@ -84,7 +85,7 @@
)
else:
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="onnx", task=task
model=model, exporter="onnx", task=task, library_name=library_name
)
onnx_config = onnx_config_constructor(
model.config,
@@ -425,7 +426,8 @@ def main_export(
if (
not custom_architecture
and not is_stable_diffusion
and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx")
and task + "-with-past"
in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx", library_name=library_name)
):
if original_task == "auto":  # Make -with-past the default if --task was not explicitly specified
task = task + "-with-past"
@@ -467,6 +469,7 @@
preprocessors=preprocessors,
_variant=_variant,
legacy=legacy,
library_name=library_name,
model_kwargs=model_kwargs,
)

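For context, the `-with-past` check above simply appends a suffix to the task name and looks it up in the library-aware supported-tasks map. A hedged illustration follows; the model type and task are examples of a transformers decoder model (timm models only expose image-classification, so this branch does not apply to them).

```python
# Illustrative only: how the "-with-past" lookup behaves for a decoder model.
from optimum.exporters.tasks import TasksManager

task = "text-generation"
supported = TasksManager.get_supported_tasks_for_model_type(
    "gpt2", "onnx", library_name="transformers"
)
if task + "-with-past" in supported:
    task = task + "-with-past"  # export with KV-cache support by default
print(task)
```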
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/convert.py
@@ -267,7 +267,6 @@ def _run_validation(
if input_shapes is None:
input_shapes = {} # will use the defaults from DEFAULT_DUMMY_SHAPES
reference_model_inputs = config.generate_dummy_inputs(framework=framework, **input_shapes)
reference_model_inputs = config.rename_ambiguous_inputs(reference_model_inputs)

# Create ONNX Runtime session
session_options = SessionOptions()
@@ -322,6 +321,7 @@ def _run_validation(

# Some models may modify in place the inputs, hence the copy.
copy_reference_model_inputs = copy.deepcopy(reference_model_inputs)
copy_reference_model_inputs = config.rename_ambiguous_inputs(copy_reference_model_inputs)

with config.patch_model_for_export(reference_model, model_kwargs=model_kwargs):
if is_torch_available() and isinstance(reference_model, nn.Module):
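The move above means the ambiguous-name renaming now only touches the deep copy handed to the PyTorch reference model, while the original dummy inputs keep their ONNX-config keys. A schematic sketch of the effect, using a placeholder value instead of a real tensor:

```python
# Sketch of the validation-time renaming after this change (placeholder value,
# not a real tensor). For timm configs, rename_ambiguous_inputs presumably maps
# the ONNX-config key "pixel_values" to the forward-signature name "x".
import copy

reference_model_inputs = {"pixel_values": "<dummy tensor>"}  # keeps ONNX-config keys

copy_reference_model_inputs = copy.deepcopy(reference_model_inputs)
# what config.rename_ambiguous_inputs does for timm models:
copy_reference_model_inputs = {"x": copy_reference_model_inputs.pop("pixel_values")}
# copy_reference_model_inputs is then fed to the PyTorch reference model.
```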
28 changes: 21 additions & 7 deletions optimum/exporters/onnx/model_configs.py
@@ -782,8 +782,26 @@ class DonutSwinOnnxConfig(ViTOnnxConfig):
pass


class TimmResNextOnnxConfig(ViTOnnxConfig):
class TimmDefaultOnnxConfig(ViTOnnxConfig):
ATOL_FOR_VALIDATION = 1e-3
DEFAULT_ONNX_OPSET = 12

def __init__(
self,
config: "PretrainedConfig",
task: str = "feature-extraction",
preprocessors: Optional[List[Any]] = None,
int_dtype: str = "int64",
float_dtype: str = "fp32",
legacy: bool = False,
):
super().__init__(config, task, preprocessors, int_dtype, float_dtype, legacy)

pretrained_cfg = self._config
if hasattr(self._config, "pretrained_cfg"):
pretrained_cfg = self._config.pretrained_cfg

self._normalized_config = self.NORMALIZED_CONFIG_CLASS(pretrained_cfg)

def rename_ambiguous_inputs(self, inputs):
# The input name in the model signature is `x`, hence the export input name is updated.
@@ -792,13 +810,9 @@ def rename_ambiguous_inputs(self, inputs):

return model_inputs


class TimmResNext50d_32x4dOnnxConfig(TimmResNextOnnxConfig):
ATOL_FOR_VALIDATION = 1e-3

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {"pixel_values": {0: "batch_size"}}
def torch_to_onnx_input_map(self) -> Dict[str, str]:
return {"x": "pixel_values"}


class SentenceTransformersTransformerOnnxConfig(TextEncoderOnnxConfig):
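A short sketch of how the two hooks on `TimmDefaultOnnxConfig` cooperate; the dictionaries below only illustrate the name mapping, they are not taken verbatim from the diff.

```python
# Illustrative only: the two name maps used by TimmDefaultOnnxConfig.
# rename_ambiguous_inputs: dummy inputs handed to the timm forward() are
#   re-keyed by its real argument name, `x`.
# torch_to_onnx_input_map: the traced input `x` is exposed as `pixel_values`
#   in the exported graph, so downstream tooling sees the usual vision name.
dummy_inputs = {"pixel_values": "<dummy tensor>"}

inputs_for_torch = {"x": dummy_inputs["pixel_values"]}   # rename_ambiguous_inputs
onnx_input_names = {"x": "pixel_values"}                  # torch_to_onnx_input_map
```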
53 changes: 36 additions & 17 deletions optimum/exporters/tasks.py
@@ -287,6 +287,10 @@ class TasksManager:
"visual-question-answering",
)

_MODEL_TYPE_FOR_DEFAULT_CONFIG = {
"timm": "default-timm-config",
}

# TODO: some models here support text-generation export but are not supported in ORTModelForCausalLM
# Set of model topologies we support associated to the tasks supported by each topology and the factory
_SUPPORTED_MODEL_TYPE = {
@@ -836,12 +840,7 @@ class TasksManager:
"resnet": supported_tasks_mapping(
"feature-extraction", "image-classification", onnx="ResNetOnnxConfig", tflite="ResNetTFLiteConfig"
),
"resnext26ts": supported_tasks_mapping("image-classification", onnx="TimmResNextOnnxConfig"),
"resnext50-32x4d": supported_tasks_mapping("image-classification", onnx="TimmResNextOnnxConfig"),
"resnext50d-32x4d": supported_tasks_mapping("image-classification", onnx="TimmResNext50d_32x4dOnnxConfig"),
"resnext101-32x4d": supported_tasks_mapping("image-classification", onnx="TimmResNextOnnxConfig"),
"resnext101-32x8d": supported_tasks_mapping("image-classification", onnx="TimmResNextOnnxConfig"),
"resnext101-64x4d": supported_tasks_mapping("image-classification", onnx="TimmResNextOnnxConfig"),
"default-timm-config": supported_tasks_mapping("image-classification", onnx="TimmDefaultOnnxConfig"),
"roberta": supported_tasks_mapping(
"feature-extraction",
"fill-mask",
@@ -1108,7 +1107,7 @@ def decorator(config_cls: Type) -> Type:

@staticmethod
def get_supported_tasks_for_model_type(
model_type: str, exporter: str, model_name: Optional[str] = None
model_type: str, exporter: str, model_name: Optional[str] = None, library_name: str = "transformers"
) -> TaskNameToExportConfigDict:
"""
Retrieves the `task -> exporter backend config constructors` map from the model type.
@@ -1120,27 +1119,37 @@ def get_supported_tasks_for_model_type(
The name of the exporter.
model_name (`Optional[str]`, defaults to `None`):
The name attribute of the model object, only used for the exception message.
library_name (defaults to `transformers`):
The library name of the model.
Returns:
`TaskNameToExportConfigDict`: The dictionary mapping each task to a corresponding `ExportConfig`
constructor.
"""
model_type = model_type.lower().replace("_", "-")
model_type_and_model_name = f"{model_type} ({model_name})" if model_name else model_type

default_model_type = None
if library_name in TasksManager._MODEL_TYPE_FOR_DEFAULT_CONFIG:
default_model_type = TasksManager._MODEL_TYPE_FOR_DEFAULT_CONFIG[library_name]

if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
raise KeyError(
f"{model_type_and_model_name} is not supported yet. "
f"Only {list(TasksManager._SUPPORTED_MODEL_TYPE.keys())} are supported. "
f"If you want to support {model_type} please propose a PR or open up an issue."
)
elif exporter not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]:
if default_model_type is not None:
model_type = default_model_type
else:
raise KeyError(
f"{model_type_and_model_name} is not supported yet for {library_name}. "
f"Only {list(TasksManager._SUPPORTED_MODEL_TYPE.keys())} are supported. "
f"If you want to support {model_type} please propose a PR or open up an issue."
)
if exporter not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]:
raise KeyError(
f"{model_type_and_model_name} is not supported yet with the {exporter} backend. "
f"Only {list(TasksManager._SUPPORTED_MODEL_TYPE[model_type].keys())} are supported. "
f"If you want to support {exporter} please propose a PR or open up an issue."
)
else:
return TasksManager._SUPPORTED_MODEL_TYPE[model_type][exporter]

return TasksManager._SUPPORTED_MODEL_TYPE[model_type][exporter]

@staticmethod
def get_supported_model_type_for_task(task: str, exporter: str) -> List[str]:
@@ -1598,7 +1607,7 @@ def infer_library_from_model(

model_config = PretrainedConfig.from_json_file(config_path)

if hasattr(model_config, "pretrained_cfg"):
if hasattr(model_config, "pretrained_cfg") or hasattr(model_config, "architecture"):
library_name = "timm"
elif hasattr(model_config, "_diffusers_version"):
library_name = "diffusers"
@@ -1662,6 +1671,9 @@ def standardize_model_attributes(

model_config = PretrainedConfig.from_json_file(config_path)

if hasattr(model_config, "pretrained_cfg"):
model_config.pretrained_cfg = PretrainedConfig.from_dict(model_config.pretrained_cfg)

# Set config as in transformers
setattr(model, "config", model_config)

@@ -1781,6 +1793,7 @@ def get_model_from_task(

if library_name == "timm":
model = model_class(f"hf_hub:{model_name_or_path}", pretrained=True, exportable=True)
model = model.to(torch_dtype).to(device)
elif library_name == "sentence_transformers":
cache_folder = model_kwargs.pop("cache_folder", None)
use_auth_token = model_kwargs.pop("use_auth_token", None)
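A sketch of what the timm branch above does, assuming `model_class` resolves to `timm.create_model` when `library_name == "timm"`; the checkpoint name is a placeholder.

```python
# Illustrative sketch of the timm loading path, under the assumption that
# model_class is timm.create_model for timm checkpoints.
import timm
import torch

model = timm.create_model(
    "hf_hub:timm/resnext50_32x4d.tv2_in1k",  # hub checkpoint, placeholder choice
    pretrained=True,
    exportable=True,   # ask timm for export-friendly ops
)
model = model.to(torch.float32).to("cpu")  # mirrors the dtype/device cast added here
```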
@@ -1831,6 +1844,7 @@ def get_exporter_config_constructor(
model_type: Optional[str] = None,
model_name: Optional[str] = None,
exporter_config_kwargs: Optional[Dict[str, Any]] = None,
library_name: str = "transformers",
) -> ExportConfigConstructor:
"""
Gets the `ExportConfigConstructor` for a model (or alternatively for a model type) and task combination.
@@ -1864,7 +1878,9 @@
model_type = model_type.replace("_", "-")
model_name = getattr(model, "name", model_name)

model_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter, model_name=model_name)
model_tasks = TasksManager.get_supported_tasks_for_model_type(
model_type, exporter, model_name=model_name, library_name=library_name
)

if task not in model_tasks:
synonyms = TasksManager.synonyms_for_task(task)
Expand All @@ -1878,6 +1894,9 @@ def get_exporter_config_constructor(
f" Supported tasks are: {', '.join(model_tasks.keys())}."
)

if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
model_type = TasksManager._MODEL_TYPE_FOR_DEFAULT_CONFIG[library_name]

exporter_config_constructor = TasksManager._SUPPORTED_MODEL_TYPE[model_type][exporter][task]
if exporter_config_kwargs is not None:
exporter_config_constructor = partial(exporter_config_constructor, **exporter_config_kwargs)
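The net effect of the `tasks.py` changes is that a timm model type which has no dedicated entry in `_SUPPORTED_MODEL_TYPE` falls back to the library-level default, `default-timm-config`. A hedged illustration (the model type is an example; the printed result is the expected outcome, not output captured from this PR):

```python
# Illustrative lookup: the timm model type is not in _SUPPORTED_MODEL_TYPE,
# so the "default-timm-config" entry is used and image-classification is the
# single supported task.
from optimum.exporters.tasks import TasksManager

tasks = TasksManager.get_supported_tasks_for_model_type(
    "resnext50_32x4d", "onnx", library_name="timm"
)
print(list(tasks.keys()))  # expected: ["image-classification"]
```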
14 changes: 10 additions & 4 deletions optimum/utils/input_generators.py
@@ -604,16 +604,22 @@ def __init__(
**kwargs,
):
self.task = task

# Some vision models can take any input sizes, in this case we use the values provided as parameters.
if normalized_config.has_attribute("image_size"):
self.image_size = normalized_config.image_size
else:
self.image_size = (height, width)
if normalized_config.has_attribute("num_channels"):
self.num_channels = normalized_config.num_channels
else:
self.num_channels = num_channels

if normalized_config.has_attribute("image_size"):
self.image_size = normalized_config.image_size
elif normalized_config.has_attribute("input_size"):
input_size = normalized_config.input_size
self.num_channels = input_size[0]
self.image_size = input_size[1:]
else:
self.image_size = (height, width)

if not isinstance(self.image_size, (tuple, list)):
self.image_size = (self.image_size, self.image_size)
self.batch_size = batch_size
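The new `input_size` branch above lets the dummy-input generator pick up timm-style shape metadata. A small sketch of the precedence it introduces, using a typical value as an example:

```python
# Sketch: when the normalized config carries a timm-style `input_size`
# (channels, height, width), it overrides both num_channels and image_size
# used to build dummy pixel inputs.
input_size = (3, 224, 224)     # example value as found in many timm pretrained_cfg entries

num_channels = input_size[0]   # -> 3
image_size = input_size[1:]    # -> (224, 224)
```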
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
@@ -99,6 +99,7 @@ class NormalizedSeq2SeqConfig(NormalizedTextConfig):
class NormalizedVisionConfig(NormalizedConfig):
IMAGE_SIZE = "image_size"
NUM_CHANNELS = "num_channels"
INPUT_SIZE = "input_size"


class NormalizedTextAndVisionConfig(NormalizedTextConfig, NormalizedVisionConfig):
43 changes: 37 additions & 6 deletions tests/exporters/exporters_utils.py
@@ -288,12 +288,43 @@
}

PYTORCH_TIMM_MODEL = {
"resnext26ts": "timm/resnext26ts.ra2_in1k",
"resnext50-32x4d": "timm/resnext50_32x4d.tv2_in1k",
"resnext50d-32x4d": "timm/resnext50d_32x4d.bt_in1k",
"resnext101-32x4d": "timm/resnext101_32x4d.gluon_in1k",
"resnext101-32x8d": "timm/resnext101_32x8d.tv_in1k",
"resnext101-64x4d": "timm/resnext101_64x4d.c1_in1k",
"default-timm-config": {
"timm/inception_v3.tf_adv_in1k": ["image-classification"],
"timm/tf_efficientnet_b0.in1k": ["image-classification"],
"timm/resnetv2_50x1_bit.goog_distilled_in1k": ["image-classification"],
"timm/cspdarknet53.ra_in1k": ["image-classification"],
"timm/cspresnet50.ra_in1k": ["image-classification"],
"timm/cspresnext50.ra_in1k": ["image-classification"],
"timm/densenet121.ra_in1k": ["image-classification"],
"timm/dla102.in1k": ["image-classification"],
"timm/dpn107.mx_in1k": ["image-classification"],
"timm/ecaresnet101d.miil_in1k": ["image-classification"],
"timm/efficientnet_b1_pruned.in1k": ["image-classification"],
"timm/inception_resnet_v2.tf_ens_adv_in1k": ["image-classification"],
"timm/fbnetc_100.rmsp_in1k": ["image-classification"],
"timm/xception41.tf_in1k": ["image-classification"],
"timm/senet154.gluon_in1k": ["image-classification"],
"timm/seresnext26d_32x4d.bt_in1k": ["image-classification"],
"timm/hrnet_w18.ms_aug_in1k": ["image-classification"],
"timm/inception_v3.gluon_in1k": ["image-classification"],
"timm/inception_v4.tf_in1k": ["image-classification"],
"timm/mixnet_s.ft_in1k": ["image-classification"],
"timm/mnasnet_100.rmsp_in1k": ["image-classification"],
"timm/mobilenetv2_100.ra_in1k": ["image-classification"],
"timm/mobilenetv3_small_050.lamb_in1k": ["image-classification"],
"timm/nasnetalarge.tf_in1k": ["image-classification"],
"timm/tf_efficientnet_b0.ns_jft_in1k": ["image-classification"],
"timm/pnasnet5large.tf_in1k": ["image-classification"],
"timm/regnetx_002.pycls_in1k": ["image-classification"],
"timm/regnety_002.pycls_in1k": ["image-classification"],
"timm/res2net101_26w_4s.in1k": ["image-classification"],
"timm/res2next50.in1k": ["image-classification"],
"timm/resnest101e.in1k": ["image-classification"],
"timm/spnasnet_100.rmsp_in1k": ["image-classification"],
"timm/resnet18.fb_swsl_ig1b_ft_in1k": ["image-classification"],
"timm/wide_resnet101_2.tv_in1k": ["image-classification"],
"timm/tresnet_l.miil_in1k": ["image-classification"],
}
}

PYTORCH_SENTENCE_TRANSFORMERS_MODEL = {