Merge branch 'main' into trasformers-446
echarlaix committed Oct 29, 2024
2 parents 0cec49b + 4a39ae0 commit f2f8f41
Showing 4 changed files with 92 additions and 9 deletions.
68 changes: 61 additions & 7 deletions docs/source/onnxruntime/usage_guides/models.mdx
@@ -9,15 +9,15 @@ to run accelerated inference without rewriting your APIs.

### Transformers models

Once your model was [exported to the ONNX format](https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model), you can load it by replacing the `AutoModelForXxx` class with the corresponding `ORTModelForXxx`.
Once your model was [exported to the ONNX format](https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model), you can load it by replacing `AutoModelForXxx` with the corresponding `ORTModelForXxx` class.

```diff
from transformers import AutoTokenizer, pipeline
- from transformers import AutoModelForQuestionAnswering
+ from optimum.onnxruntime import ORTModelForQuestionAnswering
- from transformers import AutoModelForCausalLM
+ from optimum.onnxruntime import ORTModelForCausalLM

- model = AutoModelForQuestionAnswering.from_pretrained("meta-llama/Llama-3.2-1B") # PyTorch checkpoint
+ model = ORTModelForQuestionAnswering.from_pretrained("onnx-community/Llama-3.2-1B", subfolder="onnx") # ONNX checkpoint
- model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B") # PyTorch checkpoint
+ model = ORTModelForCausalLM.from_pretrained("onnx-community/Llama-3.2-1B", subfolder="onnx") # ONNX checkpoint
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
@@ -29,7 +29,7 @@ More information for all the supported `ORTModelForXxx` in our [documentation](h

### Diffusers models

Once your model was [exported to the ONNX format](https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model), you can load it by replacing the `DiffusionPipeline` class with the corresponding `ORTDiffusionPipeline`.
Once your model was [exported to the ONNX format](https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model), you can load it by replacing `DiffusionPipeline` with the corresponding `ORTDiffusionPipeline` class.


```diff
@@ -43,6 +43,60 @@ Once your model was [exported to the ONNX format](https://huggingface.co/docs/op
image = pipeline(prompt).images[0]
```
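A minimal standalone sketch of the resulting usage, assuming a checkpoint that already ships ONNX weights (the repository id below is illustrative):

```python
from optimum.onnxruntime import ORTDiffusionPipeline

# Hypothetical ONNX export of a Stable Diffusion checkpoint; substitute your own repo id.
pipeline = ORTDiffusionPipeline.from_pretrained("my-username/stable-diffusion-onnx")

prompt = "sailing ship in storm by Leonardo da Vinci"
image = pipeline(prompt).images[0]
image.save("ship.png")
```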


### Sentence Transformers models

Once your model was [exported to the ONNX format](https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model), you can load it by replacing `AutoModel` with the corresponding `ORTModelForFeatureExtraction` class.

```diff
from transformers import AutoTokenizer
- from transformers import AutoModel
+ from optimum.onnxruntime import ORTModelForFeatureExtraction

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
- model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+ model = ORTModelForFeatureExtraction.from_pretrained("optimum/all-MiniLM-L6-v2")
inputs = tokenizer("This is an example sentence", return_tensors="pt")
outputs = model(**inputs)
```
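The `last_hidden_state` above contains token-level embeddings; to obtain a single sentence embedding you would typically mean-pool over the attention mask. A minimal sketch (the `mean_pooling` helper is illustrative, not part of Optimum):

```python
import torch

def mean_pooling(last_hidden_state, attention_mask):
    # Average the token embeddings, masking out padding positions.
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    return (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

sentence_embedding = mean_pooling(outputs.last_hidden_state, inputs["attention_mask"])
print(sentence_embedding.shape)  # (1, 384) for all-MiniLM-L6-v2
```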

You can also load your ONNX model directly using the [`sentence_transformers.SentenceTransformer`](https://sbert.net/docs/sentence_transformer/usage/efficiency.html#onnx) class; just make sure to have `sentence-transformers>=3.2` installed. If the model wasn't already converted to ONNX, it will be converted automatically on-the-fly.

```diff
from sentence_transformers import SentenceTransformer

model_id = "sentence-transformers/all-MiniLM-L6-v2"
- model = SentenceTransformer(model_id)
+ model = SentenceTransformer(model_id, backend="onnx")

sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)
```
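The returned embeddings can then be compared as usual; a short sketch using `SentenceTransformer.similarity` (available in recent `sentence-transformers` releases, cosine similarity by default):

```python
# Pairwise similarities between the two example sentences.
similarities = model.similarity(embeddings, embeddings)
print(similarities)  # 2x2 matrix, ~1.0 on the diagonal
```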


### Timm models

Once your model was [exported to the ONNX format](https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model), you can load it by replacing `create_model` with the corresponding `ORTModelForImageClassification` class.


```diff
import requests
from PIL import Image
- from timm import create_model
from timm.data import resolve_data_config, create_transform
+ from optimum.onnxruntime import ORTModelForImageClassification

- model = create_model("timm/mobilenetv3_large_100.ra_in1k", pretrained=True)
+ model = ORTModelForImageClassification.from_pretrained("optimum/mobilenetv3_large_100.ra_in1k")
transform = create_transform(**resolve_data_config(model.config.pretrained_cfg, model=model))
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
image = Image.open(requests.get(url, stream=True).raw)
inputs = transform(image).unsqueeze(0)
outputs = model(inputs)
```
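`ORTModelForImageClassification` returns logits, so the usual post-processing applies; a minimal sketch of extracting the top-5 predicted classes (assuming the PyTorch path, i.e. torch tensors in and out):

```python
import torch

# Convert raw logits into probabilities and take the five most likely classes.
probabilities = torch.softmax(outputs.logits, dim=-1)
top5_prob, top5_idx = probabilities.topk(5, dim=-1)
print(top5_idx[0].tolist(), top5_prob[0].tolist())
```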



## Converting your model to ONNX on-the-fly

In case your model wasn't already [converted to ONNX](https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model), [`~optimum.onnxruntime.ORTModel`] includes a method to convert your model to ONNX on-the-fly.
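For instance, passing `export=True` to `from_pretrained` exports a vanilla PyTorch checkpoint at load time; a minimal sketch (the checkpoint name is only an example):

```python
from optimum.onnxruntime import ORTModelForSequenceClassification

# The PyTorch weights are converted to ONNX on-the-fly before the inference session is created.
model = ORTModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english", export=True
)
```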
@@ -74,4 +128,4 @@ You can also call `push_to_hub` directly on your model to upload it to the [Hub]

# Push the onnx model to HF Hub
>>> model.push_to_hub(output_dir, repository_id="my-onnx-repo") # doctest: +SKIP
```
```
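Putting the pieces together, a hedged end-to-end sketch of exporting, saving, and pushing (the checkpoint, directory, and repository names are illustrative):

```python
from optimum.onnxruntime import ORTModelForSequenceClassification

# Export on-the-fly, save the ONNX files locally, then push them to the Hub.
model = ORTModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english", export=True
)
output_dir = "a_local_path_for_convert_onnx_model"
model.save_pretrained(output_dir)
model.push_to_hub(output_dir, repository_id="my-onnx-repo")
```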
19 changes: 18 additions & 1 deletion optimum/onnxruntime/modeling_ort.py
@@ -931,7 +931,6 @@ def _prepare_onnx_inputs(
self, use_torch: bool, **inputs: Union[torch.Tensor, np.ndarray]
) -> Dict[str, np.ndarray]:
onnx_inputs = {}

# converts pytorch inputs into numpy inputs for onnx
for input_name in self.input_names.keys():
onnx_inputs[input_name] = inputs.pop(input_name)
@@ -1086,6 +1085,9 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if token_type_ids is None and "token_type_ids" in self.input_names:
token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)

if self.device.type == "cuda" and self.use_io_binding:
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
input_ids,
@@ -1241,6 +1243,9 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if token_type_ids is None and "token_type_ids" in self.input_names:
token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)

if self.device.type == "cuda" and self.use_io_binding:
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
input_ids,
@@ -1330,6 +1335,9 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if token_type_ids is None and "token_type_ids" in self.input_names:
token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)

if self.device.type == "cuda" and self.use_io_binding:
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
input_ids,
@@ -1437,6 +1445,9 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if token_type_ids is None and "token_type_ids" in self.input_names:
token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)

if self.device.type == "cuda" and self.use_io_binding:
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
input_ids,
@@ -1527,6 +1538,9 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if token_type_ids is None and "token_type_ids" in self.input_names:
token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)

if self.device.type == "cuda" and self.use_io_binding:
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
input_ids,
@@ -1610,6 +1624,9 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if token_type_ids is None and "token_type_ids" in self.input_names:
token_type_ids = torch.zeros_like(input_ids) if use_torch else np.zeros_like(input_ids)

if self.device.type == "cuda" and self.use_io_binding:
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
input_ids,
2 changes: 1 addition & 1 deletion setup.py
@@ -94,7 +94,7 @@
"nncf": "optimum-intel[nncf]>=1.18.0",
"neural-compressor": "optimum-intel[neural-compressor]>=1.18.0",
"ipex": "optimum-intel[ipex]>=1.18.0",
"habana": ["optimum-habana", "transformers>=4.43.0,<4.44.0"],
"habana": ["optimum-habana", "transformers>=4.45.0,<4.46.0"],
"neuron": ["optimum-neuron[neuron]>=0.0.20", "transformers>=4.36.2,<4.42.0"],
"neuronx": ["optimum-neuron[neuronx]>=0.0.20", "transformers>=4.36.2,<4.42.0"],
"graphcore": "optimum-graphcore",
12 changes: 12 additions & 0 deletions tests/onnxruntime/test_modeling.py
@@ -2193,6 +2193,18 @@ def test_compare_to_io_binding(self, model_arch):

gc.collect()

def test_default_token_type_ids(self):
model_id = MODEL_NAMES["bert"]
model = ORTModelForFeatureExtraction.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer("this is a simple input", return_tensors="np")
self.assertTrue("token_type_ids" in model.input_names)
token_type_ids = tokens.pop("token_type_ids")
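# Passing the tokenizer-provided token_type_ids (all zeros for a single segment) and omitting them should give identical outputs.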
outs = model(token_type_ids=token_type_ids, **tokens)
outs_without_token_type_ids = model(**tokens)
self.assertTrue(np.allclose(outs.last_hidden_state, outs_without_token_type_ids.last_hidden_state))
gc.collect()


class ORTModelForMultipleChoiceIntegrationTest(ORTModelTestMixin):
# Multiple Choice tests are conducted on different models due to mismatch size in model's classifier
