[converter] fix shared prelu #271

Merged
merged 1 commit on Dec 5, 2023
Changes from all commits
22 changes: 22 additions & 0 deletions tests/converter_op_test.py
@@ -785,6 +785,28 @@ def model(x):
         with self.assertRaisesRegex(AssertionError, r'.* (are not close!|exceeded the margin of error).*'):
             assert_close(dummy_output, tfl_output)
 
+    def test_same_prelu_for_different_channels(self):
+        class Model(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.prelu = nn.PReLU()
+
+            def forward(self, x):
+                return self.prelu(x), self.prelu(x[:, 0:1])
+
+        model = Model()
+        model.eval()
+
+        dummy_input = torch.rand(1, 3, 224, 224)
+
+        model_path = get_model_path()
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output)
+
     def test_prelu(self):
         class Model(nn.Module):
             def __init__(self) -> None:
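
For reference, PReLU is the identity on positive values and scales negative values by a learnable slope; nn.PReLU() defaults to a single shared slope, which is how one module can serve both the 3-channel input and the 1-channel slice in the test above. A minimal NumPy sketch of those semantics (illustrative only, not converter code):

import numpy as np

def prelu(x: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    # PReLU(x) = x where x > 0, alpha * x elsewhere; alpha broadcasts over x.
    return np.where(x > 0, x, alpha * x)

x = np.random.randn(1, 3, 224, 224).astype(np.float32)
alpha = np.full((1,), 0.25, dtype=np.float32)  # nn.PReLU() default: one shared slope

# The same slope serves both call sites from the test, whatever the channel
# count; a per-channel alpha of shape (C,) would first need reshaping to
# (1, C, 1, 1), which is the reshape/tile the converter performs below.
y_full = prelu(x, alpha)
y_slice = prelu(x[:, 0:1], alpha)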
10 changes: 10 additions & 0 deletions tinynn/converter/operators/graph.py
@@ -36,6 +36,16 @@ def __init__(self) -> None:
         self.output_transpose = None
         self.node_op_counter = 0
         self.q_mapping = {}
+        self.transform_store = {}
+
+    def add_transform_store(self, tensor_name: str, transform_name: str, new_tensor_name: str):
+        self.transform_store.setdefault(tensor_name, {})
+        self.transform_store[tensor_name][transform_name] = new_tensor_name
+
+    def get_transform_store(self, tensor_name: str, transform_name: str) -> typing.Optional[tfl.Tensor]:
+        if tensor_name not in self.transform_store:
+            return None
+        return self.transform_store[tensor_name].get(transform_name, None)
 
     def add_iterable_pair(
         self, input_names: typing.List[str], output_names: typing.List[str], key: typing.Optional[str] = None
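
These two helpers memoize transformed tensors: a derived tensor is recorded under the pair (source tensor name, transform key), and a later request for the same transform gets back the previously created tensor's name instead of minting a duplicate. Note that despite the Optional[tfl.Tensor] annotation, the value stored and returned is the new tensor's name string. A standalone sketch of the pattern, with plain strings standing in for tensor names (TransformStoreDemo is illustrative, not the converter's class):

import typing

class TransformStoreDemo:
    def __init__(self) -> None:
        # source tensor name -> {transform key -> derived tensor name}
        self.transform_store: typing.Dict[str, typing.Dict[str, str]] = {}

    def add_transform_store(self, tensor_name: str, transform_name: str, new_tensor_name: str):
        self.transform_store.setdefault(tensor_name, {})
        self.transform_store[tensor_name][transform_name] = new_tensor_name

    def get_transform_store(self, tensor_name: str, transform_name: str) -> typing.Optional[str]:
        if tensor_name not in self.transform_store:
            return None
        return self.transform_store[tensor_name].get(transform_name, None)

store = TransformStoreDemo()
assert store.get_transform_store('prelu_alpha', '3') is None  # first use: miss
store.add_transform_store('prelu_alpha', '3', 'prelu_alpha_tiled')
assert store.get_transform_store('prelu_alpha', '3') == 'prelu_alpha_tiled'  # later uses hit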
15 changes: 11 additions & 4 deletions tinynn/converter/operators/torch/aten.py
@@ -1562,7 +1562,7 @@ def parse(self, node, attrs, args, graph_converter):
             alpha_tensor = self.find_or_create_input(1, graph_converter)
             shape_tensor = self.create_attr_tensor(np.array(new_shape, dtype='int32'))
 
-            update_name = True
+            update_name = None
             if weight_c == input_c:
                 new_alpha = self.create_transform_tensor(np.reshape(alpha_tensor.tensor, new_shape))
                 graph_converter.add_operator(tfl.ReshapeOperator([alpha_tensor, shape_tensor], [new_alpha], new_shape))
@@ -1571,12 +1571,19 @@
                 if alpha_tensor.buffer is None:
                     graph_converter.add_operator(tfl.TileOperator([alpha_tensor, shape_tensor], [new_alpha]))
                 else:
-                    update_name = False
-                    new_alpha = new_alpha.tensor
+                    store = graph_converter.get_transform_store(alpha_tensor.name, str(input_c))
+                    if store is None:
+                        graph_converter.add_transform_store(alpha_tensor.name, str(input_c), new_alpha.name)
+                        update_name = new_alpha.name
+                        new_alpha = new_alpha.tensor
+                    else:
+                        update_name = store
 
             self.input_tensors[1] = new_alpha
-            if update_name:
+            if update_name is None:
                 self.input_names[1] = new_alpha.name
+            else:
+                self.input_names[1] = update_name
 
         self.elementwise_binary(tfl.PreluOperator, graph_converter, False)
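
Putting the two changes together: when alpha is a constant (alpha_tensor.buffer is not None), the first PReLU call site at a given channel count registers its reshaped/tiled alpha in the transform store, and later call sites reuse that tensor's name. Previously, every constant alpha fell through with update_name = False and kept the original parameter name, so a shared nn.PReLU produced differently shaped constants competing for one name, which appears to be the failure mode the new test exercises. A condensed, runnable sketch of the fixed control flow (MiniGraph and resolve_alpha_name are hypothetical stand-ins, not the converter's real API):

import typing

class MiniGraph:
    # Hypothetical stand-in for the converter graph; transform store only.
    def __init__(self) -> None:
        self.transform_store: typing.Dict[str, typing.Dict[str, str]] = {}

    def add_transform_store(self, tensor_name: str, transform_name: str, new_tensor_name: str):
        self.transform_store.setdefault(tensor_name, {})[transform_name] = new_tensor_name

    def get_transform_store(self, tensor_name: str, transform_name: str) -> typing.Optional[str]:
        return self.transform_store.get(tensor_name, {}).get(transform_name)

def resolve_alpha_name(g: MiniGraph, alpha_name: str, tiled_name: str, input_c: int, is_constant: bool) -> str:
    # Mirrors the fix: a constant alpha is materialized once per channel
    # count; every later call site reuses the cached tensor name.
    update_name = None
    if is_constant:
        cached = g.get_transform_store(alpha_name, str(input_c))
        if cached is None:
            g.add_transform_store(alpha_name, str(input_c), tiled_name)
            update_name = tiled_name
        else:
            update_name = cached
    return tiled_name if update_name is None else update_name

g = MiniGraph()
print(resolve_alpha_name(g, 'alpha', 'alpha_tiled_0', 3, True))  # alpha_tiled_0 (created)
print(resolve_alpha_name(g, 'alpha', 'alpha_tiled_1', 3, True))  # alpha_tiled_0 (reused)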