add tutorial
youkaichao committed Sep 6, 2024
1 parent 53969aa commit 596f184
Showing 2 changed files with 121 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/index.rst
@@ -71,3 +71,4 @@ If you'd like to contribute (which we highly appreciate), please read the `devel
advanced
faq
dev_doc
opt_tutorial
120 changes: 120 additions & 0 deletions docs/opt_tutorial.rst
@@ -0,0 +1,120 @@
Optimization Tutorial
===========================================

This tutorial will guide you through the process of optimizing code with ``torch.compile``, using the ``depyf`` library.

The example code we want to optimize is as follows:

.. code-block:: python

    import torch

    class F(torch.nn.Module):
        def forward(self, x, i):
            return x + i

    @torch.compile
    def g(x):
        x = F()(x, 5)
        return x

    for i in range(1000):
        x = torch.tensor([i])  # create input tensor
        y = g(x)
        # do something with y

For illustration purposes, the computation in ``g`` is kept trivial. In practice, ``g`` can be a complex function that performs real computation.

To optimize the code, we first need to understand what is going on under the hood. We can use the ``depyf`` library to decompile the bytecode ``torch.compile`` generates for ``g``, with just two more lines:

.. code-block:: python

    import torch

    class F(torch.nn.Module):
        def forward(self, x, i):
            return x + i

    @torch.compile
    def g(x):
        x = F()(x, 5)
        return x

    import depyf

    with depyf.prepare_debug("dump_src_dir/"):
        for i in range(1000):
            x = torch.tensor([i])  # create input tensor
            y = g(x)
            # do something with y

After running the code above, you will find a new directory ``dump_src_dir/`` in the current directory. It contains the decompiled source code of the function ``g``. Inside the ``full_code_for_g_0.py`` file, you can find:

.. code-block:: python

    def __guard_0_for_g(L, G, **___kwargs_ignored):
        __guard_hit = True
        __guard_hit = __guard_hit and utils_device.CURRENT_DEVICE == None  # _dynamo/output_graph.py:460 in init_ambient_guards
        __guard_hit = __guard_hit and ___check_global_state()
        __guard_hit = __guard_hit and check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[1], stride=[1])
        __guard_hit = __guard_hit and hasattr(L['x'], '_dynamo_dynamic_indices') == False
        __guard_hit = __guard_hit and ___check_obj_id(G['F'], 4576341520)
        __guard_hit = __guard_hit and ___check_obj_id(G['__import_torch_dot_nn_dot_modules_dot_module'], 4413465488)
        __guard_hit = __guard_hit and ___check_obj_id(G['__import_torch_dot_nn_dot_modules_dot_module'].torch, 4309172144)
        __guard_hit = __guard_hit and ___check_obj_id(G['__import_torch_dot_nn_dot_modules_dot_module'].torch._C, 4314290416)
        __guard_hit = __guard_hit and ___check_obj_id(G['__import_torch_dot_nn_dot_modules_dot_module'].torch._C._get_tracing_state, 4337294032)
        __guard_hit = __guard_hit and ___check_type_id(G['__import_torch_dot_nn_dot_modules_dot_module']._global_forward_hooks, 4305934016)
        __guard_hit = __guard_hit and not G['__import_torch_dot_nn_dot_modules_dot_module']._global_forward_hooks
        __guard_hit = __guard_hit and ___check_type_id(G['__import_torch_dot_nn_dot_modules_dot_module']._global_backward_hooks, 4305934016)
        __guard_hit = __guard_hit and not G['__import_torch_dot_nn_dot_modules_dot_module']._global_backward_hooks
        __guard_hit = __guard_hit and ___check_type_id(G['__import_torch_dot_nn_dot_modules_dot_module']._global_forward_pre_hooks, 4305934016)
        __guard_hit = __guard_hit and not G['__import_torch_dot_nn_dot_modules_dot_module']._global_forward_pre_hooks
        __guard_hit = __guard_hit and ___check_type_id(G['__import_torch_dot_nn_dot_modules_dot_module']._global_backward_pre_hooks, 4305934016)
        __guard_hit = __guard_hit and not G['__import_torch_dot_nn_dot_modules_dot_module']._global_backward_pre_hooks
        return __guard_hit

This is the guard function ``torch.compile`` generates: before the compiled code runs, it checks the inputs to decide whether the compiled function can be reused. However, it is quite conservative. Many of the checks target values that stay constant for the whole execution; for example, ``___check_obj_id(G['F'], 4576341520)`` makes sure the global name ``F`` is still bound to the same class object. Technically, we could rebind the class during execution, but that is not common practice. Yet these checks run every time we call ``g``, which adds overhead to every call.
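
To see what such a guard actually checks, note that ``___check_obj_id(obj, val)`` is roughly ``id(obj) == val``: the guard pins a global name to the exact object observed at compile time (the address ``4576341520`` above is specific to one run). A minimal sketch of this idea:

.. code-block:: python

    import torch

    class F(torch.nn.Module):
        def forward(self, x, i):
            return x + i

    # Dynamo records id(F) at compile time; the guard re-checks it
    # on every call, roughly as id(F) == recorded_id.
    recorded_id = id(F)
    print(id(F) == recorded_id)  # True: the guard passes

    OldF = F  # keep the old class alive so its id cannot be reused

    class F(torch.nn.Module):  # rebinding the global name F ...
        def forward(self, x, i):
            return x + i

    print(id(F) == recorded_id)  # False: the guard fails and Dynamo recompiles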

If we just want ``torch.compile`` to compile the code, but want to skip the checks afterwards, we can use ``TorchCompileWrapperWithCustomDispacther`` from ``depyf``:

.. code-block:: python

    import torch

    class F(torch.nn.Module):
        def forward(self, x, i):
            return x + i

    def g(x):
        x = F()(x, 5)
        return x

    import depyf
    from depyf.optimization import TorchCompileWrapperWithCustomDispacther

    class MyMod(TorchCompileWrapperWithCustomDispacther):
        def __init__(self):
            compiled_callable = torch.compile(g)
            super().__init__(compiled_callable)

        def forward(self, x):
            # the original computation; its bytecode is what gets hijacked
            return g(x)

        def __call__(self, x):
            if len(self.compiled_codes) == 1:
                # a compiled version already exists: run it directly,
                # bypassing Dynamo's guard checks
                with self.dispatch_to_code(0):
                    return self.forward(x)
            else:
                # not compiled yet: go through torch.compile once
                return self.compiled_callable(x)

    mod = MyMod()
    for i in range(1000):
        x = torch.tensor([i])  # create input tensor
        y = mod(x)
        # do something with y

Under the hood, the wrapper hijacks the bytecode produced by ``torch.compile`` and calls the compiled function directly, without the guard checks. As the ``__call__`` method shows, once a compiled code object exists, it is invoked directly; otherwise the call goes through the compiled callable, which compiles the code and records the result. This removes the Dynamo overhead from every subsequent call.
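
The dispatch itself can be understood as swapping the ``__code__`` attribute of ``forward``. The sketch below is a simplified illustration of the idea, not ``depyf``'s actual implementation; ``SimpleDispatcher`` and ``fast_forward`` are hypothetical names, and the real wrapper captures the compiled code objects automatically:

.. code-block:: python

    from contextlib import contextmanager

    class SimpleDispatcher:
        def __init__(self):
            self.original_code = self.__class__.forward.__code__
            self.compiled_codes = []  # the real wrapper fills this via a bytecode hook

        @contextmanager
        def dispatch_to_code(self, index):
            # temporarily swap forward's bytecode for the compiled version
            self.__class__.forward.__code__ = self.compiled_codes[index]
            try:
                yield
            finally:
                self.__class__.forward.__code__ = self.original_code

        def forward(self, x):
            return x + 5

    def fast_forward(self, x):
        # stand-in for the bytecode torch.compile would produce
        return x + 5

    d = SimpleDispatcher()
    d.compiled_codes.append(fast_forward.__code__)
    with d.dispatch_to_code(0):
        print(d.forward(1))  # executes the swapped-in bytecode: prints 6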

This technique is used in `vLLM's TPU integration <https://github.com/vllm-project/vllm/pull/7898>`_ to remove the overhead of the Dynamo checks: the TPU executes the model very quickly, so the per-call cost of the checks is significant. With this technique, TPU throughput improves by about 4%.
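
To gauge the guard overhead in your own code, a rough timing comparison between the plain compiled function and the wrapper is a reasonable starting point. Below is a minimal sketch, assuming the ``g`` and ``MyMod`` definitions above; for serious measurements, prefer ``torch.utils.benchmark`` or a profiler:

.. code-block:: python

    import time

    import torch

    compiled_g = torch.compile(g)  # plain torch.compile: guards run on every call
    mod = MyMod()                  # custom dispatch: guards are skipped

    def per_call_seconds(fn, iters=10000):
        x = torch.tensor([0])
        fn(x)  # warm up, so compilation time is excluded from the timing
        start = time.perf_counter()
        for _ in range(iters):
            fn(x)
        return (time.perf_counter() - start) / iters

    print("with guards:    ", per_call_seconds(compiled_g))
    print("custom dispatch:", per_call_seconds(mod))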

This is just one example of how to optimize code with ``torch.compile``. You can also use the decompiled source code to understand the code better and optimize it in other ways.
