From e5062025f5dcbc5df9d6c83d7e850d37a61fd03f Mon Sep 17 00:00:00 2001
From: peterjc123
Date: Tue, 26 Mar 2024 13:58:56 +0800
Subject: [PATCH] [converter] track dynamic inputs for slices (#284)

* track dynamic inputs for slices

* update op matrix

* more fixes
---
 docs/op_matrix.md                            |  9 ++-
 scripts/gen_op_docs.py                       | 14 +++-
 tests/converter_op_test.py                   | 45 +++++++++++
 tinynn/converter/operators/graph.py          |  1 +
 tinynn/converter/operators/torch/__init__.py |  4 +-
 tinynn/converter/operators/torch/aten.py     | 82 +++++++++++++++++---
 tinynn/converter/operators/torch/base.py     | 10 +++
 7 files changed, 150 insertions(+), 15 deletions(-)

diff --git a/docs/op_matrix.md b/docs/op_matrix.md
index 6dec3cd5..543b9283 100644
--- a/docs/op_matrix.md
+++ b/docs/op_matrix.md
@@ -220,8 +220,6 @@ Operators that are implemented in Python
 Non-tracking operators that are ignored during translation
 | Operator | Limitations |
 |---------------------------|--------------|
-| `aten::Int` | |
-| `aten::ScalarImplicit` | |
 | `aten::arange` | |
 | `aten::detach` | |
 | `aten::empty` | |
@@ -232,3 +230,10 @@ Non-tracking operators that are ignored during translation
 | `aten::size` | |
 | `aten::zeros` | |
 | `aten::zeros_like` | |
+
+## Constant Tracking Operators
+Tracking operators that produce a dynamic constant
+| Operator | Limitations |
+|---------------------------|--------------|
+| `aten::Int` | |
+| `aten::ScalarImplicit` | |
diff --git a/scripts/gen_op_docs.py b/scripts/gen_op_docs.py
index 32396e13..bb3ae310 100644
--- a/scripts/gen_op_docs.py
+++ b/scripts/gen_op_docs.py
@@ -3,7 +3,7 @@
 import re
 
 from tinynn.converter.operators.torch import OPERATOR_CONVERTER_DICT
-from tinynn.converter.operators.torch.base import NoTrackOperator, PrimOperatorConverter
+from tinynn.converter.operators.torch.base import NoTrackOperator, PrimOperatorConverter, TrackConstantOperator
 
 CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))
 
@@ -12,6 +12,7 @@
 quantized_ops = []
 torchvision_ops = []
 passthrough_ops = []
+track_constant_ops = []
 
 limitation_dict = {}
 
@@ -22,13 +23,15 @@ def main():
 
 
 def collect_ops():
-    global prim_ops, aten_ops, quantized_ops, torchvision_ops, passthrough_ops, limitation_dict
+    global prim_ops, aten_ops, quantized_ops, torchvision_ops, passthrough_ops, track_constant_ops, limitation_dict
 
     for k, v in OPERATOR_CONVERTER_DICT.items():
         if issubclass(v, PrimOperatorConverter):
             prim_ops.append(k)
         elif issubclass(v, NoTrackOperator):
             passthrough_ops.append(k)
+        elif issubclass(v, TrackConstantOperator):
+            track_constant_ops.append(k)
         else:
             if v.__module__ == 'tinynn.converter.operators.torch.aten':
                 aten_ops.append(k)
@@ -50,6 +53,7 @@ def collect_ops():
     quantized_ops = sorted(quantized_ops)
     torchvision_ops = sorted(torchvision_ops)
     passthrough_ops = sorted(passthrough_ops)
+    track_constant_ops = sorted(track_constant_ops)
 
 
 def print_operators(topic, ops, f, desc=None, eol=True):
@@ -89,6 +93,12 @@ def update_file():
             passthrough_ops,
             f,
             'Non-tracking operators that are ignored during translation',
+        )
+        print_operators(
+            'Constant Tracking Operators',
+            track_constant_ops,
+            f,
+            'Tracking operators that produce a dynamic constant',
             False,
         )
 
diff --git a/tests/converter_op_test.py b/tests/converter_op_test.py
index ebe454dc..425733a9 100644
--- a/tests/converter_op_test.py
+++ b/tests/converter_op_test.py
@@ -3853,6 +3853,51 @@ def model(x):
         tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
         assert_close(dummy_output, tfl_output)
 
+    def test_slice_dyn_input(self):
+        dummy_input = [torch.randn(10, 10, dtype=torch.float32), torch.tensor(2), torch.tensor(4)]
+
+        def model(x, y, z):
+            return x[y:z]
+
+        model_path = get_model_path()
+
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(*dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output)
+
+    def test_slice_dyn_start(self):
+        dummy_input = [torch.randn(10, 10, dtype=torch.float32), torch.tensor(2)]
+
+        def model(x, y):
+            return x[y:4]
+
+        model_path = get_model_path()
+
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(*dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output)
+
+    def test_slice_dyn_end(self):
+        dummy_input = [torch.randn(10, 10, dtype=torch.float32), torch.tensor(4)]
+
+        def model(x, y):
+            return x[2:y]
+
+        model_path = get_model_path()
+
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(*dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output)
+
     def test_slice_no_end(self):
         dummy_input = torch.randn(10, 10, dtype=torch.float32)
 
diff --git a/tinynn/converter/operators/graph.py b/tinynn/converter/operators/graph.py
index 4c1fee63..9a0efebf 100644
--- a/tinynn/converter/operators/graph.py
+++ b/tinynn/converter/operators/graph.py
@@ -37,6 +37,7 @@ def __init__(self) -> None:
         self.node_op_counter = 0
         self.q_mapping = {}
         self.transform_store = {}
+        self.constant_mapping = {}
 
     def add_transform_store(self, tensor_name: str, transform_name: str, new_tensor_name: str):
         self.transform_store.setdefault(tensor_name, {})
diff --git a/tinynn/converter/operators/torch/__init__.py b/tinynn/converter/operators/torch/__init__.py
index 90d45429..c1ff35da 100644
--- a/tinynn/converter/operators/torch/__init__.py
+++ b/tinynn/converter/operators/torch/__init__.py
@@ -213,7 +213,7 @@
     "quantized::linear_relu_dynamic": QuantizedLinearReluDynamicOperator,
     "quantized::elu": QuantizedEluOperator,
     # non tracking
-    "aten::Int": NoTrackOperator,
+    "aten::Int": TrackConstantOperator,
     "aten::zeros": NoTrackOperator,
     "aten::detach": NoTrackOperator,
     "aten::size": NoTrackOperator,
@@ -224,5 +224,5 @@
     "aten::empty": NoTrackOperator,
     "aten::new_zeros": NoTrackOperator,
     "aten::new_ones": NoTrackOperator,
-    "aten::ScalarImplicit": NoTrackOperator,
+    "aten::ScalarImplicit": TrackConstantOperator,
 }
diff --git a/tinynn/converter/operators/torch/aten.py b/tinynn/converter/operators/torch/aten.py
index d96fbc63..9c4ff969 100644
--- a/tinynn/converter/operators/torch/aten.py
+++ b/tinynn/converter/operators/torch/aten.py
@@ -1781,17 +1781,84 @@ def parse(self, node, attrs, args, graph_converter):
         starts = np.zeros(input_tensor.tensor.ndim, dtype='int32')
         starts[dim] = start
 
-        start_tensor = self.create_attr_tensor(starts)
+        if self.input_names[2] in graph_converter.constant_mapping:
+            start_t = graph_converter.constant_mapping[self.input_names[2]]
+            new_shape_arr = np.array((1,), dtype='int32')
+            new_shape_tensor = self.create_attr_tensor(new_shape_arr)
+            start_reshaped = self.create_transform_tensor(np.reshape(start_t.tensor, new_shape_arr))
+            graph_converter.add_operator(
+                tfl.ReshapeOperator([start_t, new_shape_tensor], [start_reshaped], new_shape_arr)
+            )
+
+            start_casted = self.create_transform_tensor(start_reshaped.tensor.astype('int32'))
+            graph_converter.add_operator(
+                tfl.CastOperator(
+                    [start_reshaped],
+                    [start_casted],
+                    tfl.numpy_tflite_dtype_mappings[str(start_reshaped.dtype)],
+                    tfl.numpy_tflite_dtype_mappings[str(start_casted.dtype)],
+                )
+            )
+
+            start_tensor = self.create_transform_tensor(starts)
+            starts_left = starts[:dim]
+            starts_right = starts[dim + 1 :]
+            starts_tensors = []
+            if len(starts_left) > 0:
+                starts_tensors.append(self.create_attr_tensor(starts_left))
+            starts_tensors.append(start_casted)
+            if len(starts_right) > 0:
+                starts_tensors.append(self.create_attr_tensor(starts_right))
+            if len(starts_tensors) > 1:
+                graph_converter.add_operator(tfl.ConcatenationOperator(starts_tensors, [start_tensor], 0))
+            else:
+                start_tensor = starts_tensors[0]
+        else:
+            start_tensor = self.create_attr_tensor(starts)
 
-        if step != 1:
-            # if True:
-            ends = np.array(input_tensor.tensor.shape, dtype='int32')
+        ends = np.array(input_tensor.tensor.shape, dtype='int32')
+        if step != 1 or start_tensor.buffer is None or self.input_names[3] in graph_converter.constant_mapping:
             ends[dim] = end
+        else:
+            ends[dim] = end - start
+
+        if self.input_names[3] in graph_converter.constant_mapping:
+            end_t = graph_converter.constant_mapping[self.input_names[3]]
+            new_shape_arr = np.array((1,), dtype='int32')
+            new_shape_tensor = self.create_attr_tensor(new_shape_arr)
+            end_reshaped = self.create_transform_tensor(np.reshape(end_t.tensor, new_shape_arr))
+            graph_converter.add_operator(tfl.ReshapeOperator([end_t, new_shape_tensor], [end_reshaped], new_shape_arr))
 
+            end_casted = self.create_transform_tensor(end_reshaped.tensor.astype('int32'))
+            graph_converter.add_operator(
+                tfl.CastOperator(
+                    [end_reshaped],
+                    [end_casted],
+                    tfl.numpy_tflite_dtype_mappings[str(end_reshaped.dtype)],
+                    tfl.numpy_tflite_dtype_mappings[str(end_casted.dtype)],
+                )
+            )
+
+            end_tensor = self.create_transform_tensor(ends)
+            ends_left = ends[:dim]
+            ends_right = ends[dim + 1 :]
+            ends_tensors = []
+            if len(ends_left) > 0:
+                ends_tensors.append(self.create_attr_tensor(ends_left))
+            ends_tensors.append(end_casted)
+            if len(ends_right) > 0:
+                ends_tensors.append(self.create_attr_tensor(ends_right))
+            if len(ends_tensors) > 1:
+                graph_converter.add_operator(tfl.ConcatenationOperator(ends_tensors, [end_tensor], 0))
+            else:
+                end_tensor = ends_tensors[0]
+        else:
+            end_tensor = self.create_attr_tensor(ends)
+
+        if step != 1 or start_tensor.buffer is None or end_tensor.buffer is None:
             strides = np.ones(input_tensor.tensor.ndim, dtype='int32')
             strides[dim] = step
 
-            end_tensor = self.create_attr_tensor(ends)
             stride_tensor = self.create_attr_tensor(strides)
 
             inputs = [input_tensor, start_tensor, end_tensor, stride_tensor]
@@ -1799,10 +1866,7 @@ def parse(self, node, attrs, args, graph_converter):
 
             graph_converter.add_operator(tfl.StridedSliceOperator(inputs, outputs))
         else:
-            sizes = np.array(input_tensor.tensor.shape, dtype='int32')
-            sizes[dim] = end - start
-
-            size_tensor = self.create_attr_tensor(sizes)
+            size_tensor = end_tensor
             inputs = [input_tensor, start_tensor, size_tensor]
             outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
 
diff --git a/tinynn/converter/operators/torch/base.py b/tinynn/converter/operators/torch/base.py
index 53847b45..3aa36cfb 100644
--- a/tinynn/converter/operators/torch/base.py
+++ b/tinynn/converter/operators/torch/base.py
@@ -757,6 +757,16 @@ def parse(self, node, attrs, args, graph_converter):
         graph_converter.q_mapping[self.output_names[0]] = t
 
 
+class TrackConstantOperator(OperatorConverter):
+    def parse(self, node, attrs, args, graph_converter):
+        super().parse(node, attrs, args, graph_converter)
+
+        self.run(node)
+
+        t = self.find_or_create_input(0, graph_converter)
+        graph_converter.constant_mapping[self.output_names[0]] = t
+
+
 class PrimOperatorConverter(OperatorConverter):
     # prim::* ops needs custom implementation
     def run(self, node):
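For readers who want to exercise the new dynamic-slice path outside the test suite, below is a minimal end-to-end sketch mirroring test_slice_dyn_input above. It is an illustration, not part of the patch: the `from tinynn.converter import TFLiteConverter` entry point and the `nchw_transpose=False` argument follow the repository's usage in the tests, the output path `dyn_slice.tflite` is a placeholder, a TensorFlow installation is assumed for running the result, and the interpreter-feeding loop assumes the converted model keeps its inputs in the same order as `dummy_input`.

import numpy as np
import torch
import torch.nn as nn
import tensorflow as tf

from tinynn.converter import TFLiteConverter


class DynSlice(nn.Module):
    def forward(self, x, start, end):
        # `start`/`end` arrive as 0-d tensors, so tracing emits aten::Int /
        # aten::ScalarImplicit before aten::slice; with this patch those values
        # are recorded in graph_converter.constant_mapping and spliced into the
        # slice bounds instead of being dropped.
        return x[start:end]


model = DynSlice().eval()
dummy_input = [torch.randn(10, 10), torch.tensor(2), torch.tensor(4)]
model_path = 'dyn_slice.tflite'  # placeholder output path

converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

# Feed the same values through the stock TFLite interpreter; the input order
# is assumed to follow the order of dummy_input.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
for detail, value in zip(interpreter.get_input_details(), dummy_input):
    arr = np.asarray(value.numpy(), dtype=detail['dtype']).reshape(detail['shape'])
    interpreter.set_tensor(detail['index'], arr)
interpreter.invoke()
print(interpreter.get_tensor(interpreter.get_output_details()[0]['index']))

A run like this goes through the SliceOperator changes above: when a bound is dynamic, the begin/end vectors are assembled with RESHAPE, CAST and CONCATENATION and the op lowers to STRIDED_SLICE; only the static, unit-step case still constant-folds the bounds and emits SLICE.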