Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tanh activiation in pytorch parser #1055

Merged
merged 3 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions hls4ml/converters/pytorch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c
@pytorch_handler(*activation_layers)
def parse_activation_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
layer = {}

layer['class_name'] = operation
layer['activation'] = layer['class_name']
layer['name'] = layer_name
Expand All @@ -50,7 +49,9 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
# if layer['class_name'] != 'Activation':
# layer['activation'] = layer['class_name']
if node.op == 'call_module':
if layer['class_name'] == 'ReLU' or layer['class_name'] == 'Sigmoid':
if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']:
if layer['class_name'] == 'Tanh':
layer['activation'] = 'tanh'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you have to do this for tanh but not for the others?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a layer['activation'] = layer['class_name'] a little bit above in the code, but for tanh we have to make it a lower case t. I guess that could be done just with an layer['activation'] = layer['class_name'].lower(), but I'm not sure if it's a general rule that the activation attribute is supposed to be lower case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is safe. I see the templates will lower() it here so the only way to break if it is not lowered is for Quartus which does this for god knows what reason. I think we generally expect it to be lowercase, so it should be fine to move it up to a line 45.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's exactly because of this stuff in Quartus why I had to add this clause in the first place. But just making the attribute lower case works, so I changed it to that.

layer['class_name'] = 'Activation'
if layer['class_name'] == 'LeakyReLU':
layer['activ_param'] = class_object.negative_slope
Expand All @@ -68,7 +69,9 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
if hasattr(node, 'dim'):
layer['axis'] = class_object.dim
else:
if layer['class_name'] == 'ReLU' or layer['class_name'] == 'Sigmoid':
if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']:
if layer['class_name'] == 'Tanh':
layer['activation'] = 'tanh'
layer['class_name'] = 'Activation'
if layer['class_name'] == 'LeakyReLU':
layer['activ_param'] = node.kwargs['negative_slope']
Expand Down
1 change: 1 addition & 0 deletions hls4ml/converters/pytorch_to_hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def decorator(function):
# map names of operations between toch.nn and torch.nn.functionals
layer_name_map = {
'relu': 'ReLU',
'tanh': 'Tanh',
'leaky_relu': 'LeakyReLU',
'elu': 'ELU',
'prelu': 'PReLU',
Expand Down
14 changes: 12 additions & 2 deletions test/pytest/test_pytorch_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def test_linear(backend, io_type):
"activation_function",
[
nn.ReLU(),
nn.Tanh(),
nn.LeakyReLU(negative_slope=1.0),
nn.ELU(alpha=1.0),
nn.PReLU(init=0.25),
Expand Down Expand Up @@ -102,7 +103,7 @@ def test_activations(activation_function, backend, io_type):

assert nNodes - 1 == len(hls_model.get_layers())

if activation_function.__class__.__name__ == 'ReLU' or activation_function.__class__.__name__ == 'Sigmoid':
if activation_function.__class__.__name__ in ['ReLU', 'Sigmoid', 'Tanh']:
assert list(hls_model.get_layers())[2].attributes['class_name'] == 'Activation'
elif activation_function.__class__.__name__ == 'Threshold':
assert list(hls_model.get_layers())[2].attributes['class_name'] == 'ThresholdedReLU'
Expand All @@ -118,6 +119,14 @@ def forward(self, x):
return nn.functional.relu(x)


class TanHModel(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return nn.functional.tanh(x)


class LeakyReLuModel(nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -154,6 +163,7 @@ def forward(self, x):
"activation_function",
[
ReLuModel(),
TanHModel(),
LeakyReLuModel(),
EluModel(),
SigmoidModel(),
Expand All @@ -172,7 +182,7 @@ def test_activation_functionals(activation_function, backend, io_type):

config = config_from_pytorch_model(model, (1,))
fn_name = activation_function.__class__.__name__
output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_activations_functional_relu_{backend}_{io_type}_{fn_name}')
output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_activations_functional_{fn_name}_{backend}_{io_type}')
hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type)
hls_model.compile()

Expand Down
Loading