add aotautograd
youkaichao committed Nov 27, 2023
1 parent ff9ff9d commit 6b74f53
Showing 1 changed file with 26 additions and 1 deletion: docs/walk_through.rst
@@ -309,7 +309,7 @@ And we can check the correctness of two implementations against native PyTorch i
.. code-block:: python

    input = torch.randn((5, 5, 5), requires_grad=True)
-   grad_output = torch.randn((5, 5, 5), requires_grad=True)
+   grad_output = torch.randn((5, 5, 5))
    output1 = torch.cos(torch.cos(input))
    (output1 * grad_output).sum().backward()
@@ -373,6 +373,31 @@ By varying the amount of ``saved_tensors``, we can save less tensors for backwar

That is basically how AOT Autograd works!
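
The walkthrough's own two implementations are collapsed in this diff, so as a rough illustration of the ``saved_tensors`` tradeoff mentioned above, here is a minimal sketch of two custom ``torch.autograd.Function`` versions of ``cos(cos(x))`` that save different tensors for backward. The class names are made up for illustration and this is not necessarily the walkthrough's exact code:

.. code-block:: python

    import torch

    class CosCosSaveMore(torch.autograd.Function):
        """Save both x and cos(x): more memory, no recomputation in backward."""
        @staticmethod
        def forward(ctx, x):
            x1 = torch.cos(x)
            ctx.save_for_backward(x, x1)
            return torch.cos(x1)

        @staticmethod
        def backward(ctx, grad_output):
            x, x1 = ctx.saved_tensors
            # d/dx cos(cos(x)) = sin(cos(x)) * sin(x)
            return grad_output * torch.sin(x1) * torch.sin(x)

    class CosCosSaveLess(torch.autograd.Function):
        """Save only x: less memory, cos(x) is recomputed in backward."""
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return torch.cos(torch.cos(x))

        @staticmethod
        def backward(ctx, grad_output):
            x, = ctx.saved_tensors
            x1 = torch.cos(x)  # recomputed instead of saved
            return grad_output * torch.sin(x1) * torch.sin(x)

    x = torch.randn((5, 5, 5), requires_grad=True)
    CosCosSaveMore.apply(x).sum().backward()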

Note: if you are curious about how to get the joint graph of a function, here is the code:

.. code-block:: python

    def run_autograd_ahead_of_time(function, inputs, grad_outputs):
        def forward_and_backward(inputs, grad_outputs):
            outputs = function(*inputs)
            grad_inputs = torch.autograd.grad(outputs, inputs, grad_outputs)
            return grad_inputs
        from torch.fx.experimental.proxy_tensor import make_fx
        # trace the joint forward-and-backward computation with fake tensors
        wrapped_function = make_fx(forward_and_backward, tracing_mode="fake")
        joint_graph = wrapped_function(inputs, grad_outputs)
        print(joint_graph._graph.python_code(root_module="self", verbose=True).src)

    def f(x):
        x = torch.cos(x)
        x = torch.cos(x)
        return x

    input = torch.randn((5, 5, 5), requires_grad=True)
    grad_output = torch.randn((5, 5, 5))
    run_autograd_ahead_of_time(f, [input], [grad_output])

This function creates fake tensors from the real inputs and uses only their metadata (shapes, devices, dtypes) to carry out the computation. Because no real data is required, this component can run ahead of time, which is how it gets its name: AOTAutograd runs the autograd engine ahead-of-time.
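
To see what "fake" computation looks like in isolation, here is a minimal sketch (not part of the original walkthrough) using PyTorch's ``FakeTensorMode``. Note that ``torch._subclasses.fake_tensor`` is an internal module, so the import path may change across PyTorch versions:

.. code-block:: python

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Tensors created under FakeTensorMode carry only metadata
    # (shape, dtype, device); operations propagate that metadata
    # without allocating or computing any real data.
    with FakeTensorMode():
        fake_input = torch.empty((5, 5, 5))
        fake_output = torch.cos(torch.cos(fake_input))
        print(fake_output.shape, fake_output.dtype, fake_output.device)
        # torch.Size([5, 5, 5]) torch.float32 cpu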

Backend: compile and optimize computation graph
--------------------------------------------------
