diff --git a/ai_edge_torch/convert/conversion.py b/ai_edge_torch/convert/conversion.py index a6e6495c..ada2e0b9 100644 --- a/ai_edge_torch/convert/conversion.py +++ b/ai_edge_torch/convert/conversion.py @@ -88,16 +88,14 @@ def convert_signatures( _warn_training_modules(signatures) exported_programs: torch.export.ExportedProgram = [ - torch.export.export( - sig.module, sig.sample_args, dynamic_shapes=sig.dynamic_shapes - ) + torch.export.export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes) for sig in signatures ] # Apply default fx passes exported_programs = list(map(_run_convert_passes, exported_programs)) shlo_bundles: list[stablehlo.StableHLOModelBundle] = [ - cutils.exported_program_to_stablehlo_bundle(exported, sig.sample_args) + cutils.exported_program_to_stablehlo_bundle(exported, sig.flat_args) for exported, sig in zip(exported_programs, signatures) ] diff --git a/ai_edge_torch/convert/conversion_utils.py b/ai_edge_torch/convert/conversion_utils.py index 73e81030..e26e26d5 100644 --- a/ai_edge_torch/convert/conversion_utils.py +++ b/ai_edge_torch/convert/conversion_utils.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== +import collections import copy from dataclasses import dataclass import gc @@ -22,6 +23,7 @@ from typing import Any, Dict, Optional, Tuple, Union import torch +import torch.utils._pytree as pytree from torch_xla import stablehlo from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA @@ -47,8 +49,59 @@ class Signature: name: str module: torch.nn.Module sample_args: tuple[torch.Tensor] + sample_kwargs: dict[str, torch.Tensor] dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None + @property + def _normalized_sample_args_kwargs(self): + args, kwargs = self.sample_args, self.sample_kwargs + if args is not None: + if not isinstance(args, tuple): + # TODO(b/352584188): Check value types + raise ValueError("sample_args must be a tuple of torch tensors.") + if kwargs is not None: + if not isinstance(kwargs, dict) or not all( + isinstance(key, str) for key in kwargs.keys() + ): + # TODO(b/352584188): Check value types + raise ValueError("sample_kwargs must be a dict of string to tensor.") + + args = args if args is not None else tuple() + kwargs = kwargs if kwargs is not None else {} + return args, kwargs + + @property + def flat_arg_names(self) -> list[str]: + spec = pytree.tree_flatten(self._normalized_sample_args_kwargs)[1] + args_spec, kwargs_spec = spec.children_specs + + names = [] + for i in range(args_spec.num_leaves): + names.append(f"args_{i}") + + dict_context = ( + kwargs_spec.context + if kwargs_spec.type is not collections.defaultdict + # ignore mismatch of `default_factory` for defaultdict + else kwargs_spec.context[1] + ) + + for name, value_spec in zip(dict_context, kwargs_spec.children_specs): + if value_spec.num_leaves == 1: + names.append(name) + else: + # value_spec.num_leaves may be greater than 1 when the value is a (nested) + # tuple of tensors. We haven't decided how we should support flattenable + # tensor containers as inputs. + # TODO(b/352584188): Decide the behavior of tensor container as input (flatten or reject) + for i in range(value_spec.num_leaves): + names.append(f"{name}_{i}") + return names + + @property + def flat_args(self) -> tuple[torch.Tensor]: + return tuple(pytree.tree_flatten(self._normalized_sample_args_kwargs)[0]) + def exported_program_to_stablehlo_bundle( exported_program: torch.export.ExportedProgram, sample_args: tuple[torch.Tensor] @@ -189,7 +242,9 @@ def _make_tf_function( def _make_tf_signature( meta: stablehlo.StableHLOFunctionMeta, + signature: Signature, ) -> list[tf.TensorSpec]: + input_names = signature.flat_arg_names input_pos_to_spec = { loc.position: spec for loc, spec in itertools.chain( @@ -197,9 +252,11 @@ def _make_tf_signature( ) if loc.type_ == stablehlo.VariableType.INPUT_ARG } + assert len(input_pos_to_spec) == len(input_names) + primitive_type_to_tf_type = {"int": "int32", "float": "float32"} ret: list[tf.TensorSpec] = [] - for i in range(len(input_pos_to_spec)): + for i, name in enumerate(input_names): spec = input_pos_to_spec[i] shape = _get_shape_with_dynamic(spec) ret.append( @@ -208,7 +265,7 @@ def _make_tf_signature( dtype=primitive_type_to_tf_type[spec.dtype] if spec.dtype in primitive_type_to_tf_type else spec.dtype, - name=f"args_{i}", + name=name, ) ) return ret @@ -276,7 +333,8 @@ def convert_stablehlo_to_tflite( tf.Variable(v, trainable=False) for v in bundle.additional_constants ] tf_signatures: list[list[tf.TensorSpec]] = list( - _make_tf_signature(func.meta) for func in bundle.stablehlo_funcs + _make_tf_signature(func.meta, sig) + for func, sig in zip(bundle.stablehlo_funcs, signatures) ) tf_functions = _make_tf_function(shlo_graph_module, bundle) diff --git a/ai_edge_torch/convert/converter.py b/ai_edge_torch/convert/converter.py index c3787c17..f71e143d 100644 --- a/ai_edge_torch/convert/converter.py +++ b/ai_edge_torch/convert/converter.py @@ -34,17 +34,23 @@ def signature( self, name: str, module: torch.nn.Module, - sample_args: tuple[cutils.TracingArg], + sample_args=None, + sample_kwargs=None, + *, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, ) -> Converter: """Alias to `add_signature`""" - return self.add_signature(name, module, sample_args, dynamic_shapes) + return self.add_signature( + name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes + ) def add_signature( self, name: str, module: torch.nn.Module, - sample_args: tuple[cutils.TracingArg], + sample_args=None, + sample_kwargs=None, + *, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, ) -> Converter: """Allows adding a new named torch model along with sample args to the conversion. @@ -52,7 +58,8 @@ def add_signature( Args: name: The name of the signature included in the converted edge model. module: The torch module to be converted. - sample_args: Tuple of args by which the torch module will be traced prior to conversion. + sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion. + sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion. dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order. See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details. @@ -63,13 +70,21 @@ def add_signature( if name in [sig.name for sig in self._signatures]: raise ValueError(f"A signature with the provided name ({name}) is already added.") - self._signatures.append(cutils.Signature(name, module, sample_args, dynamic_shapes)) + if sample_args is None and sample_kwargs is None: + raise ValueError("sample_args or sample_kwargs must be provided.") + + self._signatures.append( + cutils.Signature( + name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes + ) + ) return self def convert( self, module: torch.nn.Module = None, - sample_args: tuple[cutils.TracingArg] = None, + sample_args=None, + sample_kwargs=None, *, quant_config: Optional[qcfg.QuantConfig] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, @@ -88,7 +103,8 @@ def convert( Args: name: The name of the signature included in the converted edge model. module: The torch module to be converted. - sample_args: Tuple of args by which the torch module will be traced prior to conversion. + sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion. + sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion. quant_config: User-defined quantization method and scheme of the model. dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order. See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details. @@ -100,12 +116,20 @@ def convert( ValueError: If the arguments are not provided as expected. See the example in this functions's comment. """ if module is not None: - if sample_args is not None: # both module and args provided + if ( + sample_args is not None or sample_kwargs is not None + ): # both module and args provided self.add_signature( - cutils.DEFAULT_SIGNATURE_NAME, module, sample_args, dynamic_shapes + cutils.DEFAULT_SIGNATURE_NAME, + module, + sample_args, + sample_kwargs, + dynamic_shapes=dynamic_shapes, + ) + else: # module is provided but not args + raise ValueError( + "sample_args or sample_kwargs must be provided if a module is specified." ) - else: # module is provided but not sample_args - raise ValueError("sample_args needs to be provided if a module is specified.") return conversion.convert_signatures( self._signatures, @@ -117,7 +141,8 @@ def convert( def signature( name: str, module: torch.nn.Module, - sample_args: tuple[cutils.TracingArg], + sample_args=None, + sample_kwargs=None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, ) -> Converter: """Initiates a Converter object with the provided signature. @@ -125,7 +150,8 @@ def signature( Args: name: The name of the signature included in the converted edge model. module: The torch module to be converted. - sample_args: Tuple of args by which the torch module will be traced prior to conversion. + sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion. + sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion. dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order. See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details. @@ -134,12 +160,15 @@ def signature( edge_model = converter.convert() """ - return Converter().signature(name, module, sample_args, dynamic_shapes) + return Converter().signature( + name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes + ) def convert( module: torch.nn.Module = None, - sample_args: tuple[cutils.TracingArg] = None, + sample_args=None, + sample_kwargs=None, *, quant_config: Optional[qcfg.QuantConfig] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, @@ -149,7 +178,8 @@ def convert( Args: module: The torch module to be converted. - sample_args: Tuple of args by which the torch module will be traced prior to conversion. + sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion. + sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion. quant_config: User-defined quantization method and scheme of the model. dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order. See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details. @@ -165,6 +195,7 @@ def convert( return Converter().convert( module, sample_args, + sample_kwargs, quant_config=quant_config, dynamic_shapes=dynamic_shapes, _ai_edge_converter_flags=_ai_edge_converter_flags, diff --git a/ai_edge_torch/convert/test/test_convert.py b/ai_edge_torch/convert/test/test_convert.py index 266e4562..c3c551e9 100644 --- a/ai_edge_torch/convert/test/test_convert.py +++ b/ai_edge_torch/convert/test/test_convert.py @@ -267,6 +267,45 @@ def forward(self, x, y): model_coverage.compare_tflite_torch(edge_model, model, validate_input) ) + def test_convert_model_with_kwargs(self): + """ + Test converting a simple model with sample_kwargs. + """ + + class SampleModel(torch.nn.Module): + + def forward(self, x, y): + return x + y + + kwargs_gen = lambda: dict(x=torch.randn(10, 10), y=torch.randn(10, 10)) + + model = SampleModel().eval() + edge_model = ai_edge_torch.convert(model, sample_kwargs=kwargs_gen()) + + self.assertTrue( + model_coverage.compare_tflite_torch(edge_model, model, kwargs=kwargs_gen) + ) + + def test_convert_model_with_args_kwargs(self): + """ + Test converting a simple model with both sample_args and sample_kwargs. + """ + + class SampleModel(torch.nn.Module): + + def forward(self, x, y): + return x + y + + args_gen = lambda: (torch.randn(10, 10),) + kwargs_gen = lambda: dict(y=torch.randn(10, 10)) + + model = SampleModel().eval() + edge_model = ai_edge_torch.convert(model, args_gen(), kwargs_gen()) + + self.assertTrue( + model_coverage.compare_tflite_torch(edge_model, model, args_gen, kwargs_gen) + ) + if __name__ == "__main__": unittest.main() diff --git a/ai_edge_torch/model.py b/ai_edge_torch/model.py index 27632887..198af876 100644 --- a/ai_edge_torch/model.py +++ b/ai_edge_torch/model.py @@ -33,7 +33,10 @@ class Model(abc.ABC): @abc.abstractmethod def __call__( - self, *args: npt.ArrayLike, signature_name: str = cutils.DEFAULT_SIGNATURE_NAME + self, + *args: npt.ArrayLike, + signature_name: str = cutils.DEFAULT_SIGNATURE_NAME, + **kwargs, ) -> npt.ArrayLike | tuple[npt.ArrayLike]: raise NotImplementedError() @@ -62,12 +65,16 @@ def __init__(self, tflite_model): self._tflite_model = tflite_model def __call__( - self, *args: npt.ArrayLike, signature_name: str = cutils.DEFAULT_SIGNATURE_NAME + self, + *args: npt.ArrayLike, + signature_name: str = cutils.DEFAULT_SIGNATURE_NAME, + **kwargs, ) -> npt.ArrayLike | tuple[npt.ArrayLike]: """Runs inference on the edge model using the provided arguments. Args: *args: The arguments to be passed to the model for inference. + **kwargs: The arguments with specific names to be passed to the model for inference. signature_name: The name of the signature to be used for inference. The default signature is used if not provided. """ @@ -90,13 +97,14 @@ def __call__( else: raise exception - if len(signature_list[signature_name]['inputs']) != len(args): + if len(signature_list[signature_name]['inputs']) != len(args) + len(kwargs): raise ValueError( f"The model requires {len(signature_list[signature_name]['inputs'])} arguments but {len(args)} was provided." ) # Gather the input dictionary based on the signature. inputs = {f'args_{idx}': args[idx] for idx in range(len(args))} + inputs = {**inputs, **kwargs} outputs = runner(**inputs) return ( diff --git a/ai_edge_torch/testing/model_coverage/model_coverage.py b/ai_edge_torch/testing/model_coverage/model_coverage.py index 00edbaa2..2d4a6b6f 100644 --- a/ai_edge_torch/testing/model_coverage/model_coverage.py +++ b/ai_edge_torch/testing/model_coverage/model_coverage.py @@ -60,7 +60,8 @@ def _torch_tensors_to_np(*argv): def compare_tflite_torch( edge_model: Model, torch_eval_func: Callable, - input_data=None, + args=None, + kwargs=None, *, num_valid_inputs: int = 1, signature_name: str = None, @@ -71,8 +72,9 @@ def compare_tflite_torch( Args: edge_model: Serialized ai_edge_torch.model.Model object. torch_eval_func: Callable function to evaluate torch model. - input_data: torch.tensor array or a callable to generate a torch.tensor array + args: torch.tensor array or a callable to generate a torch.tensor array with random data, to pass into models during inference. (default None). + kwargs: dict of str to torch.tensor, or a callable to generate such. num_valid_inputs: Defines the number of times the random inputs will be generated (if a callable is provided for input_data). signature_name: If provided, specifies the name for the signature of the edge_model to run. Calls the default signature if not provided. @@ -86,29 +88,33 @@ def compare_tflite_torch( # The supplied model_def.forward_args() will be executed num_valid_inputs # times to generate num_valid_inputs random inputs. torch_inputs = [ - input_data() if callable(input_data) else input_data + ( + (args() if callable(args) else args) or tuple(), + (kwargs() if callable(kwargs) else kwargs) or {}, + ) for _ in range(num_valid_inputs) ] - torch_outputs = [torch_eval_func(*xs) for xs in torch_inputs] - np_inputs = [_torch_tensors_to_np(xs) for xs in torch_inputs] + torch_outputs = [torch_eval_func(*args, **kwargs) for args, kwargs in torch_inputs] + np_inputs = [ + (_torch_tensors_to_np(args), _torch_tensors_to_np(kwargs)) + for args, kwargs in torch_inputs + ] np_outputs = [_torch_tensors_to_np(_flatten(ys)) for ys in torch_outputs] # Define inline utility function used throughout the function. def equal_fn(actual, expected): return np.allclose(actual, expected, atol=atol, rtol=rtol) - def get_actual_fn(input): + def get_edge_output(inputs): + args, kwargs = inputs if signature_name is None: - return _flatten(edge_model(*input)) + return _flatten(edge_model(*args, **kwargs)) else: - return _flatten(edge_model(*input, signature_name=signature_name)) - - def get_expected_fn(input=None, idx=0): - return np_outputs[idx] + return _flatten(edge_model(*args, **kwargs, signature_name=signature_name)) for idx, np_input in enumerate(np_inputs): - output = get_actual_fn(np_input) - golden_output = get_expected_fn(np_input, idx) + output = get_edge_output(np_input) + golden_output = np_outputs[idx] is_output_len_eq = len(golden_output) == len(output)