Skip to content

Commit

Permalink
Add strict_export convert param
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684588928
  • Loading branch information
chunnienc authored and copybara-github committed Oct 10, 2024
1 parent 18d7630 commit 2ebd8c7
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 6 deletions.
30 changes: 25 additions & 5 deletions ai_edge_torch/_convert/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import logging
import os
from typing import Any, Optional
from typing import Any, Literal, Optional, Union

from ai_edge_torch import fx_pass_base
from ai_edge_torch import lowertools
Expand Down Expand Up @@ -73,6 +73,7 @@ def _warn_training_modules(signatures: list[signature.Signature]):
def convert_signatures(
signatures: list[signature.Signature],
*,
strict_export: Union[Literal["auto"], bool] = True,
quant_config: Optional[qcfg.QuantConfig] = None,
_tfl_converter_flags: Optional[dict[str, Any]],
) -> model.TfLiteModel:
Expand All @@ -81,6 +82,11 @@ def convert_signatures(
Args:
signatures: The list of 'signature.Signature' objects containing PyTorch
modules to be converted.
strict_export: Experimental `strict` arg for torch.export.export. When
enabled, the export function will trace the program through TorchDynamo
and ensure the soundness of the exported graph. When
strict_export="auto", the function will try to export the module in both
modes and use the first one that succeeds for downstream conversion.
quant_config: User-defined quantization method and scheme of the model.
_tfl_converter_flags: A nested dictionary allowing setting flags for the
underlying tflite converter.
Expand All @@ -93,10 +99,24 @@ def convert_signatures(

_warn_training_modules(signatures)

exported_programs: torch.export.torch.export.ExportedProgram = [
torch.export.export(
sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes
)
def export(*args, **kwargs):
  """Exports a module via torch.export honoring the `strict_export` setting.

  Behavior by `strict_export` value (read from the enclosing scope):
    True: export with strict=True only.
    False: export with strict=False only.
    "auto": try strict=True first; if it raises, log the failure and fall
      back to strict=False.

  Args:
    *args: Positional args forwarded to torch.export.export.
    **kwargs: Keyword args forwarded to torch.export.export.

  Returns:
    The torch.export.ExportedProgram produced by torch.export.export.
  """
  # NOTE: `strict_export` is only read here, so no `nonlocal` is needed.
  if strict_export == "auto":
    try:
      return torch.export.export(*args, **kwargs, strict=True)
    except Exception as e:
      # Surface the strict-export failure reason before falling back, so
      # users can tell whether the non-strict result may be less sound.
      logging.warning(
          "torch.export.export(..., strict=True) failed with %s. Retrying"
          " with strict=False",
          e,
      )
    # Fall back outside the except block to avoid chaining a confusing
    # "during handling of the above exception" traceback if this also fails.
    return torch.export.export(*args, **kwargs, strict=False)
  if strict_export:
    return torch.export.export(*args, **kwargs, strict=True)
  return torch.export.export(*args, **kwargs, strict=False)

exported_programs: torch.export.ExportedProgram = [
export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
for sig in signatures
]

Expand Down
16 changes: 15 additions & 1 deletion ai_edge_torch/_convert/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from __future__ import annotations

from typing import Any, Optional, Tuple, Union
from typing import Any, Literal, Optional, Tuple, Union

from ai_edge_torch import model
from ai_edge_torch._convert import conversion
Expand Down Expand Up @@ -102,6 +102,7 @@ def convert(
sample_args=None,
sample_kwargs=None,
*,
strict_export: Union[Literal["auto"], bool] = True,
quant_config: Optional[qcfg.QuantConfig] = None,
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
Expand All @@ -123,6 +124,11 @@ def convert(
with prior to conversion.
sample_kwargs: Dict of str to tensor by which the torch module will be
traced with prior to conversion.
strict_export: Experimental `strict` arg for torch.export.export. When
enabled, the export function will trace the program through TorchDynamo
and ensure the soundness of the exported graph. When
strict_export="auto", the function will try to export the module in both
modes and use the first one that succeeds for downstream conversion.
quant_config: User-defined quantization method and scheme of the model.
dynamic_shapes: Optional dict or tuple that specify dynamic shape
specifications for each input in original order. See
Expand Down Expand Up @@ -162,6 +168,7 @@ def convert(
)
return conversion.convert_signatures(
self._signatures,
strict_export=strict_export,
quant_config=quant_config,
_tfl_converter_flags=_ai_edge_converter_flags,
)
Expand Down Expand Up @@ -205,6 +212,7 @@ def convert(
sample_args=None,
sample_kwargs=None,
*,
strict_export: Union[Literal["auto"], bool] = True,
quant_config: Optional[qcfg.QuantConfig] = None,
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
Expand All @@ -217,6 +225,11 @@ def convert(
prior to conversion.
sample_kwargs: Dict of str to tensor by which the torch module will be
traced with prior to conversion.
strict_export: Experimental `strict` arg for torch.export.export. When
enabled, the export function will trace the program through TorchDynamo
and ensure the soundness of the exported graph. When strict_export="auto",
the function will try to export the module in both modes and use the
first one that succeeds for downstream conversion.
quant_config: User-defined quantization method and scheme of the model.
dynamic_shapes: Optional dict or tuple that specify dynamic shape
specifications for each input in original order. See
Expand All @@ -242,6 +255,7 @@ def convert(
module,
sample_args,
sample_kwargs,
strict_export=strict_export,
quant_config=quant_config,
dynamic_shapes=dynamic_shapes,
_ai_edge_converter_flags=_ai_edge_converter_flags,
Expand Down
9 changes: 9 additions & 0 deletions ai_edge_torch/fx_pass_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ class CanonicalizePass(ExportedProgramPassBase):
}

def call(self, exported_program: torch.export.ExportedProgram):
for node in exported_program.graph.nodes:
if node.target == torch.ops.aten.view.default:
# Passes or torch.export may generate aten.view nodes that do not respect
# the tensor memory format. Change all aten.view nodes to torch.reshape
# for retracing. If the input memory format is already contiguous,
# retracing in run_decompositions below would decompose torch.reshape
# back into a single aten.view.
node.target = lambda self, size: torch.reshape(self, size)

exported_program = exported_program.run_decompositions(
self._DUMMY_DECOMP_TABLE
)
Expand Down

0 comments on commit 2ebd8c7

Please sign in to comment.