[converter] track dynamic inputs for slices #284

Merged 3 commits on Mar 26, 2024
9 changes: 7 additions & 2 deletions docs/op_matrix.md
@@ -220,8 +220,6 @@ Operators that are implemented in Python
Non-tracking operators that are ignored during translation
| Operator | Limitations |
|---------------------------|--------------|
| `aten::Int` | |
| `aten::ScalarImplicit` | |
| `aten::arange` | |
| `aten::detach` | |
| `aten::empty` | |
@@ -232,3 +230,10 @@ Non-tracking operators that are ignored during translation
| `aten::size` | |
| `aten::zeros` | |
| `aten::zeros_like` | |

## Constant Tracking Operators
Tracking operators that produce a dynamic constant
| Operator | Limitations |
|---------------------------|--------------|
| `aten::Int` | |
| `aten::ScalarImplicit` | |
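
For context, the two relocated ops are what PyTorch's tracer emits when a Python scalar is produced from a tensor, e.g. when slice bounds are tensors. A minimal sketch of a model that exercises this path (mirroring the tests added in this PR; the trace inspection is illustrative, not part of the repo):

```python
import torch

# Sketch: tensor-valued slice bounds. Under tracing, `y` and `z` are
# converted to scalars via aten::Int / aten::ScalarImplicit before they
# reach aten::slice -- the ops this new category now tracks.
def model(x, y, z):
    return x[y:z]

dummy = (torch.randn(10, 10), torch.tensor(2), torch.tensor(4))
traced = torch.jit.trace(model, dummy)
print(traced.graph)  # look for aten::ScalarImplicit feeding aten::slice
```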
14 changes: 12 additions & 2 deletions scripts/gen_op_docs.py
@@ -3,7 +3,7 @@
import re

from tinynn.converter.operators.torch import OPERATOR_CONVERTER_DICT
from tinynn.converter.operators.torch.base import NoTrackOperator, PrimOperatorConverter
from tinynn.converter.operators.torch.base import NoTrackOperator, PrimOperatorConverter, TrackConstantOperator

CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))

@@ -12,6 +12,7 @@
quantized_ops = []
torchvision_ops = []
passthrough_ops = []
track_constant_ops = []

limitation_dict = {}

@@ -22,13 +23,15 @@ def main():


def collect_ops():
global prim_ops, aten_ops, quantized_ops, torchvision_ops, passthrough_ops, limitation_dict
global prim_ops, aten_ops, quantized_ops, torchvision_ops, passthrough_ops, track_constant_ops, limitation_dict

for k, v in OPERATOR_CONVERTER_DICT.items():
if issubclass(v, PrimOperatorConverter):
prim_ops.append(k)
elif issubclass(v, NoTrackOperator):
passthrough_ops.append(k)
elif issubclass(v, TrackConstantOperator):
track_constant_ops.append(k)
else:
if v.__module__ == 'tinynn.converter.operators.torch.aten':
aten_ops.append(k)
@@ -50,6 +53,7 @@ def collect_ops():
quantized_ops = sorted(quantized_ops)
torchvision_ops = sorted(torchvision_ops)
passthrough_ops = sorted(passthrough_ops)
track_constant_ops = sorted(track_constant_ops)


def print_operators(topic, ops, f, desc=None, eol=True):
@@ -89,6 +93,12 @@ def update_file():
passthrough_ops,
f,
'Non-tracking operators that are ignored during translation',
)
print_operators(
'Constant Tracking Operators',
track_constant_ops,
f,
'Tracking operators that produce a dynamic constant',
False,
)

45 changes: 45 additions & 0 deletions tests/converter_op_test.py
@@ -3853,6 +3853,51 @@ def model(x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_slice_dyn_input(self):
dummy_input = [torch.randn(10, 10, dtype=torch.float32), torch.tensor(2), torch.tensor(4)]

def model(x, y, z):
return x[y:z]

model_path = get_model_path()

converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(*dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_slice_dyn_start(self):
dummy_input = [torch.randn(10, 10, dtype=torch.float32), torch.tensor(2)]

def model(x, y):
return x[y:4]

model_path = get_model_path()

converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(*dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_slice_dyn_end(self):
dummy_input = [torch.randn(10, 10, dtype=torch.float32), torch.tensor(4)]

def model(x, y):
return x[2:y]

model_path = get_model_path()

converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(*dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_slice_no_end(self):
dummy_input = torch.randn(10, 10, dtype=torch.float32)

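`get_model_path` and `tfl_run_model` are existing helpers in the test suite. As a rough standalone stand-in (an assumption about what the helper does, not the repo's actual code), the converted model can be run with TensorFlow's interpreter, assuming the interpreter's input order matches the argument order:

```python
import numpy as np
import tensorflow as tf

def run_tflite(model_path, inputs):
    # Load the converted model, feed inputs positionally, return output 0.
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    for detail, value in zip(interpreter.get_input_details(), inputs):
        interpreter.set_tensor(detail['index'], np.asarray(value))
    interpreter.invoke()
    return interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
```
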
1 change: 1 addition & 0 deletions tinynn/converter/operators/graph.py
@@ -37,6 +37,7 @@ def __init__(self) -> None:
self.node_op_counter = 0
self.q_mapping = {}
self.transform_store = {}
self.constant_mapping = {}

def add_transform_store(self, tensor_name: str, transform_name: str, new_tensor_name: str):
self.transform_store.setdefault(tensor_name, {})
4 changes: 2 additions & 2 deletions tinynn/converter/operators/torch/__init__.py
@@ -213,7 +213,7 @@
"quantized::linear_relu_dynamic": QuantizedLinearReluDynamicOperator,
"quantized::elu": QuantizedEluOperator,
# non tracking
"aten::Int": NoTrackOperator,
"aten::Int": TrackConstantOperator,
"aten::zeros": NoTrackOperator,
"aten::detach": NoTrackOperator,
"aten::size": NoTrackOperator,
@@ -224,5 +224,5 @@
"aten::empty": NoTrackOperator,
"aten::new_zeros": NoTrackOperator,
"aten::new_ones": NoTrackOperator,
"aten::ScalarImplicit": NoTrackOperator,
"aten::ScalarImplicit": TrackConstantOperator,
}
82 changes: 73 additions & 9 deletions tinynn/converter/operators/torch/aten.py
@@ -1781,28 +1781,92 @@ def parse(self, node, attrs, args, graph_converter):
starts = np.zeros(input_tensor.tensor.ndim, dtype='int32')
starts[dim] = start

start_tensor = self.create_attr_tensor(starts)
if self.input_names[2] in graph_converter.constant_mapping:
start_t = graph_converter.constant_mapping[self.input_names[2]]
new_shape_arr = np.array((1,), dtype='int32')
new_shape_tensor = self.create_attr_tensor(new_shape_arr)
start_reshaped = self.create_transform_tensor(np.reshape(start_t.tensor, new_shape_arr))
graph_converter.add_operator(
tfl.ReshapeOperator([start_t, new_shape_tensor], [start_reshaped], new_shape_arr)
)

start_casted = self.create_transform_tensor(start_reshaped.tensor.astype('int32'))
graph_converter.add_operator(
tfl.CastOperator(
[start_reshaped],
[start_casted],
tfl.numpy_tflite_dtype_mappings[str(start_reshaped.dtype)],
tfl.numpy_tflite_dtype_mappings[str(start_casted.dtype)],
)
)

start_tensor = self.create_transform_tensor(starts)
starts_left = starts[:dim]
starts_right = starts[dim + 1 :]
starts_tensors = []
if len(starts_left) > 0:
starts_tensors.append(self.create_attr_tensor(starts_left))
starts_tensors.append(start_casted)
if len(starts_right) > 0:
starts_tensors.append(self.create_attr_tensor(starts_right))
if len(starts_tensors) > 1:
graph_converter.add_operator(tfl.ConcatenationOperator(starts_tensors, [start_tensor], 0))
else:
start_tensor = starts_tensors[0]
else:
start_tensor = self.create_attr_tensor(starts)

if step != 1:
# if True:
ends = np.array(input_tensor.tensor.shape, dtype='int32')
ends = np.array(input_tensor.tensor.shape, dtype='int32')
if step != 1 or start_tensor.buffer is None or self.input_names[3] in graph_converter.constant_mapping:
ends[dim] = end
else:
ends[dim] = end - start

if self.input_names[3] in graph_converter.constant_mapping:
end_t = graph_converter.constant_mapping[self.input_names[3]]
new_shape_arr = np.array((1,), dtype='int32')
new_shape_tensor = self.create_attr_tensor(new_shape_arr)
end_reshaped = self.create_transform_tensor(np.reshape(end_t.tensor, new_shape_arr))
graph_converter.add_operator(tfl.ReshapeOperator([end_t, new_shape_tensor], [end_reshaped], new_shape_arr))

end_casted = self.create_transform_tensor(end_reshaped.tensor.astype('int32'))
graph_converter.add_operator(
tfl.CastOperator(
[end_reshaped],
[end_casted],
tfl.numpy_tflite_dtype_mappings[str(end_reshaped.dtype)],
tfl.numpy_tflite_dtype_mappings[str(end_casted.dtype)],
)
)

end_tensor = self.create_transform_tensor(ends)
ends_left = ends[:dim]
ends_right = ends[dim + 1 :]
ends_tensors = []
if len(ends_left) > 0:
ends_tensors.append(self.create_attr_tensor(ends_left))
ends_tensors.append(end_casted)
if len(ends_right) > 0:
ends_tensors.append(self.create_attr_tensor(ends_right))
if len(ends_tensors) > 1:
graph_converter.add_operator(tfl.ConcatenationOperator(ends_tensors, [end_tensor], 0))
else:
end_tensor = ends_tensors[0]
else:
end_tensor = self.create_attr_tensor(ends)

if step != 1 or start_tensor.buffer is None or end_tensor.buffer is None:
strides = np.ones(input_tensor.tensor.ndim, dtype='int32')
strides[dim] = step

end_tensor = self.create_attr_tensor(ends)
stride_tensor = self.create_attr_tensor(strides)

inputs = [input_tensor, start_tensor, end_tensor, stride_tensor]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)

graph_converter.add_operator(tfl.StridedSliceOperator(inputs, outputs))
else:
sizes = np.array(input_tensor.tensor.shape, dtype='int32')
sizes[dim] = end - start

size_tensor = self.create_attr_tensor(sizes)
size_tensor = end_tensor
inputs = [input_tensor, start_tensor, size_tensor]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)

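In short: when the start or end bound along `dim` is found in `constant_mapping` (i.e. it is dynamic), the scalar is reshaped to rank 1, cast to int32, and concatenated with the static portions of the begin/end vector; any dynamic bound (or `step != 1`) selects `StridedSlice`, which takes absolute end indices, while the fully static path keeps emitting `Slice`, which takes sizes (`end - start`). A numpy sketch of the vector assembly (an illustration of the logic above, not converter code):

```python
import numpy as np

def build_index_vector(static_vec, dim, dynamic_scalar):
    # Mirrors the Reshape -> Cast -> Concatenation sequence in the diff:
    # keep the static entries on either side of `dim` and splice in the
    # runtime scalar, reshaped to rank 1 and cast to int32.
    left, right = static_vec[:dim], static_vec[dim + 1:]
    dyn = np.reshape(dynamic_scalar, (1,)).astype('int32')
    parts = [p for p in (left, dyn, right) if len(p) > 0]
    return np.concatenate(parts)

starts = np.zeros(2, dtype='int32')                # rank-2 input, slicing dim 0
print(build_index_vector(starts, 0, np.int64(2)))  # -> [2 0]
```
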
10 changes: 10 additions & 0 deletions tinynn/converter/operators/torch/base.py
@@ -757,6 +757,16 @@ def parse(self, node, attrs, args, graph_converter):
graph_converter.q_mapping[self.output_names[0]] = t


class TrackConstantOperator(OperatorConverter):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)

self.run(node)

t = self.find_or_create_input(0, graph_converter)
graph_converter.constant_mapping[self.output_names[0]] = t


class PrimOperatorConverter(OperatorConverter):
# prim::* ops needs custom implementation
def run(self, node):
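
Together with the `constant_mapping` dict added to the graph in `graph.py`, this completes the handshake: the producer (`aten::Int` / `aten::ScalarImplicit`, now registered as `TrackConstantOperator`) stores the tensor feeding the scalar conversion under its output name, and a consumer such as `aten::slice` checks its input names against the mapping. A toy sketch of the flow (real converter classes carry far more state):

```python
# Toy model of the producer/consumer handshake around constant_mapping.
class GraphConverter:
    def __init__(self):
        self.constant_mapping = {}  # node output name -> tracked tensor

graph_converter = GraphConverter()

# Producer (TrackConstantOperator): record the tensor feeding the
# aten::Int / aten::ScalarImplicit node, keyed by its output name.
graph_converter.constant_mapping["scalar_out"] = "tensor_feeding_aten_Int"

# Consumer (aten::slice): a bound whose input name is in the mapping is
# dynamic, so the converter assembles begin/end vectors at runtime
# (Reshape + Cast + Concatenation) instead of a static attribute tensor.
input_names = ["x", "dim", "scalar_out", "end", "step"]
if input_names[2] in graph_converter.constant_mapping:
    print("dynamic start:", graph_converter.constant_mapping[input_names[2]])
```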