
Torch-TensorRT v1.4.0

Released by @narendasan on 03 Jun 04:05

PyTorch 2.0, CUDA 11.8, TensorRT 8.6, Support for the new torch.compile API, compatibility mode for FX frontend

Torch-TensorRT 1.4.0 targets PyTorch 2.0, CUDA 11.8, and TensorRT 8.6. This release introduces a number of beta features to set the stage for working with PyTorch and TensorRT in the 2.0 ecosystem. Chief among them is a new torch.compile backend targeting Torch-TensorRT. It also adds a compatibility layer that allows users of the TorchScript frontend for Torch-TensorRT to seamlessly try the FX and Dynamo stacks.

`torch.compile` Backend for Torch-TensorRT

One of the most prominent new features in PyTorch 2.0 is the torch.compile workflow, which enables users to accelerate code easily by specifying a backend of their choice. Torch-TensorRT 1.4.0 introduces a new backend for torch.compile as a beta feature, including a convenience frontend to perform accelerated inference. This frontend can be accessed in one of two ways:

```python
import torch_tensorrt

torch_tensorrt.dynamo.compile(model, inputs, ...)

##### OR #####

torch_tensorrt.compile(model, ir="dynamo_compile", inputs=inputs, ...)
```

For more examples, see the sample scripts provided in the repository.

This compilation method has a few key considerations:

  1. It can handle models with data-dependent control flow
  2. It automatically falls back to Torch if the TRT engine build fails for any reason
  3. It uses the Torch FX aten library of converters to accelerate models
  4. Recompilation can be triggered by changing the input batch size or by providing an input that takes a new control-flow branch
  5. Compiled models cannot be saved across Python sessions (yet)

The feature is currently in beta, and we expect updates, changes, and improvements to the above in the future.
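To make the data-dependent control flow point concrete, here is a minimal sketch. The module and shapes are illustrative, and the actual `torch_tensorrt.dynamo.compile` call is left commented out because it requires a CUDA device and a TensorRT install:

```python
import torch

class GatedBlock(torch.nn.Module):
    """Toy module with data-dependent control flow: the branch taken
    depends on the values in the input tensor, not just its shape."""
    def forward(self, x):
        if x.sum() > 0:
            return torch.relu(x)
        return torch.sigmoid(x)

model = GatedBlock().eval()
x = torch.ones(1, 3, 8, 8)

# Sketch of the accelerated path (requires torch_tensorrt and CUDA):
#   import torch_tensorrt
#   trt_model = torch_tensorrt.dynamo.compile(model, inputs=[x.cuda()])
# An input that takes the other branch, or one with a different batch
# size, would trigger recompilation of the affected subgraph.

y = model(x)  # eager execution of the same model
```

Because the branch condition reads tensor values, a purely shape-based tracer cannot capture both paths in one graph; the dynamo backend handles this by graph breaks and per-branch compilation.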

fx_ts_compat Frontend

As the ecosystem transitions from TorchScript to Dynamo, users of Torch-TensorRT may want to start experimenting with this stack. As such, we have introduced a new frontend for Torch-TensorRT which exposes the same APIs as the TorchScript frontend but uses the FX/Dynamo compiler stack. You can try this frontend by using the ir="fx_ts_compat" setting:

```python
torch_tensorrt.compile(..., ir="fx_ts_compat")
```
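A fuller sketch of the swap, assuming a toy stand-in model: the runnable part below only scripts the module (what the TorchScript frontend consumed), while the `torch_tensorrt.compile` call is commented out because it requires torch_tensorrt and a CUDA device. The argument names shown there follow the TorchScript frontend's API; only the `ir` flag changes.

```python
import torch

# Minimal model to stand in for a user network (illustrative).
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    torch.nn.ReLU(),
).eval()

x = torch.randn(1, 3, 32, 32)

# The TorchScript frontend consumed a scripted/traced module:
scripted = torch.jit.script(model)

# The fx_ts_compat frontend keeps the same call shape but routes
# compilation through FX/Dynamo (sketch only):
#   import torch_tensorrt
#   trt_mod = torch_tensorrt.compile(
#       model.cuda(),
#       ir="fx_ts_compat",
#       inputs=[torch_tensorrt.Input((1, 3, 32, 32))],
#       enabled_precisions={torch.float},
#   )

y = scripted(x)
```

Keeping the API surface identical means existing TorchScript-frontend call sites can try the new stack by changing a single argument.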

What's Changed

New Contributors

Full Changelog: v1.3.0...v1.4.0