From 567ee57674fa0335b87b68c503626da8cf423c94 Mon Sep 17 00:00:00 2001
From: Shangdi Yu
Date: Thu, 3 Oct 2024 19:59:25 -0700
Subject: [PATCH] training ir torchao migration (#1006)

Summary:
Migrate capture_pre_autograd_graph to export_for_training.

We still need to keep the capture_pre_autograd_graph call because torch/ao's
CI tests use an earlier version of PyTorch that does not have
export_for_training. See
https://github.com/pytorch/ao/blob/main/.github/workflows/regression_test.yml

Differential Revision: D63859678
---
 test/dtypes/test_uint4.py            | 16 +++++++++++-----
 test/integration/test_integration.py |  6 ++++--
 torchao/utils.py                     |  2 +-
 3 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py
index aa9415e51..98fb523d3 100644
--- a/test/dtypes/test_uint4.py
+++ b/test/dtypes/test_uint4.py
@@ -7,7 +7,6 @@
 
 from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
-from torch._export import capture_pre_autograd_graph
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     QuantizationTestCase,
@@ -25,6 +24,7 @@
     QuantizationAnnotation,
 )
 import copy
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 
 def _apply_weight_only_uint4_quant(model):
@@ -203,10 +203,16 @@ def forward(self, x):
 
         # program capture
         m = copy.deepcopy(m_eager)
-        m = capture_pre_autograd_graph(
-            m,
-            example_inputs,
-        )
+        if TORCH_VERSION_AT_LEAST_2_5:
+            m = torch.export.export_for_training(
+                m,
+                example_inputs,
+            ).module()
+        else:
+            m = torch._export.capture_pre_autograd_graph(
+                m,
+                example_inputs,
+            )
         m = prepare_pt2e(m, quantizer)
 
         # Calibrate
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 5f81858ba..be8f2f954 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1484,11 +1484,13 @@ def forward(self, x):
 
         # make sure it compiles
         example_inputs = (x,)
-        from torch._export import capture_pre_autograd_graph
         # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
         # we can re-enable this after non-functional IR is enabled in export
         # model = torch.export.export(model, example_inputs).module()
-        model = capture_pre_autograd_graph(model, example_inputs)
+        if TORCH_VERSION_AT_LEAST_2_5:
+            model = torch.export.export_for_training(model, example_inputs).module()
+        else:
+            model = torch._export.capture_pre_autograd_graph(model, example_inputs)
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
         if api is _int8da_int8w_api:
diff --git a/torchao/utils.py b/torchao/utils.py
index 4b5409e65..a0302cabe 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -180,7 +180,7 @@ def _the_op_that_needs_to_be_preserved(...)
 
         # after this, `_the_op_that_needs_to_be_preserved` will be preserved as
         # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
-        # torch.export.export / torch._export.capture_pre_autograd_graph
+        # torch.export.export / torch.export.export_for_training
     """
 
     from torch._inductor.decomposition import register_decomposition
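
For reference, the version-gated capture pattern that this patch repeats inline at
each call site can be written as a small helper. The sketch below is illustrative
only: the helper name _capture_graph is hypothetical and not part of this change,
and it assumes torchao's TORCH_VERSION_AT_LEAST_2_5 flag is importable.

    import torch

    from torchao.utils import TORCH_VERSION_AT_LEAST_2_5


    def _capture_graph(model, example_inputs):
        # Hypothetical helper (not part of this patch): return an FX graph module
        # that prepare_pt2e can consume, using the training-IR export on
        # torch >= 2.5 and the legacy capture API otherwise.
        if TORCH_VERSION_AT_LEAST_2_5:
            # export_for_training returns an ExportedProgram; .module() unwraps
            # the underlying GraphModule.
            return torch.export.export_for_training(model, example_inputs).module()
        # capture_pre_autograd_graph already returns a GraphModule directly,
        # so no .module() call is needed on this branch.
        return torch._export.capture_pre_autograd_graph(model, example_inputs)

In both branches the result is a graph module that can be passed straight to
prepare_pt2e, which is why the .module() call appears only on the
export_for_training path.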