Skip to content

Commit

Permalink
training ir torchao migration (pytorch#1006)
Browse files Browse the repository at this point in the history
Summary:

Migrate capture_pre_autograd_graph to export_for_training.

We still need to keep capture_pre_autograd_graph call because torch/ao's CI tests uses earlier version of pytorch that does not have export_for_training.

See https://github.com/pytorch/ao/blob/main/.github/workflows/regression_test.yml

Differential Revision: D63859678
  • Loading branch information
yushangdi authored and facebook-github-bot committed Oct 4, 2024
1 parent 9ce7ebb commit ed15f3a
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 8 deletions.
27 changes: 22 additions & 5 deletions test/dtypes/test_uint4.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch._export import capture_pre_autograd_graph
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
Expand All @@ -26,6 +25,18 @@
)
import copy

from packaging import version
torch_version = torch.__version__

has_export_for_training = False

if version.parse(torch_version) > version.parse('2.5.0rc1'):
from torch.export import export_for_training
has_export_for_training = True
else:
# capture_pre_autograd_graph is deprecated, it's
# left here to work with previous versions of pytorch
from torch._export import capture_pre_autograd_graph

def _apply_weight_only_uint4_quant(model):
def fn(mod):
Expand Down Expand Up @@ -203,10 +214,16 @@ def forward(self, x):

# program capture
m = copy.deepcopy(m_eager)
m = capture_pre_autograd_graph(
m,
example_inputs,
)
if has_export_for_training:
m = export_for_training(
m,
example_inputs,
).module()
else:
m = capture_pre_autograd_graph(
m,
example_inputs,
).module()

m = prepare_pt2e(m, quantizer)
# Calibrate
Expand Down
18 changes: 16 additions & 2 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,18 @@
benchmark_model
)

has_export_for_training = False

from packaging import version
torch_version = torch.__version__
if version.parse(torch_version) > version.parse('2.5.0rc1'):
from torch.export import export_for_training
has_export_for_training = True
else:
# capture_pre_autograd_graph is deprecated, it's
# left here to work with previous versions of pytorch
from torch._export import capture_pre_autograd_graph

logger = logging.getLogger("INFO")

torch.manual_seed(0)
Expand Down Expand Up @@ -1484,11 +1496,13 @@ def forward(self, x):

# make sure it compiles
example_inputs = (x,)
from torch._export import capture_pre_autograd_graph
# TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
# we can re-enable this after non-functional IR is enabled in export
# model = torch.export.export(model, example_inputs).module()
model = capture_pre_autograd_graph(model, example_inputs)
if has_export_for_training:
model = export_for_training(model, example_inputs).module()
else:
model = capture_pre_autograd_graph(model, example_inputs)
after_export = model(x)
self.assertTrue(torch.equal(after_export, ref))
if api is _int8da_int8w_api:
Expand Down
2 changes: 1 addition & 1 deletion torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _the_op_that_needs_to_be_preserved(...)
# after this, `_the_op_that_needs_to_be_preserved` will be preserved as
# torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
# torch.export.export / torch._export.capture_pre_autograd_graph
# torch.export.export / torch._export.export_for_training
"""
from torch._inductor.decomposition import register_decomposition
Expand Down

0 comments on commit ed15f3a

Please sign in to comment.