Fix tanh activation in pytorch parser (#1055)
* fix tanh activation in pytorch parser

* simplify fix by making the activation attribute lower case
JanFSchulte authored Sep 11, 2024
1 parent c8c95a7 commit d63033b
Showing 3 changed files with 20 additions and 9 deletions.
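
This commit makes hls4ml's PyTorch parser recognise tanh activations, both as nn.Tanh modules and as torch.nn.functional.tanh calls. Below is a minimal conversion sketch in the spirit of the updated tests; the layer sizes, output directory and backend are illustrative choices, not part of the commit.

import torch.nn as nn
from hls4ml.converters import convert_from_pytorch_model
from hls4ml.utils.config import config_from_pytorch_model

# A small model with a tanh activation, mirroring the nn.Tanh() case added to the tests.
model = nn.Sequential(nn.Linear(8, 8), nn.Tanh())

# Build a config and convert, following the call pattern used in test_pytorch_api.py.
config = config_from_pytorch_model(model, (8,))
hls_model = convert_from_pytorch_model(
    model, hls_config=config, output_dir='hls4mlprj_tanh_example', backend='Vivado', io_type='io_parallel'
)
hls_model.compile()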
8 changes: 3 additions & 5 deletions hls4ml/converters/pytorch/core.py
@@ -43,14 +43,12 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
     layer = {}
 
     layer['class_name'] = operation
-    layer['activation'] = layer['class_name']
+    layer['activation'] = layer['class_name'].lower()
     layer['name'] = layer_name
     layer['inputs'] = input_names
 
-    # if layer['class_name'] != 'Activation':
-    #     layer['activation'] = layer['class_name']
     if node.op == 'call_module':
-        if layer['class_name'] == 'ReLU' or layer['class_name'] == 'Sigmoid':
+        if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']:
             layer['class_name'] = 'Activation'
         if layer['class_name'] == 'LeakyReLU':
             layer['activ_param'] = class_object.negative_slope
@@ -68,7 +66,7 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
         if hasattr(node, 'dim'):
             layer['axis'] = class_object.dim
     else:
-        if layer['class_name'] == 'ReLU' or layer['class_name'] == 'Sigmoid':
+        if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']:
             layer['class_name'] = 'Activation'
         if layer['class_name'] == 'LeakyReLU':
             layer['activ_param'] = node.kwargs['negative_slope']
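For illustration only, a simplified standalone sketch of the updated parsing behaviour (not the actual parse_activation_layer, whose full signature is truncated above): the activation attribute is now kept in lower case, and Tanh modules are mapped to the generic Activation layer class alongside ReLU and Sigmoid.

def parse_activation_sketch(operation):
    # Simplified mirror of the updated logic: lower-case the activation name
    # and collapse ReLU/Sigmoid/Tanh onto the generic 'Activation' class.
    layer = {'class_name': operation, 'activation': operation.lower()}
    if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']:
        layer['class_name'] = 'Activation'
    return layer

print(parse_activation_sketch('Tanh'))  # {'class_name': 'Activation', 'activation': 'tanh'}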
1 change: 1 addition & 0 deletions hls4ml/converters/pytorch_to_hls.py
@@ -84,6 +84,7 @@ def decorator(function):
 # map names of operations between toch.nn and torch.nn.functionals
 layer_name_map = {
     'relu': 'ReLU',
+    'tanh': 'Tanh',
     'leaky_relu': 'LeakyReLU',
     'elu': 'ELU',
     'prelu': 'PReLU',
20 changes: 16 additions & 4 deletions test/pytest/test_pytorch_api.py
@@ -64,6 +64,7 @@ def test_linear(backend, io_type):
     "activation_function",
     [
         nn.ReLU(),
+        nn.Tanh(),
         nn.LeakyReLU(negative_slope=1.0),
         nn.ELU(alpha=1.0),
         nn.PReLU(init=0.25),
@@ -102,7 +103,7 @@ def test_activations(activation_function, backend, io_type):
 
     assert nNodes - 1 == len(hls_model.get_layers())
 
-    if activation_function.__class__.__name__ == 'ReLU' or activation_function.__class__.__name__ == 'Sigmoid':
+    if activation_function.__class__.__name__ in ['ReLU', 'Sigmoid', 'Tanh']:
         assert list(hls_model.get_layers())[2].attributes['class_name'] == 'Activation'
     elif activation_function.__class__.__name__ == 'Threshold':
         assert list(hls_model.get_layers())[2].attributes['class_name'] == 'ThresholdedReLU'
@@ -118,6 +119,14 @@ def forward(self, x):
         return nn.functional.relu(x)
 
 
+class TanHModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return nn.functional.tanh(x)
+
+
 class LeakyReLuModel(nn.Module):
     def __init__(self):
         super().__init__()
@@ -154,6 +163,7 @@ def forward(self, x):
     "activation_function",
     [
         ReLuModel(),
+        TanHModel(),
         LeakyReLuModel(),
         EluModel(),
         SigmoidModel(),
@@ -172,7 +182,7 @@ def test_activation_functionals(activation_function, backend, io_type):
 
     config = config_from_pytorch_model(model, (1,))
     fn_name = activation_function.__class__.__name__
-    output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_activations_functional_relu_{backend}_{io_type}_{fn_name}')
+    output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_activations_functional_{fn_name}_{backend}_{io_type}')
     hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type)
     hls_model.compile()
 
@@ -268,7 +278,7 @@ def test_conv1d(padds, backend, io_type):
     act_index = 2
     assert list(hls_model.get_layers())[conv_index].attributes['name'] == convNode.name
     assert list(hls_model.get_layers())[conv_index].attributes['class_name'] == 'Conv1D'
-    assert list(hls_model.get_layers())[act_index].attributes['activation'] == class_object_relu.__class__.__name__
+    assert list(hls_model.get_layers())[act_index].attributes['activation'] == class_object_relu.__class__.__name__.lower()
     if io_type == "io_stream" and (backend == "Vivado" or backend == "Vitis") and padds == 1:
         assert list(hls_model.get_layers())[conv_index].attributes["in_width"] == size_in + 2
     else:
@@ -412,7 +422,7 @@ def test_conv2d(padds, backend, io_type):
     act_index = 2
     assert list(hls_model.get_layers())[conv_index].attributes['name'] == convNode.name
     assert list(hls_model.get_layers())[conv_index].attributes['class_name'] == 'Conv2D'
-    assert list(hls_model.get_layers())[act_index].attributes['activation'] == class_object_relu.__class__.__name__
+    assert (
+        list(hls_model.get_layers())[act_index].attributes['activation'] == class_object_relu.__class__.__name__.lower()
+    )
     assert list(hls_model.get_layers())[conv_index].attributes["in_width"] == size_in_width
     assert list(hls_model.get_layers())[conv_index].attributes["in_height"] == size_in_height
     assert list(hls_model.get_layers())[conv_index].attributes['filt_width'] == class_object_conv.kernel_size[1]
