From 6b74f536a787d0759e83b483c13a66e7146e6b2a Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 27 Nov 2023 21:09:31 +0800
Subject: [PATCH] add aotautograd

---
 docs/walk_through.rst | 27 ++++++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/docs/walk_through.rst b/docs/walk_through.rst
index 15aee474..d4920b90 100644
--- a/docs/walk_through.rst
+++ b/docs/walk_through.rst
@@ -309,7 +309,7 @@ And we can check the correctness of two implementations against native PyTorch i
 .. code-block:: python
 
     input = torch.randn((5, 5, 5), requires_grad=True)
-    grad_output = torch.randn((5, 5, 5), requires_grad=True)
+    grad_output = torch.randn((5, 5, 5))
 
     output1 = torch.cos(torch.cos(input))
     (output1 * grad_output).sum().backward()
@@ -373,6 +373,31 @@ By varying the amount of ``saved_tensors``, we can save less tensors for backwar
 
 That is basically how AOT Autograd works!
 
+Note: if you are curious about how to get the joint graph of a function, here is the code:
+
+.. code-block:: python
+
+    def run_autograd_ahead_of_time(function, inputs, grad_outputs):
+        def forward_and_backward(inputs, grad_outputs):
+            outputs = function(*inputs)
+            grad_inputs = torch.autograd.grad(outputs, inputs, grad_outputs)
+            return grad_inputs
+        from torch.fx.experimental.proxy_tensor import make_fx
+        wrapped_function = make_fx(forward_and_backward, tracing_mode="fake")
+        joint_graph = wrapped_function(inputs, grad_outputs)
+        print(joint_graph.graph.python_code(root_module="self", verbose=True).src)
+
+    def f(x):
+        x = torch.cos(x)
+        x = torch.cos(x)
+        return x
+
+    input = torch.randn((5, 5, 5), requires_grad=True)
+    grad_output = torch.randn((5, 5, 5))
+    run_autograd_ahead_of_time(f, [input], [grad_output])
+
+This function creates fake tensors from the real inputs and uses only their metadata (shapes, devices, dtypes) to carry out the computation, so no real data is touched. The whole pass therefore runs before any real computation happens, which is why the component is called AOTAutograd: it runs the autograd engine ahead-of-time.
+
 Backend: compile and optimize computation graph
 --------------------------------------------------
 
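
For reference, the ``tracing_mode="fake"`` argument above makes ``make_fx`` trace with fake tensors. The following minimal sketch (not part of the patch above) illustrates what metadata-only computation looks like; it relies on the private ``torch._subclasses.FakeTensorMode`` API, so treat it as illustrative rather than a stable recipe:

.. code-block:: python

    import torch
    from torch._subclasses import FakeTensorMode  # private API, shown only for illustration

    # Inside this mode, tensor factories return fake tensors that carry only
    # metadata (shape, dtype, device); no real memory is allocated or computed on.
    with FakeTensorMode():
        x = torch.randn((5, 5, 5))
        y = torch.cos(torch.cos(x))
        print(y.shape, y.dtype, y.device)  # metadata is propagated through the ops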