Skip to content

Commit

Permalink
Support kwargs in ai_edge_torch.convert to named inputs (#86)
Browse files Browse the repository at this point in the history
* init

* init

* fix

* fix

* Update conversion_utils.py

* Update conversion_utils.py
  • Loading branch information
chunnienc authored Jul 11, 2024
1 parent 3b81ea0 commit 9fa0401
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 39 deletions.
6 changes: 2 additions & 4 deletions ai_edge_torch/convert/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,14 @@ def convert_signatures(
_warn_training_modules(signatures)

exported_programs: torch.export.ExportedProgram = [
torch.export.export(
sig.module, sig.sample_args, dynamic_shapes=sig.dynamic_shapes
)
torch.export.export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
for sig in signatures
]

# Apply default fx passes
exported_programs = list(map(_run_convert_passes, exported_programs))
shlo_bundles: list[stablehlo.StableHLOModelBundle] = [
cutils.exported_program_to_stablehlo_bundle(exported, sig.sample_args)
cutils.exported_program_to_stablehlo_bundle(exported, sig.flat_args)
for exported, sig in zip(exported_programs, signatures)
]

Expand Down
64 changes: 61 additions & 3 deletions ai_edge_torch/convert/conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================

import collections
import copy
from dataclasses import dataclass
import gc
Expand All @@ -22,6 +23,7 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch_xla import stablehlo

from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
Expand All @@ -47,8 +49,59 @@ class Signature:
name: str
module: torch.nn.Module
sample_args: tuple[torch.Tensor]
sample_kwargs: dict[str, torch.Tensor]
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None

@property
def _normalized_sample_args_kwargs(self):
args, kwargs = self.sample_args, self.sample_kwargs
if args is not None:
if not isinstance(args, tuple):
# TODO(b/352584188): Check value types
raise ValueError("sample_args must be a tuple of torch tensors.")
if kwargs is not None:
if not isinstance(kwargs, dict) or not all(
isinstance(key, str) for key in kwargs.keys()
):
# TODO(b/352584188): Check value types
raise ValueError("sample_kwargs must be a dict of string to tensor.")

args = args if args is not None else tuple()
kwargs = kwargs if kwargs is not None else {}
return args, kwargs

@property
def flat_arg_names(self) -> list[str]:
spec = pytree.tree_flatten(self._normalized_sample_args_kwargs)[1]
args_spec, kwargs_spec = spec.children_specs

names = []
for i in range(args_spec.num_leaves):
names.append(f"args_{i}")

dict_context = (
kwargs_spec.context
if kwargs_spec.type is not collections.defaultdict
# ignore mismatch of `default_factory` for defaultdict
else kwargs_spec.context[1]
)

for name, value_spec in zip(dict_context, kwargs_spec.children_specs):
if value_spec.num_leaves == 1:
names.append(name)
else:
# value_spec.num_leaves may be greater than 1 when the value is a (nested)
# tuple of tensors. We haven't decided how we should support flattenable
# tensor containers as inputs.
# TODO(b/352584188): Decide the behavior of tensor container as input (flatten or reject)
for i in range(value_spec.num_leaves):
names.append(f"{name}_{i}")
return names

@property
def flat_args(self) -> tuple[torch.Tensor]:
return tuple(pytree.tree_flatten(self._normalized_sample_args_kwargs)[0])


def exported_program_to_stablehlo_bundle(
exported_program: torch.export.ExportedProgram, sample_args: tuple[torch.Tensor]
Expand Down Expand Up @@ -189,17 +242,21 @@ def _make_tf_function(

def _make_tf_signature(
meta: stablehlo.StableHLOFunctionMeta,
signature: Signature,
) -> list[tf.TensorSpec]:
input_names = signature.flat_arg_names
input_pos_to_spec = {
loc.position: spec
for loc, spec in itertools.chain(
zip(meta.input_locations, meta.input_signature), meta.unused_inputs
)
if loc.type_ == stablehlo.VariableType.INPUT_ARG
}
assert len(input_pos_to_spec) == len(input_names)

primitive_type_to_tf_type = {"int": "int32", "float": "float32"}
ret: list[tf.TensorSpec] = []
for i in range(len(input_pos_to_spec)):
for i, name in enumerate(input_names):
spec = input_pos_to_spec[i]
shape = _get_shape_with_dynamic(spec)
ret.append(
Expand All @@ -208,7 +265,7 @@ def _make_tf_signature(
dtype=primitive_type_to_tf_type[spec.dtype]
if spec.dtype in primitive_type_to_tf_type
else spec.dtype,
name=f"args_{i}",
name=name,
)
)
return ret
Expand Down Expand Up @@ -276,7 +333,8 @@ def convert_stablehlo_to_tflite(
tf.Variable(v, trainable=False) for v in bundle.additional_constants
]
tf_signatures: list[list[tf.TensorSpec]] = list(
_make_tf_signature(func.meta) for func in bundle.stablehlo_funcs
_make_tf_signature(func.meta, sig)
for func, sig in zip(bundle.stablehlo_funcs, signatures)
)

tf_functions = _make_tf_function(shlo_graph_module, bundle)
Expand Down
63 changes: 47 additions & 16 deletions ai_edge_torch/convert/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,32 @@ def signature(
self,
name: str,
module: torch.nn.Module,
sample_args: tuple[cutils.TracingArg],
sample_args=None,
sample_kwargs=None,
*,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> Converter:
"""Alias to `add_signature`"""
return self.add_signature(name, module, sample_args, dynamic_shapes)
return self.add_signature(
name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
)

def add_signature(
self,
name: str,
module: torch.nn.Module,
sample_args: tuple[cutils.TracingArg],
sample_args=None,
sample_kwargs=None,
*,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> Converter:
"""Allows adding a new named torch model along with sample args to the conversion.
Args:
name: The name of the signature included in the converted edge model.
module: The torch module to be converted.
sample_args: Tuple of args by which the torch module will be traced prior to conversion.
sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
Expand All @@ -63,13 +70,21 @@ def add_signature(
if name in [sig.name for sig in self._signatures]:
raise ValueError(f"A signature with the provided name ({name}) is already added.")

self._signatures.append(cutils.Signature(name, module, sample_args, dynamic_shapes))
if sample_args is None and sample_kwargs is None:
raise ValueError("sample_args or sample_kwargs must be provided.")

self._signatures.append(
cutils.Signature(
name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
)
)
return self

def convert(
self,
module: torch.nn.Module = None,
sample_args: tuple[cutils.TracingArg] = None,
sample_args=None,
sample_kwargs=None,
*,
quant_config: Optional[qcfg.QuantConfig] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
Expand All @@ -88,7 +103,8 @@ def convert(
Args:
name: The name of the signature included in the converted edge model.
module: The torch module to be converted.
sample_args: Tuple of args by which the torch module will be traced prior to conversion.
sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
quant_config: User-defined quantization method and scheme of the model.
dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
Expand All @@ -100,12 +116,20 @@ def convert(
ValueError: If the arguments are not provided as expected. See the example in this functions's comment.
"""
if module is not None:
if sample_args is not None: # both module and args provided
if (
sample_args is not None or sample_kwargs is not None
): # both module and args provided
self.add_signature(
cutils.DEFAULT_SIGNATURE_NAME, module, sample_args, dynamic_shapes
cutils.DEFAULT_SIGNATURE_NAME,
module,
sample_args,
sample_kwargs,
dynamic_shapes=dynamic_shapes,
)
else: # module is provided but not args
raise ValueError(
"sample_args or sample_kwargs must be provided if a module is specified."
)
else: # module is provided but not sample_args
raise ValueError("sample_args needs to be provided if a module is specified.")

return conversion.convert_signatures(
self._signatures,
Expand All @@ -117,15 +141,17 @@ def convert(
def signature(
name: str,
module: torch.nn.Module,
sample_args: tuple[cutils.TracingArg],
sample_args=None,
sample_kwargs=None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> Converter:
"""Initiates a Converter object with the provided signature.
Args:
name: The name of the signature included in the converted edge model.
module: The torch module to be converted.
sample_args: Tuple of args by which the torch module will be traced prior to conversion.
sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
Expand All @@ -134,12 +160,15 @@ def signature(
edge_model = converter.convert()
"""
return Converter().signature(name, module, sample_args, dynamic_shapes)
return Converter().signature(
name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
)


def convert(
module: torch.nn.Module = None,
sample_args: tuple[cutils.TracingArg] = None,
sample_args=None,
sample_kwargs=None,
*,
quant_config: Optional[qcfg.QuantConfig] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
Expand All @@ -149,7 +178,8 @@ def convert(
Args:
module: The torch module to be converted.
sample_args: Tuple of args by which the torch module will be traced prior to conversion.
sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
quant_config: User-defined quantization method and scheme of the model.
dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
Expand All @@ -165,6 +195,7 @@ def convert(
return Converter().convert(
module,
sample_args,
sample_kwargs,
quant_config=quant_config,
dynamic_shapes=dynamic_shapes,
_ai_edge_converter_flags=_ai_edge_converter_flags,
Expand Down
39 changes: 39 additions & 0 deletions ai_edge_torch/convert/test/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,45 @@ def forward(self, x, y):
model_coverage.compare_tflite_torch(edge_model, model, validate_input)
)

def test_convert_model_with_kwargs(self):
"""
Test converting a simple model with sample_kwargs.
"""

class SampleModel(torch.nn.Module):

def forward(self, x, y):
return x + y

kwargs_gen = lambda: dict(x=torch.randn(10, 10), y=torch.randn(10, 10))

model = SampleModel().eval()
edge_model = ai_edge_torch.convert(model, sample_kwargs=kwargs_gen())

self.assertTrue(
model_coverage.compare_tflite_torch(edge_model, model, kwargs=kwargs_gen)
)

def test_convert_model_with_args_kwargs(self):
"""
Test converting a simple model with both sample_args and sample_kwargs.
"""

class SampleModel(torch.nn.Module):

def forward(self, x, y):
return x + y

args_gen = lambda: (torch.randn(10, 10),)
kwargs_gen = lambda: dict(y=torch.randn(10, 10))

model = SampleModel().eval()
edge_model = ai_edge_torch.convert(model, args_gen(), kwargs_gen())

self.assertTrue(
model_coverage.compare_tflite_torch(edge_model, model, args_gen, kwargs_gen)
)


if __name__ == "__main__":
unittest.main()
14 changes: 11 additions & 3 deletions ai_edge_torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ class Model(abc.ABC):

@abc.abstractmethod
def __call__(
self, *args: npt.ArrayLike, signature_name: str = cutils.DEFAULT_SIGNATURE_NAME
self,
*args: npt.ArrayLike,
signature_name: str = cutils.DEFAULT_SIGNATURE_NAME,
**kwargs,
) -> npt.ArrayLike | tuple[npt.ArrayLike]:
raise NotImplementedError()

Expand Down Expand Up @@ -62,12 +65,16 @@ def __init__(self, tflite_model):
self._tflite_model = tflite_model

def __call__(
self, *args: npt.ArrayLike, signature_name: str = cutils.DEFAULT_SIGNATURE_NAME
self,
*args: npt.ArrayLike,
signature_name: str = cutils.DEFAULT_SIGNATURE_NAME,
**kwargs,
) -> npt.ArrayLike | tuple[npt.ArrayLike]:
"""Runs inference on the edge model using the provided arguments.
Args:
*args: The arguments to be passed to the model for inference.
**kwargs: The arguments with specific names to be passed to the model for inference.
signature_name: The name of the signature to be used for inference.
The default signature is used if not provided.
"""
Expand All @@ -90,13 +97,14 @@ def __call__(
else:
raise exception

if len(signature_list[signature_name]['inputs']) != len(args):
if len(signature_list[signature_name]['inputs']) != len(args) + len(kwargs):
raise ValueError(
f"The model requires {len(signature_list[signature_name]['inputs'])} arguments but {len(args)} was provided."
)

# Gather the input dictionary based on the signature.
inputs = {f'args_{idx}': args[idx] for idx in range(len(args))}
inputs = {**inputs, **kwargs}
outputs = runner(**inputs)

return (
Expand Down
Loading

0 comments on commit 9fa0401

Please sign in to comment.