Skip to content

Commit

Permalink
Add Onnx Config for ImageGPT (huggingface#19868)
Browse files Browse the repository at this point in the history
* add Onnx Config for ImageGPT

* add generate_dummy_inputs for onnx config

* add TYPE_CHECKING clause

* Update doc for generate_dummy_inputs

Co-authored-by: Sylvain Gugger <[email protected]>

Co-authored-by: Sylvain Gugger <[email protected]>
  • Loading branch information
RaghavPrabhakar66 and sgugger authored Oct 28, 2022
1 parent 9b1dcba commit 0d4c45c
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/en/serialization.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ Ready-made configurations include the following architectures:
- GPT-J
- GroupViT
- I-BERT
- ImageGPT
- LayoutLM
- LayoutLMv3
- LeViT
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/models/imagegpt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available


_import_structure = {"configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig"]}
_import_structure = {
"configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig", "ImageGPTOnnxConfig"]
}

try:
if not is_vision_available():
Expand All @@ -48,7 +50,7 @@


if TYPE_CHECKING:
from .configuration_imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig
from .configuration_imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig, ImageGPTOnnxConfig

try:
if not is_vision_available():
Expand Down
60 changes: 60 additions & 0 deletions src/transformers/models/imagegpt/configuration_imagegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@
# limitations under the License.
""" OpenAI ImageGPT configuration"""

from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging


if TYPE_CHECKING:
from ... import FeatureExtractionMixin, TensorType

logger = logging.get_logger(__name__)

IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
Expand Down Expand Up @@ -140,3 +147,56 @@ def __init__(
self.tie_word_embeddings = tie_word_embeddings

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


class ImageGPTOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
]
)

def generate_dummy_inputs(
self,
preprocessor: "FeatureExtractionMixin",
batch_size: int = 1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional["TensorType"] = None,
num_channels: int = 3,
image_width: int = 32,
image_height: int = 32,
) -> Mapping[str, Any]:
"""
Generate inputs to provide to the ONNX exporter for the specific framework
Args:
preprocessor ([`PreTrainedTokenizerBase`] or [`FeatureExtractionMixin`]):
The preprocessor associated with this model configuration.
batch_size (`int`, *optional*, defaults to -1):
The batch size to export the model for (-1 means dynamic axis).
num_choices (`int`, *optional*, defaults to -1):
The number of candidate answers provided for multiple choice task (-1 means dynamic axis).
seq_length (`int`, *optional*, defaults to -1):
The sequence length to export the model for (-1 means dynamic axis).
is_pair (`bool`, *optional*, defaults to `False`):
Indicate if the input is a pair (sentence 1, sentence 2)
framework (`TensorType`, *optional*, defaults to `None`):
The framework (PyTorch or TensorFlow) that the tokenizer will generate tensors for.
num_channels (`int`, *optional*, defaults to 3):
The number of channels of the generated images.
image_width (`int`, *optional*, defaults to 40):
The width of the generated images.
image_height (`int`, *optional*, defaults to 40):
The height of the generated images.
Returns:
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
"""

input_image = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
inputs = dict(preprocessor(input_image, framework))

return inputs
3 changes: 3 additions & 0 deletions src/transformers/onnx/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ class FeaturesManager:
"question-answering",
onnx_config_cls="models.ibert.IBertOnnxConfig",
),
"imagegpt": supported_features_mapping(
"default", "image-classification", onnx_config_cls="models.imagegpt.ImageGPTOnnxConfig"
),
"layoutlm": supported_features_mapping(
"default",
"masked-lm",
Expand Down
1 change: 1 addition & 0 deletions tests/onnx/test_onnx_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def test_values_override(self):
("detr", "facebook/detr-resnet-50"),
("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"),
("imagegpt", "openai/imagegpt-small"),
("resnet", "microsoft/resnet-50"),
("roberta", "roberta-base"),
("roformer", "junnyu/roformer_chinese_base"),
Expand Down

0 comments on commit 0d4c45c

Please sign in to comment.