[converter] add aten::index_put
peterjc123 committed Mar 22, 2024
1 parent 5be0192 commit 655dd32
Showing 3 changed files with 208 additions and 0 deletions.
73 changes: 73 additions & 0 deletions tests/converter_op_test.py
@@ -4197,6 +4197,79 @@ def model(x):
        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
        assert_close(dummy_output, tfl_output)

    @unittest.skipIf(LooseVersion(tf.__version__) < LooseVersion('2.1.0'), 'scatter_nd is not supported')
    def test_index_put(self):
        dummy_input = torch.randn(3, dtype=torch.float32)

        def model(x):
            return torch.ones(10, dtype=torch.float32).index_put_((torch.tensor([1, 2, 3]),), x)

        model_path = get_model_path()

        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
        converter.convert()

        dummy_output = model(dummy_input)
        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
        assert_close(dummy_output, tfl_output)

    @unittest.skipIf(LooseVersion(tf.__version__) < LooseVersion('2.1.0'), 'scatter_nd is not supported')
    def test_index_put_tensor(self):
        dummy_input = torch.randn(3, dtype=torch.float32)
        dummy_input_1 = torch.tensor([1, 2, 3])

        dummy_input = (dummy_input, dummy_input_1)

        def model(x, y):
            return torch.ones(10, dtype=torch.float32).index_put_((y,), x)

        model_path = get_model_path()

        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
        converter.convert()

        dummy_output = model(*dummy_input)
        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
        assert_close(dummy_output, tfl_output)

    @unittest.skipIf(LooseVersion(tf.__version__) < LooseVersion('2.1.0'), 'scatter_nd is not supported')
    def test_index_put_tensor_2d(self):
        dummy_input = torch.randn(3, dtype=torch.float32)
        dummy_input_1 = torch.tensor([1, 2, 3])

        dummy_input = (dummy_input, dummy_input_1)

        def model(x, y):
            return torch.ones(10, 10, dtype=torch.float32).index_put_((y, y), x)

        model_path = get_model_path()

        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
        converter.convert()

        dummy_output = model(*dummy_input)
        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
        assert_close(dummy_output, tfl_output)

    @unittest.skipIf(LooseVersion(tf.__version__) < LooseVersion('2.1.0'), 'scatter_nd is not supported')
    def test_index_put_tensor_2d_complex(self):
        dummy_input = torch.randn(1, dtype=torch.float32)
        dummy_input_1 = torch.tensor([1, 2, 3])

        dummy_input = (dummy_input, dummy_input_1)

        def model(x, y):
            return torch.ones(10, 10, dtype=torch.float32).index_put_((y,), x)

        model_path = get_model_path()

        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
        converter.convert()

        dummy_output = model(*dummy_input)
        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
        assert_close(dummy_output, tfl_output)

    def test_var_std_no_dim(self):
        dummy_input = torch.randn(1, 3, 224, 224, dtype=torch.float32)

2 changes: 2 additions & 0 deletions tinynn/converter/operators/torch/__init__.py
@@ -181,6 +181,8 @@
"aten::broadcast_tensors": ATenBroadcastTensorsOperator,
"aten::maximum": ATenMaximumOperator,
"aten::minimum": ATenMinimumOperator,
"aten::index_put": ATenIndexPutOperator,
"aten::index_put_": ATenIndexPutOperator,
# quantized
"aten::quantize_per_tensor": ATenQuantizePerTensorOperator,
"aten::fake_quantize_per_tensor_affine": ATenFakeQuantizePerTensorAffineOperator,
133 changes: 133 additions & 0 deletions tinynn/converter/operators/torch/aten.py
@@ -3054,6 +3054,139 @@ def parse(self, node, attrs, args, graph_converter):
        )


class ATenIndexPutOperator(ATenIndexPutSchema):
    def parse(self, node, attrs, args, graph_converter):
        super().parse(node, attrs, args, graph_converter)

        # torch.Tensor.index_put_ requires index tensors of type `torch.int64`
        accumulate = self.input_tensors[3]
        assert not accumulate, "aten::index_put_ with accumulate=True is not supported"

        orig_type = self.input_tensors[1][0].dtype
        self.input_tensors[1] = tuple([x.to(dtype=torch.int64) for x in self.input_tensors[1]])
        self.run(node)

        input_tensor = self.find_or_create_input(0, graph_converter)
        output_tensor = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]

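        # Restore the original index dtype now that tracing is done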
        self.input_tensors[1] = tuple([x.to(dtype=orig_type) for x in self.input_tensors[1]])

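        # Resolve the TFLite index tensors: expand the nested name list when the
        # indices arrive as a list of tensors, otherwise use the single index input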
        if graph_converter.has_nested_names(self.input_names[1]):
            input_names = graph_converter.get_list_expanded_names(self.input_names[1])
            indices_tensors = self.to_tfl_tensors(
                input_names, self.input_tensors[1], graph_converter=graph_converter, non_existent_as_buffer=True
            )
        else:
            indices_tensors = [self.find_or_create_input(1, graph_converter)]

        dim = input_tensor.tensor.ndim

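        # Broadcast every index tensor to the longest one (shorter ones wrap around
        # via the modulo) and build the per-dimension element positions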
        indices_shape = [x.tensor.size for x in indices_tensors]
        max_len = max(indices_shape)
        indices_shape_tensor = torch.tensor(indices_shape)
        left_indices = (torch.arange(max_len).view(-1, 1).expand(-1, len(indices_shape)) % indices_shape_tensor).int()

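        # When fewer index tensors than input dimensions are given, pad the coordinates
        # with the cartesian product of the full grid of the remaining dimensions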
        if len(indices_tensors) < dim:
            pad_shape = list(input_tensor.shape[len(indices_tensors) :])
            pad_indices = torch.ones(pad_shape).nonzero().int()
            left_len = len(indices_shape)
            right_len = len(pad_shape)
            left_size = left_indices.size(0)
            right_size = pad_indices.size(0)
            left_reshaped = left_indices.view(-1, 1, left_len).expand(-1, right_size, left_len).reshape(-1, left_len)
            right_reshaped = pad_indices.view(1, -1, right_len).expand(left_size, -1, right_len).reshape(-1, right_len)
            all_indices = torch.cat([left_reshaped, right_reshaped], 1).unbind(1)
        else:
            all_indices = left_indices.unbind(1)

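        # Resolve the index values for every dimension: constant indices are looked up
        # with np.take, dynamic ones with a GATHER op (plus a CAST to int32 if needed)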
        new_indices = []
        for i in range(dim):
            if i < len(indices_tensors):
                idx_tensor = indices_tensors[i]
                actual_idx = np.take(idx_tensor.tensor, all_indices[i].numpy())
            else:
                actual_idx = all_indices[i].numpy()
            if idx_tensor.buffer is None and i < len(indices_tensors):
                actual_idx_t = self.create_transform_tensor(actual_idx)
                fake_idx_t = self.create_attr_tensor(all_indices[i].numpy())
                graph_converter.add_operator(tfl.GatherOperator([idx_tensor, fake_idx_t], [actual_idx_t], axis=0))

                if str(actual_idx_t.dtype) != 'int32':
                    index_casted = self.create_transform_tensor(actual_idx_t.tensor.astype('int32'))
                    graph_converter.add_operator(
                        tfl.CastOperator(
                            [actual_idx_t],
                            [index_casted],
                            tfl.numpy_tflite_dtype_mappings[str(actual_idx_t.dtype)],
                            tfl.numpy_tflite_dtype_mappings[str(index_casted.dtype)],
                        )
                    )
                    actual_idx_t = index_casted
                new_indices.append(actual_idx_t)
            else:
                new_indices.append(self.create_attr_tensor(actual_idx.astype(np.int32)))

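        # Stack the per-dimension indices into [N, dim] coordinates for SCATTER_ND,
        # using a PACK op at runtime whenever any of the indices is dynamic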
        index_arr = np.stack([x.tensor for x in new_indices], 1)
        if all((x.buffer is not None for x in new_indices)):
            index_tensor = self.create_attr_tensor(index_arr)
        else:
            index_tensor = self.create_transform_tensor(index_arr)
            graph_converter.add_operator(tfl.PackOperator(new_indices, [index_tensor], dim, axis=1))

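        # Broadcast (or reshape and tile at runtime) the value tensor so that it
        # provides exactly one value per scattered coordinate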
        val_tensor = self.find_or_create_input(2, graph_converter)
        actual_val = val_tensor
        orig_val_shape = val_tensor.shape
        target_val_shape = index_tensor.shape[:-1]
        if orig_val_shape != target_val_shape:
            if val_tensor.buffer is None:
                new_shape = orig_val_shape
                val_reshaped = val_tensor
                if len(target_val_shape) > len(orig_val_shape):
                    new_shape = [1] * (len(target_val_shape) - len(orig_val_shape)) + list(orig_val_shape)
                    new_shape_arr = np.array(new_shape, dtype='int32')
                    new_shape_tensor = self.create_attr_tensor(new_shape_arr)
                    reshaped = self.create_transform_tensor(np.reshape(val_tensor.tensor, new_shape_arr))
                    val_reshaped = reshaped
                    reshape_op = tfl.ReshapeOperator([val_tensor, new_shape_tensor], [reshaped], new_shape_arr)
                    reshape_op.extra_hints['direction'] = 'up'
                    graph_converter.add_operator(reshape_op)

                repeats = []
                for x, y in zip(new_shape, target_val_shape):
                    if x != y:
                        repeats.append(y // x)
                    else:
                        repeats.append(1)

                actual_val = self.create_transform_tensor(np.tile(val_reshaped.tensor, repeats))
                repeat_tensor = self.create_attr_tensor(np.array(repeats, dtype='int32'))
                graph_converter.add_operator(tfl.TileOperator([val_reshaped, repeat_tensor], [actual_val]))
            else:
                actual_val = self.create_attr_tensor(np.broadcast_to(val_tensor.tensor, target_val_shape))

        shape_tensor = self.create_attr_tensor(np.array(input_tensor.shape, dtype='int32'))

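        # Read the current values at the indexed positions: GATHER_ND when the input or
        # the indices are dynamic, plain numpy indexing when both are constant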
        if input_tensor.buffer is None or index_tensor.buffer is None:
            old_val_tensor = self.create_transform_tensor(actual_val.tensor)
            graph_converter.add_operator(tfl.GatherNdOperator([input_tensor, index_tensor], [old_val_tensor]))
        else:
            transformed_index = tuple(index_tensor.tensor[..., i] for i in range(index_tensor.shape[-1]))
            old_val_tensor = self.create_attr_tensor(input_tensor.tensor[transformed_index])

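        # The update is new value minus old value, so that the scatter-and-add below
        # writes the new values at the indexed positions and leaves the rest unchanged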
        if actual_val.buffer is None:
            update_tensor = self.create_transform_tensor(actual_val.tensor - old_val_tensor.tensor)
            graph_converter.add_operator(tfl.SubOperator([actual_val, old_val_tensor], [update_tensor]))
        else:
            update_tensor = self.create_attr_tensor(actual_val.tensor - old_val_tensor.tensor)

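        # SCATTER_ND materializes the deltas at the indexed positions (zeros elsewhere)
        # and ADD applies them to the original input tensor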
        updated_tensor = self.create_transform_tensor(input_tensor.tensor)
        graph_converter.add_operator(
            tfl.ScatterNdOperator([index_tensor, update_tensor, shape_tensor], [updated_tensor])
        )

        graph_converter.add_operator(tfl.AddOperator([input_tensor, updated_tensor], [output_tensor]))


class ATenGeluOperator(ATenGeluSchema):
    def parse(self, node, attrs, args, graph_converter):
        super().parse(node, attrs, args, graph_converter)
