Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix softmax parsing in pytorch and add test #1086

Merged
merged 5 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions hls4ml/converters/pytorch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,13 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
layer['activation'] = 'ThresholdedReLU'
if layer['activ_param'] < 0:
raise Exception('negative threshold values not supported')

if hasattr(node, 'dim'):
if hasattr(class_object, 'dim'):
layer['axis'] = class_object.dim
if layer['class_name'] == 'Softmax' and layer['axis'] is None:
layer['axis'] = -1
if 'IOType' in config:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have done these types of checks before, but now I think we should parse everything into the IR, and then backbends should check if the feature is supported. We can have optimizers for that. Check out feature_check.py in Vitis backend. We could add similar checks for all backends and move away from the parser (in a future PR)

if layer['class_name'] == 'Softmax' and config['IOType'] == 'io_stream' and layer['axis'] != -1:
raise Exception('dim needs to be -1 for io_stream')
else:
if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']:
layer['class_name'] = 'Activation'
Expand All @@ -80,6 +84,11 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
layer['activation'] = 'ThresholdedReLU'
if 'dim' in node.kwargs:
layer['axis'] = node.kwargs['dim']
if layer['class_name'] == 'Softmax' and layer['axis'] is None:
layer['axis'] = -1
if 'IOType' in config:
if layer['class_name'] == 'Softmax' and config['IOType'] == 'io_stream' and layer['axis'] != -1:
raise Exception('dim needs to be -1 for io_stream')

output_shape = input_shapes[0]
return layer, output_shape
Expand Down
6 changes: 5 additions & 1 deletion hls4ml/model/optimizer/passes/convert_to_channels_last.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def transform(self, model, node):
node.add_output_variable(shape, dims)

# Have to transpose back before flattening to get correct order of elements in the flattened tensor
if isinstance(node, Reshape) and len(node.attributes['target_shape']) == 1:
if (
isinstance(node, Reshape)
and len(node.attributes['target_shape']) == 1
and not model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "internal"
):
previous_node = node.get_input_node(node.inputs[0])
input = previous_node.name
outshape = previous_node.get_output_variable().shape
Expand Down
10 changes: 10 additions & 0 deletions test/pytest/test_pytorch_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_linear(backend, io_type):
@pytest.mark.parametrize(
"activation_function",
[
nn.Softmax(dim=-1),
nn.ReLU(),
nn.Tanh(),
nn.LeakyReLU(negative_slope=1.0),
Expand Down Expand Up @@ -119,6 +120,14 @@ def forward(self, x):
return nn.functional.relu(x)


class SoftmaxModel(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return nn.functional.softmax(x, dim=-1)


class TanHModel(nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -162,6 +171,7 @@ def forward(self, x):
@pytest.mark.parametrize(
"activation_function",
[
SoftmaxModel(),
ReLuModel(),
TanHModel(),
LeakyReLuModel(),
Expand Down
Loading