Skip to content

Commit

Permalink
Add View to layer name map for pytorch parser (#1039)
Browse files Browse the repository at this point in the history
* add view to layer name map in pytoch converter

* trigger pre-commit

* add test for view in pytorch

* Use unique output directory for pytorch 'view' tests

---------

Co-authored-by: Vladimir Loncar <[email protected]>
  • Loading branch information
JanFSchulte and vloncar authored Jul 22, 2024
1 parent 5c0c4e6 commit 7982c87
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
1 change: 1 addition & 0 deletions hls4ml/converters/pytorch_to_hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def decorator(function):
'avg_pool1d': 'AvgPool1d',
'avg_pool2d': 'AvgPool2d',
'flatten': 'Flatten',
'view': 'View',
}


Expand Down
62 changes: 62 additions & 0 deletions test/pytest/test_pytorch_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,3 +810,65 @@ def forward(self, x):
hls_prediction = hls_model.predict(hls_input).flatten()

np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2)


@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
def test_view(backend, io_type):

class TestModel(nn.Module):
def __init__(self, n_in, n_out, size_in):
super().__init__()
self.view_mult = n_out * size_in

self.conv1 = nn.Conv1d(
n_in,
n_out,
kernel_size=3,
padding=1,
bias=False,
)

def forward(self, x):
z = self.conv1(x)
z = z.view(-1, self.view_mult)
return z

n_in = 2
n_out = 4
size_in = 128
n_batch = 100

model = TestModel(n_in, n_out, size_in)
model = model.to(memory_format=torch.channels_last)
model.eval()

X_input = np.random.rand(n_batch, n_in, size_in)
pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy()

# X_input is channels last
X_input = np.ascontiguousarray(X_input.transpose(0, 2, 1))
config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False)

output_dir = str(test_root_path / f'hls4mlprj_pytorch_view_{backend}_{io_type}')
hls_model = convert_from_pytorch_model(
model,
(None, n_in, size_in),
hls_config=config,
output_dir=output_dir,
backend=backend,
io_type=io_type,
)

hls_model.compile()

# reshape hls prediction to channels last, then transpose, then reshape
# to match .view
hls_prediction = np.reshape(
np.transpose(np.reshape(hls_model.predict(X_input), (n_batch, size_in, n_out)), (0, 2, 1)),
(n_batch, size_in * n_out),
)

rtol = 0
atol = 5.0e-2
np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=rtol, atol=atol)

0 comments on commit 7982c87

Please sign in to comment.