Allow naming the outputs in the generated tflite file.
PiperOrigin-RevId: 666458444
majiddadashi authored and copybara-github committed Aug 22, 2024
1 parent 23db233 commit c6da1a3
Showing 7 changed files with 147 additions and 60 deletions.
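Before this change, every output tensor in the generated signature was named output_<index>; after it, outputs of dict-returning models keep names derived from the dictionary structure. A minimal end-to-end sketch of the new behavior (the model and tensor shapes below are illustrative, not taken from this commit's tests):

import ai_edge_torch
import torch


class WithNamedOutputs(torch.nn.Module):
  """Toy model returning a dict, so its tflite outputs get named."""

  def forward(self, x):
    return {"doubled": x * 2, "shifted": x + 1}


model = WithNamedOutputs().eval()
sample_args = (torch.randn(10),)
edge_model = ai_edge_torch.convert(model, sample_args)

# Inputs are keyed args_<i>; outputs now come back keyed by the
# dict structure instead of output_0/output_1.
out = edge_model(args_0=sample_args[0].numpy())
print(sorted(out.keys()))  # expected: ['doubled', 'shifted']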
38 changes: 2 additions & 36 deletions ai_edge_torch/_convert/signature.py
@@ -16,6 +16,7 @@
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Union

+from ai_edge_torch import lowertools
import torch
import torch.utils._pytree as pytree

@@ -53,47 +54,12 @@ def flat_arg_names(self) -> list[str]:
for i in range(args_spec.num_leaves):
names.append(f"args_{i}")

-kwargs_names = self._flat_kwarg_names(
+kwargs_names = lowertools.flat_dict_names(
kwargs_spec.children_specs, kwargs_spec.context
)
names.extend(kwargs_names)
return names

-def _flat_kwarg_names(self, specs, context) -> List[str]:
-flat_names = []
-if context is None:
-for i, spec in enumerate(specs):
-if spec.children_specs:
-flat_names.extend([
-f"{i}_{name}"
-for name in self._flat_kwarg_names(
-spec.children_specs, spec.context
-)
-])
-else:
-flat_names.append(f"{i}")
-else:
-flat_ctx = self._flatten_list(context)
-for prefix, spec in zip(flat_ctx, specs):
-leaf_flat_names = self._flat_kwarg_names(
-spec.children_specs, spec.context
-)
-if leaf_flat_names:
-flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
-else:
-flat_names.append(prefix)
-
-return flat_names
-
-def _flatten_list(self, l: List) -> List:
-flattened = []
-for item in l:
-if isinstance(item, list):
-flattened.extend(self._flatten_list(item))
-else:
-flattened.append(item)
-return flattened

@property
def flat_args(self) -> tuple[Any]:
args, kwargs = self._normalized_sample_args_kwargs
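For reference, a small sketch of the flat-naming behavior this refactor delegates to lowertools.flat_dict_names (the kwarg names below are made up for illustration):

import torch
import torch.utils._pytree as pytree

from ai_edge_torch import lowertools

# Nested keyword arguments; each leaf gets a name encoding its path.
kwargs = {"scale": torch.ones(2), "extras": {"bias": torch.zeros(2)}}
_, kwargs_spec = pytree.tree_flatten(kwargs)

names = lowertools.flat_dict_names(
    kwargs_spec.children_specs, kwargs_spec.context
)
print(names)  # expected: ['scale', 'extras_bias']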
33 changes: 32 additions & 1 deletion ai_edge_torch/_convert/test/test_convert.py
@@ -174,7 +174,7 @@ def forward(self, arg0, arg1):
self.assertTrue(result)

def test_12_outputs_model(self):
"""Tests conversion of a model that returns multiple outputs."""
"""Tests conversion of a model that returns more than 10 outputs."""

class BasicAddModelWithMultipleOutputs(torch.nn.Module):
"""A model that returns multiple outputs."""
@@ -421,6 +421,37 @@ def forward(self, x, y, z):
SampleModel(), args, kwargs, flat_inputs
)

def test_convert_model_non_flat_output_dict(self):
"""Test converting a model with non-flat output structure."""

class SampleModel(torch.nn.Module):

def forward(self, x, y, z):
return {"x": x, "y": TestContainer1(data_1=y, data_2=[y, z])}

args = (torch.randn(10, 10), torch.randn(10, 10), torch.randn(10, 10))
kwargs = dict()
flat_inputs = {
"args_0": args[0].numpy(),
"args_1": args[1].numpy(),
"args_2": args[2].numpy(),
}

edge_model = ai_edge_torch.convert(SampleModel().eval(), args, kwargs)
edge_output = edge_model(**flat_inputs)
np.testing.assert_almost_equal(edge_output["x"], args[0])
np.testing.assert_almost_equal(edge_output["y_data_1"], args[1])
np.testing.assert_almost_equal(edge_output["y_data_2_0"], args[1])
np.testing.assert_almost_equal(edge_output["y_data_2_1"], args[2])

interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
runner = interpreter.get_signature_runner("serving_default")
output_details = runner.get_output_details()
self.assertIn("x", output_details.keys())
self.assertIn("y_data_1", output_details.keys())
self.assertIn("y_data_2_0", output_details.keys())
self.assertIn("y_data_2_1", output_details.keys())

def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
model.eval()
edge_model = ai_edge_torch.convert(model, args, kwargs)
1 change: 1 addition & 0 deletions ai_edge_torch/lowertools/__init__.py
@@ -14,4 +14,5 @@
# ==============================================================================

from ._shim import *
+from .common_utils import flat_dict_names
from .test_utils import *
53 changes: 53 additions & 0 deletions ai_edge_torch/lowertools/common_utils.py
@@ -14,10 +14,63 @@
# ==============================================================================

import logging
from typing import List

from ai_edge_torch._convert import signature as signature_module
import tensorflow as tf
import torch
import torch.utils._pytree as pytree


def _flatten_list(l: List) -> List:
flattened = []
for item in l:
if isinstance(item, list):
flattened.extend(_flatten_list(item))
else:
flattened.append(item)
return flattened


def flat_dict_names(
tree_spec: List[pytree.TreeSpec], context: pytree.Context
) -> List[str]:
"""Given a TreeSpec, this produces a list of names for the leaves.
The list of names embeddeds the structure of the tree_spec. A nesting level is
indicated by an `_` and elements in a list are indicated by `_<index>`.
TODO b/361601485: The flattening of names is not collision-free and needs to
be revised.
Args:
tree_spec: The TreeSpec to extract the names from.
context: The context used to check if the provided spec belongs to a
dictionary or a list.
Returns:
A list of flattened names.
"""
flat_names = []
if context is None:
for i, spec in enumerate(tree_spec):
if spec.children_specs:
flat_names.extend([
f"{i}_{name}"
for name in flat_dict_names(spec.children_specs, spec.context)
])
else:
flat_names.append(f"{i}")
else:
flat_ctx = _flatten_list(context)
for prefix, spec in zip(flat_ctx, tree_spec):
leaf_flat_names = flat_dict_names(spec.children_specs, spec.context)
if leaf_flat_names:
flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
else:
flat_names.append(prefix)

return flat_names


def _torch_to_tf_variable(torch_tensor: torch.Tensor):
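The TODO in the docstring is easy to reproduce: because nesting is encoded with a bare `_`, two different structures can flatten to the same name. A minimal sketch of a colliding case (hypothetical values):

import torch.utils._pytree as pytree

from ai_edge_torch.lowertools import common_utils

# "a_b" (a flat key) and "a" -> "b" (a nested key) both flatten to the
# name "a_b", illustrating the collision noted in TODO b/361601485.
outputs = {"a_b": 1, "a": {"b": 2}}
_, spec = pytree.tree_flatten(outputs)
names = common_utils.flat_dict_names(spec.children_specs, spec.context)
print(names)  # ['a_b', 'a_b']: two leaves, one name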
27 changes: 22 additions & 5 deletions ai_edge_torch/lowertools/odml_torch_utils.py
@@ -38,6 +38,7 @@ class MergedBundle:
"""A bundle of MlirLowered that has been merged."""

bundles: list[odml_torch.export.MlirLowered]
+exported_programs: list[torch.export.ExportedProgram]
deduped_tf_vars: list[tf.Variable]


@@ -74,19 +75,31 @@ def _extract_call_args(
return call_args


-def _wrap_as_tf_func(bundle, tf_state_dict):
+def _wrap_as_tf_func(
+bundle: export.MlirLowered,
+tf_state_dict: Dict[str, tf.Variable],
+exported_program: torch.export.ExportedProgram,
+):
def inner(*args):
t_outs = [torch_dtype_to_tf(sig.dtype) for sig in bundle.output_signature]
s_outs = [_get_shape_with_dynamic(sig) for sig in bundle.output_signature]
call_args = _extract_call_args(bundle, args, tf_state_dict)
-return tfxla.call_module(
+call_module_return = tfxla.call_module(
tuple(call_args),
version=5,
Tout=t_outs, # dtype information
Sout=s_outs, # Shape information
function_list=[],
module=bundle.module_bytecode,
)
+spec = exported_program.call_spec.out_spec
+
+# The module returns a flat array.
+if not spec.context:
+return call_module_return
+
+flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
+return {name: value for name, value in zip(flat_names, call_module_return)}

return inner

@@ -128,8 +141,10 @@ def merged_bundle_to_tfl_model(
for bundle, sig in zip(merged_bundle.bundles, signatures)
]
tf_functions = [
-_wrap_as_tf_func(bundle, tf_state_dict)
-for bundle in merged_bundle.bundles
+_wrap_as_tf_func(bundle, tf_state_dict, ep)
+for bundle, ep in zip(
+merged_bundle.bundles, merged_bundle.exported_programs
+)
]

tf_module = tf.Module()
@@ -202,7 +217,9 @@ def merge_mlir_bundles(
)

merged_bundle = MergedBundle(
-bundles=bundles.copy(), deduped_tf_vars=deduped_vars
+bundles=bundles.copy(),
+exported_programs=exported_programs,
+deduped_tf_vars=deduped_vars,
)
for bundle, signature in zip(merged_bundle.bundles, signatures):
bundle.state_dict = state_dict
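The dispatch on out_spec can be sketched in isolation (name_outputs is a hypothetical helper; the real code inlines this logic inside _wrap_as_tf_func, with tfxla.call_module supplying the flat outputs):

from typing import Any, Dict, List, Union

import torch.utils._pytree as pytree

from ai_edge_torch.lowertools import common_utils


def name_outputs(
    out_spec: pytree.TreeSpec, flat_outputs: List[Any]
) -> Union[List[Any], Dict[str, Any]]:
  # A None context means the model returned a plain tuple/list, so the
  # flat outputs pass through unchanged (the pre-change behavior).
  if not out_spec.context:
    return flat_outputs
  names = common_utils.flat_dict_names(
      out_spec.children_specs, out_spec.context
  )
  return dict(zip(names, flat_outputs))


# Dict-returning model: leaves get structure-encoding names.
_, dict_spec = pytree.tree_flatten({"x": 0, "y": {"z": 1}})
print(name_outputs(dict_spec, ["t0", "t1"]))  # {'x': 't0', 'y_z': 't1'}

# Tuple-returning model: outputs stay positional.
_, tuple_spec = pytree.tree_flatten((0, 1))
print(name_outputs(tuple_spec, ["t0", "t1"]))  # ['t0', 't1']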
37 changes: 24 additions & 13 deletions ai_edge_torch/lowertools/torch_xla_utils.py
@@ -51,16 +51,17 @@
class MergedBundle:

bundle: stablehlo.StableHLOModelBundle
+exported_programs: list[torch.export.ExportedProgram]
deduped_tf_vars: list[tf.Variable]


def exported_program_to_mlir(
exported_program: torch.export.ExportedProgram,
sample_args: tuple[torch.Tensor],
) -> stablehlo.StableHLOModelBundle:
-# Setting export_weights to False here so that pytorch/xla avoids copying the weights
-# to a numpy array which would lead to memory bloat. This means that the state_dict
-# in the returned bundle is going to be empty.
+# Setting export_weights to False here so that pytorch/xla avoids copying the
+# weights to a numpy array which would lead to memory bloat. This means that
+# the state_dict in the returned bundle is going to be empty.
return stablehlo.exported_program_to_stablehlo(
exported_program,
stablehlo.StableHLOExportOptions(
@@ -96,7 +97,9 @@ def merge_mlir_bundles(
bundle.additional_constants
)
return MergedBundle(
-bundle=new_shlo_model_bundle, deduped_tf_vars=deduped_tf_vars
+bundle=new_shlo_model_bundle,
+exported_programs=exported_programs,
+deduped_tf_vars=deduped_tf_vars,
)


@@ -108,31 +111,34 @@ def _get_shape_with_dynamic(signature: stablehlo.VariableSignature):


def _wrap_as_tf_func(
-func: stablehlo.StableHLOFunc, bundle: stablehlo.StableHLOModelBundle
+func: stablehlo.StableHLOFunc,
+bundle: stablehlo.StableHLOModelBundle,
+exported_program: torch.export.ExportedProgram,
):
def inner(*args):
type_info = [sig.dtype for sig in func.meta.output_signature]
shape_info = [
_get_shape_with_dynamic(sig) for sig in func.meta.output_signature
]
call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
-return tfxla.call_module(
+call_module_return = tfxla.call_module(
tuple(call_args),
version=5,
Tout=type_info,
Sout=shape_info,
function_list=[],
module=func.bytecode,
)
+spec = exported_program.call_spec.out_spec
+
+# The module returns a flat array.
+if not spec.context:
+return call_module_return
+
+flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
+return {name: value for name, value in zip(flat_names, call_module_return)}

return inner

-def _make_tf_function(
-bundle: stablehlo.StableHLOModelBundle = None,
-):
-bundle = bundle if bundle is None else bundle
-return [_wrap_as_tf_func(func, bundle) for func in bundle.stablehlo_funcs]


def _make_tf_signature(
@@ -205,7 +211,12 @@ def merged_bundle_to_tfl_model(
for func, sig in zip(shlo_bundle.stablehlo_funcs, signatures)
)

-tf_functions = _make_tf_function(shlo_bundle)
+tf_functions = [
+_wrap_as_tf_func(func, shlo_bundle, ep)
+for func, ep in zip(
+shlo_bundle.stablehlo_funcs, merged_bundle.exported_programs
+)
+]

tf_module.f = []
for tf_sig, func in zip(tf_signatures, tf_functions):
18 changes: 13 additions & 5 deletions ai_edge_torch/model.py
@@ -21,6 +21,7 @@
from __future__ import annotations

import abc
+import re

import numpy.typing as npt
import tensorflow as tf
@@ -115,11 +116,18 @@ def __call__(
inputs = {**inputs, **kwargs}
outputs = runner(**inputs)

-return (
-outputs['output_0']
-if len(outputs) == 1
-else [outputs[f'output_{idx}'] for idx in range(len(outputs))]
-)
+# When attempting to run a model, check if all the output tensors are named
+# output_<number>. If so, assume the pytorch model returned a tuple and not
+# a dictionary.
+output_heuristic = lambda key: bool(re.search(r'output_\d+', key))
+if all(output_heuristic(key) for key in outputs.keys()):
+return (
+outputs['output_0']
+if len(outputs) == 1
+else [outputs[f'output_{idx}'] for idx in range(len(outputs))]
+)
+
+return outputs

def export(self, path: str) -> None:
"""Serializes the edge model to disk.
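The heuristic itself is easy to check in isolation (a hedged sketch; the dictionaries below are hypothetical interpreter outputs, not from this commit's tests):

import re

# Mirrors the check in __call__ above: every key matching output_<n>
# means the signature used the legacy positional naming.
output_heuristic = lambda key: bool(re.search(r'output_\d+', key))

legacy = {'output_0': 0.0, 'output_1': 1.0}
named = {'x': 0.0, 'y_data_1': 1.0}

assert all(output_heuristic(k) for k in legacy)      # -> list return
assert not all(output_heuristic(k) for k in named)   # -> dict return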
