diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index afa87365fb..7874b98d09 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -95,6 +95,7 @@ def decorator(function): 'avg_pool1d': 'AvgPool1d', 'avg_pool2d': 'AvgPool2d', 'flatten': 'Flatten', + 'view': 'View', } diff --git a/test/pytest/test_pytorch_api.py b/test/pytest/test_pytorch_api.py index 9d67c2867d..ae87caa54c 100644 --- a/test/pytest/test_pytorch_api.py +++ b/test/pytest/test_pytorch_api.py @@ -810,3 +810,65 @@ def forward(self, x): hls_prediction = hls_model.predict(hls_input).flatten() np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2) + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_view(backend, io_type): + + class TestModel(nn.Module): + def __init__(self, n_in, n_out, size_in): + super().__init__() + self.view_mult = n_out * size_in + + self.conv1 = nn.Conv1d( + n_in, + n_out, + kernel_size=3, + padding=1, + bias=False, + ) + + def forward(self, x): + z = self.conv1(x) + z = z.view(-1, self.view_mult) + return z + + n_in = 2 + n_out = 4 + size_in = 128 + n_batch = 100 + + model = TestModel(n_in, n_out, size_in) + model = model.to(memory_format=torch.channels_last) + model.eval() + + X_input = np.random.rand(n_batch, n_in, size_in) + pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy() + + # X_input is channels last + X_input = np.ascontiguousarray(X_input.transpose(0, 2, 1)) + config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False) + + output_dir = str(test_root_path / f'hls4mlprj_pytorch_view_{backend}_{io_type}') + hls_model = convert_from_pytorch_model( + model, + (None, n_in, size_in), + hls_config=config, + output_dir=output_dir, + backend=backend, + io_type=io_type, + ) + + hls_model.compile() + + # reshape hls prediction to channels last, then transpose, then reshape + # to match .view + hls_prediction = np.reshape( + np.transpose(np.reshape(hls_model.predict(X_input), (n_batch, size_in, n_out)), (0, 2, 1)), + (n_batch, size_in * n_out), + ) + + rtol = 0 + atol = 5.0e-2 + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=rtol, atol=atol)