From 655dd32dcc58e156c5f390d2a2ab651190046295 Mon Sep 17 00:00:00 2001 From: peterjc123 Date: Fri, 22 Mar 2024 07:09:58 +0000 Subject: [PATCH] [converter] add aten::index_put --- tests/converter_op_test.py | 73 ++++++++++ tinynn/converter/operators/torch/__init__.py | 2 + tinynn/converter/operators/torch/aten.py | 133 +++++++++++++++++++ 3 files changed, 208 insertions(+) diff --git a/tests/converter_op_test.py b/tests/converter_op_test.py index 600dd62d..ebe454dc 100644 --- a/tests/converter_op_test.py +++ b/tests/converter_op_test.py @@ -4197,6 +4197,79 @@ def model(x): tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output) + @unittest.skipIf(LooseVersion(tf.__version__) < LooseVersion('2.1.0'), 'scatter_nd is not supported') + def test_index_put(self): + dummy_input = torch.randn(3, dtype=torch.float32) + + def model(x): + return torch.ones(10, dtype=torch.float32).index_put_((torch.tensor([1, 2, 3]),), x) + + model_path = get_model_path() + + converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False) + converter.convert() + + dummy_output = model(dummy_input) + tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) + assert_close(dummy_output, tfl_output) + + @unittest.skipIf(LooseVersion(tf.__version__) < LooseVersion('2.1.0'), 'scatter_nd is not supported') + def test_index_put_tensor(self): + dummy_input = torch.randn(3, dtype=torch.float32) + dummy_input_1 = torch.tensor([1, 2, 3]) + + dummy_input = (dummy_input, dummy_input_1) + + def model(x, y): + return torch.ones(10, dtype=torch.float32).index_put_((y,), x) + + model_path = get_model_path() + + converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False) + converter.convert() + + dummy_output = model(*dummy_input) + tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) + assert_close(dummy_output, tfl_output) + + @unittest.skipIf(LooseVersion(tf.__version__) < LooseVersion('2.1.0'), 'scatter_nd is not supported') + def test_index_put_tensor_2d(self): + dummy_input = torch.randn(3, dtype=torch.float32) + dummy_input_1 = torch.tensor([1, 2, 3]) + + dummy_input = (dummy_input, dummy_input_1) + + def model(x, y): + return torch.ones(10, 10, dtype=torch.float32).index_put_((y, y), x) + + model_path = get_model_path() + + converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False) + converter.convert() + + dummy_output = model(*dummy_input) + tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) + assert_close(dummy_output, tfl_output) + + @unittest.skipIf(LooseVersion(tf.__version__) < LooseVersion('2.1.0'), 'scatter_nd is not supported') + def test_index_put_tensor_2d_complex(self): + dummy_input = torch.randn(1, dtype=torch.float32) + dummy_input_1 = torch.tensor([1, 2, 3]) + + dummy_input = (dummy_input, dummy_input_1) + + def model(x, y): + return torch.ones(10, 10, dtype=torch.float32).index_put_((y,), x) + + model_path = get_model_path() + + converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False) + converter.convert() + + dummy_output = model(*dummy_input) + tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) + assert_close(dummy_output, tfl_output) + def test_var_std_no_dim(self): dummy_input = torch.randn(1, 3, 224, 224, dtype=torch.float32) diff --git a/tinynn/converter/operators/torch/__init__.py b/tinynn/converter/operators/torch/__init__.py index 509f8d14..90d45429 100644 --- a/tinynn/converter/operators/torch/__init__.py +++ b/tinynn/converter/operators/torch/__init__.py @@ -181,6 +181,8 @@ "aten::broadcast_tensors": ATenBroadcastTensorsOperator, "aten::maximum": ATenMaximumOperator, "aten::minimum": ATenMinimumOperator, + "aten::index_put": ATenIndexPutOperator, + "aten::index_put_": ATenIndexPutOperator, # quantized "aten::quantize_per_tensor": ATenQuantizePerTensorOperator, "aten::fake_quantize_per_tensor_affine": ATenFakeQuantizePerTensorAffineOperator, diff --git a/tinynn/converter/operators/torch/aten.py b/tinynn/converter/operators/torch/aten.py index b2b4e99c..685c2d7a 100644 --- a/tinynn/converter/operators/torch/aten.py +++ b/tinynn/converter/operators/torch/aten.py @@ -3054,6 +3054,139 @@ def parse(self, node, attrs, args, graph_converter): ) +class ATenIndexPutOperator(ATenIndexPutSchema): + def parse(self, node, attrs, args, graph_converter): + super().parse(node, attrs, args, graph_converter) + + # torch.Tensor.index_put_ requires index tensor of type `torch.int64` + accumulate = self.input_tensors[3] + assert not accumulate, "aten::index_put_ with accumulate=True is not supported" + + orig_type = self.input_tensors[1][0].dtype + self.input_tensors[1] = tuple([x.to(dtype=torch.int64) for x in self.input_tensors[1]]) + self.run(node) + + input_tensor = self.find_or_create_input(0, graph_converter) + output_tensor = self.to_tfl_tensors(self.output_names, self.output_tensors)[0] + + self.input_tensors[1] = tuple([x.to(dtype=orig_type) for x in self.input_tensors[1]]) + + if graph_converter.has_nested_names(self.input_names[1]): + input_names = graph_converter.get_list_expanded_names(self.input_names[1]) + indices_tensors = self.to_tfl_tensors( + input_names, self.input_tensors[1], graph_converter=graph_converter, non_existent_as_buffer=True + ) + else: + indices_tensors = [self.find_or_create_input(1, graph_converter)] + + dim = input_tensor.tensor.ndim + + indices_shape = [x.tensor.size for x in indices_tensors] + max_len = max(indices_shape) + indices_shape_tensor = torch.tensor(indices_shape) + left_indices = (torch.arange(max_len).view(-1, 1).expand(-1, len(indices_shape)) % indices_shape_tensor).int() + + if len(indices_tensors) < dim: + pad_shape = list(input_tensor.shape[len(indices_tensors) :]) + pad_indices = torch.ones(pad_shape).nonzero().int() + left_len = len(indices_shape) + right_len = len(pad_shape) + left_size = left_indices.size(0) + right_size = pad_indices.size(0) + left_reshaped = left_indices.view(-1, 1, left_len).expand(-1, right_size, left_len).reshape(-1, left_len) + right_reshaped = pad_indices.view(1, -1, right_len).expand(left_size, -1, right_len).reshape(-1, right_len) + all_indices = torch.cat([left_reshaped, right_reshaped], 1).unbind(1) + else: + all_indices = left_indices.unbind(1) + + new_indices = [] + for i in range(dim): + if i < len(indices_tensors): + idx_tensor = indices_tensors[i] + actual_idx = np.take(idx_tensor.tensor, all_indices[i].numpy()) + else: + actual_idx = all_indices[i].numpy() + if idx_tensor.buffer is None and i < len(indices_tensors): + actual_idx_t = self.create_transform_tensor(actual_idx) + fake_idx_t = self.create_attr_tensor(all_indices[i].numpy()) + graph_converter.add_operator(tfl.GatherOperator([idx_tensor, fake_idx_t], [actual_idx_t], axis=0)) + + if str(actual_idx_t.dtype) != 'int32': + index_casted = self.create_transform_tensor(actual_idx_t.tensor.astype('int32')) + graph_converter.add_operator( + tfl.CastOperator( + [actual_idx_t], + [index_casted], + tfl.numpy_tflite_dtype_mappings[str(actual_idx_t.dtype)], + tfl.numpy_tflite_dtype_mappings[str(index_casted.dtype)], + ) + ) + actual_idx_t = index_casted + new_indices.append(actual_idx_t) + else: + new_indices.append(self.create_attr_tensor(actual_idx.astype(np.int32))) + + index_arr = np.stack([x.tensor for x in new_indices], 1) + if all((x.buffer is not None for x in new_indices)): + index_tensor = self.create_attr_tensor(index_arr) + else: + index_tensor = self.create_transform_tensor(index_arr) + graph_converter.add_operator(tfl.PackOperator(new_indices, [index_tensor], dim, axis=1)) + + val_tensor = self.find_or_create_input(2, graph_converter) + actual_val = val_tensor + orig_val_shape = val_tensor.shape + target_val_shape = index_tensor.shape[:-1] + if orig_val_shape != target_val_shape: + if val_tensor.buffer is None: + new_shape = orig_val_shape + val_reshaped = val_tensor + if len(target_val_shape) > len(orig_val_shape): + new_shape = [1] * (len(target_val_shape) - len(orig_val_shape)) + list(orig_val_shape) + new_shape_arr = np.array(new_shape, dtype='int32') + new_shape_tensor = self.create_attr_tensor(new_shape_arr) + reshaped = self.create_transform_tensor(np.reshape(val_tensor.tensor, new_shape_arr)) + val_reshaped = reshaped + reshape_op = tfl.ReshapeOperator([val_tensor, new_shape_tensor], [reshaped], new_shape_arr) + reshape_op.extra_hints['direction'] = 'up' + graph_converter.add_operator(reshape_op) + + repeats = [] + for x, y in zip(new_shape, target_val_shape): + if x != y: + repeats.append(y // x) + else: + repeats.append(1) + + actual_val = self.create_transform_tensor(np.tile(val_reshaped.tensor, repeats)) + repeat_tensor = self.create_attr_tensor(np.array(repeats, dtype='int32')) + graph_converter.add_operator(tfl.TileOperator([val_reshaped, repeat_tensor], [actual_val])) + else: + actual_val = self.create_attr_tensor(np.broadcast_to(val_tensor.tensor, target_val_shape)) + + shape_tensor = self.create_attr_tensor(np.array(input_tensor.shape, dtype='int32')) + + if input_tensor.buffer is None or index_tensor.buffer is None: + old_val_tensor = self.create_transform_tensor(actual_val.tensor) + graph_converter.add_operator(tfl.GatherNdOperator([input_tensor, index_tensor], [old_val_tensor])) + else: + transformed_index = tuple(index_tensor.tensor[..., i] for i in range(index_tensor.shape[-1])) + old_val_tensor = self.create_attr_tensor(input_tensor.tensor[transformed_index]) + + if actual_val.buffer is None: + update_tensor = self.create_transform_tensor(actual_val.tensor - old_val_tensor.tensor) + graph_converter.add_operator(tfl.SubOperator([actual_val, old_val_tensor], [update_tensor])) + else: + update_tensor = self.create_attr_tensor(actual_val.tensor - old_val_tensor.tensor) + + updated_tensor = self.create_transform_tensor(input_tensor.tensor) + graph_converter.add_operator( + tfl.ScatterNdOperator([index_tensor, update_tensor, shape_tensor], [updated_tensor]) + ) + + graph_converter.add_operator(tfl.AddOperator([input_tensor, updated_tensor], [output_tensor])) + + class ATenGeluOperator(ATenGeluSchema): def parse(self, node, attrs, args, graph_converter): super().parse(node, attrs, args, graph_converter)