From d804b2de8501dff8fe7b448fc4a17f932d25b619 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Sat, 28 Sep 2024 18:33:48 -0500 Subject: [PATCH 1/5] update parametrized activations for Xilinx --- example-models | 2 +- .../backends/vivado/passes/core_templates.py | 29 +++++++++++++++++-- hls4ml/converters/keras/core.py | 2 +- hls4ml/converters/pytorch/core.py | 2 +- hls4ml/model/layers.py | 20 ++++++++++++- .../model/optimizer/passes/infer_precision.py | 14 +++++++++ .../vivado/nnet_utils/nnet_activation.h | 21 +++++++------- .../nnet_utils/nnet_activation_stream.h | 21 +++++++------- test/pytest/test_activations.py | 2 +- 9 files changed, 85 insertions(+), 28 deletions(-) diff --git a/example-models b/example-models index ff74f73dbc..3cfbcfd062 160000 --- a/example-models +++ b/example-models @@ -1 +1 @@ -Subproject commit ff74f73dbc253d1aa7de1603ee10ede551919548 +Subproject commit 3cfbcfd062f60492507d21ff0e91559b3bdd6550 diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index 268293dd1e..b20a89f9ad 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ b/hls4ml/backends/vivado/passes/core_templates.py @@ -116,6 +116,15 @@ def format(self, node): typedef {table_t.name} table_t; }};\n""" +param_activ_config_template = """struct {type}_config{index} : nnet::activ_config {{ + static const unsigned n_in = {n_in}; + static const unsigned table_size = {table_size}; + static const unsigned io_type = nnet::{iotype}; + static const unsigned reuse_factor = {reuse}; + typedef {table_t.name} table_t; + typedef {param_t.name} param_t; +}};\n""" + hard_activ_config_template = """struct {type}_config{index} {{ static const unsigned n_in = {n_in}; static const {slope_t.name} slope; @@ -138,14 +147,16 @@ def format(self, node): }};\n""" activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});' -param_activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {param}, {output});' +param_activ_function_template = ( + 'nnet::{activation}<{input_t}, {param_t.name}, {output_t}, {config}>({input}, {param}, {output});' +) activ_include_list = ['nnet_utils/nnet_activation.h', 'nnet_utils/nnet_activation_stream.h'] class ActivationConfigTemplate(LayerConfigTemplate): def __init__(self): - super().__init__((Activation, ParametrizedActivation, PReLU, UnaryLUT)) + super().__init__((Activation, UnaryLUT)) self.template = activ_config_template def format(self, node): @@ -155,6 +166,18 @@ def format(self, node): return self.template.format(**params) +class ParamActivationConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__((ParametrizedActivation, PReLU)) + self.template = param_activ_config_template + + def format(self, node): + params = self._default_config_params(node) + params['type'] = node.get_attr('activation') + + return self.template.format(**params) + + class HardActivationConfigTemplate(LayerConfigTemplate): def __init__(self): super().__init__(HardActivation) @@ -208,7 +231,7 @@ def __init__(self): def format(self, node): params = self._default_function_params(node) params['activation'] = node.get_attr('activation').lower() - params['param'] = node.get_weights('alpha').name + params['param'] = node.get_weights('param').name params['config'] = '{}_config{}'.format(node.get_attr('activation'), node.index) return self.template.format(**params) diff --git a/hls4ml/converters/keras/core.py b/hls4ml/converters/keras/core.py index 
ca7d0b3541..aff15808ad 100644 --- a/hls4ml/converters/keras/core.py +++ b/hls4ml/converters/keras/core.py @@ -71,7 +71,7 @@ def parse_activation_layer(keras_layer, input_names, input_shapes, data_reader): elif layer['class_name'] == 'ReLU': layer['class_name'] = 'Activation' elif layer['class_name'] == 'PReLU': - layer['alpha_data'] = get_weights_data(data_reader, layer['name'], 'alpha') + layer['param_data'] = get_weights_data(data_reader, layer['name'], 'alpha') if layer['class_name'] == 'Activation' and layer['activation'] == 'softmax': layer['class_name'] = 'Softmax' diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index d3ba470bf5..c56857715a 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -55,7 +55,7 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod if layer['class_name'] == 'ELU': layer['activ_param'] = class_object.alpha if layer['class_name'] == 'PReLU': - layer['alpha_data'] = class_object.weight.data.numpy() + layer['param_data'] = class_object.weight.data.numpy() if layer['class_name'] == 'Threshold': layer['activ_param'] = class_object.threshold layer['class_name'] = 'ThresholdedReLU' diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index d8d1fb9c8f..1ceb6456b8 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -845,6 +845,17 @@ def initialize(self): class ParametrizedActivation(Activation): + _expected_attributes = [ + Attribute('n_in'), + Attribute('activation', value_type=str), + TypeAttribute('param'), + ] + + def initialize(self): + super().initialize() + param_t = NamedType(*reversed(self.model.config.get_precision(self, 'param'))) + self.set_attr('param_t', param_t) + def _get_act_function_name(self): act = self.get_attr('activation').lower() if act == 'leakyrelu': @@ -882,9 +893,16 @@ def initialize(self): class PReLU(Activation): + _expected_attributes = [ + Attribute('n_in'), + Attribute('activation', value_type=str), + WeightAttribute('param'), + TypeAttribute('param'), + ] + def initialize(self): super().initialize() - self.add_weights_variable(name='alpha', var_name='a{index}') + self.add_weights_variable(name='param', var_name='a{index}') class Softmax(Activation): diff --git a/hls4ml/model/optimizer/passes/infer_precision.py b/hls4ml/model/optimizer/passes/infer_precision.py index bb24f2206e..0e68cc346d 100644 --- a/hls4ml/model/optimizer/passes/infer_precision.py +++ b/hls4ml/model/optimizer/passes/infer_precision.py @@ -84,6 +84,9 @@ def _infer_precision(self, node, types_to_infer): if node_class in ['SimpleRNN', 'LSTM', 'GRU']: return self._infer_rnn_precision(node, types_to_infer) + if node_class in ['ParametrizedActivation']: + return self._infer_par_act_precision(node, types_to_infer) + # What about quantized activation layer? Setting it to 'auto' manually will break it here. 
We should prevent # this in config_from_* functions @@ -557,3 +560,14 @@ def _infer_rnn_precision(self, node, types_to_infer): inferred_types.append(f'{weightvar}_t') return inferred_types + + def _infer_par_act_precision(self, node, types_to_infer): + inferred_types = [] + + # for now, only set if for threshold relu + if 'param_t' in types_to_infer and node.get_attr('activation').lower() == 'thresholdedrelu': + in_type = node.get_input_variable().type.precision + node.attributes['param_t'].type = in_type + inferred_types.append('param_t') + + return inferred_types diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_activation.h b/hls4ml/templates/vivado/nnet_utils/nnet_activation.h index da13998e38..59a4ef2bed 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_activation.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_activation.h @@ -499,8 +499,8 @@ void hard_tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { // ************************************************* // Leaky RELU Activation // ************************************************* -template -void leaky_relu(data_T data[CONFIG_T::n_in], data_T alpha, res_T res[CONFIG_T::n_in]) { +template +void leaky_relu(data_T data[CONFIG_T::n_in], param_T alpha, res_T res[CONFIG_T::n_in]) { #pragma HLS PIPELINE data_T datareg; @@ -516,8 +516,8 @@ void leaky_relu(data_T data[CONFIG_T::n_in], data_T alpha, res_T res[CONFIG_T::n // ************************************************* // Thresholded RELU Activation // ************************************************* -template -void thresholded_relu(data_T data[CONFIG_T::n_in], data_T theta, res_T res[CONFIG_T::n_in]) { +template +void thresholded_relu(data_T data[CONFIG_T::n_in], param_T theta, res_T res[CONFIG_T::n_in]) { #pragma HLS PIPELINE data_T datareg; @@ -646,8 +646,8 @@ template void init_elu_table(typename CONFIG_T: } } -template -void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_in]) { +template +void elu(data_T data[CONFIG_T::n_in], const param_T alpha, res_T res[CONFIG_T::n_in]) { // Initialize the lookup table #ifdef __HLS_SYN__ bool initialized = false; @@ -679,8 +679,9 @@ void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_i } } -template void elu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { - elu(data, 1.0, res); +template +void elu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + elu(data, 1.0, res); } // ************************************************* @@ -738,8 +739,8 @@ template void selu(data_T data[CO // ************************************************* // PReLU Activation // ************************************************* -template -void prelu(data_T data[CONFIG_T::n_in], data_T alpha[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { +template +void prelu(data_T data[CONFIG_T::n_in], param_T alpha[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { #pragma HLS PIPELINE data_T datareg; diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h index 4f12ee5cb4..2513fb54b3 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h @@ -499,8 +499,8 @@ template void hard_tanh(hls::stre // Leaky RELU Activation // ************************************************* -template -void leaky_relu(hls::stream &data, typename data_T::value_type alpha, hls::stream &res) { +template +void leaky_relu(hls::stream &data, param_T alpha, hls::stream &res) {
LeakyReLUActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { #pragma HLS PIPELINE @@ -525,8 +525,8 @@ void leaky_relu(hls::stream &data, typename data_T::value_type alpha, hl // Thresholded RELU Activation // ************************************************* -template -void thresholded_relu(hls::stream &data, typename data_T::value_type theta, hls::stream &res) { +template +void thresholded_relu(hls::stream &data, param_T theta, hls::stream &res) { ThresholdedReLUActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { #pragma HLS PIPELINE @@ -633,8 +633,8 @@ template void softsign(hls::strea // ************************************************* // ELU Activation // ************************************************* -template -void elu(hls::stream &data, typename data_T::value_type alpha, hls::stream &res) { +template +void elu(hls::stream &data, param_T alpha, hls::stream &res) { // Initialize the lookup table #ifdef __HLS_SYN__ bool initialized = false; @@ -674,8 +674,9 @@ void elu(hls::stream &data, typename data_T::value_type alpha, hls::stre } } -template void elu(hls::stream &data, hls::stream &res) { - elu(data, 1.0, res); +template +void elu(hls::stream &data, hls::stream &res) { + elu(data, 1.0, res); } // ************************************************* @@ -726,8 +727,8 @@ template void selu(hls::stream -void prelu(hls::stream &data, typename data_T::value_type alpha[CONFIG_T::n_in], hls::stream &res) { +template +void prelu(hls::stream &data, param_T alpha[CONFIG_T::n_in], hls::stream &res) { PReLUActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { #pragma HLS PIPELINE diff --git a/test/pytest/test_activations.py b/test/pytest/test_activations.py index 5ab9481e1a..f156b1cdc3 100644 --- a/test/pytest/test_activations.py +++ b/test/pytest/test_activations.py @@ -41,7 +41,7 @@ def test_activations(backend, activation, name, shape, io_type): activation = activation(input) keras_model = Model(inputs=input, outputs=activation) - hls_config = hls4ml.utils.config_from_keras_model(keras_model) + hls_config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name', backend=backend) output_dir = str(test_root_path / 'hls4mlprj_activations_{}_{}_{}_{}').format(backend, io_type, str(shape), name) hls_model = hls4ml.converters.convert_from_keras_model( From b19368c0dbd1e74da17dd59bd00bef9c8957ef6f Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Sat, 28 Sep 2024 21:45:05 -0500 Subject: [PATCH 2/5] update quartus and catapult --- .../catapult/passes/core_templates.py | 29 +++++++++++++++++-- .../backends/quartus/passes/core_templates.py | 29 +++++++++++++++++-- .../catapult/nnet_utils/nnet_activation.h | 20 ++++++------- .../nnet_utils/nnet_activation_stream.h | 20 ++++++------- .../firmware/nnet_utils/nnet_activation.h | 20 ++++++------- .../nnet_utils/nnet_activation_stream.h | 20 ++++++------- .../nnet_utils/nnet_activation_stream.h | 2 +- 7 files changed, 93 insertions(+), 47 deletions(-) diff --git a/hls4ml/backends/catapult/passes/core_templates.py b/hls4ml/backends/catapult/passes/core_templates.py index 2088923428..77c3b85524 100755 --- a/hls4ml/backends/catapult/passes/core_templates.py +++ b/hls4ml/backends/catapult/passes/core_templates.py @@ -115,6 +115,15 @@ def format(self, node): typedef {table_t.name} table_t; }};\n""" +param_activ_config_template = """struct {type}_config{index} : nnet::activ_config {{ + static const unsigned n_in = {n_in}; + static const unsigned table_size = {table_size}; + static const unsigned io_type = 
nnet::{iotype}; + static const unsigned reuse_factor = {reuse}; + typedef {table_t.name} table_t; + typedef {param_t.name} param_t; +}};\n""" + hard_activ_config_template = """struct {type}_config{index} {{ static const unsigned n_in = {n_in}; static const {slope_t.name} slope; @@ -140,14 +149,16 @@ def format(self, node): }};\n""" activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});' -param_activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {param}, {output});' +param_activ_function_template = ( + 'nnet::{activation}<{input_t}, {param_t.name}, {output_t}, {config}>({input}, {param}, {output});' +) activ_include_list = ['nnet_utils/nnet_activation.h', 'nnet_utils/nnet_activation_stream.h'] class ActivationConfigTemplate(LayerConfigTemplate): def __init__(self): - super().__init__((Activation, ParametrizedActivation, PReLU)) + super().__init__(Activation) self.template = activ_config_template def format(self, node): @@ -157,6 +168,18 @@ def format(self, node): return self.template.format(**params) +class ParamActivationConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__((ParametrizedActivation, PReLU)) + self.template = param_activ_config_template + + def format(self, node): + params = self._default_config_params(node) + params['type'] = node.get_attr('activation') + + return self.template.format(**params) + + class HardActivationConfigTemplate(LayerConfigTemplate): def __init__(self): super().__init__(HardActivation) @@ -210,7 +233,7 @@ def __init__(self): def format(self, node): params = self._default_function_params(node) params['activation'] = node.get_attr('activation').lower() - params['param'] = node.get_weights('alpha').name + params['param'] = node.get_weights('param').name params['config'] = '{}_config{}'.format(node.get_attr('activation'), node.index) return self.template.format(**params) diff --git a/hls4ml/backends/quartus/passes/core_templates.py b/hls4ml/backends/quartus/passes/core_templates.py index d6998c9ab2..b474e14df5 100644 --- a/hls4ml/backends/quartus/passes/core_templates.py +++ b/hls4ml/backends/quartus/passes/core_templates.py @@ -125,6 +125,15 @@ def format(self, node): typedef {table_t.name} table_t; }};\n""" +param_activ_config_template = """struct {type}_config{index} : nnet::activ_config {{ + static const unsigned n_in = {n_in}; + static const unsigned table_size = {table_size}; + static const unsigned io_type = nnet::{iotype}; + static const unsigned reuse_factor = {reuse}; + typedef {table_t.name} table_t; + typedef {param_t.name} param_t; +}};\n""" + hard_activ_config_template = """struct {type}_config{index} {{ static const unsigned n_in = {n_in}; static const {slope_t.name} slope; @@ -146,14 +155,16 @@ def format(self, node): }};\n""" activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});' -param_activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {param}, {output});' +param_activ_function_template = ( + 'nnet::{activation}<{input_t}, {param_t.name}, {output_t}, {config}>({input}, {param}, {output});' +) activ_include_list = ['nnet_utils/nnet_activation.h', 'nnet_utils/nnet_activation_stream.h'] class ActivationConfigTemplate(LayerConfigTemplate): def __init__(self): - super().__init__((Activation, ParametrizedActivation, PReLU, UnaryLUT)) + super().__init__((Activation, UnaryLUT)) self.template = activ_config_template def format(self, node): @@ -163,6 +174,18 @@ def 
format(self, node): return self.template.format(**params) +class ParamActivationConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__((ParametrizedActivation, PReLU)) + self.template = param_activ_config_template + + def format(self, node): + params = self._default_config_params(node) + params['type'] = node.get_attr('activation') + + return self.template.format(**params) + + class HardActivationConfigTemplate(LayerConfigTemplate): def __init__(self): super().__init__(HardActivation) @@ -216,7 +239,7 @@ def __init__(self): def format(self, node): params = self._default_function_params(node) params['activation'] = node.get_attr('activation').lower() - params['param'] = node.get_weights('alpha').name + params['param'] = node.get_weights('param').name params['config'] = '{}_config{}'.format(node.get_attr('activation'), node.index) return self.template.format(**params) diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_activation.h b/hls4ml/templates/catapult/nnet_utils/nnet_activation.h index f08e75a0d6..fb72460b96 100644 --- a/hls4ml/templates/catapult/nnet_utils/nnet_activation.h +++ b/hls4ml/templates/catapult/nnet_utils/nnet_activation.h @@ -686,8 +686,8 @@ void hard_tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { // ************************************************* // Leaky RELU Activation // ************************************************* -template -void leaky_relu(data_T data[CONFIG_T::n_in], data_T alpha, res_T res[CONFIG_T::n_in]) { +template +void leaky_relu(data_T data[CONFIG_T::n_in], param_T alpha, res_T res[CONFIG_T::n_in]) { //#pragma HLS PIPELINE data_T datareg; @@ -703,8 +703,8 @@ void leaky_relu(data_T data[CONFIG_T::n_in], data_T alpha, res_T res[CONFIG_T::n // ************************************************* // Thresholded RELU Activation // ************************************************* -template -void thresholded_relu(data_T data[CONFIG_T::n_in], data_T theta, res_T res[CONFIG_T::n_in]) { +template +void thresholded_relu(data_T data[CONFIG_T::n_in], param_T theta, res_T res[CONFIG_T::n_in]) { //#pragma HLS PIPELINE data_T datareg; @@ -917,8 +917,8 @@ template void init_elu_table(typename CONFIG_T: #ifndef USE_AC_MATH -template -void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_in]) { +template +void elu(data_T data[CONFIG_T::n_in], const param_T alpha, res_T res[CONFIG_T::n_in]) { // Initialize the lookup table #ifdef __HLS_SYN__ bool initialized = false; @@ -953,8 +953,8 @@ void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_i #else -template -void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_in]) { +template +void elu(data_T data[CONFIG_T::n_in], const param_T alpha, res_T res[CONFIG_T::n_in]) { for (int ii = 0; ii < CONFIG_T::n_in; ii++) { ac_math::ac_elu_pwl(data[ii], res[ii], alpha); } @@ -1045,8 +1045,8 @@ template void selu(data_T data[CO // ************************************************* // PReLU Activation // ************************************************* -template -void prelu(data_T data[CONFIG_T::n_in], data_T alpha[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { +template +void prelu(data_T data[CONFIG_T::n_in], param_T alpha[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { //#pragma HLS PIPELINE data_T datareg; diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h index 509560bc2b..82570dbe51 100644 --- 
a/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h +++ b/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h @@ -545,8 +545,8 @@ template void hard_tanh(ac_channe // ************************************************* // Leaky RELU Activation // ************************************************* -template -void leaky_relu(ac_channel &data, typename data_T::value_type alpha, ac_channel &res) { +template +void leaky_relu(ac_channel &data, param_T alpha, ac_channel &res) { LeakyReLUActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { //#pragma HLS PIPELINE @@ -571,8 +571,8 @@ void leaky_relu(ac_channel &data, typename data_T::value_type alpha, ac_ // Thresholded RELU Activation // ************************************************* -template -void thresholded_relu(ac_channel &data, typename data_T::value_type theta, ac_channel &res) { +template +void thresholded_relu(ac_channel &data, param_T theta, ac_channel &res) { ThresholdedReLUActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { //#pragma HLS PIPELINE @@ -720,8 +720,8 @@ template void softsign(ac_channel #ifndef USE_AC_MATH -template -void elu(ac_channel &data, typename data_T::value_type alpha, ac_channel &res) { +template +void elu(ac_channel &data, param_T alpha, ac_channel &res) { // Initialize the lookup table #ifdef __HLS_SYN__ bool initialized = false; @@ -763,8 +763,8 @@ void elu(ac_channel &data, typename data_T::value_type alpha, ac_channel } #else -template -void elu(ac_channel &data, typename data_T::value_type alpha, ac_channel &res) { +template +void elu(ac_channel &data, param_T alpha, ac_channel &res) { EluActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { data_T in_data = data.read(); @@ -845,8 +845,8 @@ template void selu(ac_channel -void prelu(ac_channel &data, typename data_T::value_type alpha[CONFIG_T::n_in], ac_channel &res) { +template +void prelu(ac_channel &data, const param_T alpha[CONFIG_T::n_in], ac_channel &res) { PReLUActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { //#pragma HLS PIPELINE diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h index a70096e2f5..859583a577 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h @@ -333,8 +333,8 @@ void hard_tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { // ************************************************* // Leaky RELU Activation // ************************************************* -template -void leaky_relu(data_T data[CONFIG_T::n_in], data_T alpha, res_T res[CONFIG_T::n_in]) { +template +void leaky_relu(data_T data[CONFIG_T::n_in], param_T alpha, res_T res[CONFIG_T::n_in]) { #pragma unroll for (int ii = 0; ii < CONFIG_T::n_in; ii++) { data_T datareg = data[ii]; @@ -348,8 +348,8 @@ void leaky_relu(data_T data[CONFIG_T::n_in], data_T alpha, res_T res[CONFIG_T::n // ************************************************* // Thresholded RELU Activation // ************************************************* -template -void thresholded_relu(data_T data[CONFIG_T::n_in], data_T theta, res_T res[CONFIG_T::n_in]) { +template +void thresholded_relu(data_T data[CONFIG_T::n_in], param_T theta, res_T res[CONFIG_T::n_in]) { #pragma unroll for (int ii = 0; ii < CONFIG_T::n_in; ii++) { data_T datareg = data[ii]; @@ -414,8 +414,8 @@ void softsign(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { // 
************************************************* // ELU Activation // ************************************************* -template -void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_in]) { +template +void elu(data_T data[CONFIG_T::n_in], const param_T alpha, res_T res[CONFIG_T::n_in]) { // Initialize the lookup table #include "activation_tables/elu_table.tb" // Index into the lookup table based on data @@ -433,8 +433,8 @@ void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_i } } -template void elu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { - elu(data, 1.0, res); +template void elu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + elu(data, 1.0, res); } // ************************************************* @@ -461,8 +461,8 @@ template void selu(data_T data[CO // ************************************************* // PReLU Activation // ************************************************* -template -void prelu(data_T data[CONFIG_T::n_in], const data_T alpha[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { +template +void prelu(data_T data[CONFIG_T::n_in], const param_T alpha[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { #pragma unroll for (int ii = 0; ii < CONFIG_T::n_in; ii++) { data_T datareg = data[ii]; diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation_stream.h index f0562a9b22..d409bc482a 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation_stream.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation_stream.h @@ -52,8 +52,8 @@ template void relu(stream // ************************************************* // Leaky RELU Activation // ************************************************* -template -void leaky_relu(stream &data, const typename data_T::value_type alpha, stream &res) { +template +void leaky_relu(stream &data, param_T alpha, stream &res) { constexpr unsigned multiplier_limit = DIV_ROUNDUP(data_T::size, CONFIG_T::reuse_factor); constexpr unsigned pipeline = data_T::size / multiplier_limit; @@ -79,8 +79,8 @@ void leaky_relu(stream &data, const typename data_T::value_type alpha, s // ************************************************* // Thresholded RELU Activation // ************************************************* -template -void thresholded_relu(stream &data, const typename data_T::value_type theta, stream &res) { +template +void thresholded_relu(stream &data, param_T theta, stream &res) { ThresholdedReLUActLoop: #pragma ii 1 for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { @@ -103,8 +103,8 @@ void thresholded_relu(stream &data, const typename data_T::value_type th // ************************************************* // ELU Activation // ************************************************* -template -void elu(stream &data, const typename data_T::value_type alpha, stream &res) { +template +void elu(stream &data, param_T alpha, stream &res) { #include "activation_tables/elu_table.tb" constexpr unsigned multiplier_limit = DIV_ROUNDUP(data_T::size, CONFIG_T::reuse_factor); @@ -134,8 +134,8 @@ void elu(stream &data, const typename data_T::value_type alpha, stream void elu(stream &data, stream &res) { - elu(data, 1.0, res); +template void elu(stream &data, stream &res) { + elu(data, 1.0, res); } // ************************************************* @@ -171,8 +171,8 @@ template void selu(stream // ************************************************* // PReLU Activation // 
************************************************* -template -void prelu(stream &data, const typename data_T::value_type alpha[CONFIG_T::n_in], stream &res) { +template +void prelu(stream &data, const param_T alpha[CONFIG_T::n_in], stream &res) { constexpr unsigned multiplier_limit = DIV_ROUNDUP(data_T::size, CONFIG_T::reuse_factor); constexpr unsigned pipeline = data_T::size / multiplier_limit; diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h index 2513fb54b3..810e985600 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h @@ -728,7 +728,7 @@ template void selu(hls::stream -void prelu(hls::stream &data, param_T alpha[CONFIG_T::n_in], hls::stream &res) { +void prelu(hls::stream &data, const param_T alpha[CONFIG_T::n_in], hls::stream &res) { PReLUActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { #pragma HLS PIPELINE From 9c8c1dd27387e271543ec3309cee3b96937ff748 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Sat, 28 Sep 2024 21:55:29 -0500 Subject: [PATCH 3/5] fix pre-commit --- hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h index 859583a577..d60ff58cd0 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h @@ -433,7 +433,8 @@ void elu(data_T data[CONFIG_T::n_in], const param_T alpha, res_T res[CONFIG_T::n } } -template void elu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { +template +void elu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { elu(data, 1.0, res); } From 22f380027ab315e1e1cfc4ad0698a5e4fe66a733 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Sun, 29 Sep 2024 11:00:12 -0500 Subject: [PATCH 4/5] fix non-parametrized version of elu --- .../templates/quartus/firmware/nnet_utils/nnet_activation.h | 5 ++--- .../quartus/firmware/nnet_utils/nnet_activation_stream.h | 4 ++-- hls4ml/templates/vivado/nnet_utils/nnet_activation.h | 5 ++--- hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h | 5 ++--- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h index d60ff58cd0..1dea511c10 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h @@ -433,9 +433,8 @@ void elu(data_T data[CONFIG_T::n_in], const param_T alpha, res_T res[CONFIG_T::n } } -template -void elu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { - elu(data, 1.0, res); +template void elu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + elu, res_T, CONFIG_T>(data, 1.0, res); } // ************************************************* diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation_stream.h index d409bc482a..e29592d1e1 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation_stream.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation_stream.h @@ -134,8 +134,8 @@ void elu(stream &data, param_T alpha, stream &res) { } } -template void elu(stream 
&data, stream &res) { - elu(data, 1.0, res); +template void elu(stream &data, stream &res) { + elu, res_T, CONFIG_T>(data, 1.0, res); } // ************************************************* diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_activation.h b/hls4ml/templates/vivado/nnet_utils/nnet_activation.h index 59a4ef2bed..4683239d85 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_activation.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_activation.h @@ -679,9 +679,8 @@ void elu(data_T data[CONFIG_T::n_in], const param_T alpha, res_T res[CONFIG_T::n } } -template -void elu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { - elu(data, 1.0, res); +template void elu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + elu, res_T, CONFIG_T>(data, 1.0, res); } // ************************************************* diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h index 810e985600..ef687243bf 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h @@ -674,9 +674,8 @@ void elu(hls::stream &data, param_T alpha, hls::stream &res) { } } -template -void elu(hls::stream &data, hls::stream &res) { - elu(data, 1.0, res); +template void elu(hls::stream &data, hls::stream &res) { + elu, res_T, CONFIG_T>(data, 1.0, res); } // ************************************************* From e68b770e0abc087b8066d6f2c5ee5e013e2dcab0 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Tue, 1 Oct 2024 11:36:36 -0500 Subject: [PATCH 5/5] update comment on parametrized activation precision --- hls4ml/model/optimizer/passes/infer_precision.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/hls4ml/model/optimizer/passes/infer_precision.py b/hls4ml/model/optimizer/passes/infer_precision.py index 0e68cc346d..bd439e4a0f 100644 --- a/hls4ml/model/optimizer/passes/infer_precision.py +++ b/hls4ml/model/optimizer/passes/infer_precision.py @@ -564,7 +564,9 @@ def _infer_rnn_precision(self, node, types_to_infer): def _infer_par_act_precision(self, node, types_to_infer): inferred_types = [] - # for now, only set if for threshold relu + # For threshold relu, set the parameter precision to be the input precision by default; + # for other parametrized activations, just allow the default precision to be used. + # Can override these values in the configuration by explicitly setting them. if 'param_t' in types_to_infer and node.get_attr('activation').lower() == 'thresholdedrelu': in_type = node.get_input_variable().type.precision + node.attributes['param_t'].type = in_type
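
For illustration, the new param_activ_config_template and param_activ_function_template added in the backend passes render to HLS code roughly as sketched below for a hypothetical PReLU layer. The concrete names and values (layer index 3, n_in = 16, layer2_t, prelu_param_t, a3) are assumptions made for the sketch, not output taken from a real project.

// Sketch (assumed names/values) of the generated code for a PReLU layer after this change.
struct prelu_config3 : nnet::activ_config {
    static const unsigned n_in = 16;
    static const unsigned table_size = 1024;
    static const unsigned io_type = nnet::io_parallel;
    static const unsigned reuse_factor = 1;
    typedef prelu_table_t table_t;
    typedef prelu_param_t param_t; // new: the parameter (alpha) precision is exposed to the HLS function
};

// The function call now carries the parameter type as an explicit template argument,
// and the PReLU weight is registered under the 'param' name with variable name 'a{index}':
nnet::prelu<layer2_t, prelu_param_t, layer3_t, prelu_config3>(layer2_out, a3, layer3_out);

The param_t precision can be set explicitly in the user configuration; when left as 'auto', the _infer_par_act_precision pass copies the input precision for ThresholdedReLU and otherwise leaves the default precision in place.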