Skip to content

Commit

Permalink
Merge branch 'main' into oneapi_backend
Browse files Browse the repository at this point in the history
  • Loading branch information
jmitrevs authored Sep 13, 2024
2 parents a4f4bd9 + 5241109 commit f1c0301
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions test/pytest/test_keras_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def test_conv1d(padds, backend, io_type):
assert list(hls_model.get_layers())[1].attributes['n_chan'] == model.layers[0].input_shape[2]
assert list(hls_model.get_layers())[1].attributes['n_filt'] == model.layers[0].filters
assert list(hls_model.get_layers())[1].attributes['stride_width'] == model.layers[0].strides[0]
assert list(hls_model.get_layers())[1].attributes['padding'] == model.layers[0].padding
assert list(hls_model.get_layers())[1].attributes['data_format'] == model.layers[0].data_format
assert list(hls_model.get_layers())[1].attributes["out_width"] == list(model.layers[0].output_shape)[1]

Expand Down Expand Up @@ -235,7 +234,6 @@ def test_conv2d(chans, padds, backend, io_type):
assert list(hls_model.get_layers())[1].attributes['n_filt'] == model.layers[0].filters
assert list(hls_model.get_layers())[1].attributes['stride_width'] == model.layers[0].strides[1]
assert list(hls_model.get_layers())[1].attributes['stride_height'] == model.layers[0].strides[0]
assert list(hls_model.get_layers())[1].attributes['padding'] == model.layers[0].padding
assert list(hls_model.get_layers())[1].attributes['data_format'] == model.layers[0].data_format

if model.layers[0].data_format == 'channels_first':
Expand Down Expand Up @@ -392,7 +390,6 @@ def test_pooling(pooling, padds, chans, backend):
assert hls_pool.attributes['stride_width'] == ker_pool.strides[1]
assert hls_pool.attributes['pool_height'] == ker_pool.pool_size[1]
assert hls_pool.attributes['pool_width'] == ker_pool.pool_size[0]
assert hls_pool.attributes['padding'] == ker_pool.padding

if hls_pool.attributes['data_format'] == 'channels_last':
assert hls_pool.attributes['in_height'] == ker_pool.input_shape[1]
Expand All @@ -403,7 +400,7 @@ def test_pooling(pooling, padds, chans, backend):
assert hls_pool.attributes['in_width'] == ker_pool.input_shape[3]
assert hls_pool.attributes['n_filt'] == ker_pool.input_shape[1]

if hls_pool.attributes['padding'] == 'same':
if ker_pool.padding == 'same':
# Height
in_height = ker_pool.input_shape[1]
if ker_pool.data_format == 'channels_first':
Expand Down Expand Up @@ -434,7 +431,7 @@ def test_pooling(pooling, padds, chans, backend):
assert pad_left == hls_pool.attributes['pad_left']
assert pad_right == hls_pool.attributes['pad_right']

elif hls_pool.attributes['padding'] == 'valid':
elif ker_pool.padding == 'valid':
if hls_pool.attributes['data_format'] == 'channels_first':
in_height = ker_pool.input_shape[2]
in_width = ker_pool.input_shape[3]
Expand All @@ -459,12 +456,11 @@ def test_pooling(pooling, padds, chans, backend):
assert hls_pool.attributes['n_filt'] == ker_pool.input_shape[2]
assert hls_pool.attributes['pool_width'] == ker_pool.pool_size[0]
assert hls_pool.attributes['stride_width'] == ker_pool.strides[0]
assert hls_pool.attributes['padding'] == ker_pool.padding

out_same = math.ceil(float(ker_pool.input_shape[1]) / float(ker_pool.strides[0]))
out_valid = math.ceil(float(ker_pool.input_shape[1] - ker_pool.pool_size[0] + 1) / ker_pool.strides[0])

if hls_pool.attributes['padding'] == 'same':
if ker_pool.padding == 'same':
assert hls_pool.attributes['n_out'] == out_same
if ker_pool.input_shape[1] % ker_pool.strides[0] == 0:
pad_along_width = max(ker_pool.pool_size[0] - ker_pool.strides[0], 0)
Expand All @@ -473,7 +469,7 @@ def test_pooling(pooling, padds, chans, backend):
assert hls_pool.attributes['pad_left'] == pad_along_width // 2
assert hls_pool.attributes['pad_right'] == pad_along_width - pad_along_width // 2

elif hls_pool.attributes['padding'] == 'valid':
elif ker_pool.padding == 'valid':
assert hls_pool.attributes['n_out'] == out_valid
assert hls_pool.attributes['pad_left'] == 0
assert hls_pool.attributes['pad_right'] == 0

0 comments on commit f1c0301

Please sign in to comment.