Commit 6bb40ca

Format source with new rules (#118)
* update scripts and configs

* fmt
chunnienc authored Aug 1, 2024
1 parent 4c40530 commit 6bb40ca
Showing 95 changed files with 1,529 additions and 900 deletions.
9 changes: 9 additions & 0 deletions .isort.cfg
@@ -0,0 +1,9 @@
+[settings]
+profile=google
+multi_line_output=7
+line_length=200
+skip=.downloads,venv,bazel
+known_third_party=ai_edge_torch
+known_internal=tensorflow.python.platform,tensorflow.compiler.tf2xla.python
+default_section=THIRDPARTY
+sections=FUTURE,STDLIB,LOCALFOLDER,THIRDPARTY,INTERNAL
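
For reference, a sketch (with illustrative module names) of how imports sort under this config: sections are emitted in the order FUTURE, STDLIB, LOCALFOLDER, THIRDPARTY, INTERNAL; ai_edge_torch is pinned to THIRDPARTY, so it sorts alphabetically alongside torch; and the listed tensorflow internals land in a trailing INTERNAL section.

from __future__ import annotations  # FUTURE

import os  # STDLIB
from typing import Optional

from ai_edge_torch import model  # THIRDPARTY: ai_edge_torch sorts before torch
import torch
from torch_xla import stablehlo

from tensorflow.python.platform import gfile  # INTERNAL
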
20 changes: 12 additions & 8 deletions ai_edge_torch/convert/conversion.py
@@ -18,10 +18,6 @@
 import os
 from typing import Optional
 
-import torch
-from torch.export import ExportedProgram
-from torch_xla import stablehlo
-
 from ai_edge_torch import model
 from ai_edge_torch.convert import conversion_utils as cutils
 from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
@@ -32,6 +28,9 @@
 from ai_edge_torch.convert.fx_passes import run_passes
 from ai_edge_torch.generative.fx_passes import run_generative_passes
 from ai_edge_torch.quantize import quant_config as qcfg
+import torch
+from torch.export import ExportedProgram
+from torch_xla import stablehlo
 
 os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
 
@@ -61,8 +60,9 @@ def _warn_training_modules(signatures: list[cutils.Signature]):
       continue
 
     message = (
-        "Your model {sig_name}is converted in training mode. "
-        "Please set the module in evaluation mode with `module.eval()` for better on-device performance and compatibility."
+        "Your model {sig_name}is converted in training mode. Please set the"
+        " module in evaluation mode with `module.eval()` for better on-device"
+        " performance and compatibility."
     )
     if len(signatures) == 1 and sig.name == cutils.DEFAULT_SIGNATURE_NAME:
       # User does not specify any signature names explicitly.
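
The rewrapped warning relies on Python's implicit concatenation of adjacent string literals, so the runtime message is unchanged; a minimal sketch:

message = (
    "Your model {sig_name}is converted in training mode. Please set the"
    " module in evaluation mode with `module.eval()` for better on-device"
    " performance and compatibility."
)
# Adjacent literals fuse into one string; each continuation line carries
# the separating space, so no words run together.
assert "set the module in evaluation mode" in message
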
@@ -88,7 +88,9 @@ def convert_signatures(
   _warn_training_modules(signatures)
 
   exported_programs: torch.export.ExportedProgram = [
-      torch.export.export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
+      torch.export.export(
+          sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes
+      )
       for sig in signatures
   ]
 
@@ -100,7 +102,9 @@ def convert_signatures(
   ]
 
   merged_shlo_graph_module: stablehlo.StableHLOGraphModule = (
-      cutils.merge_stablehlo_bundles(shlo_bundles, signatures, exported_programs)
+      cutils.merge_stablehlo_bundles(
+          shlo_bundles, signatures, exported_programs
+      )
   )
   del exported_programs
   del shlo_bundles
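
For context, convert_signatures builds one ExportedProgram per signature through the torch.export.export call wrapped above. A minimal standalone sketch of that API, with an illustrative module and inputs:

import torch


class AddOne(torch.nn.Module):

  def forward(self, x):
    return x + 1


# dynamic_shapes is optional and mirrors the sig.dynamic_shapes argument
# threaded through convert_signatures.
exported = torch.export.export(AddOne().eval(), (torch.randn(2, 3),))
print(exported.graph_module.code)
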
58 changes: 38 additions & 20 deletions ai_edge_torch/convert/conversion_utils.py
@@ -22,15 +22,15 @@
 import tempfile
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe  # NOQA
+from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 import torch.utils._pytree as pytree
 from torch_xla import stablehlo
 
-from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe  # NOQA
-from ai_edge_torch.quantize import quant_config as qcfg
-
 try:
   import tensorflow as tf
 
   from tensorflow.compiler.tf2xla.python import xla as tfxla
 
   from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb  # isort:skip
@@ -90,18 +90,20 @@ def _flat_kwarg_names(self, specs, context) -> List[str]:
     if context is None:
       for i, spec in enumerate(specs):
         if spec.children_specs:
-          flat_names.extend(
-              [
-                  f"{i}_{name}"
-                  for name in self._flat_kwarg_names(spec.children_specs, spec.context)
-              ]
-          )
+          flat_names.extend([
+              f"{i}_{name}"
+              for name in self._flat_kwarg_names(
+                  spec.children_specs, spec.context
+              )
+          ])
         else:
           flat_names.append(f"{i}")
     else:
       flat_ctx = self._flatten_list(context)
       for prefix, spec in zip(flat_ctx, specs):
-        leaf_flat_names = self._flat_kwarg_names(spec.children_specs, spec.context)
+        leaf_flat_names = self._flat_kwarg_names(
+            spec.children_specs, spec.context
+        )
         if leaf_flat_names:
           flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
         else:
@@ -125,7 +127,8 @@ def flat_args(self) -> tuple[Any]:
 
 
 def exported_program_to_stablehlo_bundle(
-    exported_program: torch.export.ExportedProgram, sample_args: tuple[torch.Tensor]
+    exported_program: torch.export.ExportedProgram,
+    sample_args: tuple[torch.Tensor],
 ) -> stablehlo.StableHLOModelBundle:
   # Setting export_weights to False here so that pytorch/xla avoids copying the weights
   # to a numpy array which would lead to memory bloat. This means that the state_dict
@@ -146,15 +149,18 @@ def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
     dlpack_capsule = torch.utils.dlpack.to_dlpack(torch_tensor)
     tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_capsule)
   except Exception:
-    logging.info("Can not use dlpack to convert torch tensors. Falling back to numpy.")
+    logging.info(
+        "Can not use dlpack to convert torch tensors. Falling back to numpy."
+    )
     nparray = torch_tensor.cpu().detach().numpy()
     tf_tensor = tf.convert_to_tensor(nparray)
 
   return tf_tensor
 
 
 def _get_states(
-    exported_programs: list[torch.export.ExportedProgram], signatures: list[Signature]
+    exported_programs: list[torch.export.ExportedProgram],
+    signatures: list[Signature],
 ):
   for exported_program, signature in zip(exported_programs, signatures):
     args, _ = exported_program.example_inputs
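
The try/except above prefers DLPack for a zero-copy handoff from torch to TensorFlow and only round-trips through NumPy when that fails. The same pattern in isolation, as a hedged sketch:

import tensorflow as tf
import torch

t = torch.arange(4, dtype=torch.float32)
try:
  # Zero-copy: both frameworks alias the same underlying buffer.
  tf_t = tf.experimental.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))
except Exception:
  # Fallback copies through host memory instead.
  tf_t = tf.convert_to_tensor(t.cpu().detach().numpy())
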
@@ -166,7 +172,8 @@ def _get_states(
     # Only interested in Tensors that are part of the state (and not user input).
     if (
         not isinstance(tensor, torch.Tensor)
-        or input_spec.kind == torch.export.graph_signature.InputKind.USER_INPUT
+        or input_spec.kind
+        == torch.export.graph_signature.InputKind.USER_INPUT
     ):
       continue
     yield signature, tensor, input_spec
@@ -192,9 +199,13 @@ def _gather_state_dict(
     deduped_tensor_map[unique_id] = _torch_to_tf_tensor(tensor)
 
   state_dict = {}
-  for signature, tensor, input_spec in _get_states(exported_programs, signatures):
+  for signature, tensor, input_spec in _get_states(
+      exported_programs, signatures
+  ):
     unique_id = _tensor_unique_id(tensor)
-    state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[unique_id]
+    state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[
+        unique_id
+    ]
 
   return state_dict
 
@@ -236,7 +247,9 @@ def _wrap_as_tf_func(
 ):
   def inner(*args):
     type_info = [sig.dtype for sig in func.meta.output_signature]
-    shape_info = [_get_shape_with_dynamic(sig) for sig in func.meta.output_signature]
+    shape_info = [
+        _get_shape_with_dynamic(sig) for sig in func.meta.output_signature
+    ]
     call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
     return tfxla.call_module(
         tuple(call_args),
@@ -369,7 +382,9 @@ def convert_stablehlo_to_tflite(
         )
     )
 
-  tf_module._variables = list(bundle.state_dict.values()) + bundle.additional_constants
+  tf_module._variables = (
+      list(bundle.state_dict.values()) + bundle.additional_constants
+  )
   del bundle
   gc.collect()
 
@@ -385,7 +400,8 @@ def convert_stablehlo_to_tflite(
       tf_module,
       temp_dir_path,
       signatures={
-          sig.name: tf_concrete_funcs[idx] for idx, sig in enumerate(signatures)
+          sig.name: tf_concrete_funcs[idx]
+          for idx, sig in enumerate(signatures)
       },
   )
   # Clean up intermediate memory early.
Expand Down Expand Up @@ -416,6 +432,8 @@ def convert_stablehlo_to_tflite(
and quant_config._quantizer_mode
== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
):
tflite_model = translate_recipe.quantize_model(tflite_model, translated_recipe)
tflite_model = translate_recipe.quantize_model(
tflite_model, translated_recipe
)

return tflite_model
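
As the _gather_state_dict hunk shows, tensors shared across signatures are converted once and then keyed by signature name plus parameter target. A toy illustration of the resulting key scheme (names are made up):

import tensorflow as tf

shared = tf.constant([1.0, 2.0])  # one deduplicated weight
state_dict = {
    "encode_layer0.weight": shared,  # signature "encode", target "layer0.weight"
    "decode_layer0.weight": shared,  # signature "decode", same underlying tensor
}
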
16 changes: 11 additions & 5 deletions ai_edge_torch/convert/converter.py
@@ -17,12 +17,11 @@
 
 from typing import Any, Dict, Optional, Tuple, Union
 
-import torch
-
 from ai_edge_torch import model
 from ai_edge_torch.convert import conversion
 from ai_edge_torch.convert import conversion_utils as cutils
 from ai_edge_torch.quantize import quant_config as qcfg
+import torch
 
 
 class Converter:
@@ -68,14 +67,20 @@ def add_signature(
     """
 
     if name in [sig.name for sig in self._signatures]:
-      raise ValueError(f"A signature with the provided name ({name}) is already added.")
+      raise ValueError(
+          f"A signature with the provided name ({name}) is already added."
+      )
 
     if sample_args is None and sample_kwargs is None:
       raise ValueError("sample_args or sample_kwargs must be provided.")
 
     self._signatures.append(
         cutils.Signature(
-            name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
+            name,
+            module,
+            sample_args,
+            sample_kwargs,
+            dynamic_shapes=dynamic_shapes,
         )
     )
     return self
@@ -128,7 +133,8 @@ def convert(
       )
     else:  # module is provided but not args
       raise ValueError(
-          "sample_args or sample_kwargs must be provided if a module is specified."
+          "sample_args or sample_kwargs must be provided if a module is"
+          " specified."
       )
 
     return conversion.convert_signatures(
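
Taken together, the Converter changes keep the public entry point unchanged; typical usage looks like this sketch (model and inputs are illustrative, following the ai_edge_torch README):

import ai_edge_torch
import torch
import torchvision

# eval() avoids the training-mode warning emitted by conversion.py above.
model = torchvision.models.resnet18(weights=None).eval()
sample_args = (torch.randn(1, 3, 224, 224),)

edge_model = ai_edge_torch.convert(model, sample_args)
edge_model.export("resnet18.tflite")
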
7 changes: 3 additions & 4 deletions ai_edge_torch/convert/fx_passes/__init__.py
@@ -15,10 +15,6 @@
 
 from typing import Sequence, Union
 
-from torch.export import ExportedProgram
-from torch.fx.passes.infra.pass_manager import pass_result_wrapper
-import torch.utils._pytree as pytree
-
 from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
 from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
 from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase
@@ -28,6 +24,9 @@
 from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass
 from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass  # NOQA
 from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass  # NOQA
+from torch.export import ExportedProgram
+from torch.fx.passes.infra.pass_manager import pass_result_wrapper
+import torch.utils._pytree as pytree
 
 
 # TODO(cnchan): make a PassManager class.
8 changes: 6 additions & 2 deletions ai_edge_torch/convert/fx_passes/_pass_base.py
@@ -32,14 +32,18 @@ def __new__(cls, exported_program, modified):
 
 
 class ExportedProgramPassBase(abc.ABC):
 
-  def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
+  def __call__(
+      self, exported_program: ExportedProgram
+  ) -> ExportedProgramPassResult:
     self.requires(exported_program)
     res = self.call(exported_program)
     self.ensures(exported_program)
     return res
 
   @abc.abstractmethod
-  def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
+  def call(
+      self, exported_program: ExportedProgram
+  ) -> ExportedProgramPassResult:
     pass
 
   def requires(self, exported_program: ExportedProgram) -> None:
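
To make the reflowed interface concrete, a minimal hypothetical pass conforming to ExportedProgramPassBase might look like:

from ai_edge_torch.convert.fx_passes import ExportedProgramPassBase
from ai_edge_torch.convert.fx_passes import ExportedProgramPassResult
from torch.export import ExportedProgram


class NoopPass(ExportedProgramPassBase):
  """Illustrative pass that inspects the graph but changes nothing."""

  def call(
      self, exported_program: ExportedProgram
  ) -> ExportedProgramPassResult:
    node_count = len(list(exported_program.graph_module.graph.nodes))
    del node_count  # a real pass would rewrite the graph here
    return ExportedProgramPassResult(exported_program, modified=False)
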
(The diffs for the remaining 89 changed files are not shown.)