Skip to content

Commit

Permalink
remove pointwise conv implementation option; make it default
Browse files Browse the repository at this point in the history
  • Loading branch information
jmduarte committed Nov 1, 2024
1 parent 1a93246 commit d37a843
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 89 deletions.
46 changes: 2 additions & 44 deletions hls4ml/backends/vivado/passes/pointwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Conv1DFunctionTemplate,
Conv2DConfigTemplate,
Conv2DFunctionTemplate,
conv2d_config_template,
conv_mult_config_template,
)
from hls4ml.model.layers import register_layer
Expand Down Expand Up @@ -42,49 +43,6 @@
}};
const ap_uint<config{index}::filt_width> config{index}::pixels[] = {{{instructions}}};\n"""

pointwise_conv2d_config_template = """struct config{index} : nnet::conv2d_config {{
static const unsigned pad_top = {pad_top};
static const unsigned pad_bottom = {pad_bottom};
static const unsigned pad_left = {pad_left};
static const unsigned pad_right = {pad_right};
static const unsigned in_height = {in_height};
static const unsigned in_width = {in_width};
static const unsigned n_chan = {n_chan};
static const unsigned filt_height = {filt_height};
static const unsigned filt_width = {filt_width};
static const unsigned kernel_size = filt_height * filt_width;
static const unsigned n_filt = {n_filt};
static const unsigned stride_height = {stride_height};
static const unsigned stride_width = {stride_width};
static const unsigned out_height = {out_height};
static const unsigned out_width = {out_width};
static const unsigned reuse_factor = {reuse};
static const unsigned n_zeros = {nzeros};
static const unsigned multiplier_limit =
DIV_ROUNDUP(kernel_size * n_chan * n_filt, reuse_factor) - n_zeros / reuse_factor;
static const bool store_weights_in_bram = false;
static const unsigned strategy = nnet::{strategy};
static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};
static const unsigned min_height = {min_height};
static const unsigned min_width = {min_width};
static const ap_uint<filt_height * filt_width> pixels[min_height * min_width];
static const unsigned n_partitions = {n_partitions};
static const unsigned n_pixels = out_height * out_width / n_partitions;
template<class data_T, class CONFIG_T>
using fill_buffer = nnet::{fill_fn}<data_T, CONFIG_T>;
typedef {accum_t.name} accum_t;
typedef {bias_t.name} bias_t;
typedef {weight_t.name} weight_t;
typedef {config_t} mult_config;
template<unsigned K, unsigned S, unsigned W>
using scale_index_height = nnet::{scale_index_height_type}<K, S, W>;
template<unsigned K, unsigned S, unsigned W>
using scale_index_width = nnet::{scale_index_width_type}<K, S, W>;
template<class data_T, class res_T, class CONFIG_T>
using pointwise_conv = nnet::{pointwise_fn}<data_T, res_T, CONFIG_T>;
}};
const ap_uint<config{index}::filt_height * config{index}::filt_width> config{index}::pixels[] = {{{instructions}}};\n"""

pointwise_conv1d_function_template = (
'nnet::pointwise_conv_1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
)
Expand Down Expand Up @@ -112,7 +70,7 @@ def __init__(self):
class PointwiseConv2DConfigTemplate(Conv2DConfigTemplate):
def __init__(self):
super(Conv2DConfigTemplate, self).__init__(PointwiseConv2D)
self.template = pointwise_conv2d_config_template
self.template = conv2d_config_template
self.mult_template = conv_mult_config_template


Expand Down
4 changes: 1 addition & 3 deletions hls4ml/backends/vivado/vivado_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ def _register_layer_attributes(self):
cnn_layers = [Conv1D, Conv2D, SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Pooling1D, Pooling2D]
for layer in cnn_layers:
attrs = self.attribute_map.get(layer, [])
attrs.append(
ChoiceAttribute('conv_implementation', choices=['LineBuffer', 'Encoded', 'Pointwise'], default='LineBuffer')
)
attrs.append(ChoiceAttribute('conv_implementation', choices=['LineBuffer', 'Encoded'], default='LineBuffer'))
self.attribute_map[layer] = attrs

def _register_flows(self):
Expand Down
10 changes: 0 additions & 10 deletions hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,6 @@ template <class data_T, class res_T, typename CONFIG_T> class PointwiseConv1D {
}
};

template <class data_T, class res_T, typename CONFIG_T> class PointwiseConv2D {
public:
static void pointwise_conv(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan],
res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt],
typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
// To be implemented in subclasses
}
};

// hls4ml insert code

} // namespace nnet
Expand Down
2 changes: 0 additions & 2 deletions hls4ml/templates/vivado/nnet_utils/nnet_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#define NNET_COMMON_H_

#include "ap_fixed.h"
#include "nnet_helpers.h"

// This is a substitute for "ceil(n/(float)d)".
#define DIV_ROUNDUP(n, d) ((n + d - 1) / d)
Expand All @@ -25,7 +24,6 @@ namespace nnet {
// Common type definitions
enum io_type { io_parallel = 0, io_stream };
enum strategy { latency, resource, resource_unrolled };
enum class conv_implementation { linebuffer = 0, encoded = 1, pointwise = 2 };

/* ---
* Balanced tree reduce implementation.
Expand Down
9 changes: 2 additions & 7 deletions hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,8 @@ void pointwise_conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
#pragma HLS INLINE region

if (CONFIG_T::strategy == nnet::latency) {
if (CONFIG_T::implementation == conv_implementation::pointwise) {
// Use pointwise unrolled implementation
CONFIG_T::template pointwise_conv<data_T, res_T, CONFIG_T>::pointwise_conv(data, res, weights, biases);
} else {
// Use standard unrolled implementation
conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
}
// Use pointwise unrolled implementation
CONFIG_T::template pointwise_conv<data_T, res_T, CONFIG_T>::pointwise_conv(data, res, weights, biases);
} else {
conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
}
Expand Down
4 changes: 3 additions & 1 deletion hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

namespace nnet {

enum class conv_implementation { linebuffer = 0, encoded = 1 };

// *************************************************
// Encoded Implementation (Vlad's)
// *************************************************
Expand Down Expand Up @@ -56,7 +58,7 @@ template <unsigned K, unsigned S, unsigned W> unsigned scale_index_K_lt_S(const
template <unsigned K, unsigned S, unsigned W> class scale_index_regular {
public:
static unsigned scale_index(const unsigned idx) {
#pragma HLS INLINE
#pragma HLS INLINE

if (K >= S) {
return scale_index_K_gte_S<K, S, W>(idx);
Expand Down
41 changes: 19 additions & 22 deletions test/pytest/test_pointwiseconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,27 @@
@pytest.mark.parametrize('padds', padds_options)
@pytest.mark.parametrize('strides', strides1d_options)
@pytest.mark.parametrize(
'backend, io_type, strategy, conv_impl, rf',
'backend, io_type, strategy, rf',
[
('Quartus', 'io_parallel', 'resource', 'LineBuffer', 1),
('Quartus', 'io_stream', 'resource', 'LineBuffer', 1),
('oneAPI', 'io_parallel', 'resource', 'LineBuffer', 1),
('oneAPI', 'io_stream', 'resource', 'LineBuffer', 1),
('Vivado', 'io_parallel', 'resource', 'LineBuffer', 1),
('Vitis', 'io_parallel', 'resource', 'LineBuffer', 1),
('Vivado', 'io_parallel', 'latency', 'LineBuffer', 1),
('Vitis', 'io_parallel', 'latency', 'LineBuffer', 1),
('Vivado', 'io_parallel', 'latency', 'Pointwise', 1),
('Vivado', 'io_parallel', 'latency', 'Pointwise', 14),
('Vitis', 'io_parallel', 'latency', 'Pointwise', 1),
('Vitis', 'io_parallel', 'latency', 'Pointwise', 14),
('Vivado', 'io_stream', 'latency', 'LineBuffer', 1),
('Vivado', 'io_stream', 'resource', 'LineBuffer', 1),
('Vitis', 'io_stream', 'latency', 'LineBuffer', 1),
('Vitis', 'io_stream', 'resource', 'LineBuffer', 1),
('Catapult', 'io_stream', 'latency', 'LineBuffer', 1),
('Catapult', 'io_stream', 'resource', 'LineBuffer', 1),
('Quartus', 'io_parallel', 'resource', 1),
('Quartus', 'io_stream', 'resource', 1),
('oneAPI', 'io_parallel', 'resource', 1),
('oneAPI', 'io_stream', 'resource', 1),
('Vivado', 'io_parallel', 'resource', 1),
('Vitis', 'io_parallel', 'resource', 1),
('Vivado', 'io_parallel', 'latency', 1),
('Vitis', 'io_parallel', 'latency', 1),
('Vivado', 'io_parallel', 'latency', 14),
('Vitis', 'io_parallel', 'latency', 14),
('Vivado', 'io_stream', 'latency', 1),
('Vivado', 'io_stream', 'resource', 1),
('Vitis', 'io_stream', 'latency', 1),
('Vitis', 'io_stream', 'resource', 1),
('Catapult', 'io_stream', 'latency', 1),
('Catapult', 'io_stream', 'resource', 1),
],
)
def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy, conv_impl, rf):
def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy, rf):
model = tf.keras.models.Sequential()
input_shape = (28, 3)
model.add(
Expand All @@ -65,12 +63,11 @@ def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy, conv
default_precision = 'fixed<32,16>'
config = hls4ml.utils.config_from_keras_model(model, default_precision=default_precision, granularity='name')
config['Model']['Strategy'] = strategy
config['LayerName']['pointwise1d']['ConvImplementation'] = conv_impl
config['LayerName']['pointwise1d']['ReuseFactor'] = rf

output_dir = str(
test_root_path
/ f'hls4mlprj_pointwise1d_{chans}_{strides[0]}_{padds}_{backend}_{io_type}_{strategy}_{conv_impl}_rf{rf}'
/ f'hls4mlprj_pointwise1d_{chans}_{strides[0]}_{padds}_{backend}_{io_type}_{strategy}_rf{rf}'
)
hls_model = hls4ml.converters.convert_from_keras_model(
model, hls_config=config, output_dir=output_dir, io_type=io_type, backend=backend
Expand Down

0 comments on commit d37a843

Please sign in to comment.