updated docs
IlyasMoutawwakil committed Sep 19, 2023
1 parent 89d08c4 commit 625162d
Showing 1 changed file with 16 additions and 19 deletions.
docs/source/onnxruntime/usage_guides/gpu.mdx: 35 changes (16 additions, 19 deletions)
@@ -291,15 +291,17 @@ We recommend setting these two provider options when using the TensorRT execution provider
... )
```

TensorRT builds its engine depending on specified input shapes. Unfortunately, in the [current ONNX Runtime implementation](https://github.com/microsoft/onnxruntime/blob/613920d6c5f53a8e5e647c5f1dcdecb0a8beef31/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L1677-L1688) (references: [1](https://github.com/microsoft/onnxruntime/issues/13559), [2](https://github.com/microsoft/onnxruntime/issues/13851)), the engine is rebuilt whenever an input shape is smaller than the smallest shape encountered so far, or larger than the largest shape encountered so far. For example, if a model takes `(batch_size, input_ids)` as inputs and successively receives the following inputs (a minimal sketch of this rebuild rule follows the list):
TensorRT builds its engine depending on specified input shapes. One big issue is that building the engine can be time consuming, especially for large models. Therefore, as a workaround, one recommendation is to build the TensorRT engine with dynamic shapes. This allows to avoid rebuilding the engine for new small and large shapes, which is unwanted once the model is deployed for inference.

1. `input.shape: (4, 5) --> the engine is built (first input)`
2. `input.shape: (4, 10) --> engine rebuilt (10 larger than 5)`
3. `input.shape: (4, 7) --> no rebuild (5 <= 7 <= 10)`
4. `input.shape: (4, 12) --> engine rebuilt (12 larger than 10)`
5. `input.shape: (4, 3) --> engine rebuilt (3 smaller than 5)`
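
The rule can be illustrated with a short sketch (plain Python, purely illustrative and not ONNX Runtime code): the engine is rebuilt whenever any dimension falls outside the range of values seen so far.

```python
def needs_rebuild(shape, seen_min, seen_max):
    # Rebuild if any dimension is smaller than the smallest or larger than the
    # largest value encountered so far for that dimension.
    return any(d < lo or d > hi for d, lo, hi in zip(shape, seen_min, seen_max))

seen_min, seen_max = (4, 5), (4, 5)  # the first input (4, 5) builds the engine
for shape in [(4, 10), (4, 7), (4, 12), (4, 3)]:
    print(shape, "rebuild" if needs_rebuild(shape, seen_min, seen_max) else "no rebuild")
    seen_min = tuple(min(a, b) for a, b in zip(seen_min, shape))
    seen_max = tuple(max(a, b) for a, b in zip(seen_max, shape))
```
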
To do so, we use the provider options `trt_profile_min_shapes`, `trt_profile_max_shapes` and `trt_profile_opt_shapes` to specify the minimum, maximum and optimal shapes for the engine. For example, for GPT2 we can use the following shapes:

One big issue is that building the engine can be time consuming, especially for large models. Therefore, as a workaround, one recommendation is to **first build the TensorRT engine with an input of small shape, and then with an input of large shape to have an engine valid for all shapes inbetween**. This allows to avoid rebuilding the engine for new small and large shapes, which is unwanted once the model is deployed for inference.
```python
provider_options = {
"trt_profile_min_shapes": "input_ids:1x1,attention_mask:1x1,position_ids:1x1",
"trt_profile_opt_shapes": "input_ids:1x1,attention_mask:1x1,position_ids:1x1",
"trt_profile_max_shapes": "input_ids:1x64,attention_mask:1x64,position_ids:1x64",
}
```
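
Each entry follows the pattern `input_name:dim0xdim1`, and the entries for different inputs are separated by commas. As a purely illustrative sketch (the `format_trt_shapes` helper below is hypothetical, not part of Optimum or ONNX Runtime), such strings can also be assembled programmatically:

```python
def format_trt_shapes(shapes: dict) -> str:
    # {"input_ids": (1, 64)} -> "input_ids:1x64"
    return ",".join(f"{name}:" + "x".join(str(d) for d in dims) for name, dims in shapes.items())

min_shapes = {"input_ids": (1, 1), "attention_mask": (1, 1), "position_ids": (1, 1)}
max_shapes = {"input_ids": (1, 64), "attention_mask": (1, 64), "position_ids": (1, 64)}

provider_options = {
    "trt_profile_min_shapes": format_trt_shapes(min_shapes),
    "trt_profile_opt_shapes": format_trt_shapes(min_shapes),
    "trt_profile_max_shapes": format_trt_shapes(max_shapes),
}
```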

By passing an engine cache path in the provider options, the engine can therefore be built once and reused for inference thereafter.

@@ -314,25 +316,20 @@ For example, for text generation, the engine can be built with:
>>> provider_options = {
... "trt_engine_cache_enable": True,
... "trt_engine_cache_path": "tmp/trt_cache_gpt2_example"
... "trt_profile_min_shapes": "input_ids:1x1,attention_mask:1x1,position_ids:1x1",
... "trt_profile_opt_shapes": "input_ids:1x1,attention_mask:1x1,position_ids:1x1",
... "trt_profile_max_shapes": "input_ids:1x64,attention_mask:1x64,position_ids:1x64",
... }

>>> ort_model = ORTModelForCausalLM.from_pretrained(
... "optimum/gpt2",
... "gpt2",
... export=True,
... use_cache=False,
... use_merged=False,
... use_io_binding=False,
... provider="TensorrtExecutionProvider",
... provider_options=provider_options,
... )
>>> tokenizer = AutoTokenizer.from_pretrained("optimum/gpt2")

>>> print("Building engine for a short sequence...") # doctest: +IGNORE_RESULT
>>> text = ["short"]
>>> encoded_input = tokenizer(text, return_tensors="pt").to("cuda")
>>> output = ort_model(**encoded_input)

>>> print("Building engine for a long sequence...") # doctest: +IGNORE_RESULT
>>> text = [" a very long input just for demo purpose, this is very long" * 10]
>>> encoded_input = tokenizer(text, return_tensors="pt").to("cuda")
>>> output = ort_model(**encoded_input)
```
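
Once the engine has been built for the profile above, subsequent calls whose shapes fall within that profile reuse it directly. A minimal sketch, assuming the `ort_model` and `tokenizer` from the example above (the prompt text is arbitrary):

```python
text = ["Replace this with any prompt."]
encoded_input = tokenizer(text, return_tensors="pt").to("cuda")
# No engine rebuild here: the sequence length stays within the 1-64 profile above.
generated = ort_model.generate(**encoded_input, max_new_tokens=16)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```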

The engine is stored as: