From a2d7517071a3cb5d5082b4d963317567e75cee1f Mon Sep 17 00:00:00 2001
From: peterjc123
Date: Tue, 5 Dec 2023 06:40:28 +0000
Subject: [PATCH 1/2] [converter] support aten::{maximum, minimum}

---
 tests/converter_op_test.py                   | 14 ++++++++---
 tinynn/converter/operators/torch/__init__.py |  2 ++
 tinynn/converter/operators/torch/aten.py     | 26 ++++++++++++++++++--
 tinynn/converter/operators/torch/base.py     |  1 +
 4 files changed, 38 insertions(+), 5 deletions(-)

diff --git a/tests/converter_op_test.py b/tests/converter_op_test.py
index f89a7ea7..b10c81c4 100644
--- a/tests/converter_op_test.py
+++ b/tests/converter_op_test.py
@@ -86,14 +86,14 @@ def test_sign(self):
         def model(x):
             return torch.sign(x)
 
-        
+
         model_path = get_model_path()
         converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
         converter.convert()
 
         dummy_output = model(dummy_input)
         tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
-        assert_close(dummy_output, tfl_output) 
+        assert_close(dummy_output, tfl_output)
 
     def test_masked_fill(self):
         class TestModel(nn.Module):
@@ -188,6 +188,8 @@ def test_binary_elementwise_same_dtype(self):
             (torch, 'eq'),
             (torch, 'ne'),
             (torch, 'rsub'),
+            (torch, 'maximum'),
+            (torch, 'minimum'),
         ]
 
         funcs = [getattr(ns, attr) for ns, attr in func_names if hasattr(ns, attr)]
@@ -224,6 +226,8 @@ def test_binary_elementwise_constant_same_dtype(self):
             (torch, 'eq'),
             (torch, 'ne'),
             (torch, 'rsub'),
+            (torch, 'maximum'),
+            (torch, 'minimum'),
         ]
 
         funcs = [getattr(ns, attr) for ns, attr in func_names if hasattr(ns, attr)]
@@ -259,6 +263,8 @@ def test_binary_elementwise_different_dtype(self):
             (torch, 'eq'),
             (torch, 'ne'),
             (torch, 'rsub'),
+            (torch, 'maximum'),
+            (torch, 'minimum'),
         ]
 
         funcs = [getattr(ns, attr) for ns, attr in func_names if hasattr(ns, attr)]
@@ -295,6 +301,8 @@ def test_binary_elementwise_constant_different_dtype(self):
             (torch, 'eq'),
             (torch, 'ne'),
             (torch, 'rsub'),
+            (torch, 'maximum'),
+            (torch, 'minimum'),
         ]
 
         funcs = [getattr(ns, attr) for ns, attr in func_names if hasattr(ns, attr)]
@@ -5294,7 +5302,7 @@ def model(x):
         dummy_output = model(dummy_input)
         tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
         assert_close(dummy_output, tfl_output, atol=256.0, rtol=256.0)
-        
+
     @unittest.skipIf(not hasattr(torch, 'norm'), "Norm is not supported")
     def test_norm_p1(self):
         dummy_input = torch.randn(10, 10, dtype=torch.float32)
diff --git a/tinynn/converter/operators/torch/__init__.py b/tinynn/converter/operators/torch/__init__.py
index 558dcea8..509f8d14 100644
--- a/tinynn/converter/operators/torch/__init__.py
+++ b/tinynn/converter/operators/torch/__init__.py
@@ -179,6 +179,8 @@
     "aten::baddbmm": ATenBaddbmmOperator,
     "aten::linalg_vector_norm": ATenLinalgVectorNormOperator,
     "aten::broadcast_tensors": ATenBroadcastTensorsOperator,
+    "aten::maximum": ATenMaximumOperator,
+    "aten::minimum": ATenMinimumOperator,
     # quantized
     "aten::quantize_per_tensor": ATenQuantizePerTensorOperator,
     "aten::fake_quantize_per_tensor_affine": ATenFakeQuantizePerTensorAffineOperator,
diff --git a/tinynn/converter/operators/torch/aten.py b/tinynn/converter/operators/torch/aten.py
index 59dc4199..e4416c18 100644
--- a/tinynn/converter/operators/torch/aten.py
+++ b/tinynn/converter/operators/torch/aten.py
@@ -15,7 +15,7 @@
 
 class ATenSignOperator(ATenSignSchema):
     def parse(self, node, attrs, args, graph_converter):
-        
+
         super().parse(node, attrs, args, graph_converter)
         self.run(node)
 
@@ -3453,6 +3453,28 @@ def parse_common(self, graph_converter, input_idx=0, mask_idx=1, other_idx=2, ou
         graph_converter.add_operator(op)
 
 
+class ATenMaximumOperator(ATenMaximumSchema):
+    def parse(self, node, attrs, args, graph_converter):
+        super().parse(node, attrs, args, graph_converter)
+
+        self.run(node)
+        if type(self.input_tensors[1]) != torch.Tensor:
+            self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])
+
+        self.elementwise_binary(tfl.MaximumOperator, graph_converter, True)
+
+
+class ATenMinimumOperator(ATenMinimumSchema):
+    def parse(self, node, attrs, args, graph_converter):
+        super().parse(node, attrs, args, graph_converter)
+
+        self.run(node)
+        if type(self.input_tensors[1]) != torch.Tensor:
+            self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])
+
+        self.elementwise_binary(tfl.MinimumOperator, graph_converter, True)
+
+
 class ATenGtOperator(ATenGtSchema):
     def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)
@@ -3792,7 +3814,7 @@ def parse(self, node, attrs, args, graph_converter):
 
         self.run(node)
         self.parse_common(node, attrs, args, graph_converter)
-        
+
 class ATenLinalgVectorNormOperator(ATenLinalgVectorNormSchema):
     def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)
diff --git a/tinynn/converter/operators/torch/base.py b/tinynn/converter/operators/torch/base.py
index 53847b45..8ed94f08 100644
--- a/tinynn/converter/operators/torch/base.py
+++ b/tinynn/converter/operators/torch/base.py
@@ -545,6 +545,7 @@ def get_minimum_constant(self, ref_tensor):
 
     def handle_reduce(self, converter_class, input_args, graph_converter, transpose_opt, *args, **kwargs):
         input_tensor = self.find_or_create_input(0, graph_converter)
+        print(input_args)
 
         if 'dim' in input_args and 'keepdim' in input_args:
             dims, keep_dim = self.input_tensors[1:3]

From 353a0baeff32347f3f8751ee450a0d37359db4a9 Mon Sep 17 00:00:00 2001
From: peterjc123
Date: Tue, 5 Dec 2023 06:43:16 +0000
Subject: [PATCH 2/2] minor fixes

---
 docs/op_matrix.md                        | 2 ++
 tinynn/converter/operators/torch/base.py | 1 -
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/docs/op_matrix.md b/docs/op_matrix.md
index f0f105fe..0b0a9173 100644
--- a/docs/op_matrix.md
+++ b/docs/op_matrix.md
@@ -111,9 +111,11 @@ Operators that are implemented in Python
 | `aten::matmul` | |
 | `aten::max` | |
 | `aten::max_pool2d` | Only dilation == 1 is supported |
+| `aten::maximum` | |
 | `aten::mean` | |
 | `aten::meshgrid` | aten::meshgrid for dynamic tensors is not supported |
 | `aten::min` | |
+| `aten::minimum` | |
 | `aten::mish` | |
 | `aten::mm` | |
 | `aten::mul` | |
diff --git a/tinynn/converter/operators/torch/base.py b/tinynn/converter/operators/torch/base.py
index 8ed94f08..53847b45 100644
--- a/tinynn/converter/operators/torch/base.py
+++ b/tinynn/converter/operators/torch/base.py
@@ -545,7 +545,6 @@ def get_minimum_constant(self, ref_tensor):
 
     def handle_reduce(self, converter_class, input_args, graph_converter, transpose_opt, *args, **kwargs):
         input_tensor = self.find_or_create_input(0, graph_converter)
-        print(input_args)
 
         if 'dim' in input_args and 'keepdim' in input_args:
             dims, keep_dim = self.input_tensors[1:3]
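
For reference, a minimal usage sketch of the ops added by this series, following the
TFLiteConverter invocation used in the tests above. The module name and output path are
illustrative only; non-tensor operands are also handled, since the new parsers promote
them to tensors via torch_tensor_from_scalar (see the aten.py hunk in PATCH 1/2).

    import torch
    import torch.nn as nn

    from tinynn.converter import TFLiteConverter


    # Exercises the newly supported aten::maximum / aten::minimum.
    class MinMaxModel(nn.Module):
        def forward(self, x):
            zeros = torch.zeros_like(x)
            return torch.maximum(x, zeros), torch.minimum(x, zeros)


    model = MinMaxModel()
    model.eval()

    dummy_input = torch.randn(1, 3, 224, 224)

    # Same converter invocation as in tests/converter_op_test.py above;
    # 'minmax.tflite' is a hypothetical output path.
    converter = TFLiteConverter(model, dummy_input, 'minmax.tflite', nchw_transpose=False)
    converter.convert()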