Graph tracer for PyTorch.
It is always hard to get the connections between nodes within a computation graph without calling `torch.jit.trace`. However, once the JIT pass runs, some of the nodes in the graph may be removed, which makes it difficult to restore the graph in Python. With our tracer, you can get the connections between the nodes while the model remains runnable in Python. What's more, it can be used as a code generator that organizes the code of your model into a single script.
```py
import torch
import torchvision

from tinynn.graph.tracer import model_tracer, trace

with model_tracer():
    # Prepare the model
    model = torchvision.models.alexnet()
    model.eval()

    # Provide a viable input for the model
    dummy_input = torch.rand((1, 3, 224, 224))

    # After tracing the model, we will get a TraceGraph object
    graph = trace(model, dummy_input)

    # We can use it to generate the code for the original model
    graph.generate_code('my_alexnet.py', 'my_alexnet.pth', 'Alexnet')
```
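The generated script can then be used on its own. Below is a minimal sketch of loading it back, assuming the class name `Alexnet` passed above, that the generated script is on the Python path, and that the `.pth` file holds a state dict (the last point is an assumption about the generated artifacts):

```py
import torch

from my_alexnet import Alexnet

# Rebuild the model from the generated script and weights
# (assumes my_alexnet.pth stores a state dict)
model = Alexnet()
model.load_state_dict(torch.load('my_alexnet.pth'))
model.eval()

with torch.no_grad():
    out = model(torch.rand((1, 3, 224, 224)))
```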
According to the official PyTorch tutorial, you'll have to do the following things to get a quantized model running on devices.
- Insert the QuantStub / DeQuantStub nodes after the input nodes / before the output nodes
- Rewrite the unsupported OPs (e.g. `torch.add(x, y)` -> `FloatFunctional.add(x, y)`)
- Fuse the modules (e.g. Conv + BatchNorm + ReLU -> ConvBnReLU)
- Convert the model to the quantized version
- Trace the model with JIT and serialize it to a TorchScript model file
- Run the inference through PyTorch Mobile
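Concretely, steps 1-3 look roughly like the following in eager-mode PyTorch (a minimal sketch; the block and its layers are made up for illustration):

```py
import torch
import torch.nn as nn

class QuantizableBlock(nn.Module):
    # Illustrative module showing the manual edits of steps 1-3
    def __init__(self):
        super().__init__()
        # Step 1: stubs that mark the float <-> quantized boundaries
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.conv = nn.Conv2d(16, 16, 3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        # Step 2: FloatFunctional replaces torch.add in forward()
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        x = self.quant(x)
        y = self.relu(self.bn(self.conv(x)))
        y = self.skip_add.add(x, y)  # instead of torch.add(x, y)
        return self.dequant(y)

model = QuantizableBlock().eval()
# Step 3: fuse Conv + BatchNorm + ReLU into a single fused module
model = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])
```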
Steps 1-3 require a lot of manual work, which is pretty cumbersome and error-prone. Therefore, we came up with the idea of writing an automatic quantization preparation tool.
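With such a tool, the rewrites above are applied for you on top of the traced graph. Below is a hedged sketch of what driving it could look like; the `QATQuantizer` name and its arguments are assumptions based on tinynn's quantization module, not something defined in this section:

```py
import torch
import torchvision

from tinynn.graph.tracer import model_tracer
# Assumption: tinynn exposes a quantizer built on top of the tracer
from tinynn.graph.quantization.quantizer import QATQuantizer

with model_tracer():
    model = torchvision.models.alexnet().eval()
    dummy_input = torch.rand((1, 3, 224, 224))

    # The quantizer performs the stub insertion, op rewriting and
    # module fusion (steps 1-3) automatically
    quantizer = QATQuantizer(model, dummy_input, work_dir='out')
    qat_model = quantizer.quantize()
```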
- The model to be traced can be instantiated either inside or outside the with-block.
- You may trace multiple models in one with-block (see the first sketch after this list).
- Runtime-defined constants are supported. If they are too large, these constants will be transformed into parameters.
- Like `torch.jit.trace`, if you trace models with control-flow ops, you may silently get incorrect results on subsequent invocations of the model.
- Only the flow of PyTorch tensors is tracked. Other variables (e.g. numpy arrays, numbers and strings) won't be tracked and will be treated as constants.
- Only some properties of a tensor are tracked. For example, if you access `.data` or `.shape` of a tensor, the returned values will be joined into the computation graph. Below are the tensor properties that will be tracked.
  - `.data`
  - `.shape`
  - `.device`
  - `.dtype`
- It is not supported to call `numel` on a `torch.Size` object generated by calling `.size()` or `.shape` on a tensor. Please use `torch.prod` instead; an example follows the first sketch below.
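As noted in the list above, a single with-block can host multiple traces. A minimal sketch (the two torchvision models are arbitrary examples):

```py
import torch
import torchvision

from tinynn.graph.tracer import model_tracer, trace

with model_tracer():
    dummy_input = torch.rand((1, 3, 224, 224))

    # Trace two independent models within the same with-block
    graph_a = trace(torchvision.models.resnet18().eval(), dummy_input)
    graph_b = trace(torchvision.models.mobilenet_v2().eval(), dummy_input)
```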
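For the `numel` limitation, one possible replacement looks like this (a sketch; wrapping the shape in a tensor before calling `torch.prod` is an assumption about the intended usage):

```py
import torch

x = torch.rand((1, 3, 224, 224))

# Not supported by the tracer:
# n = x.size().numel()

# Workaround: compute the element count via torch.prod
n = torch.prod(torch.tensor(x.shape))
```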