Commit: [converter] track dynamic inputs for slices (#284)
* track dynamic inputs for slices

* update op matrix

* more fixes
peterjc123 authored Mar 26, 2024
1 parent 663e7ce commit e506202
Showing 7 changed files with 150 additions and 15 deletions.
9 changes: 7 additions & 2 deletions docs/op_matrix.md
@@ -220,8 +220,6 @@ Operators that are implemented in Python
Non-tracking operators that are ignored during translation
| Operator | Limitations |
|---------------------------|--------------|
| `aten::Int` | |
| `aten::ScalarImplicit` | |
| `aten::arange` | |
| `aten::detach` | |
| `aten::empty` | |
@@ -232,3 +230,10 @@ Non-tracking operators that are ignored during translation
| `aten::size` | |
| `aten::zeros` | |
| `aten::zeros_like` | |

## Constant Tracking Operators
Tracking operators that produce a dynamic constant
| Operator | Limitations |
|---------------------------|--------------|
| `aten::Int` | |
| `aten::ScalarImplicit` | |
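
For context, a minimal sketch of where these operators show up: converting a scalar tensor to a Python number inside a traced model (e.g. via `int(...)`) emits `aten::Int`/`aten::ScalarImplicit`. The example below is an illustration only, not taken from this commit; the slice tests added in `tests/converter_op_test.py` below exercise the same path.

```python
import torch

# Hypothetical illustration: `int(y)` on a scalar tensor traces to
# `aten::Int`, which is now tracked as a dynamic constant rather than
# ignored, so the slice bound stays dynamic after conversion.
def model(x, y):
    return x[: int(y)]

dummy_input = [torch.randn(10, 10), torch.tensor(4)]
```
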
14 changes: 12 additions & 2 deletions scripts/gen_op_docs.py
@@ -3,7 +3,7 @@
import re

from tinynn.converter.operators.torch import OPERATOR_CONVERTER_DICT
from tinynn.converter.operators.torch.base import NoTrackOperator, PrimOperatorConverter
from tinynn.converter.operators.torch.base import NoTrackOperator, PrimOperatorConverter, TrackConstantOperator

CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))

@@ -12,6 +12,7 @@
quantized_ops = []
torchvision_ops = []
passthrough_ops = []
track_constant_ops = []

limitation_dict = {}

@@ -22,13 +23,15 @@ def main():


def collect_ops():
global prim_ops, aten_ops, quantized_ops, torchvision_ops, passthrough_ops, limitation_dict
global prim_ops, aten_ops, quantized_ops, torchvision_ops, passthrough_ops, track_constant_ops, limitation_dict

for k, v in OPERATOR_CONVERTER_DICT.items():
if issubclass(v, PrimOperatorConverter):
prim_ops.append(k)
elif issubclass(v, NoTrackOperator):
passthrough_ops.append(k)
elif issubclass(v, TrackConstantOperator):
track_constant_ops.append(k)
else:
if v.__module__ == 'tinynn.converter.operators.torch.aten':
aten_ops.append(k)
@@ -50,6 +53,7 @@ def collect_ops():
quantized_ops = sorted(quantized_ops)
torchvision_ops = sorted(torchvision_ops)
passthrough_ops = sorted(passthrough_ops)
track_constant_ops = sorted(track_constant_ops)


def print_operators(topic, ops, f, desc=None, eol=True):
@@ -89,6 +93,12 @@ def update_file():
passthrough_ops,
f,
'Non-tracking operators that are ignored during translation',
)
print_operators(
'Constant Tracking Operators',
track_constant_ops,
f,
'Tracking operators that produce a dynamic constant',
False,
)

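
The body of `print_operators` is not shown in this diff. A minimal sketch consistent with the markdown it evidently generates (cf. the `docs/op_matrix.md` section above) might look like the following; this is an assumption for illustration only, and the real implementation may differ.

```python
def print_operators(topic, ops, f, desc=None, eol=True):
    # Write one markdown section per topic: a heading, an optional
    # description line, and a table with one row per operator.
    f.write(f'## {topic}\n')
    if desc is not None:
        f.write(f'{desc}\n')
    f.write('| Operator | Limitations |\n')
    f.write('|---------------------------|--------------|\n')
    for op in ops:
        f.write(f'| `{op}` | {limitation_dict.get(op, "")} |\n')
    if eol:
        f.write('\n')
```
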
45 changes: 45 additions & 0 deletions tests/converter_op_test.py
@@ -3853,6 +3853,51 @@ def model(x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_slice_dyn_input(self):
dummy_input = [torch.randn(10, 10, dtype=torch.float32), torch.tensor(2), torch.tensor(4)]

def model(x, y, z):
return x[y:z]

model_path = get_model_path()

converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(*dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_slice_dyn_start(self):
dummy_input = [torch.randn(10, 10, dtype=torch.float32), torch.tensor(2)]

def model(x, y):
return x[y:4]

model_path = get_model_path()

converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(*dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_slice_dyn_end(self):
dummy_input = [torch.randn(10, 10, dtype=torch.float32), torch.tensor(4)]

def model(x, y):
return x[2:y]

model_path = get_model_path()

converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(*dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_slice_no_end(self):
dummy_input = torch.randn(10, 10, dtype=torch.float32)

1 change: 1 addition & 0 deletions tinynn/converter/operators/graph.py
@@ -37,6 +37,7 @@ def __init__(self) -> None:
self.node_op_counter = 0
self.q_mapping = {}
self.transform_store = {}
self.constant_mapping = {}
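# maps the output tensor name of a tracked dynamic constant op (e.g. aten::Int) to its source TFLite tensor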

def add_transform_store(self, tensor_name: str, transform_name: str, new_tensor_name: str):
self.transform_store.setdefault(tensor_name, {})
4 changes: 2 additions & 2 deletions tinynn/converter/operators/torch/__init__.py
@@ -213,7 +213,7 @@
"quantized::linear_relu_dynamic": QuantizedLinearReluDynamicOperator,
"quantized::elu": QuantizedEluOperator,
# non tracking
"aten::Int": NoTrackOperator,
"aten::Int": TrackConstantOperator,
"aten::zeros": NoTrackOperator,
"aten::detach": NoTrackOperator,
"aten::size": NoTrackOperator,
@@ -224,5 +224,5 @@
"aten::empty": NoTrackOperator,
"aten::new_zeros": NoTrackOperator,
"aten::new_ones": NoTrackOperator,
"aten::ScalarImplicit": NoTrackOperator,
"aten::ScalarImplicit": TrackConstantOperator,
}
82 changes: 73 additions & 9 deletions tinynn/converter/operators/torch/aten.py
@@ -1781,28 +1781,92 @@ def parse(self, node, attrs, args, graph_converter):
starts = np.zeros(input_tensor.tensor.ndim, dtype='int32')
starts[dim] = start

start_tensor = self.create_attr_tensor(starts)
if self.input_names[2] in graph_converter.constant_mapping:
start_t = graph_converter.constant_mapping[self.input_names[2]]
new_shape_arr = np.array((1,), dtype='int32')
new_shape_tensor = self.create_attr_tensor(new_shape_arr)
start_reshaped = self.create_transform_tensor(np.reshape(start_t.tensor, new_shape_arr))
graph_converter.add_operator(
tfl.ReshapeOperator([start_t, new_shape_tensor], [start_reshaped], new_shape_arr)
)

start_casted = self.create_transform_tensor(start_reshaped.tensor.astype('int32'))
graph_converter.add_operator(
tfl.CastOperator(
[start_reshaped],
[start_casted],
tfl.numpy_tflite_dtype_mappings[str(start_reshaped.dtype)],
tfl.numpy_tflite_dtype_mappings[str(start_casted.dtype)],
)
)

start_tensor = self.create_transform_tensor(starts)
starts_left = starts[:dim]
starts_right = starts[dim + 1 :]
starts_tensors = []
if len(starts_left) > 0:
starts_tensors.append(self.create_attr_tensor(starts_left))
starts_tensors.append(start_casted)
if len(starts_right) > 0:
starts_tensors.append(self.create_attr_tensor(starts_right))
if len(starts_tensors) > 1:
graph_converter.add_operator(tfl.ConcatenationOperator(starts_tensors, [start_tensor], 0))
else:
start_tensor = starts_tensors[0]
else:
start_tensor = self.create_attr_tensor(starts)

if step != 1:
# if True:
ends = np.array(input_tensor.tensor.shape, dtype='int32')
ends = np.array(input_tensor.tensor.shape, dtype='int32')
if step != 1 or start_tensor.buffer is None or self.input_names[3] in graph_converter.constant_mapping:
ends[dim] = end
else:
ends[dim] = end - start

if self.input_names[3] in graph_converter.constant_mapping:
end_t = graph_converter.constant_mapping[self.input_names[3]]
new_shape_arr = np.array((1,), dtype='int32')
new_shape_tensor = self.create_attr_tensor(new_shape_arr)
end_reshaped = self.create_transform_tensor(np.reshape(end_t.tensor, new_shape_arr))
graph_converter.add_operator(tfl.ReshapeOperator([end_t, new_shape_tensor], [end_reshaped], new_shape_arr))

end_casted = self.create_transform_tensor(end_reshaped.tensor.astype('int32'))
graph_converter.add_operator(
tfl.CastOperator(
[end_reshaped],
[end_casted],
tfl.numpy_tflite_dtype_mappings[str(end_reshaped.dtype)],
tfl.numpy_tflite_dtype_mappings[str(end_casted.dtype)],
)
)

end_tensor = self.create_transform_tensor(ends)
ends_left = ends[:dim]
ends_right = ends[dim + 1 :]
ends_tensors = []
if len(ends_left) > 0:
ends_tensors.append(self.create_attr_tensor(ends_left))
ends_tensors.append(end_casted)
if len(ends_right) > 0:
ends_tensors.append(self.create_attr_tensor(ends_right))
if len(ends_tensors) > 1:
graph_converter.add_operator(tfl.ConcatenationOperator(ends_tensors, [end_tensor], 0))
else:
end_tensor = ends_tensors[0]
else:
end_tensor = self.create_attr_tensor(ends)

if step != 1 or start_tensor.buffer is None or end_tensor.buffer is None:
strides = np.ones(input_tensor.tensor.ndim, dtype='int32')
strides[dim] = step

end_tensor = self.create_attr_tensor(ends)
stride_tensor = self.create_attr_tensor(strides)

inputs = [input_tensor, start_tensor, end_tensor, stride_tensor]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)

graph_converter.add_operator(tfl.StridedSliceOperator(inputs, outputs))
else:
sizes = np.array(input_tensor.tensor.shape, dtype='int32')
sizes[dim] = end - start

size_tensor = self.create_attr_tensor(sizes)
size_tensor = end_tensor
inputs = [input_tensor, start_tensor, size_tensor]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)

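
To make the new index-vector assembly concrete: when a slice bound is a tracked dynamic constant, the converter reshapes the scalar to shape `(1,)`, casts it to int32, and concatenates it with the static parts of the index vector, instead of folding everything into a constant `starts`/`ends` attribute. A standalone NumPy sketch of that assembly (the helper name `build_begin_vector` is hypothetical, for illustration only):

```python
import numpy as np

def build_begin_vector(ndim: int, dim: int, dynamic_start: np.ndarray) -> np.ndarray:
    starts = np.zeros(ndim, dtype='int32')  # static begin indices default to 0
    # Mirrors the emitted Reshape + Cast: scalar -> shape (1,) -> int32
    start_casted = np.reshape(dynamic_start, (1,)).astype('int32')
    parts = []
    if dim > 0:
        parts.append(starts[:dim])       # static prefix
    parts.append(start_casted)           # dynamic value for the sliced dim
    if dim + 1 < ndim:
        parts.append(starts[dim + 1:])   # static suffix
    return np.concatenate(parts, 0)      # mirrors the emitted Concatenation

# e.g. build_begin_vector(3, 1, np.array(2)) -> array([0, 2, 0], dtype=int32)
```

The same pattern builds the `ends` vector, and `StridedSlice` is emitted whenever any bound is dynamic (`start_tensor.buffer is None or end_tensor.buffer is None`), since the size argument of a plain `Slice` can no longer be computed ahead of time.
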
10 changes: 10 additions & 0 deletions tinynn/converter/operators/torch/base.py
@@ -757,6 +757,16 @@ def parse(self, node, attrs, args, graph_converter):
graph_converter.q_mapping[self.output_names[0]] = t


class TrackConstantOperator(OperatorConverter):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)

self.run(node)

t = self.find_or_create_input(0, graph_converter)
graph_converter.constant_mapping[self.output_names[0]] = t


class PrimOperatorConverter(OperatorConverter):
# prim::* ops need custom implementations
def run(self, node):
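
Downstream converters consume these entries by checking whether one of their input names was produced by a tracked constant, as the slice converter above now does:

```python
# Consumption pattern (cf. the aten.py changes above): if the slice start
# was produced by a tracked dynamic constant, emit runtime Reshape/Cast/
# Concatenation ops instead of folding a constant index vector.
if self.input_names[2] in graph_converter.constant_mapping:
    start_t = graph_converter.constant_mapping[self.input_names[2]]
    ...
```
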
