- API Walkthrough
- Debugging & Reporting Errors
- Visualization
- Use Torch XLA Conversion Backend (Legacy)
## API Walkthrough

This section walks through the end-to-end process of preparing a PyTorch model for on-device deployment. We'll use the `resnet18` model from the PyTorch torchvision package as an example. The model can be executed in PyTorch as follows:
```python
import torch
import torchvision

resnet18 = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)
torch_output = resnet18(*sample_inputs)
```
`ai_edge_torch.convert()` converts a PyTorch model to an on-device (Edge) model. The conversion process also requires sample inputs for tracing and shape inference, passed in as a tuple. For example, if the PyTorch model receives 3 tensors as positional arguments, the `convert` function receives 1 tuple with 3 entries.
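For instance, a minimal sketch, where `three_input_model` is a hypothetical module whose `forward` takes three positional tensors:

```python
import torch
import ai_edge_torch

# Hypothetical: three sample tensors packed into a single tuple for convert().
sample_inputs = (
    torch.randn(1, 8),
    torch.randn(1, 16),
    torch.randn(1, 32),
)
edge_model = ai_edge_torch.convert(three_input_model.eval(), sample_inputs)
```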
- Note 1: The source PyTorch model needs to be compliant with `torch.export`, introduced in PyTorch 2.1.0.
- Note 2: `convert` expects a `torch.nn.Module` with a `forward` function that receives tensors as arguments and returns tensors as outputs. If your model has a different interface, you need to provide a model wrapper, as demonstrated in the Providing a Wrapper section.
- Note 3: `convert` does not support passing keyword arguments to the model.
```python
import ai_edge_torch

# Note that we are setting the model to evaluation mode prior to conversion.
edge_model = ai_edge_torch.convert(resnet18.eval(), sample_inputs)
```
Once the model is converted, it is ready for inference with the TFLite runtime. Prior to deployment on-device, the outputs from PyTorch and the edge model can be compared in Python as a smoke check for the converted model.
```python
import numpy as np

edge_output = edge_model(*sample_inputs)
assert np.allclose(torch_output.detach().numpy(), edge_output, atol=1e-5)
```
The on-device prepared model provides an `export` function that serializes the model to a TFLite FlatBuffers file (`.tflite`) for deployment.
```python
edge_model.export('resnet.tflite')
```
A model serialized via `export`, or any TFLite FlatBuffers file, can be imported into `ai_edge_torch` as follows:
```python
imported_edge_model = ai_edge_torch.load('resnet.tflite')

# Once imported, you can run the model with an input.
imported_edge_model(*sample_inputs)
```
Sometimes it is desirable to convert multiple PyTorch modules into one edge model. This is often the case when a model comprises multiple components that share weights. Signatures are a TFLite feature that addresses this. The API for multi-signature conversion with `ai_edge_torch` is as follows:
```python
inputs_1 = (...,)
inputs_2 = (...,)

edge_model = (
    ai_edge_torch.signature("input1", model, inputs_1)
    .signature("input2", model, inputs_2)
    .convert()
)

# Run each signature separately by providing signature_name as a keyword argument.
edge_model(*inputs_1, signature_name="input1")
edge_model(*inputs_2, signature_name="input2")
```
The following code snippet quantizes a model with PT2E quantization using the `ai_edge_torch` backend.
```python
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch._export import capture_pre_autograd_graph

from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.quant_config import QuantConfig

pt2e_quantizer = PT2EQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)

pt2e_torch_model = capture_pre_autograd_graph(torch_model, sample_args)
pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)

# Run the prepared model with sample input data to ensure that the internal
# observers are populated with correct values.
pt2e_torch_model(*sample_args)

# Convert the prepared model to a quantized model.
pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)

# Convert to an ai_edge_torch model.
pt2e_drq_model = ai_edge_torch.convert(
    pt2e_torch_model,
    sample_args,
    quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer),
)
```
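Before deploying the quantized model, a quick numeric sanity check can be useful. A sketch, reusing `torch_model` and `sample_args` from above; since quantization changes numerics, expect small differences rather than exact equality:

```python
import numpy as np

# Compare the dynamic-range-quantized edge model against the float PyTorch
# model on the same inputs and report the worst-case absolute difference.
float_output = torch_model(*sample_args)
quant_output = pt2e_drq_model(*sample_args)
print(np.max(np.abs(float_output.detach().numpy() - quant_output)))
```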
The following code snippet quantizes a model with TensorFlow Lite quantization.
```python
import tensorflow as tf

# Pass TFLite Converter quantization flags via the _ai_edge_converter_flags parameter.
tfl_converter_flags = {'optimizations': [tf.lite.Optimize.DEFAULT]}

tfl_drq_model = ai_edge_torch.convert(
    torch_model, sample_args, _ai_edge_converter_flags=tfl_converter_flags
)
```
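The quantized edge model serializes exactly like a float model (the file name below is illustrative):

```python
# Quantized edge models support export() just like float edge models.
tfl_drq_model.export('resnet_drq.tflite')
```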
`ai_edge_torch.convert` expects an `nn.Module` with a `forward` function that receives tensors as positional arguments and returns a tensor, or multiple tensors in a Python list or tuple. If you have a model with a different interface, you will need to provide a wrapper.
As an example, let's say `MyModel` receives only kwargs and returns a custom object. Such a wrapper would look like this:
```python
class MyModelWrapper(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.m = MyModel()

  def forward(self, tensor1, tensor2):
    custom_output_object = self.m(arg1=tensor1, arg2=tensor2)
    return custom_output_object.out_tensor1, custom_output_object.out_tensor2
```
The instance in evaluation mode, `MyModelWrapper().eval()`, would be the right argument to pass to `ai_edge_torch.convert`.
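Conversion then proceeds as usual; a sketch with placeholder input shapes (use whatever shapes `MyModel` actually expects):

```python
# The wrapper's forward() takes two positional tensors, so the sample inputs
# are a 2-tuple; the shapes here are hypothetical.
sample_inputs = (torch.randn(1, 16), torch.randn(1, 16))
edge_model = ai_edge_torch.convert(MyModelWrapper().eval(), sample_inputs)
```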
`ai_edge_torch.to_channel_last_io` is a helper function that facilitates the conversion of PyTorch models (which typically use NCHW, channel-first, ordering) to TFLite models with channel-last (NHWC) input/output layouts. It achieves this by wrapping the original model with layout-transposing ops, ensuring compatibility with target deployment environments. This is particularly useful for deploying models such as image classifiers to mobile environments that expect NHWC (channel-last) image data.
Here is an example of converting ResNet18 with NHWC image input:
```python
import torch
import torchvision
import ai_edge_torch

# Use resnet18 with pre-trained weights.
resnet18 = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.IMAGENET1K_V1)

# Transform the first input to NHWC.
nhwc_resnet18 = ai_edge_torch.to_channel_last_io(resnet18, args=[0])

# Convert the transformed model with NHWC input(s).
edge_model = ai_edge_torch.convert(nhwc_resnet18, (torch.randn(1, 224, 224, 3),))
edge_model.export("resnet18.tflite")
```
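If the model's outputs should also be channel last, a minimal sketch, assuming the helper accepts an `outputs` index list analogous to `args`:

```python
# Assumption: outputs=[0] wraps the model so its first output is transposed
# to NHWC as well.
nhwc_io_resnet18 = ai_edge_torch.to_channel_last_io(resnet18, args=[0], outputs=[0])
```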
More examples of usage can be found here.
## Debugging & Reporting Errors

Failure of `ai_edge_torch.convert(...)` can happen at multiple different steps, with verbose and potentially hard-to-understand error messages. The two high-level steps that users should be aware of are:
1. `torch.export`, to convert the PyTorch model to an `ExportedProgram`.
2. Lowering from the `ExportedProgram` to an edge model.
In case of a `convert` failure, please use our `find_culprits` tool to help narrow down the issue and, in some cases, generate a minimal PyTorch program that reproduces the failure. `find_culprits` can be given the same parameters as `convert`:
```python
from ai_edge_torch.debug import find_culprits

culprits = find_culprits(model.eval(), args)
culprit = next(culprits)
culprit.print_code()
```
If the failure occurs in the `torch.export` step, `print_code()` will provide all the logs from `torch.export.export`, followed by an error message confirming the error type:

```
ValueError: Your model is not exportable by torch.export.export. Please modify your model to be torch-exportable first.
```
The fix for these errors involves changing the model source to be compliant with `torch.export`, and is not a bug in `ai_edge_torch.convert`. Please refer to the PyTorch torch.export documentation for more information.
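To confirm that a failure originates in this first step, you can also run the export in isolation; a minimal sketch:

```python
import torch

# If this raises, the failure comes from torch.export itself, not from the
# ai_edge_torch lowering step.
exported_program = torch.export.export(model.eval(), args)
```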
For errors that occur after an `ExportedProgram` has been obtained, `find_culprits` can provide a minimal reproduction code sample that can be attached to a GitHub issue. Below is a code snippet that causes such a failure:
```python
import torch
import torchaudio
import ai_edge_torch

model = torchaudio.models.ConvTasNet()
args = (torch.rand((1, 1, 256)),)

ai_edge_torch.convert(model.eval(), args)
```
To debug the error, call `ai_edge_torch.debug.find_culprits` with the same arguments provided to `ai_edge_torch.convert(...)` to get a generator of culprits:
```python
from ai_edge_torch.debug import find_culprits

culprits = find_culprits(model, args)
```
Next, print a Python code snippet that reproduces the error:

```python
culprit = next(culprits)
culprit.print_code()
```

This prints the following to the console:
```python
import torch
from torch import device
import ai_edge_torch

class CulpritGraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[512, 1, 16]", arg1_1: "f32[2, 512, 33]"):
        # File: /opt/venv/lib/python3.10/site-packages/torchaudio/models/conv_tasnet.py:300 in forward, code: decoded = self.decoder(masked)  # B*S, 1, L'
        convolution: "f32[2, 1, 256]" = torch.ops.aten.convolution.default(arg1_1, arg0_1, None, [8], [8], [1], True, [0], 1);  arg1_1 = arg0_1 = None
        return (convolution,)

_args = (
    torch.randn((512, 1, 16,), dtype=torch.float32),
    torch.randn((2, 512, 33,), dtype=torch.float32),
)

_edge_model = ai_edge_torch.convert(CulpritGraphModule().eval(), _args)  # conversion should fail
```
You can attach the code snippet to a GitHub issue after:

- Confirming that the generated code snippet fails conversion with the same error as the original program.
- Removing any sensitive information before reporting the issue with the code snippets to us.

Note that the culprit finder tool overwrites weights and inputs with random values in the generated code.
You can also find and print all culprits at once:
```python
for culprit in find_culprits(model, args):
    culprit.print_code()
```
## Visualization

Once the exported TFLite model is obtained, you can visualize the model structure with Model Explorer.

```bash
pip install ai-edge-model-explorer
model-explorer 'resnet.tflite'
```
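Model Explorer can also be driven from Python; a minimal sketch, assuming its `model_explorer.visualize` entry point:

```python
import model_explorer

# Launch the Model Explorer UI for the exported model.
model_explorer.visualize('resnet.tflite')
```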
## Use Torch XLA Conversion Backend (Legacy)

AI Edge Torch has switched to a modern conversion backend for better on-device compatibility and performance. However, if you encounter compatibility issues, you can optionally switch back to the legacy Torch XLA backend:

```bash
# Install ai-edge-torch with the torch-xla dependency.
pip install ai-edge-torch-nightly[torch-xla]

# Enable torch-xla as the AI Edge Torch backend.
export USE_TORCH_XLA=1
```
Torch XLA builds a shared library, `_XLAC.so`, that needs to link against the version of Python it was built with (currently 3.10 or 3.11). To ensure that `import _XLAC` can succeed, update `LD_LIBRARY_PATH` to include the lib directory of your Python environment:

```bash
export LD_LIBRARY_PATH=<path to Python installation>/lib:$LD_LIBRARY_PATH
```
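One way to sanity-check the setup is to import the package that loads the shared library; a sketch:

```python
# Importing torch_xla loads _XLAC.so; if LD_LIBRARY_PATH is set correctly,
# this succeeds without a linker error.
import torch_xla  # noqa: F401
```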