diff --git a/tests/converter_op_test.py b/tests/converter_op_test.py index b10c81c4..600dd62d 100644 --- a/tests/converter_op_test.py +++ b/tests/converter_op_test.py @@ -785,6 +785,28 @@ def model(x): with self.assertRaisesRegex(AssertionError, r'.* (are not close!|exceeded the margin of error).*'): assert_close(dummy_output, tfl_output) + def test_same_prelu_for_different_channels(self): + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.prelu = nn.PReLU() + + def forward(self, x): + return self.prelu(x), self.prelu(x[:, 0:1]) + + model = Model() + model.eval() + + dummy_input = torch.rand(1, 3, 224, 224) + + model_path = get_model_path() + converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False) + converter.convert() + + dummy_output = model(dummy_input) + tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) + assert_close(dummy_output, tfl_output) + def test_prelu(self): class Model(nn.Module): def __init__(self) -> None: diff --git a/tinynn/converter/operators/graph.py b/tinynn/converter/operators/graph.py index e8d81591..4c1fee63 100644 --- a/tinynn/converter/operators/graph.py +++ b/tinynn/converter/operators/graph.py @@ -36,6 +36,16 @@ def __init__(self) -> None: self.output_transpose = None self.node_op_counter = 0 self.q_mapping = {} + self.transform_store = {} + + def add_transform_store(self, tensor_name: str, transform_name: str, new_tensor_name: str): + self.transform_store.setdefault(tensor_name, {}) + self.transform_store[tensor_name][transform_name] = new_tensor_name + + def get_transform_store(self, tensor_name: str, transform_name: str) -> typing.Optional[tfl.Tensor]: + if tensor_name not in self.transform_store: + return None + return self.transform_store[tensor_name].get(transform_name, None) def add_iterable_pair( self, input_names: typing.List[str], output_names: typing.List[str], key: typing.Optional[str] = None diff --git a/tinynn/converter/operators/torch/aten.py b/tinynn/converter/operators/torch/aten.py index cbe38f08..b2b4e99c 100644 --- a/tinynn/converter/operators/torch/aten.py +++ b/tinynn/converter/operators/torch/aten.py @@ -1562,7 +1562,7 @@ def parse(self, node, attrs, args, graph_converter): alpha_tensor = self.find_or_create_input(1, graph_converter) shape_tensor = self.create_attr_tensor(np.array(new_shape, dtype='int32')) - update_name = True + update_name = None if weight_c == input_c: new_alpha = self.create_transform_tensor(np.reshape(alpha_tensor.tensor, new_shape)) graph_converter.add_operator(tfl.ReshapeOperator([alpha_tensor, shape_tensor], [new_alpha], new_shape)) @@ -1571,12 +1571,19 @@ def parse(self, node, attrs, args, graph_converter): if alpha_tensor.buffer is None: graph_converter.add_operator(tfl.TileOperator([alpha_tensor, shape_tensor], [new_alpha])) else: - update_name = False - new_alpha = new_alpha.tensor + store = graph_converter.get_transform_store(alpha_tensor.name, str(input_c)) + if store is None: + graph_converter.add_transform_store(alpha_tensor.name, str(input_c), new_alpha.name) + update_name = new_alpha.name + new_alpha = new_alpha.tensor + else: + update_name = store self.input_tensors[1] = new_alpha - if update_name: + if update_name is None: self.input_names[1] = new_alpha.name + else: + self.input_names[1] = update_name self.elementwise_binary(tfl.PreluOperator, graph_converter, False)