From 4e92b7b95e111621bc4eb82ea5e50aeb67cdc4be Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Fri, 26 May 2023 17:50:22 +0100 Subject: [PATCH 01/14] Introduce unrolled implementation of Dense Resource --- .../vivado/passes/convolution_templates.py | 29 ++++++++++++++++ .../backends/vivado/passes/core_templates.py | 9 +++++ .../vivado/passes/recurrent_templates.py | 13 +++++-- hls4ml/backends/vivado/vivado_backend.py | 34 +++++++++++++++++-- hls4ml/model/graph.py | 22 ++++++++++++ .../vivado/nnet_utils/nnet_code_gen.h | 12 ++++++- .../vivado/nnet_utils/nnet_conv_stream.h | 20 +++++------ .../templates/vivado/nnet_utils/nnet_dense.h | 15 +++++++- .../vivado/nnet_utils/nnet_dense_stream.h | 2 ++ 9 files changed, 139 insertions(+), 17 deletions(-) diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index 4845b8f1da..dde42d97fe 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -9,6 +9,9 @@ static const unsigned n_out = {n_out}; static const unsigned reuse_factor = {reuse}; static const unsigned strategy = nnet::{strategy}; + static const unsigned resource_implementation = nnet::{dense_resource_implementation}; + template + using dense_unrolled = nnet::{unrolled_function}; static const unsigned n_zeros = 0; static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor; typedef {accum_t.name} accum_t; @@ -86,6 +89,8 @@ def format(self, node): mult_params['product_type'] = get_backend('vivado').product_type( node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) + # TODO - Extend unrolled Dense Resource to Conv1D + mult_params['unrolled_function'] = 'DenseResourceUnrolled' mult_config = self.mult_template.format(**mult_params) return mult_config + '\n' + conv_config @@ -130,6 +135,9 @@ def format(self, node): static const bool store_weights_in_bram = false; static const unsigned strategy = nnet::{strategy}; static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation}; + static const unsigned resource_implementation = nnet::{dense_resource_implementation}; + template + using dense_unrolled = nnet::{unrolled_function}; static const unsigned min_height = {min_height}; static const unsigned min_width = {min_width}; static const ap_uint pixels[min_height * min_width]; @@ -183,6 +191,12 @@ def format(self, node): params['fill_fn'] = f'fill_buffer_{node.index}' else: params['fill_fn'] = 'FillConv2DBuffer' + + if node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' and node.get_attr('strategy').lower() == 'resource' and node.get_attr('reuse_factor') > 1: + # Implemented in subsequent commits + params['unrolled_function'] = 'DenseResourceUnrolled' + else: + params['unrolled_function'] = 'DenseResourceUnrolled' conv_config = self.template.format(**params) @@ -192,6 +206,11 @@ def format(self, node): mult_params['product_type'] = get_backend('vivado').product_type( node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) + if node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' and node.get_attr('strategy').lower() == 'resource' and node.get_attr('reuse_factor') > 1: + # Implemented in subsequent commits + mult_params['unrolled_function'] = 'DenseResourceUnrolled' + else: + mult_params['unrolled_function'] = 'DenseResourceUnrolled' mult_config = self.mult_template.format(**mult_params) return mult_config + '\n' + conv_config @@ -278,6 +297,9 @@ def format(self, node): mult_params['product_type'] = get_backend('vivado').product_type( node.get_input_variable().type.precision, node.get_weights('depthwise').type.precision ) + # TODO - Extend unrolled Dense Resource to depthwise Conv1D + mult_params['unrolled_function'] = 'DenseResourceUnrolled' + depthwise_mult_config = self.depthwise_mult_template.format(**mult_params) # Pointwise config @@ -317,6 +339,9 @@ def format(self, node): mult_params['product_type'] = get_backend('vivado').product_type( node.get_input_variable().type.precision, node.get_weights('pointwise').type.precision ) + # TODO - Extend unrolled Dense Resource to separable Conv1D + mult_params['unrolled_function'] = 'DenseResourceUnrolled' + pointwise_mult_config = self.pointwise_mult_template.format(**mult_params) return ( @@ -399,6 +424,8 @@ def format(self, node): mult_params['product_type'] = get_backend('vivado').product_type( node.get_input_variable().type.precision, node.get_weights('depthwise').type.precision ) + # TODO - Extend unrolled Dense Resource to depthwise Conv2D + mult_params['unrolled_function'] = 'DenseResourceUnrolled' depthwise_mult_config = self.depthwise_mult_template.format(**mult_params) # Pointwise config @@ -442,6 +469,8 @@ def format(self, node): mult_params['product_type'] = get_backend('vivado').product_type( node.get_input_variable().type.precision, node.get_weights('pointwise').type.precision ) + # TODO - Extend unrolled Dense Resource to separable Conv2D + mult_params['unrolled_function'] = 'DenseResourceUnrolled' pointwise_mult_config = self.pointwise_mult_template.format(**mult_params) return ( diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index c8119c0c2e..faabf434eb 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ b/hls4ml/backends/vivado/passes/core_templates.py @@ -9,6 +9,9 @@ static const unsigned n_out = {n_out}; static const unsigned io_type = nnet::{iotype}; static const unsigned strategy = nnet::{strategy}; + static const unsigned resource_implementation = nnet::{dense_resource_implementation}; + template + using dense_unrolled = nnet::{unrolled_function}; static const unsigned reuse_factor = {reuse}; static const unsigned n_zeros = {nzeros}; static const unsigned n_nonzeros = {nonzeros}; @@ -40,6 +43,12 @@ def format(self, node): node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) + if node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' and node.get_attr('strategy').lower() == 'resource' and node.get_attr('reuse_factor') > 1: + # Implemented in subsequent commits + params['unrolled_function'] = 'DenseResourceUnrolled' + else: + params['unrolled_function'] = 'DenseResourceUnrolled' + return self.template.format(**params) diff --git a/hls4ml/backends/vivado/passes/recurrent_templates.py b/hls4ml/backends/vivado/passes/recurrent_templates.py index aae806b35c..eb12412def 100644 --- a/hls4ml/backends/vivado/passes/recurrent_templates.py +++ b/hls4ml/backends/vivado/passes/recurrent_templates.py @@ -11,6 +11,9 @@ static const unsigned reuse_factor = {reuse}; static const unsigned n_zeros = {nzeros}; static const unsigned n_nonzeros = {nonzeros}; + static const unsigned resource_implementation = nnet::{dense_resource_implementation}; + template + using dense_unrolled = nnet::{unrolled_function}; static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor; static const bool store_weights_in_bram = false; typedef {accum_t.name} accum_t; @@ -137,6 +140,10 @@ def format(self, node): mult_params1['index'] = str(node.index) + '_1' mult_params1['nzeros'] = node.get_weights('weight').nzeros mult_params1['nonzeros'] = node.get_weights('weight').nonzeros + + # TODO - Extend unrolled Dense Resource to recurrent kernels + mult_params1['unrolled_function'] = 'DenseResourceUnrolled' + if node.get_attr('return_sequences'): mult_params2['n_in'] = node.get_output_variable().dim_names[1] mult_params2['n_out'] = node.get_output_variable().dim_names[1] + ' * %i' % n_recr_mult @@ -150,13 +157,15 @@ def format(self, node): mult_params2['index'] = str(node.index) + '_2' mult_params2['nzeros'] = node.get_weights('recurrent_weight').nzeros mult_params2['nonzeros'] = node.get_weights('recurrent_weight').nonzeros - + + # TODO - Extend unrolled Dense Resource to recurrent kernels + mult_params2['unrolled_function'] = 'DenseResourceUnrolled' + mult_config1 = self.mult1_template.format(**mult_params1) mult_config2 = self.mult2_template.format(**mult_params2) return mult_config1 + '\n' + mult_config2 + '\n' + recr_act_config + '\n' + act_config + '\n' + recr_config - class RecurrentFunctionTemplate(FunctionCallTemplate): def __init__(self): super().__init__((LSTM, GRU), include_header=recr_include_list) diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 1d4c96d982..1f71ddcdc4 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -68,12 +68,20 @@ def _register_layer_attributes(self): # Add ConvImplementation to Convolution+Pooling layers cnn_layers = [Conv1D, Conv2D, SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Pooling1D, Pooling2D] - for layer in cnn_layers: attrs = self.attribute_map.get(layer, []) # attrs.append(ConfigurableAttribute('conv_implementation', value_type=str, default='LineBuffer')) attrs.append(ChoiceAttribute('conv_implementation', choices=['LineBuffer', 'Encoded'], default='LineBuffer')) self.attribute_map[layer] = attrs + + # Add implementation of Dense Resource for all layers that use Dense for matrix mult + # Handle different implementations of Resource strategy; this attribute only makes a difference if strategy == Resource + # Standard -> nnet_dense_resource.h + # Unrolled -> Code generation, ignoring zero DSPs and optimizing zero-filled BRAM blocks + for layer in [Dense] + cnn_layers + rnn_layers: + attrs = self.attribute_map.get(layer, []) + attrs.append(ChoiceAttribute('dense_resource_implementation', choices=['standard', 'unrolled'], default='standard')) + self.attribute_map[layer] = attrs def _register_flows(self): initializers = self._get_layer_initializers() @@ -240,6 +248,7 @@ def init_dense(self, layer): else: layer.set_attr('strategy', 'latency') layer.set_attr('index_t', NamedType(f'layer{layer.index}_index', index_t)) + layer.set_attr('dense_resource_implementation', layer.model.config.get_dense_resource_implementation(layer).lower()) # TODO consolidate these functions into a single `init_conv` @layer_optimizer(Conv1D) @@ -270,6 +279,9 @@ def init_conv1d(self, layer): layer.set_attr('n_partitions', out_width // closest_pf) layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + # TODO - Extend unrolled Dense Resource to Conv1D kernels + layer.set_attr('dense_resource_implementation', 'standard') self._validate_conv_strategy(layer) @@ -286,7 +298,10 @@ def init_sepconv1d(self, layer): 'n_partitions', 1 ) # TODO Once we have SeparableConv implementation for io_parallel this should be set properly layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) - + + # TODO - Extend unrolled Dense Resource to separable Conv1D + layer.set_attr('dense_resource_implementation', 'standard') + @layer_optimizer(Conv2D) def init_conv2d(self, layer): if len(layer.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D @@ -313,9 +328,10 @@ def init_conv2d(self, layer): ) else: closest_pf = chosen_pf + layer.set_attr('n_partitions', out_height * out_width // closest_pf) - layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + layer.set_attr('dense_resource_implementation', layer.model.config.get_dense_resource_implementation(layer).lower()) self._validate_conv_strategy(layer) @@ -333,6 +349,9 @@ def init_sepconv2d(self, layer): ) # TODO Once we have SeparableConv implementation for io_parallel this should be set properly layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + # TODO - Extend unrolled Dense Resource to separable Conv2D + layer.set_attr('dense_resource_implementation', 'standard') + @layer_optimizer(DepthwiseConv2D) def init_depconv2d(self, layer): if layer.model.config.is_resource_strategy(layer): @@ -346,6 +365,9 @@ def init_depconv2d(self, layer): 'n_partitions', 1 ) # TODO Once we have SeparableConv implementation for io_parallel this should be set properly layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + # TODO - Extend unrolled Dense Resource to depthwise Conv2D + layer.set_attr('dense_resource_implementation', 'standard') def _set_pooling_accum_t(self, layer, pool_size): extra_bits = ceil_log2(pool_size) @@ -404,6 +426,9 @@ def init_lstm(self, layer): layer.set_attr('strategy', 'latency') layer.set_attr('index_t', NamedType(f'layer{layer.index}_index', IntegerPrecisionType(width=1, signed=False))) + + # TODO - Extend unrolled Dense Resource to recurrent kernels + layer.set_attr('dense_resource_implementation', 'standard') @layer_optimizer(GRU) def init_gru(self, layer): @@ -419,6 +444,9 @@ def init_gru(self, layer): layer.set_attr('strategy', 'latency') layer.set_attr('index_t', NamedType(f'layer{layer.index}_index', IntegerPrecisionType(width=1, signed=False))) + + # TODO - Extend unrolled Dense Resource to recurrent kernels + layer.set_attr('dense_resource_implementation', 'standard') @layer_optimizer(GarNet) def init_garnet(self, layer): diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index c44fd8f02e..57fb31841d 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -43,6 +43,10 @@ def __init__(self, config): self.layer_type_conv_implementation = {} self.layer_name_conv_implementation = {} + self.model_dense_resource_implementation = 'Standard' + self.layer_type_dense_resource_implementation = {} + self.layer_name_dense_resource_implementation = {} + self.model_compression = False self.layer_type_compression = {} self.layer_name_compression = {} @@ -165,6 +169,15 @@ def get_conv_implementation(self, layer): return conv_implementation + def get_dense_resource_implementation(self, layer): + dense_resource_implementation = self.layer_name_dense_resource_implementation.get(layer.name.lower()) + if dense_resource_implementation is None: + dense_resource_implementation = self.layer_type_dense_resource_implementation.get(layer.__class__.__name__.lower()) + if dense_resource_implementation is None: + dense_resource_implementation = self.model_dense_resource_implementation + + return dense_resource_implementation + def is_resource_strategy(self, layer): return self.get_strategy(layer).lower() == 'resource' @@ -212,6 +225,7 @@ def _parse_hls_config(self): self.model_rf = model_cfg.get('ReuseFactor') self.model_targ_cycles = model_cfg.get('TargetCycles') self.model_conv_implementation = model_cfg.get('ConvImplementation', 'LineBuffer') + self.model_dense_resource_implementation = model_cfg.get('DenseResourceImplementation', 'Standard') self.model_strategy = model_cfg.get('Strategy', 'Latency') self.model_compression = bool(model_cfg.get('Compression', 0)) self.pipeline_style = model_cfg.get('PipelineStyle', 'pipeline') @@ -241,6 +255,10 @@ def _parse_hls_config(self): conv_implementation = layer_cfg.get('ConvImplementation') if conv_implementation is not None: self.layer_type_conv_implementation[layer_type.lower()] = conv_implementation + + dense_resource_implementation = layer_cfg.get('DenseResourceImplementation') + if conv_implementation is not None: + self.layer_type_dense_resource_implementation[layer_type.lower()] = dense_resource_implementation compression = layer_cfg.get('Compression') if compression is not None: @@ -271,6 +289,10 @@ def _parse_hls_config(self): conv_implementation = layer_cfg.get('ConvImplementation') if conv_implementation is not None: self.layer_name_conv_implementation[layer_name.lower()] = conv_implementation + + dense_resource_implementation = layer_cfg.get('DenseResourceImplementation') + if conv_implementation is not None: + self.layer_name_dense_resource_implementation[layer_name.lower()] = dense_resource_implementation compression = layer_cfg.get('Compression') if compression is not None: diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h b/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h index e4db43682e..553044479e 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h @@ -2,7 +2,6 @@ #define NNET_INSTR_GEN_H_ #include "nnet_helpers.h" -#include namespace nnet { @@ -25,6 +24,17 @@ template class FillConv2DBuffer { } }; +template class DenseResourceUnrolled { + public: + static void dense_unrolled( + data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out] + ) { + // To be implemented in subclasses + } +}; + // hls4ml insert code } // namespace nnet diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h index 7bd47442f6..509feb5f35 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h @@ -291,11 +291,11 @@ void compute_output_buffer_2d( // Dense multiply // #pragma HLS INLINE recursive if (CONFIG_T::strategy == nnet::latency) { - dense_latency( - kernel_data, res_out, weights, biases); + dense_latency(kernel_data, res_out, weights, biases); + } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && CONFIG_T::reuse_factor > 1) { + CONFIG_T::template dense_unrolled::dense_unrolled(kernel_data, res_out, weights, biases); } else { - dense_resource( - kernel_data, res_out, weights, biases); + dense_resource(kernel_data, res_out, weights, biases); } // Pack output @@ -335,7 +335,7 @@ void compute_output_buffer_1d( const data_T &in_elem, hls::stream &res_stream, typename CONFIG_T::weight_t weights[CONFIG_T::kernel_size * CONFIG_T::n_chan * CONFIG_T::n_filt], typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { - #pragma HLS INLINE + #pragma HLS INLINE OFF // Thresholds const static int lShiftX = CONFIG_T::filt_width - 1; @@ -360,13 +360,13 @@ void compute_output_buffer_1d( if ((sX - lShiftX) == 0 && pX > lShiftX - 1) { // Dense multiply - #pragma HLS INLINE recursive + // #pragma HLS INLINE recursive if (CONFIG_T::strategy == nnet::latency) { - dense_latency( - kernel_data, res_out, weights, biases); + dense_latency(kernel_data, res_out, weights, biases); + } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && CONFIG_T::reuse_factor > 1) { + CONFIG_T::template dense_unrolled::dense_unrolled(kernel_data, res_out, weights, biases); } else { - dense_resource( - kernel_data, res_out, weights, biases); + dense_resource(kernel_data, res_out, weights, biases); } // Pack output diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense.h index c5155d8485..c278606594 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_dense.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense.h @@ -11,6 +11,11 @@ namespace nnet { +// Different implementations of Resource strategy; this attribute only makes a difference if strategy == Resource +// Default -> nnet_dense_resource.h +// Unrolled -> Code generation, ignoring zero DSPs and optimizing BRAM +enum resource_implementation { standard, unrolled }; + struct dense_config { // Internal data type definitions typedef float bias_t; @@ -27,7 +32,13 @@ struct dense_config { static const unsigned reuse_factor = 1; static const bool store_weights_in_bram = false; static const unsigned n_zeros = 0; - // partitioning arrays cyclically to go with roll factors? + + static const unsigned resource_implementation = standard; + template + using dense_unrolled = nnet::DenseResourceUnrolled; + + // Partitioning arrays cyclically to go with roll factors? + // Product function to use template using product = nnet::product::mult; }; @@ -39,6 +50,8 @@ void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], #pragma HLS inline if (CONFIG_T::strategy == nnet::latency) { dense_latency(data, res, weights, biases); + } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && CONFIG_T::reuse_factor > 1) { + CONFIG_T::template dense_unrolled::dense_unrolled(data, res, weights, biases); } else { dense_resource(data, res, weights, biases); } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_stream.h index ad3a972ef6..28bdfa7fe3 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_stream.h @@ -17,6 +17,8 @@ void dense_wrapper(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], if (CONFIG_T::strategy == nnet::latency) { #pragma HLS PIPELINE II=CONFIG_T::reuse_factor dense_latency(data, res, weights, biases); + } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled and CONFIG_T::reuse_factor > 1) { + CONFIG_T::template dense_unrolled::dense_unrolled(data, res, weights, biases); } else { dense_resource(data, res, weights, biases); } From 4fa21cbecdb4add9ea68942e3b6c8f3f4a4e9ae2 Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Fri, 26 May 2023 21:28:49 +0100 Subject: [PATCH 02/14] Code generation for unrolled Dense --- hls4ml/backends/fpga/passes/codegen.py | 197 +++++++++++++++++- .../vivado/passes/convolution_templates.py | 10 +- .../backends/vivado/passes/core_templates.py | 3 +- hls4ml/backends/vivado/vivado_backend.py | 1 + .../vivado/nnet_utils/nnet_code_gen.h | 3 + .../vivado/nnet_utils/nnet_conv2d_stream.h | 5 + .../vivado/nnet_utils/nnet_conv_stream.h | 2 +- hls4ml/writer/vivado_writer.py | 19 +- 8 files changed, 225 insertions(+), 15 deletions(-) diff --git a/hls4ml/backends/fpga/passes/codegen.py b/hls4ml/backends/fpga/passes/codegen.py index f1f1080996..2936645355 100644 --- a/hls4ml/backends/fpga/passes/codegen.py +++ b/hls4ml/backends/fpga/passes/codegen.py @@ -1,7 +1,8 @@ -from hls4ml.model.layers import Conv1D, Conv2D -from hls4ml.model.optimizer import OptimizerPass +import math +import numpy as np from hls4ml.model.types import Source - +from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.layers import Dense, Conv1D, Conv2D class GenerateConvIm2col(OptimizerPass): '''Generates tcode for im2col step of 1D/2d convolution''' @@ -49,3 +50,193 @@ def _generate_im2col_2d(self, node): ) node.set_attr('line_buffer_codegen', Source(code_str)) + +class GenerateUnrolledDenseResource(OptimizerPass): + '''Generates C++ code for unrolled Dense resource''' + + def match(self, node): + # Only apply to layers use that use Dense Matrix Multiplication + # TODO - Extend (& test) for Conv1D / Separable Conv / Depthwise Conv / Recurrent layers + layers_with_dense = (Dense, Conv2D) + + # Unrolled Dense mimicks the hardware implementation of Resource strategy -> apply after Resource optimizer + weights_transposed = node.get_attr('_weights_transposed', False) + + # RF = 1 will optimize DSPs anyway, so no need to unroll code + rf_gt_one = node.get_attr('reuse_factor') > 1 + + # User requested unrolled implementation of Dense + is_unrolled = node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' + + return isinstance(node, layers_with_dense) and weights_transposed and rf_gt_one and is_unrolled + + def transform(self, model, node): + code_str = self.__generate_unrolled_dense_resource(model, node) + node.set_attr('unrolled_dense_resource_codegen', Source(code_str)) + + def __generate_unrolled_dense_resource(self, model, node): + """ + Generate a C++ function that mimics the Dense Resource implementation. Similar to Dense Resource, 3 cases are considered + + The HLS compiler produces suboptimal designs for Dense Resource when the weights processed by the same DSP are zero. + Latency strategy can optimize zero mutiplications, Resource strategy, on the other hand, cannot. + Furthermore, when all the weights in the same BRAM block are zero (e.g. due to model pruning), Vivado is unable to optimize it + With this (and additional TCL scripts) zero BRAM are optimised + + Args: + node: Layer to generate code for + Returns: + generated_code: Generated C++ function (string) + """ + + # Variable instantiation and function pragmas + generated_code = ( + "template\n" + "class dense_unrolled_{index} : public DenseResourceUnrolled {{\n" + " public:\n" + " static void dense_unrolled(\n" + " data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],\n" + " typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],\n" + " typename CONFIG_T::bias_t biases[CONFIG_T::n_out]\n" + " ) {{\n" + " #pragma HLS pipeline II=CONFIG_T::reuse_factor\n" + "\n" + " constexpr int block_factor = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, CONFIG_T::reuse_factor);\n" + " #pragma HLS function_instantiate variable=weights,biases\n" + " #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor\n" + " #pragma HLS RESOURCE variable=weights core=ROM_nP_BRAM\n" + " #pragma HLS ARRAY_PARTITION variable=biases complete\n" + "\n" + " typename CONFIG_T::accum_t acc[CONFIG_T::n_out];\n" + " #pragma HLS ARRAY_PARTITION variable=acc complete\n" + "\n" + " InitAccum:\n" + " for (int i = 0; i < CONFIG_T::n_out; i++) {{\n" + " #pragma HLS UNROLL\n" + " acc[i] = (typename CONFIG_T::accum_t) biases[i];\n" + " }}\n" + "\n" + ).format(index=node.index) + + # Unrolled multiplication, according to the three cases + n_in, n_out = node.model.config.backend.get_layer_mult_size(node) + reuse_factor = node.get_attr('reuse_factor') + weights = node.weights['weight'] + if reuse_factor <= n_in: + mult_code = self.__generate_unrolled_mult_code_rf_leq_nin(n_in, n_out, reuse_factor, weights) + elif reuse_factor > n_in and reuse_factor % n_in == 0: + mult_code = self.__generate_unrolled_mult_code_rf_gt_nin_rem0(n_in, n_out, reuse_factor, weights) + else: + # This case shouldn't happen if my understanding of RF is correct + # The function fpga_backend._validate_reuse_factor() has assertion rf % n_in == 0 or rf < n_in + raise Exception('Not implemented...') + + # Write output + generated_code += mult_code + "\n" + generated_code += ( + " Result:\n" + " for (int i = 0; i < CONFIG_T::n_out; i++) {\n" + " #pragma HLS UNROLL\n" + " res[i] = cast(acc[i]);\n" + " }\n" + " }\n" + "};\n" + ) + + return generated_code + + def __generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, weights): + # Function constants + mult_factor = min(n_in, reuse_factor) + block_factor = int(math.ceil(n_in * n_out / reuse_factor)) + mult_limit = int(math.ceil(n_in * n_out / mult_factor)) + mult_scale = mult_limit // n_out + + # Zero DSPs are the DSP blocks that always have zero input + # In this case, it is the number of rows in the transposed and reshaped weight matrix + # The new shape is (parallel_mult, reuse_factor) + zeros = np.sum(~weights.data.reshape(block_factor, reuse_factor).any(1)) + + # Generate unrolled multiplications + mult_code = f"\t\t#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n" + mult_code += "\t\tMULT: {\n" + mult_code += "\t\t\t#pragma HLS protocol\n" + + for ir in range(reuse_factor): + acc_step = 0 + out_index = 0 + w_index = ir + in_index = ir + + mult_code += f"\t\t\tM{ir}: {{\n" + for _ in range(block_factor): + if weights.data.flatten()[w_index] != 0: + mult_code += f"\t\t\t\tacc[{out_index}] += static_cast(CONFIG_T::template product::product(data[{in_index}], weights[{w_index}]));\n" + + w_index += reuse_factor + in_index += reuse_factor + if in_index >= n_in: + in_index = ir + if acc_step + 1 >= mult_scale: + acc_step = 0 + out_index += 1 + else: + acc_step += 1 + + mult_code += "\t\t\t}\n" + + mult_code += "\t\t}\n" + + return mult_code + + def __generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor, weights): + # Function constants + mult_factor = min(n_in, reuse_factor) + block_factor = int(math.ceil(n_in * n_out / reuse_factor)) + mult_limit = int(math.ceil(n_in * n_out / mult_factor)) + + # Zero DSPs are the DSP blocks that always have zero input + # In this case, it is the number of rows in the transposed and reshaped weight matrix + # The new shape is (parallel_mult, reuse_factor) + zeros = np.sum(~weights.data.reshape(block_factor, reuse_factor).any(1)) + + # Generate out indices + outidx = [0] * reuse_factor + outstep = 0 + outscale = reuse_factor // n_in + for ir in range(reuse_factor): + outidx[ir] = outstep + if (ir + 1) % n_in == 0: + outstep += 1 + + # Define variables + in_index = 0 + + # Generate unrolled multiplications + mult_code = f"\t\t#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n" + mult_code += "\t\tMULT: {\n" + mult_code += "\t\t\t#pragma HLS protocol\n" + + for ir in range(reuse_factor): + w_index = ir + out_index = outidx[ir] + + mult_code += f"\t\t\tM{ir}: {{\n" + for _ in range(block_factor): + if weights.data.flatten()[w_index] != 0: + mult_code += f"\t\t\t\tacc[{int(out_index)}] += static_cast(CONFIG_T::template product::product(data[{in_index}], weights[{w_index}]));\n" + + w_index += reuse_factor + if w_index > n_in * n_out: + break + out_index += outscale + mult_code += "\t\t\t}\n" + + in_index += 1 + if in_index >= n_in: + in_index = 0 + + mult_code += "\t\t}\n" + + return mult_code + diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index dde42d97fe..f3e8f969af 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -193,10 +193,9 @@ def format(self, node): params['fill_fn'] = 'FillConv2DBuffer' if node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' and node.get_attr('strategy').lower() == 'resource' and node.get_attr('reuse_factor') > 1: - # Implemented in subsequent commits - params['unrolled_function'] = 'DenseResourceUnrolled' + params['unrolled_function'] = f'dense_unrolled_{node.index}' else: - params['unrolled_function'] = 'DenseResourceUnrolled' + params['unrolled_function'] = 'DenseResourceUnrolled' conv_config = self.template.format(**params) @@ -207,8 +206,7 @@ def format(self, node): node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) if node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' and node.get_attr('strategy').lower() == 'resource' and node.get_attr('reuse_factor') > 1: - # Implemented in subsequent commits - mult_params['unrolled_function'] = 'DenseResourceUnrolled' + mult_params['unrolled_function'] = f'dense_unrolled_{node.index}' else: mult_params['unrolled_function'] = 'DenseResourceUnrolled' mult_config = self.mult_template.format(**mult_params) @@ -299,7 +297,7 @@ def format(self, node): ) # TODO - Extend unrolled Dense Resource to depthwise Conv1D mult_params['unrolled_function'] = 'DenseResourceUnrolled' - + depthwise_mult_config = self.depthwise_mult_template.format(**mult_params) # Pointwise config diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index faabf434eb..9f5353cf93 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ b/hls4ml/backends/vivado/passes/core_templates.py @@ -44,8 +44,7 @@ def format(self, node): ) if node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' and node.get_attr('strategy').lower() == 'resource' and node.get_attr('reuse_factor') > 1: - # Implemented in subsequent commits - params['unrolled_function'] = 'DenseResourceUnrolled' + params['unrolled_function'] = f'dense_unrolled_{node.index}' else: params['unrolled_function'] = 'DenseResourceUnrolled' diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 1f71ddcdc4..d2f793568e 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -118,6 +118,7 @@ def _register_flows(self): 'vivado:generate_conv_streaming_instructions', 'vivado:apply_resource_strategy', 'vivado:generate_conv_im2col', + 'vivado:generate_unrolled_dense_resource' ] vivado_types_flow = register_flow('specific_types', vivado_types, requires=[init_flow], backend=self.name) diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h b/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h index 553044479e..9687cb7b44 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h @@ -2,6 +2,9 @@ #define NNET_INSTR_GEN_H_ #include "nnet_helpers.h" +#include "hls_stream.h" +#include "nnet_common.h" +#include "nnet_mult.h" namespace nnet { diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h index 8a4fb6be81..803fc7cc23 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h @@ -74,6 +74,11 @@ void conv_2d_buffer_cl( static ap_shift_reg line_buffer[MAX(CONFIG_T::filt_height - 1, 1)] [CONFIG_T::n_chan]; #pragma HLS ARRAY_PARTITION variable = line_buffer complete dim = 2 + + if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && CONFIG_T::reuse_factor > 1) { + #pragma HLS allocation instances=compute_output_buffer_1d limit=1 function + #pragma HLS allocation instances=compute_output_buffer_2d limit=1 function + } ReadInputHeight: for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) { diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h index 509feb5f35..bb1b97dc07 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h @@ -358,7 +358,7 @@ void compute_output_buffer_1d( // Check to see if we have a full kernel if ((sX - lShiftX) == 0 && pX > lShiftX - 1) { - + // Dense multiply // #pragma HLS INLINE recursive if (CONFIG_T::strategy == nnet::latency) { diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index 2fbe3d9438..1f148452ad 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -13,6 +13,14 @@ class VivadoWriter(Writer): + def __get_max_reuse_factor(self, model): + max_rf = 0 + for layer in model.get_layers(): + rf = int(layer.get_attr('reuse_factor')) + if rf > max_rf: + max_rf = rf + return max_rf + def print_array_to_cpp(self, var, odir, write_txt_file=True): """Write a weights array to C++ header files. @@ -171,10 +179,15 @@ def write_project_cpp(self, model): newline += indent + '#pragma HLS INTERFACE ap_vld port={},{} \n'.format( ','.join(all_inputs), ','.join(all_outputs) ) - if model.config.pipeline_style.lower() == 'dataflow': - newline += indent + '#pragma HLS DATAFLOW \n' + + model_cfg = model.config.get_config_value('HLSConfig')['Model'] + if 'DenseResourceImplementation' in model_cfg and model_cfg['DenseResourceImplementation'].lower() == 'unrolled': + newline += indent + f'#pragma HLS PIPELINE ii={self.__get_max_reuse_factor(model)} \n' else: - newline += indent + '#pragma HLS PIPELINE \n' + if model.config.pipeline_style.lower() == 'dataflow': + newline += indent + '#pragma HLS DATAFLOW \n' + else: + newline += indent + '#pragma HLS PIPELINE \n' if io_type == 'io_stream': newline += indent + '#pragma HLS INTERFACE axis port={},{} \n'.format( ','.join(all_inputs), ','.join(all_outputs) From 22e815b1b63a296d4ee260b50346c2f52d9c055f Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Mon, 29 May 2023 15:51:37 +0100 Subject: [PATCH 03/14] Fix incorrect BRAM reporting (#798) --- hls4ml/templates/vivado/build_prj.tcl | 6 +++--- .../vivado/nnet_utils/nnet_dense_resource.h | 15 ++++++++++++--- hls4ml/templates/vivado/vivado_synth.tcl | 4 ++-- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/hls4ml/templates/vivado/build_prj.tcl b/hls4ml/templates/vivado/build_prj.tcl index d34337c573..2a8326aae0 100644 --- a/hls4ml/templates/vivado/build_prj.tcl +++ b/hls4ml/templates/vivado/build_prj.tcl @@ -236,15 +236,15 @@ if {$opt(export)} { if {$opt(vsynth)} { puts "***** VIVADO SYNTHESIS *****" - if {[file exist ${project_name}_prj/solution1/syn/vhdl]} { + if {[file exist ${project_name}_prj/solution1/syn/verilog]} { set time_start [clock clicks -milliseconds] exec vivado -mode batch -source vivado_synth.tcl >@ stdout set time_end [clock clicks -milliseconds] report_time "VIVADO SYNTHESIS" $time_start $time_end } else { - puts "ERROR: Cannot find generated VHDL files. Did you run C synthesis?" + puts "ERROR: Cannot find generated Verilog files. Did you run C synthesis?" exit 1 } } -exit +exit \ No newline at end of file diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h index 88de94729b..333a0e75fe 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h @@ -26,10 +26,13 @@ void dense_resource_rf_leq_nin(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T:: assert((multiplier_limit == block_factor) && "This function is correct only for RF <= N_IN"); #pragma HLS function_instantiate variable=weights,biases - //#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor #pragma HLS ARRAY_PARTITION variable=biases complete + if (CONFIG_T::reuse_factor > 1) { + #pragma HLS RESOURCE variable=weights core=ROM_nP_BRAM + } + typename CONFIG_T::accum_t acc[CONFIG_T::n_out]; #pragma HLS ARRAY_PARTITION variable=acc complete @@ -97,10 +100,13 @@ void dense_resource_rf_gt_nin_rem0(data_T data[CONFIG_T::n_in], res_T res[CONFIG assert((rufactor > nin && rufactor % nin == 0) && "This function is correct only for RF > N_IN && RF % N_IN == 0"); #pragma HLS function_instantiate variable=weights,biases - //#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor #pragma HLS ARRAY_PARTITION variable=biases complete + if (CONFIG_T::reuse_factor > 1) { + #pragma HLS RESOURCE variable=weights core=ROM_nP_BRAM + } + typename CONFIG_T::accum_t acc[CONFIG_T::n_out]; #pragma HLS ARRAY_PARTITION variable=acc complete @@ -176,10 +182,13 @@ void dense_resource_rf_gt_nin(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n assert((rufactor > nin) && "This function is correct only for RF > N_IN"); #pragma HLS function_instantiate variable=weights,biases - //#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor #pragma HLS ARRAY_PARTITION variable=biases complete + if (CONFIG_T::reuse_factor > 1) { + #pragma HLS RESOURCE variable=weights core=ROM_nP_BRAM + } + typename CONFIG_T::accum_t acc[CONFIG_T::n_out]; #pragma HLS ARRAY_PARTITION variable=acc complete diff --git a/hls4ml/templates/vivado/vivado_synth.tcl b/hls4ml/templates/vivado/vivado_synth.tcl index 4634b166f6..96bd21c672 100644 --- a/hls4ml/templates/vivado/vivado_synth.tcl +++ b/hls4ml/templates/vivado/vivado_synth.tcl @@ -1,6 +1,6 @@ set tcldir [file dirname [info script]] source [file join $tcldir project.tcl] -add_files ${project_name}_prj/solution1/syn/vhdl +add_files ${project_name}_prj/solution1/syn/verilog synth_design -top ${project_name} -part $part -report_utilization -file vivado_synth.rpt +report_utilization -file vivado_synth.rpt \ No newline at end of file From 9cab74a2eb314f2bcb785af6f38576055f97ef5a Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Sun, 11 Jun 2023 18:58:32 +0100 Subject: [PATCH 04/14] Add post-synthesis design optimisation to remove unused BRAM --- hls4ml/templates/vivado/vivado_synth.tcl | 1 + 1 file changed, 1 insertion(+) diff --git a/hls4ml/templates/vivado/vivado_synth.tcl b/hls4ml/templates/vivado/vivado_synth.tcl index 96bd21c672..9f4119d6bd 100644 --- a/hls4ml/templates/vivado/vivado_synth.tcl +++ b/hls4ml/templates/vivado/vivado_synth.tcl @@ -3,4 +3,5 @@ source [file join $tcldir project.tcl] add_files ${project_name}_prj/solution1/syn/verilog synth_design -top ${project_name} -part $part +opt_design -retarget -propconst -sweep -bram_power_opt -shift_register_opt report_utilization -file vivado_synth.rpt \ No newline at end of file From d79f868c4dc50d666a1528d6efd076083e226d0c Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Sun, 11 Jun 2023 20:47:16 +0100 Subject: [PATCH 05/14] Tests for unrolled Dense --- test/pytest/test_dense_unrolled.py | 63 ++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 test/pytest/test_dense_unrolled.py diff --git a/test/pytest/test_dense_unrolled.py b/test/pytest/test_dense_unrolled.py new file mode 100644 index 0000000000..69daf9cd96 --- /dev/null +++ b/test/pytest/test_dense_unrolled.py @@ -0,0 +1,63 @@ +import pytest +import numpy as np +from pathlib import Path + +from tensorflow.keras.models import Sequential +from tensorflow.keras.layers import Dense, Conv2D, Flatten + +from hls4ml.utils import config_from_keras_model +from hls4ml.converters import convert_from_keras_model + +test_root_path = Path(__file__).parent + +# Tests a wide range of RF to ensure the unrolled Dense is correct +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +@pytest.mark.parametrize('reuse_factor', [1, 2, 4, 8, 16, 32, 48, 64, 96, 192]) +def test_dense_unrolled(io_type, reuse_factor): + input_shape = (16, ) + X = np.random.rand(100, *input_shape) + + model = Sequential() + model.add(Dense(12, input_shape=input_shape, kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform')) + model.compile('adam', 'mse') + keras_prediction = model.predict(X) + + config = config_from_keras_model(model, default_precision='ac_fixed<32, 16>', default_reuse_factor=reuse_factor) + config['Model']['Strategy'] = 'Resource' + config['Model']['DenseResourceImplementation'] = 'Unrolled' + + output_dir = str(test_root_path / f'hls4mlprj_dense_unrolled_{io_type}_{reuse_factor}') + hls_model = convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend='Vivado', io_type=io_type + ) + hls_model.compile() + + hls_prediction = hls_model.predict(X) + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=1e-2) + +# Tests a wide range RF on streaming Conv2D to ensure the unrolled Dense is correct +@pytest.mark.parametrize('io_type', ['io_stream']) +@pytest.mark.parametrize('reuse_factor', [1, 3, 9, 27, 54, 108]) +def test_dense_unrolled_streaming_conv(io_type, reuse_factor): + input_shape = (8, 8, 3) + X = np.random.rand(100, *input_shape) + + model = Sequential() + model.add(Conv2D(4, (3, 3), input_shape=input_shape, kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform')) + model.add(Flatten()) + model.add(Dense(1, kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform')) + model.compile('adam', 'mse') + keras_prediction = model.predict(X) + + config = config_from_keras_model(model, default_precision='ac_fixed<32, 16>', default_reuse_factor=reuse_factor) + config['Model']['Strategy'] = 'Resource' + config['Model']['DenseResourceImplementation'] = 'Unrolled' + + output_dir = str(test_root_path / f'hls4mlprj_dense_unrolled_conv2d_{io_type}_{reuse_factor}') + hls_model = convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, backend='Vivado', io_type=io_type + ) + hls_model.compile() + + hls_prediction = hls_model.predict(X) + np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=1e-2) From ff86c266008f963f1a16bd1f30a3100ac83a5e0d Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Fri, 16 Jun 2023 11:27:59 +0100 Subject: [PATCH 06/14] pre-commit on hls4ml Optimization pt.2 --- hls4ml/backends/fpga/passes/codegen.py | 66 +++++++++++-------- .../vivado/passes/convolution_templates.py | 16 +++-- .../backends/vivado/passes/core_templates.py | 6 +- .../vivado/passes/recurrent_templates.py | 13 ++-- hls4ml/backends/vivado/vivado_backend.py | 24 +++---- hls4ml/model/graph.py | 8 ++- hls4ml/templates/vivado/build_prj.tcl | 2 +- .../vivado/nnet_utils/nnet_code_gen.h | 11 ++-- .../vivado/nnet_utils/nnet_conv2d_stream.h | 5 +- .../vivado/nnet_utils/nnet_conv_stream.h | 28 +++++--- .../templates/vivado/nnet_utils/nnet_dense.h | 9 +-- .../vivado/nnet_utils/nnet_dense_stream.h | 3 +- hls4ml/templates/vivado/vivado_synth.tcl | 2 +- hls4ml/writer/vivado_writer.py | 7 +- test/pytest/test_dense_unrolled.py | 28 ++++---- 15 files changed, 135 insertions(+), 93 deletions(-) diff --git a/hls4ml/backends/fpga/passes/codegen.py b/hls4ml/backends/fpga/passes/codegen.py index 2936645355..32243356c3 100644 --- a/hls4ml/backends/fpga/passes/codegen.py +++ b/hls4ml/backends/fpga/passes/codegen.py @@ -1,8 +1,11 @@ import math + import numpy as np -from hls4ml.model.types import Source + +from hls4ml.model.layers import Conv1D, Conv2D, Dense from hls4ml.model.optimizer import OptimizerPass -from hls4ml.model.layers import Dense, Conv1D, Conv2D +from hls4ml.model.types import Source + class GenerateConvIm2col(OptimizerPass): '''Generates tcode for im2col step of 1D/2d convolution''' @@ -51,6 +54,7 @@ def _generate_im2col_2d(self, node): node.set_attr('line_buffer_codegen', Source(code_str)) + class GenerateUnrolledDenseResource(OptimizerPass): '''Generates C++ code for unrolled Dense resource''' @@ -73,14 +77,15 @@ def match(self, node): def transform(self, model, node): code_str = self.__generate_unrolled_dense_resource(model, node) node.set_attr('unrolled_dense_resource_codegen', Source(code_str)) - + def __generate_unrolled_dense_resource(self, model, node): """ - Generate a C++ function that mimics the Dense Resource implementation. Similar to Dense Resource, 3 cases are considered + Generate a C++ function that mimics the Dense Resource implementation. The HLS compiler produces suboptimal designs for Dense Resource when the weights processed by the same DSP are zero. - Latency strategy can optimize zero mutiplications, Resource strategy, on the other hand, cannot. - Furthermore, when all the weights in the same BRAM block are zero (e.g. due to model pruning), Vivado is unable to optimize it + Latency strategy can optimize zero mutiplications + Resource strategy, on the other hand, cannot. + When all the weights in the same BRAM block are zero, Vivado is unable to optimize it With this (and additional TCL scripts) zero BRAM are optimised Args: @@ -97,16 +102,16 @@ def __generate_unrolled_dense_resource(self, model, node): " static void dense_unrolled(\n" " data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],\n" " typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],\n" - " typename CONFIG_T::bias_t biases[CONFIG_T::n_out]\n" + " typename CONFIG_T::bias_t biases[CONFIG_T::n_out]\n" " ) {{\n" " #pragma HLS pipeline II=CONFIG_T::reuse_factor\n" "\n" " constexpr int block_factor = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, CONFIG_T::reuse_factor);\n" " #pragma HLS function_instantiate variable=weights,biases\n" - " #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor\n" + " #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor\n" " #pragma HLS RESOURCE variable=weights core=ROM_nP_BRAM\n" " #pragma HLS ARRAY_PARTITION variable=biases complete\n" - "\n" + "\n" " typename CONFIG_T::accum_t acc[CONFIG_T::n_out];\n" " #pragma HLS ARRAY_PARTITION variable=acc complete\n" "\n" @@ -117,7 +122,7 @@ def __generate_unrolled_dense_resource(self, model, node): " }}\n" "\n" ).format(index=node.index) - + # Unrolled multiplication, according to the three cases n_in, n_out = node.model.config.backend.get_layer_mult_size(node) reuse_factor = node.get_attr('reuse_factor') @@ -128,7 +133,7 @@ def __generate_unrolled_dense_resource(self, model, node): mult_code = self.__generate_unrolled_mult_code_rf_gt_nin_rem0(n_in, n_out, reuse_factor, weights) else: # This case shouldn't happen if my understanding of RF is correct - # The function fpga_backend._validate_reuse_factor() has assertion rf % n_in == 0 or rf < n_in + # The function fpga_backend._validate_reuse_factor() has assertion rf % n_in == 0 or rf < n_in raise Exception('Not implemented...') # Write output @@ -151,8 +156,8 @@ def __generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, we block_factor = int(math.ceil(n_in * n_out / reuse_factor)) mult_limit = int(math.ceil(n_in * n_out / mult_factor)) mult_scale = mult_limit // n_out - - # Zero DSPs are the DSP blocks that always have zero input + + # Zero DSPs are the DSP blocks that always have zero input # In this case, it is the number of rows in the transposed and reshaped weight matrix # The new shape is (parallel_mult, reuse_factor) zeros = np.sum(~weights.data.reshape(block_factor, reuse_factor).any(1)) @@ -161,7 +166,7 @@ def __generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, we mult_code = f"\t\t#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n" mult_code += "\t\tMULT: {\n" mult_code += "\t\t\t#pragma HLS protocol\n" - + for ir in range(reuse_factor): acc_step = 0 out_index = 0 @@ -171,8 +176,11 @@ def __generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, we mult_code += f"\t\t\tM{ir}: {{\n" for _ in range(block_factor): if weights.data.flatten()[w_index] != 0: - mult_code += f"\t\t\t\tacc[{out_index}] += static_cast(CONFIG_T::template product::product(data[{in_index}], weights[{w_index}]));\n" - + mult_code += f"\t\t\t\tacc[{out_index}] += \ + static_cast\ + (CONFIG_T::template product::\ + product(data[{in_index}], weights[{w_index}]));\n" + w_index += reuse_factor in_index += reuse_factor if in_index >= n_in: @@ -181,10 +189,10 @@ def __generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, we acc_step = 0 out_index += 1 else: - acc_step += 1 - + acc_step += 1 + mult_code += "\t\t\t}\n" - + mult_code += "\t\t}\n" return mult_code @@ -194,13 +202,13 @@ def __generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor mult_factor = min(n_in, reuse_factor) block_factor = int(math.ceil(n_in * n_out / reuse_factor)) mult_limit = int(math.ceil(n_in * n_out / mult_factor)) - - # Zero DSPs are the DSP blocks that always have zero input + + # Zero DSPs are the DSP blocks that always have zero input # In this case, it is the number of rows in the transposed and reshaped weight matrix # The new shape is (parallel_mult, reuse_factor) zeros = np.sum(~weights.data.reshape(block_factor, reuse_factor).any(1)) - - # Generate out indices + + # Generate out indices outidx = [0] * reuse_factor outstep = 0 outscale = reuse_factor // n_in @@ -216,7 +224,7 @@ def __generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor mult_code = f"\t\t#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n" mult_code += "\t\tMULT: {\n" mult_code += "\t\t\t#pragma HLS protocol\n" - + for ir in range(reuse_factor): w_index = ir out_index = outidx[ir] @@ -224,14 +232,17 @@ def __generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor mult_code += f"\t\t\tM{ir}: {{\n" for _ in range(block_factor): if weights.data.flatten()[w_index] != 0: - mult_code += f"\t\t\t\tacc[{int(out_index)}] += static_cast(CONFIG_T::template product::product(data[{in_index}], weights[{w_index}]));\n" - + mult_code += f"\t\t\t\tacc[{int(out_index)}] += \ + static_cast\ + (CONFIG_T::template product::\ + product(data[{in_index}], weights[{w_index}]));\n" + w_index += reuse_factor if w_index > n_in * n_out: break out_index += outscale mult_code += "\t\t\t}\n" - + in_index += 1 if in_index >= n_in: in_index = 0 @@ -239,4 +250,3 @@ def __generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor mult_code += "\t\t}\n" return mult_code - diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index f3e8f969af..0c5a1da729 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -191,11 +191,15 @@ def format(self, node): params['fill_fn'] = f'fill_buffer_{node.index}' else: params['fill_fn'] = 'FillConv2DBuffer' - - if node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' and node.get_attr('strategy').lower() == 'resource' and node.get_attr('reuse_factor') > 1: + + if ( + node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' + and node.get_attr('strategy').lower() == 'resource' + and node.get_attr('reuse_factor') > 1 + ): params['unrolled_function'] = f'dense_unrolled_{node.index}' else: - params['unrolled_function'] = 'DenseResourceUnrolled' + params['unrolled_function'] = 'DenseResourceUnrolled' conv_config = self.template.format(**params) @@ -205,7 +209,11 @@ def format(self, node): mult_params['product_type'] = get_backend('vivado').product_type( node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) - if node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' and node.get_attr('strategy').lower() == 'resource' and node.get_attr('reuse_factor') > 1: + if ( + node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' + and node.get_attr('strategy').lower() == 'resource' + and node.get_attr('reuse_factor') > 1 + ): mult_params['unrolled_function'] = f'dense_unrolled_{node.index}' else: mult_params['unrolled_function'] = 'DenseResourceUnrolled' diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index 9f5353cf93..5f1a25e37f 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ b/hls4ml/backends/vivado/passes/core_templates.py @@ -43,7 +43,11 @@ def format(self, node): node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) - if node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' and node.get_attr('strategy').lower() == 'resource' and node.get_attr('reuse_factor') > 1: + if ( + node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' + and node.get_attr('strategy').lower() == 'resource' + and node.get_attr('reuse_factor') > 1 + ): params['unrolled_function'] = f'dense_unrolled_{node.index}' else: params['unrolled_function'] = 'DenseResourceUnrolled' diff --git a/hls4ml/backends/vivado/passes/recurrent_templates.py b/hls4ml/backends/vivado/passes/recurrent_templates.py index eb12412def..e5c3937fd3 100644 --- a/hls4ml/backends/vivado/passes/recurrent_templates.py +++ b/hls4ml/backends/vivado/passes/recurrent_templates.py @@ -140,10 +140,10 @@ def format(self, node): mult_params1['index'] = str(node.index) + '_1' mult_params1['nzeros'] = node.get_weights('weight').nzeros mult_params1['nonzeros'] = node.get_weights('weight').nonzeros - + # TODO - Extend unrolled Dense Resource to recurrent kernels - mult_params1['unrolled_function'] = 'DenseResourceUnrolled' - + mult_params1['unrolled_function'] = 'DenseResourceUnrolled' + if node.get_attr('return_sequences'): mult_params2['n_in'] = node.get_output_variable().dim_names[1] mult_params2['n_out'] = node.get_output_variable().dim_names[1] + ' * %i' % n_recr_mult @@ -157,15 +157,16 @@ def format(self, node): mult_params2['index'] = str(node.index) + '_2' mult_params2['nzeros'] = node.get_weights('recurrent_weight').nzeros mult_params2['nonzeros'] = node.get_weights('recurrent_weight').nonzeros - + # TODO - Extend unrolled Dense Resource to recurrent kernels - mult_params2['unrolled_function'] = 'DenseResourceUnrolled' - + mult_params2['unrolled_function'] = 'DenseResourceUnrolled' + mult_config1 = self.mult1_template.format(**mult_params1) mult_config2 = self.mult2_template.format(**mult_params2) return mult_config1 + '\n' + mult_config2 + '\n' + recr_act_config + '\n' + act_config + '\n' + recr_config + class RecurrentFunctionTemplate(FunctionCallTemplate): def __init__(self): super().__init__((LSTM, GRU), include_header=recr_include_list) diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index d2f793568e..3300f31dc9 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -73,14 +73,16 @@ def _register_layer_attributes(self): # attrs.append(ConfigurableAttribute('conv_implementation', value_type=str, default='LineBuffer')) attrs.append(ChoiceAttribute('conv_implementation', choices=['LineBuffer', 'Encoded'], default='LineBuffer')) self.attribute_map[layer] = attrs - + # Add implementation of Dense Resource for all layers that use Dense for matrix mult - # Handle different implementations of Resource strategy; this attribute only makes a difference if strategy == Resource + # Handle different implementations of Resource strategy; only makes a difference if strategy == Resource # Standard -> nnet_dense_resource.h # Unrolled -> Code generation, ignoring zero DSPs and optimizing zero-filled BRAM blocks for layer in [Dense] + cnn_layers + rnn_layers: attrs = self.attribute_map.get(layer, []) - attrs.append(ChoiceAttribute('dense_resource_implementation', choices=['standard', 'unrolled'], default='standard')) + attrs.append( + ChoiceAttribute('dense_resource_implementation', choices=['standard', 'unrolled'], default='standard') + ) self.attribute_map[layer] = attrs def _register_flows(self): @@ -118,7 +120,7 @@ def _register_flows(self): 'vivado:generate_conv_streaming_instructions', 'vivado:apply_resource_strategy', 'vivado:generate_conv_im2col', - 'vivado:generate_unrolled_dense_resource' + 'vivado:generate_unrolled_dense_resource', ] vivado_types_flow = register_flow('specific_types', vivado_types, requires=[init_flow], backend=self.name) @@ -280,7 +282,7 @@ def init_conv1d(self, layer): layer.set_attr('n_partitions', out_width // closest_pf) layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) - + # TODO - Extend unrolled Dense Resource to Conv1D kernels layer.set_attr('dense_resource_implementation', 'standard') @@ -299,10 +301,10 @@ def init_sepconv1d(self, layer): 'n_partitions', 1 ) # TODO Once we have SeparableConv implementation for io_parallel this should be set properly layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) - + # TODO - Extend unrolled Dense Resource to separable Conv1D layer.set_attr('dense_resource_implementation', 'standard') - + @layer_optimizer(Conv2D) def init_conv2d(self, layer): if len(layer.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D @@ -329,7 +331,7 @@ def init_conv2d(self, layer): ) else: closest_pf = chosen_pf - + layer.set_attr('n_partitions', out_height * out_width // closest_pf) layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) layer.set_attr('dense_resource_implementation', layer.model.config.get_dense_resource_implementation(layer).lower()) @@ -366,7 +368,7 @@ def init_depconv2d(self, layer): 'n_partitions', 1 ) # TODO Once we have SeparableConv implementation for io_parallel this should be set properly layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) - + # TODO - Extend unrolled Dense Resource to depthwise Conv2D layer.set_attr('dense_resource_implementation', 'standard') @@ -427,7 +429,7 @@ def init_lstm(self, layer): layer.set_attr('strategy', 'latency') layer.set_attr('index_t', NamedType(f'layer{layer.index}_index', IntegerPrecisionType(width=1, signed=False))) - + # TODO - Extend unrolled Dense Resource to recurrent kernels layer.set_attr('dense_resource_implementation', 'standard') @@ -445,7 +447,7 @@ def init_gru(self, layer): layer.set_attr('strategy', 'latency') layer.set_attr('index_t', NamedType(f'layer{layer.index}_index', IntegerPrecisionType(width=1, signed=False))) - + # TODO - Extend unrolled Dense Resource to recurrent kernels layer.set_attr('dense_resource_implementation', 'standard') diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 57fb31841d..55ec06e18a 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -172,7 +172,9 @@ def get_conv_implementation(self, layer): def get_dense_resource_implementation(self, layer): dense_resource_implementation = self.layer_name_dense_resource_implementation.get(layer.name.lower()) if dense_resource_implementation is None: - dense_resource_implementation = self.layer_type_dense_resource_implementation.get(layer.__class__.__name__.lower()) + dense_resource_implementation = self.layer_type_dense_resource_implementation.get( + layer.__class__.__name__.lower() + ) if dense_resource_implementation is None: dense_resource_implementation = self.model_dense_resource_implementation @@ -255,7 +257,7 @@ def _parse_hls_config(self): conv_implementation = layer_cfg.get('ConvImplementation') if conv_implementation is not None: self.layer_type_conv_implementation[layer_type.lower()] = conv_implementation - + dense_resource_implementation = layer_cfg.get('DenseResourceImplementation') if conv_implementation is not None: self.layer_type_dense_resource_implementation[layer_type.lower()] = dense_resource_implementation @@ -289,7 +291,7 @@ def _parse_hls_config(self): conv_implementation = layer_cfg.get('ConvImplementation') if conv_implementation is not None: self.layer_name_conv_implementation[layer_name.lower()] = conv_implementation - + dense_resource_implementation = layer_cfg.get('DenseResourceImplementation') if conv_implementation is not None: self.layer_name_dense_resource_implementation[layer_name.lower()] = dense_resource_implementation diff --git a/hls4ml/templates/vivado/build_prj.tcl b/hls4ml/templates/vivado/build_prj.tcl index 2a8326aae0..b6419773cb 100644 --- a/hls4ml/templates/vivado/build_prj.tcl +++ b/hls4ml/templates/vivado/build_prj.tcl @@ -247,4 +247,4 @@ if {$opt(vsynth)} { } } -exit \ No newline at end of file +exit diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h b/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h index 9687cb7b44..caab69663e 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h @@ -2,6 +2,7 @@ #define NNET_INSTR_GEN_H_ #include "nnet_helpers.h" + #include "hls_stream.h" #include "nnet_common.h" #include "nnet_mult.h" @@ -29,13 +30,11 @@ template class FillConv2DBuffer { template class DenseResourceUnrolled { public: - static void dense_unrolled( - data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], - typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], - typename CONFIG_T::bias_t biases[CONFIG_T::n_out] - ) { + static void dense_unrolled(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { // To be implemented in subclasses - } + } }; // hls4ml insert code diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h index 803fc7cc23..08d06501c3 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h @@ -74,8 +74,9 @@ void conv_2d_buffer_cl( static ap_shift_reg line_buffer[MAX(CONFIG_T::filt_height - 1, 1)] [CONFIG_T::n_chan]; #pragma HLS ARRAY_PARTITION variable = line_buffer complete dim = 2 - - if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && CONFIG_T::reuse_factor > 1) { + + if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && + CONFIG_T::reuse_factor > 1) { #pragma HLS allocation instances=compute_output_buffer_1d limit=1 function #pragma HLS allocation instances=compute_output_buffer_2d limit=1 function } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h index bb1b97dc07..d95d528e46 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h @@ -291,11 +291,16 @@ void compute_output_buffer_2d( // Dense multiply // #pragma HLS INLINE recursive if (CONFIG_T::strategy == nnet::latency) { - dense_latency(kernel_data, res_out, weights, biases); - } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && CONFIG_T::reuse_factor > 1) { - CONFIG_T::template dense_unrolled::dense_unrolled(kernel_data, res_out, weights, biases); + dense_latency( + kernel_data, res_out, weights, biases); + } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && + CONFIG_T::reuse_factor > 1) { + CONFIG_T::template dense_unrolled::dense_unrolled(kernel_data, res_out, weights, + biases); } else { - dense_resource(kernel_data, res_out, weights, biases); + dense_resource( + kernel_data, res_out, weights, biases); } // Pack output @@ -358,15 +363,20 @@ void compute_output_buffer_1d( // Check to see if we have a full kernel if ((sX - lShiftX) == 0 && pX > lShiftX - 1) { - + // Dense multiply // #pragma HLS INLINE recursive if (CONFIG_T::strategy == nnet::latency) { - dense_latency(kernel_data, res_out, weights, biases); - } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && CONFIG_T::reuse_factor > 1) { - CONFIG_T::template dense_unrolled::dense_unrolled(kernel_data, res_out, weights, biases); + dense_latency( + kernel_data, res_out, weights, biases); + } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && + CONFIG_T::reuse_factor > 1) { + CONFIG_T::template dense_unrolled::dense_unrolled(kernel_data, res_out, weights, + biases); } else { - dense_resource(kernel_data, res_out, weights, biases); + dense_resource( + kernel_data, res_out, weights, biases); } // Pack output diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense.h index c278606594..2037daf0b9 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_dense.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense.h @@ -34,11 +34,11 @@ struct dense_config { static const unsigned n_zeros = 0; static const unsigned resource_implementation = standard; - template + template using dense_unrolled = nnet::DenseResourceUnrolled; - + // Partitioning arrays cyclically to go with roll factors? - + // Product function to use template using product = nnet::product::mult; }; @@ -50,7 +50,8 @@ void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], #pragma HLS inline if (CONFIG_T::strategy == nnet::latency) { dense_latency(data, res, weights, biases); - } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && CONFIG_T::reuse_factor > 1) { + } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && + CONFIG_T::reuse_factor > 1) { CONFIG_T::template dense_unrolled::dense_unrolled(data, res, weights, biases); } else { dense_resource(data, res, weights, biases); diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_stream.h index 28bdfa7fe3..db3039fc33 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_stream.h @@ -17,7 +17,8 @@ void dense_wrapper(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], if (CONFIG_T::strategy == nnet::latency) { #pragma HLS PIPELINE II=CONFIG_T::reuse_factor dense_latency(data, res, weights, biases); - } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled and CONFIG_T::reuse_factor > 1) { + } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && + CONFIG_T::reuse_factor > 1) { CONFIG_T::template dense_unrolled::dense_unrolled(data, res, weights, biases); } else { dense_resource(data, res, weights, biases); diff --git a/hls4ml/templates/vivado/vivado_synth.tcl b/hls4ml/templates/vivado/vivado_synth.tcl index 9f4119d6bd..342b1e6740 100644 --- a/hls4ml/templates/vivado/vivado_synth.tcl +++ b/hls4ml/templates/vivado/vivado_synth.tcl @@ -4,4 +4,4 @@ source [file join $tcldir project.tcl] add_files ${project_name}_prj/solution1/syn/verilog synth_design -top ${project_name} -part $part opt_design -retarget -propconst -sweep -bram_power_opt -shift_register_opt -report_utilization -file vivado_synth.rpt \ No newline at end of file +report_utilization -file vivado_synth.rpt diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index 1f148452ad..6509fb5e3d 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -20,7 +20,7 @@ def __get_max_reuse_factor(self, model): if rf > max_rf: max_rf = rf return max_rf - + def print_array_to_cpp(self, var, odir, write_txt_file=True): """Write a weights array to C++ header files. @@ -181,7 +181,10 @@ def write_project_cpp(self, model): ) model_cfg = model.config.get_config_value('HLSConfig')['Model'] - if 'DenseResourceImplementation' in model_cfg and model_cfg['DenseResourceImplementation'].lower() == 'unrolled': + if ( + 'DenseResourceImplementation' in model_cfg + and model_cfg['DenseResourceImplementation'].lower() == 'unrolled' + ): newline += indent + f'#pragma HLS PIPELINE ii={self.__get_max_reuse_factor(model)} \n' else: if model.config.pipeline_style.lower() == 'dataflow': diff --git a/test/pytest/test_dense_unrolled.py b/test/pytest/test_dense_unrolled.py index 69daf9cd96..a3318049be 100644 --- a/test/pytest/test_dense_unrolled.py +++ b/test/pytest/test_dense_unrolled.py @@ -1,20 +1,21 @@ -import pytest -import numpy as np from pathlib import Path +import numpy as np +import pytest +from tensorflow.keras.layers import Conv2D, Dense, Flatten from tensorflow.keras.models import Sequential -from tensorflow.keras.layers import Dense, Conv2D, Flatten -from hls4ml.utils import config_from_keras_model from hls4ml.converters import convert_from_keras_model +from hls4ml.utils import config_from_keras_model test_root_path = Path(__file__).parent + # Tests a wide range of RF to ensure the unrolled Dense is correct @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) @pytest.mark.parametrize('reuse_factor', [1, 2, 4, 8, 16, 32, 48, 64, 96, 192]) def test_dense_unrolled(io_type, reuse_factor): - input_shape = (16, ) + input_shape = (16,) X = np.random.rand(100, *input_shape) model = Sequential() @@ -25,16 +26,15 @@ def test_dense_unrolled(io_type, reuse_factor): config = config_from_keras_model(model, default_precision='ac_fixed<32, 16>', default_reuse_factor=reuse_factor) config['Model']['Strategy'] = 'Resource' config['Model']['DenseResourceImplementation'] = 'Unrolled' - + output_dir = str(test_root_path / f'hls4mlprj_dense_unrolled_{io_type}_{reuse_factor}') - hls_model = convert_from_keras_model( - model, hls_config=config, output_dir=output_dir, backend='Vivado', io_type=io_type - ) + hls_model = convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend='Vivado', io_type=io_type) hls_model.compile() hls_prediction = hls_model.predict(X) np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=1e-2) + # Tests a wide range RF on streaming Conv2D to ensure the unrolled Dense is correct @pytest.mark.parametrize('io_type', ['io_stream']) @pytest.mark.parametrize('reuse_factor', [1, 3, 9, 27, 54, 108]) @@ -43,7 +43,9 @@ def test_dense_unrolled_streaming_conv(io_type, reuse_factor): X = np.random.rand(100, *input_shape) model = Sequential() - model.add(Conv2D(4, (3, 3), input_shape=input_shape, kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform')) + model.add( + Conv2D(4, (3, 3), input_shape=input_shape, kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform') + ) model.add(Flatten()) model.add(Dense(1, kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform')) model.compile('adam', 'mse') @@ -52,11 +54,9 @@ def test_dense_unrolled_streaming_conv(io_type, reuse_factor): config = config_from_keras_model(model, default_precision='ac_fixed<32, 16>', default_reuse_factor=reuse_factor) config['Model']['Strategy'] = 'Resource' config['Model']['DenseResourceImplementation'] = 'Unrolled' - + output_dir = str(test_root_path / f'hls4mlprj_dense_unrolled_conv2d_{io_type}_{reuse_factor}') - hls_model = convert_from_keras_model( - model, hls_config=config, output_dir=output_dir, backend='Vivado', io_type=io_type - ) + hls_model = convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend='Vivado', io_type=io_type) hls_model.compile() hls_prediction = hls_model.predict(X) From 0f0adc4908e22b23d9a5bd8953d528b527d9523b Mon Sep 17 00:00:00 2001 From: Benjamin Ramhorst Date: Fri, 16 Jun 2023 12:13:13 +0100 Subject: [PATCH 07/14] Fix failing PyTests --- hls4ml/backends/fpga/passes/codegen.py | 2 +- .../backends/vivado/passes/convolution_templates.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/hls4ml/backends/fpga/passes/codegen.py b/hls4ml/backends/fpga/passes/codegen.py index 32243356c3..09e600d421 100644 --- a/hls4ml/backends/fpga/passes/codegen.py +++ b/hls4ml/backends/fpga/passes/codegen.py @@ -67,7 +67,7 @@ def match(self, node): weights_transposed = node.get_attr('_weights_transposed', False) # RF = 1 will optimize DSPs anyway, so no need to unroll code - rf_gt_one = node.get_attr('reuse_factor') > 1 + rf_gt_one = node.get_attr('reuse_factor', 1) > 1 # User requested unrolled implementation of Dense is_unrolled = node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index 0c5a1da729..2b9fe13b7a 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -39,6 +39,9 @@ static const bool store_weights_in_bram = false; static const unsigned strategy = nnet::{strategy}; static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation}; + static const unsigned resource_implementation = nnet::{dense_resource_implementation}; + template + using dense_unrolled = nnet::{unrolled_function}; static const unsigned min_width = {min_width}; static const ap_uint pixels[min_width]; static const unsigned n_partitions = {n_partitions}; @@ -80,6 +83,8 @@ def format(self, node): params['fill_fn'] = f'fill_buffer_{node.index}' else: params['fill_fn'] = 'FillConv1DBuffer' + # TODO - Extend unrolled Dense Resource to Conv1D + params['unrolled_function'] = 'DenseResourceUnrolled' conv_config = self.template.format(**params) @@ -292,6 +297,8 @@ def format(self, node): params['scale_index_type'] = 'scale_index_regular' params['config_t'] = f'config{node.index}_depthwise_mult' + # TODO - Extend unrolled Dense Resource + params['unrolled_function'] = 'DenseResourceUnrolled' depthwise_config = self.depthwise_template.format(**params) # Depthwise mult config @@ -334,6 +341,8 @@ def format(self, node): params['scale_index_type'] = 'scale_index_regular' params['config_t'] = f'config{node.index}_pointwise_mult' + # TODO - Extend unrolled Dense Resource + params['unrolled_function'] = 'DenseResourceUnrolled' pointwise_config = self.pointwise_template.format(**params) # Pointwise mult config @@ -419,6 +428,8 @@ def format(self, node): params['scale_index_width_type'] = 'scale_index_regular' params['config_t'] = f'config{node.index}_depthwise_mult' + # TODO - Extend unrolled Dense Resource + params['unrolled_function'] = 'DenseResourceUnrolled' depthwise_config = self.depthwise_template.format(**params) # Depthwise mult config @@ -464,6 +475,8 @@ def format(self, node): else: params['scale_index_width_type'] = 'scale_index_regular' params['config_t'] = f'config{node.index}_pointwise_mult' + # TODO - Extend unrolled Dense Resource + params['unrolled_function'] = 'DenseResourceUnrolled' pointwise_config = self.pointwise_template.format(**params) # Pointwise mult config From 0ea246ce2b51de1c82ecf93cb8d5bc3a37f51a6c Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Mon, 15 Jul 2024 18:20:13 +0200 Subject: [PATCH 08/14] Refactor matrix-multiplication kernel as a function pointer --- hls4ml/backends/fpga/passes/codegen.py | 167 ++++++++++-------- .../vivado/passes/convolution_templates.py | 56 +++--- .../backends/vivado/passes/core_templates.py | 23 +-- .../vivado/passes/recurrent_templates.py | 52 ++++-- .../vivado/passes/resource_strategy.py | 2 +- hls4ml/backends/vivado/vivado_backend.py | 34 +++- .../vivado/nnet_utils/nnet_code_gen.h | 29 +-- .../templates/vivado/nnet_utils/nnet_common.h | 2 +- .../vivado/nnet_utils/nnet_conv1d_stream.h | 4 + .../vivado/nnet_utils/nnet_conv2d_stream.h | 3 +- .../vivado/nnet_utils/nnet_conv_stream.h | 37 +--- .../templates/vivado/nnet_utils/nnet_dense.h | 51 ++++-- .../vivado/nnet_utils/nnet_dense_stream.h | 7 +- .../vivado/nnet_utils/nnet_function_stubs.h | 42 +++++ test/pytest/test_dense_unrolled.py | 92 ++++++++-- 15 files changed, 362 insertions(+), 239 deletions(-) create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_function_stubs.h diff --git a/hls4ml/backends/fpga/passes/codegen.py b/hls4ml/backends/fpga/passes/codegen.py index 09e600d421..3667680ed5 100644 --- a/hls4ml/backends/fpga/passes/codegen.py +++ b/hls4ml/backends/fpga/passes/codegen.py @@ -2,7 +2,7 @@ import numpy as np -from hls4ml.model.layers import Conv1D, Conv2D, Dense +from hls4ml.model.layers import GRU, LSTM, Conv1D, Conv2D, Dense from hls4ml.model.optimizer import OptimizerPass from hls4ml.model.types import Source @@ -60,8 +60,8 @@ class GenerateUnrolledDenseResource(OptimizerPass): def match(self, node): # Only apply to layers use that use Dense Matrix Multiplication - # TODO - Extend (& test) for Conv1D / Separable Conv / Depthwise Conv / Recurrent layers - layers_with_dense = (Dense, Conv2D) + # TODO - Extend (& test) for Separable Conv / Depthwise Conv / Recurrent layers + layers_with_dense = (Dense, Conv1D, Conv2D, LSTM, GRU) # Unrolled Dense mimicks the hardware implementation of Resource strategy -> apply after Resource optimizer weights_transposed = node.get_attr('_weights_transposed', False) @@ -70,23 +70,43 @@ def match(self, node): rf_gt_one = node.get_attr('reuse_factor', 1) > 1 # User requested unrolled implementation of Dense - is_unrolled = node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' + is_unrolled = node.get_attr('strategy', 'latency') == 'unrolled' return isinstance(node, layers_with_dense) and weights_transposed and rf_gt_one and is_unrolled def transform(self, model, node): - code_str = self.__generate_unrolled_dense_resource(model, node) - node.set_attr('unrolled_dense_resource_codegen', Source(code_str)) + if isinstance(node, (LSTM, GRU)): + n_in, n_out, n_in_recr, n_out_recr = node.model.config.backend.get_layer_mult_size(node) - def __generate_unrolled_dense_resource(self, model, node): + reuse_factor = node.get_attr('reuse_factor') + weights = node.weights['weight'] + code_str = self._generate_unrolled_function(n_in, n_out, reuse_factor, weights, str(node.index) + '_1') + node.set_attr('unrolled_dense_resource_codegen_1', Source(code_str)) + + recr_reuse_factor = node.get_attr('recurrent_reuse_factor') + recr_weights = node.weights['recurrent_weight'] + code_str = self._generate_unrolled_function( + n_in_recr, n_out_recr, recr_reuse_factor, recr_weights, str(node.index) + '_2' + ) + node.set_attr('unrolled_dense_resource_codegen_2', Source(code_str)) + + else: + n_in, n_out = node.model.config.backend.get_layer_mult_size(node) + reuse_factor = node.get_attr('reuse_factor') + weights = node.weights['weight'] + + code_str = self._generate_unrolled_function(n_in, n_out, reuse_factor, weights, node.index) + node.set_attr('unrolled_dense_resource_codegen', Source(code_str)) + + def _generate_unrolled_function(self, n_in, n_out, reuse_factor, weights, function_suffix): """ Generate a C++ function that mimics the Dense Resource implementation. The HLS compiler produces suboptimal designs for Dense Resource when the weights processed by the same DSP are zero. - Latency strategy can optimize zero mutiplications + Latency strategy can optimize zero multiplications Resource strategy, on the other hand, cannot. When all the weights in the same BRAM block are zero, Vivado is unable to optimize it - With this (and additional TCL scripts) zero BRAM are optimised + With this (and additional TCL scripts) zero BRAM are optimized Args: node: Layer to generate code for @@ -96,61 +116,58 @@ def __generate_unrolled_dense_resource(self, model, node): # Variable instantiation and function pragmas generated_code = ( - "template\n" - "class dense_unrolled_{index} : public DenseResourceUnrolled {{\n" - " public:\n" - " static void dense_unrolled(\n" - " data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],\n" - " typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],\n" - " typename CONFIG_T::bias_t biases[CONFIG_T::n_out]\n" - " ) {{\n" - " #pragma HLS pipeline II=CONFIG_T::reuse_factor\n" - "\n" - " constexpr int block_factor = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, CONFIG_T::reuse_factor);\n" - " #pragma HLS function_instantiate variable=weights,biases\n" - " #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor\n" - " #pragma HLS RESOURCE variable=weights core=ROM_nP_BRAM\n" - " #pragma HLS ARRAY_PARTITION variable=biases complete\n" - "\n" - " typename CONFIG_T::accum_t acc[CONFIG_T::n_out];\n" - " #pragma HLS ARRAY_PARTITION variable=acc complete\n" - "\n" - " InitAccum:\n" - " for (int i = 0; i < CONFIG_T::n_out; i++) {{\n" - " #pragma HLS UNROLL\n" - " acc[i] = (typename CONFIG_T::accum_t) biases[i];\n" - " }}\n" - "\n" - ).format(index=node.index) + 'template\n' + 'class dense_unrolled_{suffix} : public DenseKernel {{\n' + ' public:\n' + ' static void dense(\n' + ' data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],\n' + ' typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],\n' + ' typename CONFIG_T::bias_t biases[CONFIG_T::n_out]\n' + ' ) {{\n' + ' #pragma HLS pipeline II=CONFIG_T::reuse_factor\n' + '\n' + ' constexpr int block_factor = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, CONFIG_T::reuse_factor);\n' + ' #pragma HLS function_instantiate variable=weights,biases\n' + ' #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor\n' + ' #pragma HLS RESOURCE variable=weights core=ROM_nP_BRAM\n' + ' #pragma HLS ARRAY_PARTITION variable=biases complete\n' + '\n' + ' typename CONFIG_T::accum_t acc[CONFIG_T::n_out];\n' + ' #pragma HLS ARRAY_PARTITION variable=acc complete\n' + '\n' + ' InitAccum:\n' + ' for (int i = 0; i < CONFIG_T::n_out; i++) {{\n' + ' #pragma HLS UNROLL\n' + ' acc[i] = (typename CONFIG_T::accum_t) biases[i];\n' + ' }}\n' + '\n' + ).format(suffix=function_suffix) # Unrolled multiplication, according to the three cases - n_in, n_out = node.model.config.backend.get_layer_mult_size(node) - reuse_factor = node.get_attr('reuse_factor') - weights = node.weights['weight'] if reuse_factor <= n_in: - mult_code = self.__generate_unrolled_mult_code_rf_leq_nin(n_in, n_out, reuse_factor, weights) + mult_code = self._generate_unrolled_mult_code_rf_leq_nin(n_in, n_out, reuse_factor, weights) elif reuse_factor > n_in and reuse_factor % n_in == 0: - mult_code = self.__generate_unrolled_mult_code_rf_gt_nin_rem0(n_in, n_out, reuse_factor, weights) + mult_code = self._generate_unrolled_mult_code_rf_gt_nin_rem0(n_in, n_out, reuse_factor, weights) else: # This case shouldn't happen if my understanding of RF is correct # The function fpga_backend._validate_reuse_factor() has assertion rf % n_in == 0 or rf < n_in raise Exception('Not implemented...') # Write output - generated_code += mult_code + "\n" + generated_code += mult_code + '\n' generated_code += ( - " Result:\n" - " for (int i = 0; i < CONFIG_T::n_out; i++) {\n" - " #pragma HLS UNROLL\n" - " res[i] = cast(acc[i]);\n" - " }\n" - " }\n" - "};\n" + ' Result:\n' + ' for (int i = 0; i < CONFIG_T::n_out; i++) {\n' + ' #pragma HLS UNROLL\n' + ' res[i] = cast(acc[i]);\n' + ' }\n' + ' }\n' + '};\n' ) return generated_code - def __generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, weights): + def _generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, weights): # Function constants mult_factor = min(n_in, reuse_factor) block_factor = int(math.ceil(n_in * n_out / reuse_factor)) @@ -162,10 +179,13 @@ def __generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, we # The new shape is (parallel_mult, reuse_factor) zeros = np.sum(~weights.data.reshape(block_factor, reuse_factor).any(1)) + # Used to pad the code to make it human-readable + indent = ' ' + # Generate unrolled multiplications - mult_code = f"\t\t#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n" - mult_code += "\t\tMULT: {\n" - mult_code += "\t\t\t#pragma HLS protocol\n" + mult_code = f'{indent*2}#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n' + mult_code += f'{indent*2}MULT: {{\n' + mult_code += f'{indent*3}#pragma HLS protocol\n' for ir in range(reuse_factor): acc_step = 0 @@ -173,13 +193,15 @@ def __generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, we w_index = ir in_index = ir - mult_code += f"\t\t\tM{ir}: {{\n" + mult_code += f'{indent*3}M{ir}: {{\n' for _ in range(block_factor): if weights.data.flatten()[w_index] != 0: - mult_code += f"\t\t\t\tacc[{out_index}] += \ - static_cast\ - (CONFIG_T::template product::\ - product(data[{in_index}], weights[{w_index}]));\n" + mult_code += ( + f'{indent*4}acc[{out_index}] += ' + 'static_cast' + '(CONFIG_T::template product::' + f'product(data[{in_index}], weights[{w_index}]));\n' + ) w_index += reuse_factor in_index += reuse_factor @@ -191,13 +213,13 @@ def __generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, we else: acc_step += 1 - mult_code += "\t\t\t}\n" + mult_code += f'{indent*3}}}\n' - mult_code += "\t\t}\n" + mult_code += f'{indent*2}}}\n' return mult_code - def __generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor, weights): + def _generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor, weights): # Function constants mult_factor = min(n_in, reuse_factor) block_factor = int(math.ceil(n_in * n_out / reuse_factor)) @@ -208,6 +230,9 @@ def __generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor # The new shape is (parallel_mult, reuse_factor) zeros = np.sum(~weights.data.reshape(block_factor, reuse_factor).any(1)) + # Used to pad the code to make it human-readable + indent = ' ' + # Generate out indices outidx = [0] * reuse_factor outstep = 0 @@ -221,32 +246,34 @@ def __generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor in_index = 0 # Generate unrolled multiplications - mult_code = f"\t\t#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n" - mult_code += "\t\tMULT: {\n" - mult_code += "\t\t\t#pragma HLS protocol\n" + mult_code = f'{indent*2}#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n' + mult_code += f'{indent*2}MULT: {{\n' + mult_code += f'{indent*3}#pragma HLS protocol\n' for ir in range(reuse_factor): w_index = ir out_index = outidx[ir] - mult_code += f"\t\t\tM{ir}: {{\n" + mult_code += f'{indent*3}M{ir}: {{\n' for _ in range(block_factor): if weights.data.flatten()[w_index] != 0: - mult_code += f"\t\t\t\tacc[{int(out_index)}] += \ - static_cast\ - (CONFIG_T::template product::\ - product(data[{in_index}], weights[{w_index}]));\n" + mult_code += ( + f'{indent*4}acc[{int(out_index)}] += ' + 'static_cast' + '(CONFIG_T::template product::' + f'product(data[{in_index}], weights[{w_index}]));\n' + ) w_index += reuse_factor if w_index > n_in * n_out: break out_index += outscale - mult_code += "\t\t\t}\n" + mult_code += f'{indent*3}}}\n' in_index += 1 if in_index >= n_in: in_index = 0 - mult_code += "\t\t}\n" + mult_code += f'{indent*2}}}\n' return mult_code diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index 95a9c10cb0..9b584237a6 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -17,14 +17,13 @@ static const unsigned n_out = {n_out}; static const unsigned reuse_factor = {reuse}; static const unsigned strategy = nnet::{strategy}; - static const unsigned resource_implementation = nnet::{dense_resource_implementation}; - template - using dense_unrolled = nnet::{unrolled_function}; static const unsigned n_zeros = {nzeros}; static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor; typedef {accum_t.name} accum_t; typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; + template + using kernel = nnet::{dense_function}; template using product = nnet::product::{product_type}; }};\n""" @@ -49,9 +48,6 @@ static const bool store_weights_in_bram = false; static const unsigned strategy = nnet::{strategy}; static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation}; - static const unsigned resource_implementation = nnet::{dense_resource_implementation}; - template - using dense_unrolled = nnet::{unrolled_function}; static const unsigned min_width = {min_width}; static const ap_uint pixels[min_width]; static const unsigned n_partitions = {n_partitions}; @@ -96,8 +92,6 @@ def format(self, node): params['fill_fn'] = f'fill_buffer_{node.index}' else: params['fill_fn'] = 'FillConv1DBuffer' - # TODO - Extend unrolled Dense Resource to Conv1D - params['unrolled_function'] = 'DenseResourceUnrolled' conv_config = self.template.format(**params) @@ -108,8 +102,18 @@ def format(self, node): mult_params['product_type'] = get_backend('vivado').product_type( node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) - # TODO - Extend unrolled Dense Resource to Conv1D - mult_params['unrolled_function'] = 'DenseResourceUnrolled' + + if node.get_attr('strategy').lower() == 'latency': + mult_params['dense_function'] = 'DenseLatency' + elif node.get_attr('strategy').lower() == 'resource': + if int(mult_params['reuse_factor']) <= int(mult_params['n_in']): + mult_params['dense_function'] = 'DenseResource_rf_leq_nin' + else: + mult_params['dense_function'] = 'DenseResource_rf_gt_nin_rem0' + # The 3rd case is never used + elif node.get_attr('strategy').lower() == 'unrolled': + mult_params['dense_function'] = f'dense_unrolled_{node.index}' + mult_config = self.mult_template.format(**mult_params) return mult_config + '\n' + conv_config @@ -160,9 +164,6 @@ def __init__(self): static const bool store_weights_in_bram = false; static const unsigned strategy = nnet::{strategy}; static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation}; - static const unsigned resource_implementation = nnet::{dense_resource_implementation}; - template - using dense_unrolled = nnet::{unrolled_function}; static const unsigned min_height = {min_height}; static const unsigned min_width = {min_width}; static const ap_uint pixels[min_height * min_width]; @@ -217,15 +218,6 @@ def format(self, node): else: params['fill_fn'] = 'FillConv2DBuffer' - if ( - node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' - and node.get_attr('strategy').lower() == 'resource' - and node.get_attr('reuse_factor') > 1 - ): - params['unrolled_function'] = f'dense_unrolled_{node.index}' - else: - params['unrolled_function'] = 'DenseResourceUnrolled' - conv_config = self.template.format(**params) mult_params = self._default_config_params(node) @@ -235,14 +227,18 @@ def format(self, node): mult_params['product_type'] = get_backend('vivado').product_type( node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) - if ( - node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' - and node.get_attr('strategy').lower() == 'resource' - and node.get_attr('reuse_factor') > 1 - ): - mult_params['unrolled_function'] = f'dense_unrolled_{node.index}' - else: - mult_params['unrolled_function'] = 'DenseResourceUnrolled' + + if node.get_attr('strategy').lower() == 'latency': + mult_params['dense_function'] = 'DenseLatency' + elif node.get_attr('strategy').lower() == 'resource': + if int(mult_params['reuse_factor']) <= int(mult_params['n_in']): + mult_params['dense_function'] = 'DenseResource_rf_leq_nin' + else: + mult_params['dense_function'] = 'DenseResource_rf_gt_nin_rem0' + # The 3rd case is never used + elif node.get_attr('strategy').lower() == 'unrolled': + mult_params['dense_function'] = f'dense_unrolled_{node.index}' + mult_config = self.mult_template.format(**mult_params) return mult_config + '\n' + conv_config diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index 5f1a25e37f..16973b7fe2 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ b/hls4ml/backends/vivado/passes/core_templates.py @@ -9,9 +9,6 @@ static const unsigned n_out = {n_out}; static const unsigned io_type = nnet::{iotype}; static const unsigned strategy = nnet::{strategy}; - static const unsigned resource_implementation = nnet::{dense_resource_implementation}; - template - using dense_unrolled = nnet::{unrolled_function}; static const unsigned reuse_factor = {reuse}; static const unsigned n_zeros = {nzeros}; static const unsigned n_nonzeros = {nonzeros}; @@ -21,6 +18,8 @@ typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; typedef {index_t.name} index_t; + template + using kernel = nnet::{dense_function}; template using product = nnet::product::{product_type}; }};\n""" @@ -43,14 +42,16 @@ def format(self, node): node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) - if ( - node.get_attr('dense_resource_implementation', 'standard') == 'unrolled' - and node.get_attr('strategy').lower() == 'resource' - and node.get_attr('reuse_factor') > 1 - ): - params['unrolled_function'] = f'dense_unrolled_{node.index}' - else: - params['unrolled_function'] = 'DenseResourceUnrolled' + if node.get_attr('strategy').lower() == 'latency': + params['dense_function'] = 'DenseLatency' + elif node.get_attr('strategy').lower() == 'resource': + if int(params['reuse_factor']) <= int(params['n_in']): + params['dense_function'] = 'DenseResource_rf_leq_nin' + else: + params['dense_function'] = 'DenseResource_rf_gt_nin_rem0' + # The 3rd case is never used + elif node.get_attr('strategy').lower() == 'unrolled': + params['dense_function'] = f'dense_unrolled_{node.index}' return self.template.format(**params) diff --git a/hls4ml/backends/vivado/passes/recurrent_templates.py b/hls4ml/backends/vivado/passes/recurrent_templates.py index e5c3937fd3..34e3e2f9f0 100644 --- a/hls4ml/backends/vivado/passes/recurrent_templates.py +++ b/hls4ml/backends/vivado/passes/recurrent_templates.py @@ -11,15 +11,13 @@ static const unsigned reuse_factor = {reuse}; static const unsigned n_zeros = {nzeros}; static const unsigned n_nonzeros = {nonzeros}; - static const unsigned resource_implementation = nnet::{dense_resource_implementation}; - template - using dense_unrolled = nnet::{unrolled_function}; static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor; static const bool store_weights_in_bram = false; typedef {accum_t.name} accum_t; typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; - typedef {index_t.name} index_t; + template + using kernel = nnet::{dense_function}; template using product = nnet::product::{product_type}; }};\n""" @@ -116,11 +114,11 @@ def format(self, node): act_params['type'] = node.get_attr('activation') recr_act_params['type'] = node.get_attr('recurrent_activation') if node.get_attr('return_sequences'): - act_params['n_in'] = node.get_output_variable().dim_names[1] - recr_act_params['n_in'] = node.get_output_variable().dim_names[1] + ' * %i' % (n_recr_mult - 1) + act_params['n_in'] = node.get_output_variable().shape[1] + recr_act_params['n_in'] = node.get_output_variable().shape[1] * (n_recr_mult - 1) else: - act_params['n_in'] = node.get_output_variable().dim_names[0] - recr_act_params['n_in'] = node.get_output_variable().dim_names[0] + ' * %i' % (n_recr_mult - 1) + act_params['n_in'] = node.get_output_variable().shape[0] + recr_act_params['n_in'] = node.get_output_variable().shape[0] * (n_recr_mult - 1) act_config = self.act_template.format(**act_params) recr_act_config = self.recr_act_template.format(**recr_act_params) @@ -128,11 +126,11 @@ def format(self, node): mult_params1 = self._default_config_params(node) mult_params2 = self._default_config_params(node) - mult_params1['n_in'] = node.get_input_variable().dim_names[1] + mult_params1['n_in'] = node.get_input_variable().shape[1] if node.get_attr('return_sequences'): - mult_params1['n_out'] = node.get_output_variable().dim_names[1] + ' * %i' % n_recr_mult + mult_params1['n_out'] = node.get_output_variable().shape[1] * n_recr_mult else: - mult_params1['n_out'] = node.get_output_variable().dim_names[0] + ' * %i' % n_recr_mult + mult_params1['n_out'] = node.get_output_variable().shape[0] * n_recr_mult mult_params1['product_type'] = get_backend('vivado').product_type( node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) @@ -141,15 +139,23 @@ def format(self, node): mult_params1['nzeros'] = node.get_weights('weight').nzeros mult_params1['nonzeros'] = node.get_weights('weight').nonzeros - # TODO - Extend unrolled Dense Resource to recurrent kernels - mult_params1['unrolled_function'] = 'DenseResourceUnrolled' + if node.get_attr('strategy').lower() == 'latency': + mult_params1['dense_function'] = 'DenseLatency' + elif node.get_attr('strategy').lower() == 'resource': + if int(mult_params1['reuse_factor']) <= int(mult_params1['n_in']): + mult_params1['dense_function'] = 'DenseResource_rf_leq_nin' + else: + mult_params1['dense_function'] = 'DenseResource_rf_gt_nin_rem0' + # The 3rd case is never used + elif node.get_attr('strategy').lower() == 'unrolled': + mult_params1['dense_function'] = f'dense_unrolled_{node.index}_1' if node.get_attr('return_sequences'): - mult_params2['n_in'] = node.get_output_variable().dim_names[1] - mult_params2['n_out'] = node.get_output_variable().dim_names[1] + ' * %i' % n_recr_mult + mult_params2['n_in'] = node.get_output_variable().shape[1] + mult_params2['n_out'] = node.get_output_variable().shape[1] * n_recr_mult else: - mult_params2['n_in'] = node.get_output_variable().dim_names[0] - mult_params2['n_out'] = node.get_output_variable().dim_names[0] + ' * %i' % n_recr_mult + mult_params2['n_in'] = node.get_output_variable().shape[0] + mult_params2['n_out'] = node.get_output_variable().shape[0] * n_recr_mult mult_params2['product_type'] = get_backend('vivado').product_type( node.get_input_variable().type.precision, node.get_weights('recurrent_weight').type.precision ) @@ -158,8 +164,16 @@ def format(self, node): mult_params2['nzeros'] = node.get_weights('recurrent_weight').nzeros mult_params2['nonzeros'] = node.get_weights('recurrent_weight').nonzeros - # TODO - Extend unrolled Dense Resource to recurrent kernels - mult_params2['unrolled_function'] = 'DenseResourceUnrolled' + if node.get_attr('strategy').lower() == 'latency': + mult_params2['dense_function'] = 'DenseLatency' + elif node.get_attr('strategy').lower() == 'resource': + if int(mult_params2['reuse_factor']) <= int(mult_params2['n_in']): + mult_params2['dense_function'] = 'DenseResource_rf_leq_nin' + else: + mult_params2['dense_function'] = 'DenseResource_rf_gt_nin_rem0' + # The 3rd case is never used + elif node.get_attr('strategy').lower() == 'unrolled': + mult_params2['dense_function'] = f'dense_unrolled_{node.index}_2' mult_config1 = self.mult1_template.format(**mult_params1) mult_config2 = self.mult2_template.format(**mult_params2) diff --git a/hls4ml/backends/vivado/passes/resource_strategy.py b/hls4ml/backends/vivado/passes/resource_strategy.py index 63e6e0b4db..d65b0dc48e 100644 --- a/hls4ml/backends/vivado/passes/resource_strategy.py +++ b/hls4ml/backends/vivado/passes/resource_strategy.py @@ -9,7 +9,7 @@ class ApplyResourceStrategy(OptimizerPass): def match(self, node): node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU)) - is_resource_strategy = node.get_attr('strategy', '').lower() == 'resource' + is_resource_strategy = node.get_attr('strategy', '').lower() in ['resource', 'unrolled'] already_transformed = node.get_attr('_weights_transposed', False) is True return node_matches and is_resource_strategy and not already_transformed diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 35cc908ed4..6c5deccc68 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -75,15 +75,6 @@ def _register_layer_attributes(self): attrs.append(ChoiceAttribute('conv_implementation', choices=['LineBuffer', 'Encoded'], default='LineBuffer')) self.attribute_map[layer] = attrs - # Add implementation of Dense Resource for all layers that use Dense for matrix mult - # Handle different implementations of Resource strategy; only makes a difference if strategy == Resource - # Standard -> nnet_dense_resource.h - # Unrolled -> Code generation, ignoring zero DSPs and optimizing zero-filled BRAM blocks - for layer in [Dense] + cnn_layers + rnn_layers: - attrs = self.attribute_map.get(layer, []) - attrs.append( - ChoiceAttribute('dense_resource_implementation', choices=['standard', 'unrolled'], default='standard') - ) sep_conv_layers = [SeparableConv1D, SeparableConv2D] for layer in sep_conv_layers: attrs = self.attribute_map.get(layer, []) @@ -259,6 +250,11 @@ def init_dense(self, layer): index_t = layer.get_weights('weight').type.index_precision else: layer.set_attr('strategy', 'resource') + elif layer.model.config.get_strategy(layer).lower() == 'unrolled' and layer.get_attr('reuse_factor', 1) > 1: + n_in, n_out = self.get_layer_mult_size(layer) + self.set_target_reuse_factor(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + layer.set_attr('strategy', 'unrolled') else: layer.set_attr('strategy', 'latency') layer.set_attr('index_t', NamedType(f'layer{layer.index}_index', index_t)) @@ -275,6 +271,11 @@ def init_conv1d(self, layer): n_in, n_out = self.get_layer_mult_size(layer) self.set_target_reuse_factor(layer) self.set_closest_reuse_factor(layer, n_in, n_out) + elif layer.model.config.get_strategy(layer).lower() == 'unrolled' and layer.get_attr('reuse_factor', 1) > 1: + n_in, n_out = self.get_layer_mult_size(layer) + self.set_target_reuse_factor(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + layer.set_attr('strategy', 'unrolled') else: layer.set_attr('strategy', 'latency') @@ -334,6 +335,11 @@ def init_conv2d(self, layer): self.set_target_reuse_factor(layer) n_in, n_out = self.get_layer_mult_size(layer) self.set_closest_reuse_factor(layer, n_in, n_out) + elif layer.model.config.get_strategy(layer).lower() == 'unrolled' and layer.get_attr('reuse_factor', 1) > 1: + n_in, n_out = self.get_layer_mult_size(layer) + self.set_target_reuse_factor(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + layer.set_attr('strategy', 'unrolled') else: layer.set_attr('strategy', 'latency') @@ -453,6 +459,11 @@ def init_lstm(self, layer): self.set_closest_reuse_factor(layer, n_in, n_out) self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor') layer.set_attr('strategy', 'resource') + elif layer.model.config.get_strategy(layer).lower() == 'unrolled' and layer.get_attr('reuse_factor', 1) > 1: + n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor') + layer.set_attr('strategy', 'unrolled') else: layer.set_attr('strategy', 'latency') @@ -471,6 +482,11 @@ def init_gru(self, layer): self.set_closest_reuse_factor(layer, n_in, n_out) self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor') layer.set_attr('strategy', 'resource') + elif layer.model.config.get_strategy(layer).lower() == 'unrolled' and layer.get_attr('reuse_factor', 1) > 1: + n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor') + layer.set_attr('strategy', 'unrolled') else: layer.set_attr('strategy', 'latency') diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h b/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h index caab69663e..4a8a40cd10 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h @@ -5,38 +5,11 @@ #include "hls_stream.h" #include "nnet_common.h" +#include "nnet_function_stubs.h" #include "nnet_mult.h" namespace nnet { -template class FillConv1DBuffer { - public: - static void fill_buffer(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], - data_T buffer[CONFIG_T::n_pixels][CONFIG_T::filt_width * CONFIG_T::n_chan], - const unsigned partition) { - // To be implemented in subclasses - } -}; - -template class FillConv2DBuffer { - public: - static void - fill_buffer(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], - data_T buffer[CONFIG_T::n_pixels][CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan], - const unsigned partition) { - // To be implemented in subclasses - } -}; - -template class DenseResourceUnrolled { - public: - static void dense_unrolled(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], - typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], - typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { - // To be implemented in subclasses - } -}; - // hls4ml insert code } // namespace nnet diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_common.h b/hls4ml/templates/vivado/nnet_utils/nnet_common.h index fed0395a1a..fee8b7b935 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_common.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_common.h @@ -23,7 +23,7 @@ namespace nnet { // Common type definitions enum io_type { io_parallel = 0, io_stream }; -enum strategy { latency, resource }; +enum strategy { latency, resource, unrolled }; /* --- * Balanced tree reduce implementation. diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_stream.h index b23c330c78..4a55700d8d 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_stream.h @@ -60,6 +60,10 @@ void conv_1d_buffer_cl(hls::stream &data, hls::stream &res, typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + if (CONFIG_T::strategy == nnet::unrolled && CONFIG_T::reuse_factor > 1) { + #pragma HLS allocation instances=compute_output_buffer_1d limit=1 function + } + ReadInputWidth: for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width; i_iw++) { #pragma HLS LOOP_FLATTEN diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h index 08d06501c3..d5583f2669 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h @@ -75,8 +75,7 @@ void conv_2d_buffer_cl( [CONFIG_T::n_chan]; #pragma HLS ARRAY_PARTITION variable = line_buffer complete dim = 2 - if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && - CONFIG_T::reuse_factor > 1) { + if (CONFIG_T::strategy == nnet::unrolled && CONFIG_T::reuse_factor > 1) { #pragma HLS allocation instances=compute_output_buffer_1d limit=1 function #pragma HLS allocation instances=compute_output_buffer_2d limit=1 function } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h index d95d528e46..dcd914dffe 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv_stream.h @@ -95,13 +95,8 @@ void mult_buffer(hls::stream data_window[CONFIG_T:: } #pragma HLS INLINE recursive - if (CONFIG_T::strategy == nnet::latency) { - dense_latency( - data, res, weights, biases); - } else { - dense_resource( - data, res, weights, biases); - } + CONFIG_T::mult_config::template kernel::dense(data, res, weights, biases); CastLoop: for (unsigned jj = 0; jj < CONFIG_T::n_filt; jj++) { @@ -290,18 +285,8 @@ void compute_output_buffer_2d( // Dense multiply // #pragma HLS INLINE recursive - if (CONFIG_T::strategy == nnet::latency) { - dense_latency( - kernel_data, res_out, weights, biases); - } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && - CONFIG_T::reuse_factor > 1) { - CONFIG_T::template dense_unrolled::dense_unrolled(kernel_data, res_out, weights, - biases); - } else { - dense_resource( - kernel_data, res_out, weights, biases); - } + CONFIG_T::mult_config::template kernel::dense(kernel_data, res_out, weights, biases); // Pack output CastLoop: @@ -366,18 +351,8 @@ void compute_output_buffer_1d( // Dense multiply // #pragma HLS INLINE recursive - if (CONFIG_T::strategy == nnet::latency) { - dense_latency( - kernel_data, res_out, weights, biases); - } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && - CONFIG_T::reuse_factor > 1) { - CONFIG_T::template dense_unrolled::dense_unrolled(kernel_data, res_out, weights, - biases); - } else { - dense_resource( - kernel_data, res_out, weights, biases); - } + CONFIG_T::mult_config::template kernel::dense(kernel_data, res_out, weights, biases); // Pack output CastLoop: diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense.h index 2037daf0b9..d6c7beb70e 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_dense.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense.h @@ -5,17 +5,13 @@ #include "nnet_common.h" #include "nnet_dense_latency.h" #include "nnet_dense_resource.h" +#include "nnet_function_stubs.h" #include "nnet_helpers.h" #include "nnet_mult.h" #include namespace nnet { -// Different implementations of Resource strategy; this attribute only makes a difference if strategy == Resource -// Default -> nnet_dense_resource.h -// Unrolled -> Code generation, ignoring zero DSPs and optimizing BRAM -enum resource_implementation { standard, unrolled }; - struct dense_config { // Internal data type definitions typedef float bias_t; @@ -33,9 +29,7 @@ struct dense_config { static const bool store_weights_in_bram = false; static const unsigned n_zeros = 0; - static const unsigned resource_implementation = standard; - template - using dense_unrolled = nnet::DenseResourceUnrolled; + template using kernel = nnet::DenseKernel; // Partitioning arrays cyclically to go with roll factors? @@ -47,16 +41,41 @@ template void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { - #pragma HLS inline - if (CONFIG_T::strategy == nnet::latency) { + #pragma HLS INLINE + CONFIG_T::template kernel::dense(data, res, weights, biases); +} + +template class DenseLatency : public DenseKernel { + public: + static void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + #pragma HLS INLINE dense_latency(data, res, weights, biases); - } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && - CONFIG_T::reuse_factor > 1) { - CONFIG_T::template dense_unrolled::dense_unrolled(data, res, weights, biases); - } else { - dense_resource(data, res, weights, biases); } -} +}; + +template +class DenseResource_rf_leq_nin : public DenseKernel { + public: + static void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + #pragma HLS INLINE + dense_resource_rf_leq_nin(data, res, weights, biases); + } +}; + +template +class DenseResource_rf_gt_nin_rem0 : public DenseKernel { + public: + static void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + #pragma HLS INLINE + dense_resource_rf_gt_nin_rem0(data, res, weights, biases); + } +}; } // namespace nnet diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_stream.h index db3039fc33..3e3183480e 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_stream.h @@ -16,13 +16,8 @@ void dense_wrapper(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], #pragma HLS INLINE recursive if (CONFIG_T::strategy == nnet::latency) { #pragma HLS PIPELINE II=CONFIG_T::reuse_factor - dense_latency(data, res, weights, biases); - } else if (CONFIG_T::strategy == nnet::resource && CONFIG_T::resource_implementation == nnet::unrolled && - CONFIG_T::reuse_factor > 1) { - CONFIG_T::template dense_unrolled::dense_unrolled(data, res, weights, biases); - } else { - dense_resource(data, res, weights, biases); } + CONFIG_T::template kernel::dense(data, res, weights, biases); } template diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_function_stubs.h b/hls4ml/templates/vivado/nnet_utils/nnet_function_stubs.h new file mode 100644 index 0000000000..1316bbe776 --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_function_stubs.h @@ -0,0 +1,42 @@ +#ifndef NNET_FUNCTION_STUBS_H_ +#define NNET_FUNCTION_STUBS_H_ + +#include "nnet_helpers.h" + +#include "hls_stream.h" +#include "nnet_common.h" +#include "nnet_mult.h" + +namespace nnet { + +template class FillConv1DBuffer { + public: + static void fill_buffer(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + data_T buffer[CONFIG_T::n_pixels][CONFIG_T::filt_width * CONFIG_T::n_chan], + const unsigned partition) { + // To be implemented in subclasses + } +}; + +template class FillConv2DBuffer { + public: + static void + fill_buffer(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + data_T buffer[CONFIG_T::n_pixels][CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan], + const unsigned partition) { + // To be implemented in subclasses + } +}; + +template class DenseKernel { + public: + static void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + // To be implemented in subclasses + } +}; + +} // namespace nnet + +#endif diff --git a/test/pytest/test_dense_unrolled.py b/test/pytest/test_dense_unrolled.py index a3318049be..6b7503c543 100644 --- a/test/pytest/test_dense_unrolled.py +++ b/test/pytest/test_dense_unrolled.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from tensorflow.keras.layers import Conv2D, Dense, Flatten +from tensorflow.keras.layers import GRU, LSTM, Conv1D, Conv2D, Dense, Flatten from tensorflow.keras.models import Sequential from hls4ml.converters import convert_from_keras_model @@ -14,37 +14,51 @@ # Tests a wide range of RF to ensure the unrolled Dense is correct @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) @pytest.mark.parametrize('reuse_factor', [1, 2, 4, 8, 16, 32, 48, 64, 96, 192]) -def test_dense_unrolled(io_type, reuse_factor): +@pytest.mark.parametrize('backend', ['Vitis', 'Vivado']) +def test_dense_unrolled(io_type, reuse_factor, backend): input_shape = (16,) X = np.random.rand(100, *input_shape) model = Sequential() - model.add(Dense(12, input_shape=input_shape, kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform')) + model.add( + Dense( + 12, input_shape=input_shape, kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform', name='dense' + ) + ) model.compile('adam', 'mse') keras_prediction = model.predict(X) - config = config_from_keras_model(model, default_precision='ac_fixed<32, 16>', default_reuse_factor=reuse_factor) - config['Model']['Strategy'] = 'Resource' - config['Model']['DenseResourceImplementation'] = 'Unrolled' + config = config_from_keras_model( + model, default_precision='ac_fixed<32, 16>', backend=backend, default_reuse_factor=reuse_factor + ) + config['Model']['Strategy'] = 'Unrolled' + + output_dir = str(test_root_path / f'hls4mlprj_dense_unrolled_{io_type}_{reuse_factor}_{backend}') + hls_model = convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) + + # Check if strategy was not overridden + assert list(hls_model.get_layers())[1].get_attr('strategy') == 'unrolled' if reuse_factor > 1 else 'latency' - output_dir = str(test_root_path / f'hls4mlprj_dense_unrolled_{io_type}_{reuse_factor}') - hls_model = convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend='Vivado', io_type=io_type) hls_model.compile() hls_prediction = hls_model.predict(X) np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=1e-2) -# Tests a wide range RF on streaming Conv2D to ensure the unrolled Dense is correct +# Tests a wide range RF on streaming Conv1D/2D to ensure the unrolled Dense is correct +@pytest.mark.parametrize('dim', [1, 2]) @pytest.mark.parametrize('io_type', ['io_stream']) @pytest.mark.parametrize('reuse_factor', [1, 3, 9, 27, 54, 108]) -def test_dense_unrolled_streaming_conv(io_type, reuse_factor): - input_shape = (8, 8, 3) +def test_dense_unrolled_streaming_conv(dim, io_type, reuse_factor): + input_shape = (8,) * dim + (3,) X = np.random.rand(100, *input_shape) + conv_class = Conv1D if dim == 1 else Conv2D model = Sequential() model.add( - Conv2D(4, (3, 3), input_shape=input_shape, kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform') + conv_class( + 4, (3,) * dim, input_shape=input_shape, kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform' + ) ) model.add(Flatten()) model.add(Dense(1, kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform')) @@ -52,12 +66,60 @@ def test_dense_unrolled_streaming_conv(io_type, reuse_factor): keras_prediction = model.predict(X) config = config_from_keras_model(model, default_precision='ac_fixed<32, 16>', default_reuse_factor=reuse_factor) - config['Model']['Strategy'] = 'Resource' - config['Model']['DenseResourceImplementation'] = 'Unrolled' + config['Model']['Strategy'] = 'Unrolled' - output_dir = str(test_root_path / f'hls4mlprj_dense_unrolled_conv2d_{io_type}_{reuse_factor}') + output_dir = str(test_root_path / f'hls4mlprj_dense_unrolled_conv{dim}d_{io_type}_{reuse_factor}') hls_model = convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend='Vivado', io_type=io_type) + + # Check if strategy was not overridden + assert list(hls_model.get_layers())[1].get_attr('strategy') == 'unrolled' if reuse_factor > 1 else 'latency' + hls_model.compile() hls_prediction = hls_model.predict(X) np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=1e-2) + + +@pytest.mark.parametrize('rnn_layer', [LSTM, GRU]) +@pytest.mark.parametrize('backend', ['Vitis', 'Vivado']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +@pytest.mark.parametrize('static', [True, False]) +@pytest.mark.parametrize('reuse_factor', [1, 4, 32, 128]) # These should be enough +def test_rnn_unrolled(rnn_layer, backend, io_type, static, reuse_factor): + # Subtract 0.5 to include negative values + input_shape = (12, 8) + X = np.random.rand(50, *input_shape) - 0.5 + + layer_name = rnn_layer.__name__.lower() + keras_model = Sequential() + keras_model.add( + rnn_layer( + units=8, + input_shape=input_shape, + kernel_initializer='lecun_uniform', + recurrent_initializer='lecun_uniform', + bias_initializer='lecun_uniform', + return_sequences=False, + name=layer_name, + ) + ) + keras_model.compile() + + default_precision = 'ap_fixed<32, 16>' if backend in ['Vivado', 'Vitis'] else 'ac_fixed<32, 16, true>' + hls_config = config_from_keras_model( + keras_model, granularity='name', default_precision=default_precision, backend=backend + ) + hls_config['LayerName'][layer_name]['static'] = static + hls_config['LayerName'][layer_name]['Strategy'] = 'Unrolled' + hls_config['LayerName'][layer_name]['ReuseFactor'] = reuse_factor + prj_name = f'hls4mlprj_rnn_unrolled_{layer_name}_static_{int(static)}_{io_type}_{reuse_factor}_{backend}' + output_dir = str(test_root_path / prj_name) + + hls_model = convert_from_keras_model( + keras_model, hls_config=hls_config, output_dir=output_dir, backend=backend, io_type=io_type + ) + hls_model.compile() + + keras_prediction = keras_model.predict(X) + hls_prediction = hls_model.predict(X) + np.testing.assert_allclose(hls_prediction.flatten(), keras_prediction.flatten(), rtol=0.0, atol=5e-2) From 2ed0865032b3decea6af3a246c876ecc5cd2aa81 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Thu, 22 Aug 2024 19:04:57 +0200 Subject: [PATCH 09/14] Reorganize codegen of unrolled implementation --- hls4ml/backends/fpga/fpga_backend.py | 4 +- hls4ml/backends/fpga/passes/codegen.py | 230 +---------------- hls4ml/backends/vitis/passes/feature_check.py | 23 +- hls4ml/backends/vitis/vitis_backend.py | 3 +- .../vivado/passes/unrolled_codegen.py | 243 ++++++++++++++++++ hls4ml/backends/vivado/vivado_backend.py | 103 ++++++-- hls4ml/writer/vivado_writer.py | 17 +- 7 files changed, 359 insertions(+), 264 deletions(-) create mode 100644 hls4ml/backends/vivado/passes/unrolled_codegen.py diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index 8d0ed64aad..ad8e917dd8 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -227,10 +227,12 @@ def get_closest_reuse_factor(self, valid_rf, chosen_rf): else: return before - def set_closest_reuse_factor(self, layer, n_in, n_out, attribute='reuse_factor'): + def set_closest_reuse_factor(self, layer, n_in, n_out, attribute='reuse_factor', include_max_rf=True): assert attribute is not None, 'Reuse factor attribute cannot be None' valid_rf = self.get_valid_reuse_factors(n_in, n_out) + if not include_max_rf: + valid_rf.pop() chosen_rf = layer.get_attr(attribute) if chosen_rf not in valid_rf: closest_rf = self.get_closest_reuse_factor(valid_rf, chosen_rf) diff --git a/hls4ml/backends/fpga/passes/codegen.py b/hls4ml/backends/fpga/passes/codegen.py index 3667680ed5..f1f1080996 100644 --- a/hls4ml/backends/fpga/passes/codegen.py +++ b/hls4ml/backends/fpga/passes/codegen.py @@ -1,8 +1,4 @@ -import math - -import numpy as np - -from hls4ml.model.layers import GRU, LSTM, Conv1D, Conv2D, Dense +from hls4ml.model.layers import Conv1D, Conv2D from hls4ml.model.optimizer import OptimizerPass from hls4ml.model.types import Source @@ -53,227 +49,3 @@ def _generate_im2col_2d(self, node): ) node.set_attr('line_buffer_codegen', Source(code_str)) - - -class GenerateUnrolledDenseResource(OptimizerPass): - '''Generates C++ code for unrolled Dense resource''' - - def match(self, node): - # Only apply to layers use that use Dense Matrix Multiplication - # TODO - Extend (& test) for Separable Conv / Depthwise Conv / Recurrent layers - layers_with_dense = (Dense, Conv1D, Conv2D, LSTM, GRU) - - # Unrolled Dense mimicks the hardware implementation of Resource strategy -> apply after Resource optimizer - weights_transposed = node.get_attr('_weights_transposed', False) - - # RF = 1 will optimize DSPs anyway, so no need to unroll code - rf_gt_one = node.get_attr('reuse_factor', 1) > 1 - - # User requested unrolled implementation of Dense - is_unrolled = node.get_attr('strategy', 'latency') == 'unrolled' - - return isinstance(node, layers_with_dense) and weights_transposed and rf_gt_one and is_unrolled - - def transform(self, model, node): - if isinstance(node, (LSTM, GRU)): - n_in, n_out, n_in_recr, n_out_recr = node.model.config.backend.get_layer_mult_size(node) - - reuse_factor = node.get_attr('reuse_factor') - weights = node.weights['weight'] - code_str = self._generate_unrolled_function(n_in, n_out, reuse_factor, weights, str(node.index) + '_1') - node.set_attr('unrolled_dense_resource_codegen_1', Source(code_str)) - - recr_reuse_factor = node.get_attr('recurrent_reuse_factor') - recr_weights = node.weights['recurrent_weight'] - code_str = self._generate_unrolled_function( - n_in_recr, n_out_recr, recr_reuse_factor, recr_weights, str(node.index) + '_2' - ) - node.set_attr('unrolled_dense_resource_codegen_2', Source(code_str)) - - else: - n_in, n_out = node.model.config.backend.get_layer_mult_size(node) - reuse_factor = node.get_attr('reuse_factor') - weights = node.weights['weight'] - - code_str = self._generate_unrolled_function(n_in, n_out, reuse_factor, weights, node.index) - node.set_attr('unrolled_dense_resource_codegen', Source(code_str)) - - def _generate_unrolled_function(self, n_in, n_out, reuse_factor, weights, function_suffix): - """ - Generate a C++ function that mimics the Dense Resource implementation. - - The HLS compiler produces suboptimal designs for Dense Resource when the weights processed by the same DSP are zero. - Latency strategy can optimize zero multiplications - Resource strategy, on the other hand, cannot. - When all the weights in the same BRAM block are zero, Vivado is unable to optimize it - With this (and additional TCL scripts) zero BRAM are optimized - - Args: - node: Layer to generate code for - Returns: - generated_code: Generated C++ function (string) - """ - - # Variable instantiation and function pragmas - generated_code = ( - 'template\n' - 'class dense_unrolled_{suffix} : public DenseKernel {{\n' - ' public:\n' - ' static void dense(\n' - ' data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],\n' - ' typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],\n' - ' typename CONFIG_T::bias_t biases[CONFIG_T::n_out]\n' - ' ) {{\n' - ' #pragma HLS pipeline II=CONFIG_T::reuse_factor\n' - '\n' - ' constexpr int block_factor = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, CONFIG_T::reuse_factor);\n' - ' #pragma HLS function_instantiate variable=weights,biases\n' - ' #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor\n' - ' #pragma HLS RESOURCE variable=weights core=ROM_nP_BRAM\n' - ' #pragma HLS ARRAY_PARTITION variable=biases complete\n' - '\n' - ' typename CONFIG_T::accum_t acc[CONFIG_T::n_out];\n' - ' #pragma HLS ARRAY_PARTITION variable=acc complete\n' - '\n' - ' InitAccum:\n' - ' for (int i = 0; i < CONFIG_T::n_out; i++) {{\n' - ' #pragma HLS UNROLL\n' - ' acc[i] = (typename CONFIG_T::accum_t) biases[i];\n' - ' }}\n' - '\n' - ).format(suffix=function_suffix) - - # Unrolled multiplication, according to the three cases - if reuse_factor <= n_in: - mult_code = self._generate_unrolled_mult_code_rf_leq_nin(n_in, n_out, reuse_factor, weights) - elif reuse_factor > n_in and reuse_factor % n_in == 0: - mult_code = self._generate_unrolled_mult_code_rf_gt_nin_rem0(n_in, n_out, reuse_factor, weights) - else: - # This case shouldn't happen if my understanding of RF is correct - # The function fpga_backend._validate_reuse_factor() has assertion rf % n_in == 0 or rf < n_in - raise Exception('Not implemented...') - - # Write output - generated_code += mult_code + '\n' - generated_code += ( - ' Result:\n' - ' for (int i = 0; i < CONFIG_T::n_out; i++) {\n' - ' #pragma HLS UNROLL\n' - ' res[i] = cast(acc[i]);\n' - ' }\n' - ' }\n' - '};\n' - ) - - return generated_code - - def _generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, weights): - # Function constants - mult_factor = min(n_in, reuse_factor) - block_factor = int(math.ceil(n_in * n_out / reuse_factor)) - mult_limit = int(math.ceil(n_in * n_out / mult_factor)) - mult_scale = mult_limit // n_out - - # Zero DSPs are the DSP blocks that always have zero input - # In this case, it is the number of rows in the transposed and reshaped weight matrix - # The new shape is (parallel_mult, reuse_factor) - zeros = np.sum(~weights.data.reshape(block_factor, reuse_factor).any(1)) - - # Used to pad the code to make it human-readable - indent = ' ' - - # Generate unrolled multiplications - mult_code = f'{indent*2}#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n' - mult_code += f'{indent*2}MULT: {{\n' - mult_code += f'{indent*3}#pragma HLS protocol\n' - - for ir in range(reuse_factor): - acc_step = 0 - out_index = 0 - w_index = ir - in_index = ir - - mult_code += f'{indent*3}M{ir}: {{\n' - for _ in range(block_factor): - if weights.data.flatten()[w_index] != 0: - mult_code += ( - f'{indent*4}acc[{out_index}] += ' - 'static_cast' - '(CONFIG_T::template product::' - f'product(data[{in_index}], weights[{w_index}]));\n' - ) - - w_index += reuse_factor - in_index += reuse_factor - if in_index >= n_in: - in_index = ir - if acc_step + 1 >= mult_scale: - acc_step = 0 - out_index += 1 - else: - acc_step += 1 - - mult_code += f'{indent*3}}}\n' - - mult_code += f'{indent*2}}}\n' - - return mult_code - - def _generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor, weights): - # Function constants - mult_factor = min(n_in, reuse_factor) - block_factor = int(math.ceil(n_in * n_out / reuse_factor)) - mult_limit = int(math.ceil(n_in * n_out / mult_factor)) - - # Zero DSPs are the DSP blocks that always have zero input - # In this case, it is the number of rows in the transposed and reshaped weight matrix - # The new shape is (parallel_mult, reuse_factor) - zeros = np.sum(~weights.data.reshape(block_factor, reuse_factor).any(1)) - - # Used to pad the code to make it human-readable - indent = ' ' - - # Generate out indices - outidx = [0] * reuse_factor - outstep = 0 - outscale = reuse_factor // n_in - for ir in range(reuse_factor): - outidx[ir] = outstep - if (ir + 1) % n_in == 0: - outstep += 1 - - # Define variables - in_index = 0 - - # Generate unrolled multiplications - mult_code = f'{indent*2}#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n' - mult_code += f'{indent*2}MULT: {{\n' - mult_code += f'{indent*3}#pragma HLS protocol\n' - - for ir in range(reuse_factor): - w_index = ir - out_index = outidx[ir] - - mult_code += f'{indent*3}M{ir}: {{\n' - for _ in range(block_factor): - if weights.data.flatten()[w_index] != 0: - mult_code += ( - f'{indent*4}acc[{int(out_index)}] += ' - 'static_cast' - '(CONFIG_T::template product::' - f'product(data[{in_index}], weights[{w_index}]));\n' - ) - - w_index += reuse_factor - if w_index > n_in * n_out: - break - out_index += outscale - mult_code += f'{indent*3}}}\n' - - in_index += 1 - if in_index >= n_in: - in_index = 0 - - mult_code += f'{indent*2}}}\n' - - return mult_code diff --git a/hls4ml/backends/vitis/passes/feature_check.py b/hls4ml/backends/vitis/passes/feature_check.py index d7f9c2a7f5..7f0b832765 100644 --- a/hls4ml/backends/vitis/passes/feature_check.py +++ b/hls4ml/backends/vitis/passes/feature_check.py @@ -14,7 +14,7 @@ def transform(self, model, node): node.set_attr('implementation', 'linebuffer') -class ValidateStrategy(OptimizerPass): +class ValidateResourceStrategy(OptimizerPass): _resource_layer_cls = ['Conv1D', 'Conv2D', 'Dense'] def match(self, node): @@ -29,6 +29,23 @@ def transform(self, model, node): if rf > n_in and rf % n_in > 0: print( f'WARNING: "Resource" strategy in "{node.name}" ({node.class_name}) may have suboptimal QoR in Vitis ' - 'backend due to use of "urem" cores.\n' - 'Consider using a different ReuseFactor or switching to "Latency" strategy.' + 'backend due to use of "urem" cores in Vitis HLS <= 2022.1.\n' + 'Consider using a different ReuseFactor or switching to "Latency" strategy if using older versions ' + 'of Vitis HLS.' ) + + +class ValidateUnrolledStrategy(OptimizerPass): + _unrolled_layer_cls = ['Conv1D', 'Conv2D', 'Dense', 'GRU', 'LSTM'] + + def match(self, node): + is_unrolled_layer = len([layer_cls for layer_cls in self._unrolled_layer_cls if layer_cls in node.class_name]) > 0 + is_unrolled_strategy = node.get_attr('strategy', 'latency').lower() == 'unrolled' + + return is_unrolled_layer and is_unrolled_strategy + + def transform(self, model, node): + print( + f'WARNING: "Unrolled" strategy in "{node.name}" ({node.class_name}) may have unexpected II in Vitis backend.\n' + 'Verify that the final design satisfies the latency/II constraints.' + ) diff --git a/hls4ml/backends/vitis/vitis_backend.py b/hls4ml/backends/vitis/vitis_backend.py index 2a0616a198..6e9cbbb10c 100644 --- a/hls4ml/backends/vitis/vitis_backend.py +++ b/hls4ml/backends/vitis/vitis_backend.py @@ -15,7 +15,8 @@ def __init__(self): def _register_flows(self): validation_passes = [ 'vitis:validate_conv_implementation', - 'vitis:validate_strategy', + 'vitis:validate_resource_strategy', + 'vitis:validate_unrolled_strategy', ] validation_flow = register_flow('validation', validation_passes, requires=['vivado:init_layers'], backend=self.name) diff --git a/hls4ml/backends/vivado/passes/unrolled_codegen.py b/hls4ml/backends/vivado/passes/unrolled_codegen.py new file mode 100644 index 0000000000..6fd6c584af --- /dev/null +++ b/hls4ml/backends/vivado/passes/unrolled_codegen.py @@ -0,0 +1,243 @@ +import math + +import numpy as np + +from hls4ml.model.layers import GRU, LSTM, Conv1D, Conv2D, Dense +from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.types import Source + + +class GenerateUnrolledDenseResource(OptimizerPass): + '''Generates C++ code for unrolled Dense resource''' + + def match(self, node): + # Only apply to layers use that use Dense Matrix Multiplication + # TODO - Extend (& test) for Separable Conv / Depthwise Conv / Recurrent layers + layers_with_dense = (Dense, Conv1D, Conv2D, LSTM, GRU) + + # Unrolled Dense mimicks the hardware implementation of Resource strategy -> apply after Resource optimizer + weights_transposed = node.get_attr('_weights_transposed', False) + + # RF = 1 will optimize DSPs anyway, so no need to unroll code + rf_gt_one = node.get_attr('reuse_factor', 1) > 1 + + # User requested unrolled implementation of Dense + is_unrolled = node.get_attr('strategy', 'latency') == 'unrolled' + + return isinstance(node, layers_with_dense) and weights_transposed and rf_gt_one and is_unrolled + + def transform(self, model, node): + if isinstance(node, (LSTM, GRU)): + n_in, n_out, n_in_recr, n_out_recr = node.model.config.backend.get_layer_mult_size(node) + + reuse_factor = node.get_attr('reuse_factor') + weights = node.weights['weight'] + code_str = self._generate_unrolled_function(n_in, n_out, reuse_factor, weights, str(node.index) + '_1') + code_str = self._add_backend_specific_pragmas_to_generated_code(code_str, model.config.backend) + node.set_attr('unrolled_dense_resource_codegen_1', Source(code_str)) + + recr_reuse_factor = node.get_attr('recurrent_reuse_factor') + recr_weights = node.weights['recurrent_weight'] + code_str = self._generate_unrolled_function( + n_in_recr, n_out_recr, recr_reuse_factor, recr_weights, str(node.index) + '_2' + ) + code_str = self._add_backend_specific_pragmas_to_generated_code(code_str, model.config.backend) + node.set_attr('unrolled_dense_resource_codegen_2', Source(code_str)) + + else: + n_in, n_out = node.model.config.backend.get_layer_mult_size(node) + reuse_factor = node.get_attr('reuse_factor') + weights = node.weights['weight'] + + code_str = self._generate_unrolled_function(n_in, n_out, reuse_factor, weights, node.index) + code_str = self._add_backend_specific_pragmas_to_generated_code(code_str, model.config.backend) + node.set_attr('unrolled_dense_resource_codegen', Source(code_str)) + + def _generate_unrolled_function(self, n_in, n_out, reuse_factor, weights, function_suffix): + """ + Generate a C++ function that mimics the Dense Resource implementation. + + The HLS compiler produces suboptimal designs for Dense Resource when the weights processed by the same DSP are zero. + Latency strategy can optimize zero multiplications + Resource strategy, on the other hand, cannot. + When all the weights in the same BRAM block are zero, Vivado is unable to optimize it + With this (and additional TCL scripts) zero BRAM are optimized + + Args: + node: Layer to generate code for + Returns: + generated_code: Generated C++ function (string) + """ + + # Variable instantiation and function pragmas + generated_code = ( + 'template\n' + 'class dense_unrolled_{suffix} : public DenseKernel {{{{\n' + ' public:\n' + ' static void dense(\n' + ' data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],\n' + ' typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],\n' + ' typename CONFIG_T::bias_t biases[CONFIG_T::n_out]\n' + ' ) {{{{\n' + ' #pragma HLS pipeline II=CONFIG_T::reuse_factor\n' + '\n' + ' constexpr int block_factor = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, CONFIG_T::reuse_factor);\n' + ' #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor\n' + ' {{weights_resource_pragma}}\n' + ' #pragma HLS ARRAY_PARTITION variable=biases complete\n' + '\n' + ' typename CONFIG_T::accum_t acc[CONFIG_T::n_out];\n' + ' #pragma HLS ARRAY_PARTITION variable=acc complete\n' + '\n' + ' InitAccum:\n' + ' for (int i = 0; i < CONFIG_T::n_out; i++) {{{{\n' + ' #pragma HLS UNROLL\n' + ' acc[i] = (typename CONFIG_T::accum_t) biases[i];\n' + ' }}}}\n' + '\n' + ).format(suffix=function_suffix) + + # Unrolled multiplication, according to the three cases + if reuse_factor <= n_in: + mult_code = self._generate_unrolled_mult_code_rf_leq_nin(n_in, n_out, reuse_factor, weights) + elif reuse_factor > n_in and reuse_factor % n_in == 0: + mult_code = self._generate_unrolled_mult_code_rf_gt_nin_rem0(n_in, n_out, reuse_factor, weights) + else: + # This case shouldn't happen if my understanding of RF is correct + # The function fpga_backend._validate_reuse_factor() has assertion rf % n_in == 0 or rf < n_in + raise Exception('Not implemented...') + + # Write output + generated_code += mult_code + '\n' + generated_code += ( + ' Result:\n' + ' for (int i = 0; i < CONFIG_T::n_out; i++) {{\n' + ' #pragma HLS UNROLL\n' + ' res[i] = cast(acc[i]);\n' + ' }}\n' + ' }}\n' + '}};\n' + ) + + return generated_code + + def _generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, weights): + # Function constants + mult_factor = min(n_in, reuse_factor) + block_factor = int(math.ceil(n_in * n_out / reuse_factor)) + mult_limit = int(math.ceil(n_in * n_out / mult_factor)) + mult_scale = mult_limit // n_out + + # Zero DSPs are the DSP blocks that always have zero input + # In this case, it is the number of rows in the transposed and reshaped weight matrix + # The new shape is (parallel_mult, reuse_factor) + zeros = np.sum(~weights.data.reshape(block_factor, reuse_factor).any(1)) + + # Used to pad the code to make it human-readable + indent = ' ' + + # Generate unrolled multiplications + mult_code = f'{indent*2}#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n' + mult_code += f'{indent*2}MULT: {{{{\n' + + for ir in range(reuse_factor): + acc_step = 0 + out_index = 0 + w_index = ir + in_index = ir + + mult_code += f'{indent*3}M{ir}: {{{{\n' + for _ in range(block_factor): + if weights.data.flatten()[w_index] != 0: + mult_code += ( + f'{indent*4}acc[{out_index}] += ' + 'static_cast' + '(CONFIG_T::template product::' + f'product(data[{in_index}], weights[{w_index}]));\n' + ) + + w_index += reuse_factor + in_index += reuse_factor + if in_index >= n_in: + in_index = ir + if acc_step + 1 >= mult_scale: + acc_step = 0 + out_index += 1 + else: + acc_step += 1 + + mult_code += f'{indent*3}}}}}\n' + + mult_code += f'{indent*2}}}}}\n' + + return mult_code + + def _generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor, weights): + # Function constants + mult_factor = min(n_in, reuse_factor) + block_factor = int(math.ceil(n_in * n_out / reuse_factor)) + mult_limit = int(math.ceil(n_in * n_out / mult_factor)) + + # Zero DSPs are the DSP blocks that always have zero input + # In this case, it is the number of rows in the transposed and reshaped weight matrix + # The new shape is (parallel_mult, reuse_factor) + zeros = np.sum(~weights.data.reshape(block_factor, reuse_factor).any(1)) + + # Used to pad the code to make it human-readable + indent = ' ' + + # Generate out indices + outidx = [0] * reuse_factor + outstep = 0 + outscale = reuse_factor // n_in + for ir in range(reuse_factor): + outidx[ir] = outstep + if (ir + 1) % n_in == 0: + outstep += 1 + + # Define variables + in_index = 0 + + # Generate unrolled multiplications + mult_code = f'{indent*2}#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n' + mult_code += f'{indent*2}MULT: {{{{\n' + + for ir in range(reuse_factor): + w_index = ir + out_index = outidx[ir] + + mult_code += f'{indent*3}M{ir}: {{{{\n' + for _ in range(block_factor): + if weights.data.flatten()[w_index] != 0: + mult_code += ( + f'{indent*4}acc[{int(out_index)}] += ' + 'static_cast' + '(CONFIG_T::template product::' + f'product(data[{in_index}], weights[{w_index}]));\n' + ) + + w_index += reuse_factor + if w_index > n_in * n_out: + break + out_index += outscale + mult_code += f'{indent*3}}}}}\n' + + in_index += 1 + if in_index >= n_in: + in_index = 0 + + mult_code += f'{indent*2}}}}}\n' + + return mult_code + + def _add_backend_specific_pragmas_to_generated_code(self, code, backend): + if backend.name == 'Vivado': + weights_resource_pragma = '#pragma HLS RESOURCE variable=weights core=ROM_nP_BRAM' + elif backend.name == 'Vitis': + weights_resource_pragma = '#pragma HLS BIND_STORAGE variable=weights type=ROM_NP impl=BRAM' + else: + raise Exception(f'Unexpected backend {backend.name} in GenerateUnrolledDenseResource optimizer.') + + code = code.format(weights_resource_pragma=weights_resource_pragma) + + return code diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 6c5deccc68..834dec9d5e 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -250,11 +250,22 @@ def init_dense(self, layer): index_t = layer.get_weights('weight').type.index_precision else: layer.set_attr('strategy', 'resource') - elif layer.model.config.get_strategy(layer).lower() == 'unrolled' and layer.get_attr('reuse_factor', 1) > 1: + elif layer.model.config.get_strategy(layer).lower() == 'unrolled': + use_resource_instead = False + if layer.get_attr('reuse_factor', 1) == 1: + print( + f'Unrolled strategy cannot be combined with reuse factor 1 in layer "{layer.name}". ' + 'Using "resource" strategy instead.' + ) + use_resource_instead = True n_in, n_out = self.get_layer_mult_size(layer) self.set_target_reuse_factor(layer) - self.set_closest_reuse_factor(layer, n_in, n_out) - layer.set_attr('strategy', 'unrolled') + if use_resource_instead: + self.set_closest_reuse_factor(layer, n_in, n_out) + layer.set_attr('strategy', 'resource') + else: + self.set_closest_reuse_factor(layer, n_in, n_out, include_max_rf=False) + layer.set_attr('strategy', 'unrolled') else: layer.set_attr('strategy', 'latency') layer.set_attr('index_t', NamedType(f'layer{layer.index}_index', index_t)) @@ -271,11 +282,28 @@ def init_conv1d(self, layer): n_in, n_out = self.get_layer_mult_size(layer) self.set_target_reuse_factor(layer) self.set_closest_reuse_factor(layer, n_in, n_out) - elif layer.model.config.get_strategy(layer).lower() == 'unrolled' and layer.get_attr('reuse_factor', 1) > 1: + elif layer.model.config.get_strategy(layer).lower() == 'unrolled': + use_resource_instead = False + if layer.get_attr('reuse_factor', 1) == 1: + print( + f'Unrolled strategy cannot be combined with reuse factor 1 in layer "{layer.name}".' + 'Using "resource" strategy instead.' + ) + use_resource_instead = True + elif layer.model.config.get_config_value('IOType') == 'io_parallel': + print( + f'Unrolled strategy cannot be combined with io_parallel in layer "{layer.name}". ' + 'Using "resource" strategy instead.' + ) + use_resource_instead = True n_in, n_out = self.get_layer_mult_size(layer) self.set_target_reuse_factor(layer) - self.set_closest_reuse_factor(layer, n_in, n_out) - layer.set_attr('strategy', 'unrolled') + if use_resource_instead: + self.set_closest_reuse_factor(layer, n_in, n_out) + layer.set_attr('strategy', 'resource') + else: + self.set_closest_reuse_factor(layer, n_in, n_out, include_max_rf=False) + layer.set_attr('strategy', 'unrolled') else: layer.set_attr('strategy', 'latency') @@ -335,11 +363,28 @@ def init_conv2d(self, layer): self.set_target_reuse_factor(layer) n_in, n_out = self.get_layer_mult_size(layer) self.set_closest_reuse_factor(layer, n_in, n_out) - elif layer.model.config.get_strategy(layer).lower() == 'unrolled' and layer.get_attr('reuse_factor', 1) > 1: + elif layer.model.config.get_strategy(layer).lower() == 'unrolled': + use_resource_instead = False + if layer.get_attr('reuse_factor', 1) == 1: + print( + f'Unrolled strategy cannot be combined with reuse factor 1 in layer "{layer.name}". ' + 'Using "resource" strategy instead.' + ) + use_resource_instead = True + elif layer.model.config.get_config_value('IOType') == 'io_parallel': + print( + f'Unrolled strategy cannot be combined with io_parallel in layer "{layer.name}". ' + 'Using "resource" strategy instead.' + ) + use_resource_instead = True n_in, n_out = self.get_layer_mult_size(layer) self.set_target_reuse_factor(layer) - self.set_closest_reuse_factor(layer, n_in, n_out) - layer.set_attr('strategy', 'unrolled') + if use_resource_instead: + self.set_closest_reuse_factor(layer, n_in, n_out) + layer.set_attr('strategy', 'resource') + else: + self.set_closest_reuse_factor(layer, n_in, n_out, include_max_rf=False) + layer.set_attr('strategy', 'unrolled') else: layer.set_attr('strategy', 'latency') @@ -459,11 +504,23 @@ def init_lstm(self, layer): self.set_closest_reuse_factor(layer, n_in, n_out) self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor') layer.set_attr('strategy', 'resource') - elif layer.model.config.get_strategy(layer).lower() == 'unrolled' and layer.get_attr('reuse_factor', 1) > 1: + elif layer.model.config.get_strategy(layer).lower() == 'unrolled': + use_resource_instead = False + if layer.get_attr('reuse_factor', 1) == 1: + print( + f'Unrolled strategy cannot be combined with reuse factor 1 in layer "{layer.name}". ' + 'Using "resource" strategy instead.' + ) + use_resource_instead = True n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer) - self.set_closest_reuse_factor(layer, n_in, n_out) - self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor') - layer.set_attr('strategy', 'unrolled') + if use_resource_instead: + self.set_closest_reuse_factor(layer, n_in, n_out) + layer.set_attr('strategy', 'resource') + else: + self.set_closest_reuse_factor( + layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor', include_max_rf=False + ) + layer.set_attr('strategy', 'unrolled') else: layer.set_attr('strategy', 'latency') @@ -482,11 +539,23 @@ def init_gru(self, layer): self.set_closest_reuse_factor(layer, n_in, n_out) self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor') layer.set_attr('strategy', 'resource') - elif layer.model.config.get_strategy(layer).lower() == 'unrolled' and layer.get_attr('reuse_factor', 1) > 1: + elif layer.model.config.get_strategy(layer).lower() == 'unrolled': + use_resource_instead = False + if layer.get_attr('reuse_factor', 1) == 1: + print( + f'Unrolled strategy cannot be combined with reuse factor 1 in layer "{layer.name}". ' + 'Using "resource" strategy instead.' + ) + use_resource_instead = True n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer) - self.set_closest_reuse_factor(layer, n_in, n_out) - self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor') - layer.set_attr('strategy', 'unrolled') + if use_resource_instead: + self.set_closest_reuse_factor(layer, n_in, n_out) + layer.set_attr('strategy', 'resource') + else: + self.set_closest_reuse_factor( + layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor', include_max_rf=False + ) + layer.set_attr('strategy', 'unrolled') else: layer.set_attr('strategy', 'latency') diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index 4202ba9700..ab691912be 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -13,13 +13,6 @@ class VivadoWriter(Writer): - def __get_max_reuse_factor(self, model): - max_rf = 0 - for layer in model.get_layers(): - rf = int(layer.get_attr('reuse_factor')) - if rf > max_rf: - max_rf = rf - return max_rf def print_array_to_cpp(self, var, odir, write_txt_file=True): """Write a weights array to C++ header files. @@ -181,11 +174,9 @@ def write_project_cpp(self, model): ) model_cfg = model.config.get_config_value('HLSConfig')['Model'] - if ( - 'DenseResourceImplementation' in model_cfg - and model_cfg['DenseResourceImplementation'].lower() == 'unrolled' - ): - newline += indent + f'#pragma HLS PIPELINE ii={self.__get_max_reuse_factor(model)} \n' + if model_cfg.get('Strategy', 'latency').lower() == 'unrolled': + max_rf = max([int(layer.get_attr('reuse_factor')) for layer in model.get_layers()]) + newline += indent + f'#pragma HLS PIPELINE II={max_rf} \n' else: if model.config.pipeline_style.lower() == 'dataflow': newline += indent + '#pragma HLS DATAFLOW \n' @@ -724,7 +715,7 @@ def write_tar(self, model): """ with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive: - archive.add(model.config.get_output_dir(), recursive=True) + archive.add(model.config.get_output_dir(), recursive=True, arcname='') def write_hls(self, model): print('Writing HLS project') From fbc4107948892d2332aac8e05349dc623a0d3fca Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Mon, 26 Aug 2024 00:54:10 +0200 Subject: [PATCH 10/14] Remove mentions of dense_resource_implementation --- hls4ml/backends/vivado/vivado_backend.py | 1 - hls4ml/model/graph.py | 20 -------------------- 2 files changed, 21 deletions(-) diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index fb7377a655..b1e2ffddd8 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -292,7 +292,6 @@ def init_dense(self, layer): else: layer.set_attr('strategy', 'latency') layer.set_attr('index_t', NamedType(f'layer{layer.index}_index', index_t)) - layer.set_attr('dense_resource_implementation', layer.model.config.get_dense_resource_implementation(layer).lower()) # TODO consolidate these functions into a single `init_conv` @layer_optimizer(Conv1D) diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 0e2c3d33ae..d0a1fdf7fc 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -43,10 +43,6 @@ def __init__(self, config): self.layer_type_conv_implementation = {} self.layer_name_conv_implementation = {} - self.model_dense_resource_implementation = 'Standard' - self.layer_type_dense_resource_implementation = {} - self.layer_name_dense_resource_implementation = {} - self.model_compression = False self.layer_type_compression = {} self.layer_name_compression = {} @@ -190,17 +186,6 @@ def get_conv_implementation(self, layer): return conv_implementation - def get_dense_resource_implementation(self, layer): - dense_resource_implementation = self.layer_name_dense_resource_implementation.get(layer.name.lower()) - if dense_resource_implementation is None: - dense_resource_implementation = self.layer_type_dense_resource_implementation.get( - layer.__class__.__name__.lower() - ) - if dense_resource_implementation is None: - dense_resource_implementation = self.model_dense_resource_implementation - - return dense_resource_implementation - def is_resource_strategy(self, layer): return self.get_strategy(layer).lower() == 'resource' @@ -280,7 +265,6 @@ def _parse_hls_config(self): self.model_rf = model_cfg.get('ReuseFactor') self.model_targ_cycles = model_cfg.get('TargetCycles') self.model_conv_implementation = model_cfg.get('ConvImplementation', 'LineBuffer') - self.model_dense_resource_implementation = model_cfg.get('DenseResourceImplementation', 'Standard') self.model_strategy = model_cfg.get('Strategy', 'Latency') self.model_compression = bool(model_cfg.get('Compression', 0)) self.pipeline_style = model_cfg.get('PipelineStyle', 'pipeline') @@ -311,10 +295,6 @@ def _parse_hls_config(self): if conv_implementation is not None: self.layer_type_conv_implementation[layer_type.lower()] = conv_implementation - dense_resource_implementation = layer_cfg.get('DenseResourceImplementation') - if conv_implementation is not None: - self.layer_type_dense_resource_implementation[layer_type.lower()] = dense_resource_implementation - compression = layer_cfg.get('Compression') if compression is not None: self.layer_type_compression[layer_type.lower()] = bool(compression) From ecda5c946e6757b68579b590d09e9a9e6e0f3ac5 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Mon, 26 Aug 2024 01:02:13 +0200 Subject: [PATCH 11/14] Default to 'auto' for pipeline style and move check to an optimizer --- .../backends/vivado/passes/pipeline_style.py | 131 ++++++++++++++++++ hls4ml/backends/vivado/vivado_backend.py | 14 +- hls4ml/model/graph.py | 51 +------ hls4ml/writer/vivado_writer.py | 22 +-- test/pytest/test_dense_unrolled.py | 6 +- test/pytest/test_pipeline_style.py | 99 +++++++++++++ 6 files changed, 255 insertions(+), 68 deletions(-) create mode 100644 hls4ml/backends/vivado/passes/pipeline_style.py create mode 100755 test/pytest/test_pipeline_style.py diff --git a/hls4ml/backends/vivado/passes/pipeline_style.py b/hls4ml/backends/vivado/passes/pipeline_style.py new file mode 100644 index 0000000000..326745e455 --- /dev/null +++ b/hls4ml/backends/vivado/passes/pipeline_style.py @@ -0,0 +1,131 @@ +from hls4ml.model.layers import Conv1D, Conv2D +from hls4ml.model.optimizer import ModelOptimizerPass + + +class SetPipelineStyle(ModelOptimizerPass): + def __init__(self): + pass + + def transform(self, model): + if model.config.pipeline_style not in ['auto', 'pipeline', 'dataflow']: + print( + f'WARNING: Pipeline style set to {model.config.pipeline_style}, valid values: auto, pipeline, dataflow. ' + 'Using "auto".' + ) + self._set_pipeline_style(model, 'auto') + + if model.config.pipeline_style is None or model.config.pipeline_style == 'auto': + + if self._maybe_set_dataflow_io_stream(model): + return True + + if self._maybe_set_dataflow_conv_layers(model): + return True + + if self._maybe_set_dataflow_resource_strategy(model): + return True + + if self._maybe_set_pipeline_unrolled_strategy(model): + return True + + if self._maybe_set_pipeline_io_parallel(model): + return True + + self._set_safe_default_dataflow(model) + return True + else: + self._validate_hls_config(model) + + return False # No model changes made + + def _set_pipeline_style(self, model, pipeline_style): + # Could add logging here + model.config.pipeline_style = pipeline_style + + def _maybe_set_dataflow_io_stream(self, model): + if model.config.get_config_value('IOType') == 'io_stream': + self._set_pipeline_style(model, 'dataflow') + return True + + return False + + def _maybe_set_dataflow_conv_layers(self, model): + for layer in model.get_layers(): + if isinstance(layer, (Conv1D, Conv2D)): + self._set_pipeline_style(model, 'dataflow') + return True + + return False + + def _maybe_set_dataflow_resource_strategy(self, model): + for layer in model.get_layers(): + if model.config.is_resource_strategy(layer): + self._set_pipeline_style(model, 'dataflow') + return True + + return False + + def _maybe_set_pipeline_unrolled_strategy(self, model): + have_unrolled = False + for layer in model.get_layers(): + if model.config.get_strategy(layer).lower() == 'unrolled': + self._set_pipeline_style(model, 'pipeline') + have_unrolled = True + break + + if have_unrolled: + model.config.pipeline_ii = max([int(layer.get_attr('reuse_factor')) for layer in model.get_layers()]) + + return have_unrolled + + def _maybe_set_pipeline_io_parallel(self, model): + if model.config.get_config_value('IOType') == 'io_parallel': + self._set_pipeline_style(model, 'pipeline') + return True + + return False + + def _set_safe_default_dataflow(self, model): + print( + 'WARNING: Couldn\'t determine best pipeline style, defaulting to "DATAFLOW". ' + 'Use "PipelineStyle" property to override.' + ) + self._set_pipeline_style(model, 'dataflow') + + def _validate_hls_config(self, model): + if model.config.pipeline_style.lower() == 'pipeline': + if model.config.model_compression: + print('WARNING: Compression enabled while pipeline style set to "pipeline".') + if model.config.model_strategy.lower() == 'resource': + print( + 'WARNING: Model strategy "Resource" will lead to bad QoR in combination ' + 'with pipeline style set to "pipeline".' + ) + if any(isinstance(layer, (Conv1D, Conv2D)) for layer in model.get_layers()): + print('WARNING: Convolution layers require "dataflow" pipeline style.') + for layer_type, strategy in model.config.layer_type_strategy.items(): + if strategy.lower() == 'resource' and model.config.pipeline_style.lower() == 'pipeline': + print( + f'WARNING: Strategy for layer type {layer_type} set to "Resource", while pipeline style set to ' + '"pipeline". This will lead to bad QoR.' + ) + + for layer_name, strategy in model.config.layer_name_strategy.items(): + if strategy.lower() == 'resource' and model.config.pipeline_style.lower() == 'pipeline': + print( + 'WARNING: Strategy for layer {} set to "Resource", while pipeline style set to "pipeline".'.format( + layer_name + ) + ) + + for layer_type, compression in model.config.layer_type_compression.items(): + if compression and model.config.pipeline_style.lower() == 'pipeline': + print( + 'WARNING: Compression enabled for layer type {}, while pipeline style set to "pipeline".'.format( + layer_type + ) + ) + + for layer_name, compression in model.config.layer_name_compression.items(): + if compression and model.config.pipeline_style.lower() == 'pipeline': + print(f'WARNING: Compression enabled for layer {layer_name}, while pipeline style set to "pipeline".') diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index b1e2ffddd8..17fd994598 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -114,6 +114,7 @@ def _register_flows(self): 'vivado:apply_resource_strategy', 'vivado:generate_conv_im2col', 'vivado:generate_unrolled_dense_resource', + 'vivado:set_pipeline_style', ] vivado_types_flow = register_flow('specific_types', vivado_types, requires=[init_flow], backend=self.name) @@ -247,11 +248,6 @@ def build( return parse_vivado_report(model.config.get_output_dir()) - def _validate_conv_strategy(self, layer): - if layer.model.config.pipeline_style.lower() != 'dataflow': - print(f'WARNING: Layer {layer.name} requires "dataflow" pipeline style. Switching to "dataflow" pipeline style.') - layer.model.config.pipeline_style = 'dataflow' - @layer_optimizer(Layer) def init_base_layer(self, layer): reuse_factor = layer.model.config.get_reuse_factor(layer) @@ -356,8 +352,6 @@ def init_conv1d(self, layer): layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) - self._validate_conv_strategy(layer) - @layer_optimizer(SeparableConv1D) def init_sepconv1d(self, layer): if layer.model.config.is_resource_strategy(layer): @@ -480,8 +474,6 @@ def init_conv2d(self, layer): layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) - self._validate_conv_strategy(layer) - @layer_optimizer(SeparableConv2D) def init_sepconv2d(self, layer): if layer.model.config.is_resource_strategy(layer): @@ -585,8 +577,10 @@ def init_lstm(self, layer): n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer) if use_resource_instead: self.set_closest_reuse_factor(layer, n_in, n_out) + self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor') layer.set_attr('strategy', 'resource') else: + self.set_closest_reuse_factor(layer, n_in, n_out, include_max_rf=False) self.set_closest_reuse_factor( layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor', include_max_rf=False ) @@ -617,8 +611,10 @@ def init_gru(self, layer): n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer) if use_resource_instead: self.set_closest_reuse_factor(layer, n_in, n_out) + self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor') layer.set_attr('strategy', 'resource') else: + self.set_closest_reuse_factor(layer, n_in, n_out, include_max_rf=False) self.set_closest_reuse_factor( layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor', include_max_rf=False ) diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index d0a1fdf7fc..609417f94a 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -49,7 +49,8 @@ def __init__(self, config): self.trace_output = self.get_config_value('TraceOutput', False) - self.pipeline_style = 'pipeline' + self.pipeline_style = 'auto' + self.pipeline_ii = None if 'WriterConfig' in self.config: self.writer_config = self.config['WriterConfig'] @@ -61,7 +62,6 @@ def __init__(self, config): } self._parse_hls_config() - self._validate_hls_config() def get_config_value(self, key, default=None): return self.config.get(key, default) @@ -267,7 +267,8 @@ def _parse_hls_config(self): self.model_conv_implementation = model_cfg.get('ConvImplementation', 'LineBuffer') self.model_strategy = model_cfg.get('Strategy', 'Latency') self.model_compression = bool(model_cfg.get('Compression', 0)) - self.pipeline_style = model_cfg.get('PipelineStyle', 'pipeline') + self.pipeline_style = model_cfg.get('PipelineStyle', 'auto') + self.pipeline_ii = model_cfg.get('PipelineInterval', None) layer_type_cfg = hls_config.get('LayerType') if layer_type_cfg is not None: @@ -304,50 +305,6 @@ def _parse_hls_config(self): for layer_name, layer_cfg in layer_name_cfg.items(): self.parse_name_config(layer_name, layer_cfg) - def _validate_hls_config(self): - use_dataflow = False - if self.pipeline_style.lower() == 'pipeline' and self.model_compression: - print('WARNING: Compression enabled while pipeline style set to "pipeline".') - use_dataflow = True - for layer_type, strategy in self.layer_type_strategy.items(): - if strategy.lower() == 'resource' and self.pipeline_style.lower() == 'pipeline': - print( - 'WARNING: Strategy for layer type {} set to "Resource", while pipeline style set to "pipeline".'.format( - layer_type - ) - ) - use_dataflow = True - - for layer_name, strategy in self.layer_name_strategy.items(): - if strategy.lower() == 'resource' and self.pipeline_style.lower() == 'pipeline': - print( - 'WARNING: Strategy for layer {} set to "Resource", while pipeline style set to "pipeline".'.format( - layer_name - ) - ) - use_dataflow = True - - for layer_type, compression in self.layer_type_compression.items(): - if compression and self.pipeline_style.lower() == 'pipeline': - print( - 'WARNING: Compression enabled for layer type {}, while pipeline style set to "pipeline".'.format( - layer_type - ) - ) - use_dataflow = True - - for layer_name, compression in self.layer_name_compression.items(): - if compression and self.pipeline_style.lower() == 'pipeline': - print(f'WARNING: Compression enabled for layer {layer_name}, while pipeline style set to "pipeline".') - use_dataflow = True - - if self.model_strategy.lower() == 'resource': - use_dataflow = True - - if use_dataflow: - print('WARNING: Changing pipeline style to "dataflow".') - self.pipeline_style = 'dataflow' - class ModelGraph: """The ModelGraph represents the network that is being processed by hls4ml. diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index fae3984e5f..1adee08093 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -199,7 +199,15 @@ def write_project_cpp(self, model): all_inputs = [i.name for i in model_inputs] all_outputs = [o.name for o in model_outputs] all_brams = [b.name for b in model_brams] - io_type = model.config.get_config_value("IOType") + io_type = model.config.get_config_value('IOType') + + pipeline_style = model.config.pipeline_style + pipeline_ii = model.config.pipeline_ii + pipeline_pragma = indent + f'#pragma HLS {pipeline_style.upper()}' + if pipeline_style == 'pipeline' and pipeline_ii is not None: + pipeline_pragma += f' II={pipeline_ii}\n' + else: + pipeline_pragma += '\n' if io_type == 'io_parallel': for i in model_inputs: @@ -211,23 +219,15 @@ def write_project_cpp(self, model): newline += indent + '#pragma HLS INTERFACE ap_vld port={},{} \n'.format( ','.join(all_inputs), ','.join(all_outputs) ) + newline += pipeline_pragma - model_cfg = model.config.get_config_value('HLSConfig')['Model'] - if model_cfg.get('Strategy', 'latency').lower() == 'unrolled': - max_rf = max([int(layer.get_attr('reuse_factor')) for layer in model.get_layers()]) - newline += indent + f'#pragma HLS PIPELINE II={max_rf} \n' - else: - if model.config.pipeline_style.lower() == 'dataflow': - newline += indent + '#pragma HLS DATAFLOW \n' - else: - newline += indent + '#pragma HLS PIPELINE \n' if io_type == 'io_stream': newline += indent + '#pragma HLS INTERFACE axis port={},{} \n'.format( ','.join(all_inputs), ','.join(all_outputs) ) if all_brams: newline += indent + '#pragma HLS INTERFACE bram port={} \n'.format(','.join(all_brams)) - newline += indent + '#pragma HLS DATAFLOW \n' + newline += pipeline_pragma elif '// hls-fpga-machine-learning insert layers' in line: newline = line + '\n' diff --git a/test/pytest/test_dense_unrolled.py b/test/pytest/test_dense_unrolled.py index 6b7503c543..69d024b84f 100644 --- a/test/pytest/test_dense_unrolled.py +++ b/test/pytest/test_dense_unrolled.py @@ -84,7 +84,7 @@ def test_dense_unrolled_streaming_conv(dim, io_type, reuse_factor): @pytest.mark.parametrize('backend', ['Vitis', 'Vivado']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) @pytest.mark.parametrize('static', [True, False]) -@pytest.mark.parametrize('reuse_factor', [1, 4, 32, 128]) # These should be enough +@pytest.mark.parametrize('reuse_factor', [1, 4, 32, 128]) # RF=128 also tests if setting closest RF works well def test_rnn_unrolled(rnn_layer, backend, io_type, static, reuse_factor): # Subtract 0.5 to include negative values input_shape = (12, 8) @@ -118,6 +118,10 @@ def test_rnn_unrolled(rnn_layer, backend, io_type, static, reuse_factor): hls_model = convert_from_keras_model( keras_model, hls_config=hls_config, output_dir=output_dir, backend=backend, io_type=io_type ) + + # Check if strategy was not overridden + assert list(hls_model.get_layers())[1].get_attr('strategy') == 'unrolled' if reuse_factor > 1 else 'latency' + hls_model.compile() keras_prediction = keras_model.predict(X) diff --git a/test/pytest/test_pipeline_style.py b/test/pytest/test_pipeline_style.py new file mode 100755 index 0000000000..f8706fa52c --- /dev/null +++ b/test/pytest/test_pipeline_style.py @@ -0,0 +1,99 @@ +""" Test that pipeline style is properly handled by optimizers (respected if user-defined, correctly set if 'auto'). """ + +from pathlib import Path + +import pytest +import tensorflow as tf + +import hls4ml + +test_root_path = Path(__file__).parent + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize( + 'param_group, pipeline_style, io_type, strategy, ii', + [ + (1, 'auto', 'io_stream', 'resource', None), # io_stream should result in DATAFLOW pragma regardless of other params + (2, 'auto', 'io_stream', 'latency', None), + (3, None, 'io_stream', 'unrolled', None), # None should be interpreted as 'auto' + (4, 'auto', 'io_parallel', 'resource', None), # Should end up with DATAFLOW pragma + (5, 'auto', 'io_parallel', 'latency', None), # Should end up with PIPELINE pragma + (6, 'auto', 'io_parallel', 'unrolled', None), # Should end up with PIPELINE pragma and II + (7, 'pipeline', 'io_stream', 'resource', None), # Should result in a warning + (8, 'pipeline', 'io_parallel', 'resource', None), # Should result in a warning + (9, 'pipeline', 'io_parallel', 'latency', None), # No warning + (10, 'pipeline', 'io_parallel', 'latency', 10), # No warning, should include II=10 + (11, 'dataflow', 'io_stream', 'latency', None), # No warning + (12, 'dataflow', 'io_parallel', 'latency', None), # No warning + (13, 'dataflow', 'io_parallel', 'latency', None), # No warning + (14, 'wrong', 'io_parallel', 'latency', None), # Incorrect settings should issue a warning and switch to 'auto' + (15, 'auto', 'io_parallel', 'resource', None), # Special case to test Conv layer. No warning + (16, 'pipeline', 'io_parallel', 'resource', None), # Special case to test Conv layer. Should result in two warnings + ], +) +def test_pipeline_style(capfd, backend, param_group, pipeline_style, io_type, strategy, ii): + def _check_top_hls_pragma(model, pragma, ii=None): + assert model.config.pipeline_style == pragma + + pragma_to_check = f'#pragma HLS {pragma.upper()}' + if ii is not None: + pragma_to_check += f' II={ii}' + + with open(model.config.get_output_dir() + '/firmware/myproject.cpp') as main_file: + contents = main_file.readlines() + for line in contents: + if pragma_to_check in line: + return True + + return False + + if param_group in [15, 16]: + model = tf.keras.models.Sequential([tf.keras.layers.Conv1D(8, 2, input_shape=(10, 4))]) + else: + model = tf.keras.models.Sequential([tf.keras.layers.Dense(8, input_shape=(10,))]) + + config = hls4ml.utils.config_from_keras_model(model) + if pipeline_style is not None: + config['Model']['PipelineStyle'] = pipeline_style + if ii is not None: + config['Model']['PipelineInterval'] = ii + config['Model']['Strategy'] = strategy + config['Model']['ReuseFactor'] = 2 + + prj_name = f'hls4mlprj_pipeline_style_{backend}_{param_group}' + output_dir = str(test_root_path / prj_name) + hls_model = hls4ml.converters.convert_from_keras_model( + model, hls_config=config, output_dir=output_dir, io_type=io_type, backend=backend + ) + hls_model.write() + + captured_warnings = [line for line in capfd.readouterr().out.split('\n') if line.startswith('WARNING')] + + if param_group in [1, 2, 3, 4]: + assert _check_top_hls_pragma(hls_model, 'dataflow') + elif param_group == 5: + assert _check_top_hls_pragma(hls_model, 'pipeline') + elif param_group == 6: + assert _check_top_hls_pragma(hls_model, 'pipeline', ii=2) + elif param_group in [7, 8]: + assert _check_top_hls_pragma(hls_model, 'pipeline') + assert any('bad QoR' in warning for warning in captured_warnings) + elif param_group == 9: + assert _check_top_hls_pragma(hls_model, 'pipeline') + assert len(captured_warnings) == 0 + elif param_group == 10: + assert _check_top_hls_pragma(hls_model, 'pipeline', ii=ii) + assert len(captured_warnings) == 0 + elif param_group in [11, 12, 13]: + assert _check_top_hls_pragma(hls_model, 'dataflow') + assert len(captured_warnings) == 0 + elif param_group == 14: + assert _check_top_hls_pragma(hls_model, 'pipeline') + assert any('Using "auto"' in warning for warning in captured_warnings) + elif param_group == 15: + assert _check_top_hls_pragma(hls_model, 'dataflow') + elif param_group == 16: + assert _check_top_hls_pragma(hls_model, 'pipeline') + assert any('bad QoR' in warning for warning in captured_warnings) + assert any('Convolution' in warning for warning in captured_warnings) From ce8431d51a8d96307cd1016c2eafd1150a33d498 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Mon, 26 Aug 2024 01:27:03 +0200 Subject: [PATCH 12/14] Pimp the docs a bit --- docs/advanced/model_optimization.rst | 4 ++-- docs/api/configuration.rst | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/advanced/model_optimization.rst b/docs/advanced/model_optimization.rst index a75224b8cc..c1396b3d20 100644 --- a/docs/advanced/model_optimization.rst +++ b/docs/advanced/model_optimization.rst @@ -130,5 +130,5 @@ Note, to ensure DSPs are optimized, "unrolled" Dense multiplication must be used .. code-block:: Python hls_config = config_from_keras_model(optimized_model) - hls_config['Model']['DenseResourceImplementation'] = 'Unrolled' - # Any addition hls4ml config, such as strategy, reuse factor etc... + hls_config['Model']['Strategy'] = 'Unrolled' + # Any addition hls4ml config, reuse factor etc... diff --git a/docs/api/configuration.rst b/docs/api/configuration.rst index 091f88e619..9303e652c9 100644 --- a/docs/api/configuration.rst +++ b/docs/api/configuration.rst @@ -103,7 +103,10 @@ For Vivado backend the options are: * **IOType**\ : your options are ``io_parallel`` or ``io_stream`` which defines the type of data structure used for inputs, intermediate activations between layers, and outputs. For ``io_parallel``, arrays are used that, in principle, can be fully unrolled and are typically implemented in RAMs. For ``io_stream``, HLS streams are used, which are a more efficient/scalable mechanism to represent data that are produced and consumed in a sequential manner. Typically, HLS streams are implemented with FIFOs instead of RAMs. For more information see `here `__. * **HLSConfig**\: the detailed configuration of precision and parallelism, including: * **ReuseFactor**\ : in the case that you are pipelining, this defines the pipeline interval or initiation interval - * **Strategy**\ : Optimization strategy on FPGA, either "Latency" or "Resource". If none is supplied then hl4ml uses "Latency" as default. Note that a reuse factor larger than 1 should be specified when using "resource" strategy. An example of using larger reuse factor can be found `here. `__ + * **ParallelizationFactor**\ : The number of output "pixels" to compute in parallel in convolutional layers. Increasing this parameter results in significant increase in resources required on the FPGA. + * **Strategy**\ : Optimization strategy on FPGA, either "Latency", "Resource" or "Unrolled". If none is supplied then hl4ml uses "Latency" as default. Note that a reuse factor larger than 1 should be specified when using "resource" or "unrolled" strategy. An example of using larger reuse factor can be found `here. `__ + * **PipelineStyle**\ : Set the top level pipeline style. Valid options are "auto", "pipeline" and "dataflow". If unspecified, it defaults to "auto". + * **PipelineInterval**\ : Optionally override the desired initiation interval of the design. Only valid in combination with "pipeline" style. If unspecified, it is left to the compiler to decide, ideally matching the largest reuse factor of the network. * **Precision**\ : this defines the precsion of your inputs, outputs, weights and biases. It is denoted by ``ap_fixed``\ , where ``Y`` is the number of bits representing the signed number above the binary point (i.e. the integer part), and ``X`` is the total number of bits. Additionally, integers in fixed precision data type (\ ``ap_int``\ , where ``N`` is a bit-size from 1 to 1024) can also be used. You have a chance to further configure this more finely with per-layer configuration described below. From c4af46af9835b0d2f9c791f2c539e2a30e04f87c Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Mon, 7 Oct 2024 10:41:34 +0200 Subject: [PATCH 13/14] Rename "unrolled" -> "resource_unrolled" --- hls4ml/backends/vitis/passes/feature_check.py | 8 ++-- hls4ml/backends/vitis/vitis_backend.py | 2 +- .../vivado/passes/convolution_templates.py | 8 ++-- .../backends/vivado/passes/core_templates.py | 4 +- .../backends/vivado/passes/pipeline_style.py | 6 +-- .../vivado/passes/recurrent_templates.py | 8 ++-- .../vivado/passes/resource_strategy.py | 2 +- .../vivado/passes/unrolled_codegen.py | 12 ++--- hls4ml/backends/vivado/vivado_backend.py | 34 +++++++------- hls4ml/model/graph.py | 9 ++-- .../templates/vivado/nnet_utils/nnet_common.h | 2 +- .../vivado/nnet_utils/nnet_conv1d_stream.h | 2 +- .../vivado/nnet_utils/nnet_conv2d_stream.h | 2 +- hls4ml/utils/string_utils.py | 3 +- test/pytest/test_dense_unrolled.py | 46 +++++++++++++------ test/pytest/test_pipeline_style.py | 4 +- 16 files changed, 86 insertions(+), 66 deletions(-) mode change 100755 => 100644 test/pytest/test_pipeline_style.py diff --git a/hls4ml/backends/vitis/passes/feature_check.py b/hls4ml/backends/vitis/passes/feature_check.py index 7f0b832765..a38f6581f6 100644 --- a/hls4ml/backends/vitis/passes/feature_check.py +++ b/hls4ml/backends/vitis/passes/feature_check.py @@ -35,17 +35,17 @@ def transform(self, model, node): ) -class ValidateUnrolledStrategy(OptimizerPass): +class ValidateResourceUnrolledStrategy(OptimizerPass): _unrolled_layer_cls = ['Conv1D', 'Conv2D', 'Dense', 'GRU', 'LSTM'] def match(self, node): is_unrolled_layer = len([layer_cls for layer_cls in self._unrolled_layer_cls if layer_cls in node.class_name]) > 0 - is_unrolled_strategy = node.get_attr('strategy', 'latency').lower() == 'unrolled' + is_unrolled_strategy = node.get_attr('strategy', 'latency').lower() == 'resource_unrolled' return is_unrolled_layer and is_unrolled_strategy def transform(self, model, node): print( - f'WARNING: "Unrolled" strategy in "{node.name}" ({node.class_name}) may have unexpected II in Vitis backend.\n' - 'Verify that the final design satisfies the latency/II constraints.' + f'WARNING: "ResourceUnrolled" strategy in "{node.name}" ({node.class_name}) may have unexpected II in' + 'Vitis backend.\nVerify that the final design satisfies the latency/II constraints.' ) diff --git a/hls4ml/backends/vitis/vitis_backend.py b/hls4ml/backends/vitis/vitis_backend.py index c9fd452619..0110f78313 100644 --- a/hls4ml/backends/vitis/vitis_backend.py +++ b/hls4ml/backends/vitis/vitis_backend.py @@ -16,7 +16,7 @@ def _register_flows(self): validation_passes = [ 'vitis:validate_conv_implementation', 'vitis:validate_resource_strategy', - 'vitis:validate_unrolled_strategy', + 'vitis:validate_resource_unrolled_strategy', ] validation_flow = register_flow('validation', validation_passes, requires=['vivado:init_layers'], backend=self.name) diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index 6b13319174..dd77bee85e 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -111,8 +111,8 @@ def format(self, node): else: mult_params['dense_function'] = 'DenseResource_rf_gt_nin_rem0' # The 3rd case is never used - elif node.get_attr('strategy').lower() == 'unrolled': - mult_params['dense_function'] = f'dense_unrolled_{node.index}' + elif node.get_attr('strategy').lower() == 'resource_unrolled': + mult_params['dense_function'] = f'dense_resource_unrolled_{node.index}' mult_config = self.mult_template.format(**mult_params) @@ -236,8 +236,8 @@ def format(self, node): else: mult_params['dense_function'] = 'DenseResource_rf_gt_nin_rem0' # The 3rd case is never used - elif node.get_attr('strategy').lower() == 'unrolled': - mult_params['dense_function'] = f'dense_unrolled_{node.index}' + elif node.get_attr('strategy').lower() == 'resource_unrolled': + mult_params['dense_function'] = f'dense_resource_unrolled_{node.index}' mult_config = self.mult_template.format(**mult_params) diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index 118ba41335..836da6e68a 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ b/hls4ml/backends/vivado/passes/core_templates.py @@ -51,8 +51,8 @@ def format(self, node): else: params['dense_function'] = 'DenseResource_rf_gt_nin_rem0' # The 3rd case is never used - elif node.get_attr('strategy').lower() == 'unrolled': - params['dense_function'] = f'dense_unrolled_{node.index}' + elif node.get_attr('strategy').lower() == 'resource_unrolled': + params['dense_function'] = f'dense_resource_unrolled_{node.index}' return self.template.format(**params) diff --git a/hls4ml/backends/vivado/passes/pipeline_style.py b/hls4ml/backends/vivado/passes/pipeline_style.py index 326745e455..66c2bbe71e 100644 --- a/hls4ml/backends/vivado/passes/pipeline_style.py +++ b/hls4ml/backends/vivado/passes/pipeline_style.py @@ -25,7 +25,7 @@ def transform(self, model): if self._maybe_set_dataflow_resource_strategy(model): return True - if self._maybe_set_pipeline_unrolled_strategy(model): + if self._maybe_set_pipeline_resource_unrolled_strategy(model): return True if self._maybe_set_pipeline_io_parallel(model): @@ -65,10 +65,10 @@ def _maybe_set_dataflow_resource_strategy(self, model): return False - def _maybe_set_pipeline_unrolled_strategy(self, model): + def _maybe_set_pipeline_resource_unrolled_strategy(self, model): have_unrolled = False for layer in model.get_layers(): - if model.config.get_strategy(layer).lower() == 'unrolled': + if model.config.get_strategy(layer).lower() == 'resource_unrolled': self._set_pipeline_style(model, 'pipeline') have_unrolled = True break diff --git a/hls4ml/backends/vivado/passes/recurrent_templates.py b/hls4ml/backends/vivado/passes/recurrent_templates.py index 6c4ee51cdb..939713af22 100644 --- a/hls4ml/backends/vivado/passes/recurrent_templates.py +++ b/hls4ml/backends/vivado/passes/recurrent_templates.py @@ -149,8 +149,8 @@ def format(self, node): else: mult_params1['dense_function'] = 'DenseResource_rf_gt_nin_rem0' # The 3rd case is never used - elif node.get_attr('strategy').lower() == 'unrolled': - mult_params1['dense_function'] = f'dense_unrolled_{node.index}_1' + elif node.get_attr('strategy').lower() == 'resource_unrolled': + mult_params1['dense_function'] = f'dense_resource_unrolled_{node.index}_1' if node.get_attr('return_sequences'): mult_params2['n_in'] = node.get_output_variable().shape[1] @@ -174,8 +174,8 @@ def format(self, node): else: mult_params2['dense_function'] = 'DenseResource_rf_gt_nin_rem0' # The 3rd case is never used - elif node.get_attr('strategy').lower() == 'unrolled': - mult_params2['dense_function'] = f'dense_unrolled_{node.index}_2' + elif node.get_attr('strategy').lower() == 'resource_unrolled': + mult_params2['dense_function'] = f'dense_resource_unrolled_{node.index}_2' mult_config1 = self.mult1_template.format(**mult_params1) mult_config2 = self.mult2_template.format(**mult_params2) diff --git a/hls4ml/backends/vivado/passes/resource_strategy.py b/hls4ml/backends/vivado/passes/resource_strategy.py index d65b0dc48e..0c06190f30 100644 --- a/hls4ml/backends/vivado/passes/resource_strategy.py +++ b/hls4ml/backends/vivado/passes/resource_strategy.py @@ -9,7 +9,7 @@ class ApplyResourceStrategy(OptimizerPass): def match(self, node): node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU)) - is_resource_strategy = node.get_attr('strategy', '').lower() in ['resource', 'unrolled'] + is_resource_strategy = node.get_attr('strategy', '').lower() in ['resource', 'resource_unrolled'] already_transformed = node.get_attr('_weights_transposed', False) is True return node_matches and is_resource_strategy and not already_transformed diff --git a/hls4ml/backends/vivado/passes/unrolled_codegen.py b/hls4ml/backends/vivado/passes/unrolled_codegen.py index 6fd6c584af..d901c77008 100644 --- a/hls4ml/backends/vivado/passes/unrolled_codegen.py +++ b/hls4ml/backends/vivado/passes/unrolled_codegen.py @@ -15,14 +15,14 @@ def match(self, node): # TODO - Extend (& test) for Separable Conv / Depthwise Conv / Recurrent layers layers_with_dense = (Dense, Conv1D, Conv2D, LSTM, GRU) - # Unrolled Dense mimicks the hardware implementation of Resource strategy -> apply after Resource optimizer + # Unrolled Dense mimics the hardware implementation of Resource strategy -> apply after Resource optimizer weights_transposed = node.get_attr('_weights_transposed', False) # RF = 1 will optimize DSPs anyway, so no need to unroll code rf_gt_one = node.get_attr('reuse_factor', 1) > 1 # User requested unrolled implementation of Dense - is_unrolled = node.get_attr('strategy', 'latency') == 'unrolled' + is_unrolled = node.get_attr('strategy', 'latency') == 'resource_unrolled' return isinstance(node, layers_with_dense) and weights_transposed and rf_gt_one and is_unrolled @@ -34,7 +34,7 @@ def transform(self, model, node): weights = node.weights['weight'] code_str = self._generate_unrolled_function(n_in, n_out, reuse_factor, weights, str(node.index) + '_1') code_str = self._add_backend_specific_pragmas_to_generated_code(code_str, model.config.backend) - node.set_attr('unrolled_dense_resource_codegen_1', Source(code_str)) + node.set_attr('resource_unrolled_dense_codegen_1', Source(code_str)) recr_reuse_factor = node.get_attr('recurrent_reuse_factor') recr_weights = node.weights['recurrent_weight'] @@ -42,7 +42,7 @@ def transform(self, model, node): n_in_recr, n_out_recr, recr_reuse_factor, recr_weights, str(node.index) + '_2' ) code_str = self._add_backend_specific_pragmas_to_generated_code(code_str, model.config.backend) - node.set_attr('unrolled_dense_resource_codegen_2', Source(code_str)) + node.set_attr('resource_unrolled_dense_codegen_2', Source(code_str)) else: n_in, n_out = node.model.config.backend.get_layer_mult_size(node) @@ -51,7 +51,7 @@ def transform(self, model, node): code_str = self._generate_unrolled_function(n_in, n_out, reuse_factor, weights, node.index) code_str = self._add_backend_specific_pragmas_to_generated_code(code_str, model.config.backend) - node.set_attr('unrolled_dense_resource_codegen', Source(code_str)) + node.set_attr('resource_unrolled_dense_codegen', Source(code_str)) def _generate_unrolled_function(self, n_in, n_out, reuse_factor, weights, function_suffix): """ @@ -72,7 +72,7 @@ def _generate_unrolled_function(self, n_in, n_out, reuse_factor, weights, functi # Variable instantiation and function pragmas generated_code = ( 'template\n' - 'class dense_unrolled_{suffix} : public DenseKernel {{{{\n' + 'class dense_resource_unrolled_{suffix} : public DenseKernel {{{{\n' ' public:\n' ' static void dense(\n' ' data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],\n' diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 8fdc862287..9f8a5171d3 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -266,11 +266,11 @@ def init_dense(self, layer): index_t = layer.get_weights('weight').type.index_precision else: layer.set_attr('strategy', 'resource') - elif layer.model.config.get_strategy(layer).lower() == 'unrolled': + elif layer.model.config.get_strategy(layer).lower() == 'resource_unrolled': use_resource_instead = False if layer.get_attr('reuse_factor', 1) == 1: print( - f'Unrolled strategy cannot be combined with reuse factor 1 in layer "{layer.name}". ' + f'Unrolled resource strategy cannot be combined with reuse factor 1 in layer "{layer.name}". ' 'Using "resource" strategy instead.' ) use_resource_instead = True @@ -281,7 +281,7 @@ def init_dense(self, layer): layer.set_attr('strategy', 'resource') else: self.set_closest_reuse_factor(layer, n_in, n_out, include_max_rf=False) - layer.set_attr('strategy', 'unrolled') + layer.set_attr('strategy', 'resource_unrolled') else: layer.set_attr('strategy', 'latency') layer.set_attr('index_t', NamedType(f'layer{layer.index}_index', index_t)) @@ -297,17 +297,17 @@ def init_conv1d(self, layer): n_in, n_out = self.get_layer_mult_size(layer) self.set_target_reuse_factor(layer) self.set_closest_reuse_factor(layer, n_in, n_out) - elif layer.model.config.get_strategy(layer).lower() == 'unrolled': + elif layer.model.config.get_strategy(layer).lower() == 'resource_unrolled': use_resource_instead = False if layer.get_attr('reuse_factor', 1) == 1: print( - f'Unrolled strategy cannot be combined with reuse factor 1 in layer "{layer.name}".' + f'Unrolled resource strategy cannot be combined with reuse factor 1 in layer "{layer.name}".' 'Using "resource" strategy instead.' ) use_resource_instead = True elif layer.model.config.get_config_value('IOType') == 'io_parallel': print( - f'Unrolled strategy cannot be combined with io_parallel in layer "{layer.name}". ' + f'Unrolled resource strategy cannot be combined with io_parallel in layer "{layer.name}". ' 'Using "resource" strategy instead.' ) use_resource_instead = True @@ -318,7 +318,7 @@ def init_conv1d(self, layer): layer.set_attr('strategy', 'resource') else: self.set_closest_reuse_factor(layer, n_in, n_out, include_max_rf=False) - layer.set_attr('strategy', 'unrolled') + layer.set_attr('strategy', 'resource_unrolled') else: layer.set_attr('strategy', 'latency') @@ -418,17 +418,17 @@ def init_conv2d(self, layer): self.set_target_reuse_factor(layer) n_in, n_out = self.get_layer_mult_size(layer) self.set_closest_reuse_factor(layer, n_in, n_out) - elif layer.model.config.get_strategy(layer).lower() == 'unrolled': + elif layer.model.config.get_strategy(layer).lower() == 'resource_unrolled': use_resource_instead = False if layer.get_attr('reuse_factor', 1) == 1: print( - f'Unrolled strategy cannot be combined with reuse factor 1 in layer "{layer.name}". ' + f'Unrolled resource strategy cannot be combined with reuse factor 1 in layer "{layer.name}". ' 'Using "resource" strategy instead.' ) use_resource_instead = True elif layer.model.config.get_config_value('IOType') == 'io_parallel': print( - f'Unrolled strategy cannot be combined with io_parallel in layer "{layer.name}". ' + f'Unrolled resource strategy cannot be combined with io_parallel in layer "{layer.name}". ' 'Using "resource" strategy instead.' ) use_resource_instead = True @@ -439,7 +439,7 @@ def init_conv2d(self, layer): layer.set_attr('strategy', 'resource') else: self.set_closest_reuse_factor(layer, n_in, n_out, include_max_rf=False) - layer.set_attr('strategy', 'unrolled') + layer.set_attr('strategy', 'resource_unrolled') else: layer.set_attr('strategy', 'latency') @@ -563,11 +563,11 @@ def init_lstm(self, layer): self.set_closest_reuse_factor(layer, n_in, n_out) self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor') layer.set_attr('strategy', 'resource') - elif layer.model.config.get_strategy(layer).lower() == 'unrolled': + elif layer.model.config.get_strategy(layer).lower() == 'resource_unrolled': use_resource_instead = False if layer.get_attr('reuse_factor', 1) == 1: print( - f'Unrolled strategy cannot be combined with reuse factor 1 in layer "{layer.name}". ' + f'Unrolled resource strategy cannot be combined with reuse factor 1 in layer "{layer.name}". ' 'Using "resource" strategy instead.' ) use_resource_instead = True @@ -581,7 +581,7 @@ def init_lstm(self, layer): self.set_closest_reuse_factor( layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor', include_max_rf=False ) - layer.set_attr('strategy', 'unrolled') + layer.set_attr('strategy', 'resource_unrolled') else: layer.set_attr('strategy', 'latency') @@ -597,11 +597,11 @@ def init_gru(self, layer): self.set_closest_reuse_factor(layer, n_in, n_out) self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor') layer.set_attr('strategy', 'resource') - elif layer.model.config.get_strategy(layer).lower() == 'unrolled': + elif layer.model.config.get_strategy(layer).lower() == 'resource_unrolled': use_resource_instead = False if layer.get_attr('reuse_factor', 1) == 1: print( - f'Unrolled strategy cannot be combined with reuse factor 1 in layer "{layer.name}". ' + f'Unrolled resource strategy cannot be combined with reuse factor 1 in layer "{layer.name}". ' 'Using "resource" strategy instead.' ) use_resource_instead = True @@ -615,7 +615,7 @@ def init_gru(self, layer): self.set_closest_reuse_factor( layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor', include_max_rf=False ) - layer.set_attr('strategy', 'unrolled') + layer.set_attr('strategy', 'resource_unrolled') else: layer.set_attr('strategy', 'latency') diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 609417f94a..678e6d49af 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -10,6 +10,7 @@ from hls4ml.model.flow import get_flow from hls4ml.model.layers import layer_map from hls4ml.model.optimizer import get_available_passes, optimize_model +from hls4ml.utils.string_utils import convert_to_snake_case class HLSConfig: @@ -35,7 +36,7 @@ def __init__(self, config): self.layer_type_targ_cycles = {} self.layer_name_targ_cycles = {} - self.model_strategy = 'Latency' + self.model_strategy = convert_to_snake_case('Latency') self.layer_type_strategy = {} self.layer_name_strategy = {} @@ -217,7 +218,7 @@ def parse_name_config(self, layer_name, layer_cfg): strategy = layer_cfg.get('Strategy') if strategy is not None: - self.layer_name_strategy[layer_name.lower()] = strategy + self.layer_name_strategy[layer_name.lower()] = convert_to_snake_case(strategy) conv_implementation = layer_cfg.get('ConvImplementation') if conv_implementation is not None: @@ -265,7 +266,7 @@ def _parse_hls_config(self): self.model_rf = model_cfg.get('ReuseFactor') self.model_targ_cycles = model_cfg.get('TargetCycles') self.model_conv_implementation = model_cfg.get('ConvImplementation', 'LineBuffer') - self.model_strategy = model_cfg.get('Strategy', 'Latency') + self.model_strategy = convert_to_snake_case(model_cfg.get('Strategy', 'Latency')) self.model_compression = bool(model_cfg.get('Compression', 0)) self.pipeline_style = model_cfg.get('PipelineStyle', 'auto') self.pipeline_ii = model_cfg.get('PipelineInterval', None) @@ -290,7 +291,7 @@ def _parse_hls_config(self): strategy = layer_cfg.get('Strategy') if strategy is not None: - self.layer_type_strategy[layer_type.lower()] = strategy + self.layer_type_strategy[layer_type.lower()] = convert_to_snake_case(strategy) conv_implementation = layer_cfg.get('ConvImplementation') if conv_implementation is not None: diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_common.h b/hls4ml/templates/vivado/nnet_utils/nnet_common.h index fee8b7b935..a14517df5b 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_common.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_common.h @@ -23,7 +23,7 @@ namespace nnet { // Common type definitions enum io_type { io_parallel = 0, io_stream }; -enum strategy { latency, resource, unrolled }; +enum strategy { latency, resource, resource_unrolled }; /* --- * Balanced tree reduce implementation. diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_stream.h index 4a55700d8d..2b481930b7 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_stream.h @@ -60,7 +60,7 @@ void conv_1d_buffer_cl(hls::stream &data, hls::stream &res, typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); - if (CONFIG_T::strategy == nnet::unrolled && CONFIG_T::reuse_factor > 1) { + if (CONFIG_T::strategy == nnet::resource_unrolled && CONFIG_T::reuse_factor > 1) { #pragma HLS allocation instances=compute_output_buffer_1d limit=1 function } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h index d5583f2669..1408b0db13 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_stream.h @@ -75,7 +75,7 @@ void conv_2d_buffer_cl( [CONFIG_T::n_chan]; #pragma HLS ARRAY_PARTITION variable = line_buffer complete dim = 2 - if (CONFIG_T::strategy == nnet::unrolled && CONFIG_T::reuse_factor > 1) { + if (CONFIG_T::strategy == nnet::resource_unrolled && CONFIG_T::reuse_factor > 1) { #pragma HLS allocation instances=compute_output_buffer_1d limit=1 function #pragma HLS allocation instances=compute_output_buffer_2d limit=1 function } diff --git a/hls4ml/utils/string_utils.py b/hls4ml/utils/string_utils.py index fa341cd8af..a08c4c52a7 100644 --- a/hls4ml/utils/string_utils.py +++ b/hls4ml/utils/string_utils.py @@ -10,7 +10,8 @@ def convert_to_snake_case(pascal_case): Returns: str: converted string """ - return re.sub(r'(?', backend='Vitis', default_reuse_factor=8) + config['Model']['Strategy'] = strategy + + output_dir = str(test_root_path / f'hls4mlprj_resource_unrolled_parsing_{strategy}') + hls_model = convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend='Vitis') + + # Check if strategy was not overridden + assert list(hls_model.get_layers())[1].get_attr('strategy') == 'resource_unrolled' + + +# Tests a wide range of RF to ensure the unrolled resource kernel is correct @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) @pytest.mark.parametrize('reuse_factor', [1, 2, 4, 8, 16, 32, 48, 64, 96, 192]) @pytest.mark.parametrize('backend', ['Vitis', 'Vivado']) -def test_dense_unrolled(io_type, reuse_factor, backend): +def test_resource_unrolled_dense(io_type, reuse_factor, backend): input_shape = (16,) X = np.random.rand(100, *input_shape) @@ -31,13 +49,13 @@ def test_dense_unrolled(io_type, reuse_factor, backend): config = config_from_keras_model( model, default_precision='ac_fixed<32, 16>', backend=backend, default_reuse_factor=reuse_factor ) - config['Model']['Strategy'] = 'Unrolled' + config['Model']['Strategy'] = 'ResourceUnrolled' - output_dir = str(test_root_path / f'hls4mlprj_dense_unrolled_{io_type}_{reuse_factor}_{backend}') + output_dir = str(test_root_path / f'hls4mlprj_resource_unrolled_dense_{io_type}_{reuse_factor}_{backend}') hls_model = convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) # Check if strategy was not overridden - assert list(hls_model.get_layers())[1].get_attr('strategy') == 'unrolled' if reuse_factor > 1 else 'latency' + assert list(hls_model.get_layers())[1].get_attr('strategy') == 'resource_unrolled' if reuse_factor > 1 else 'latency' hls_model.compile() @@ -45,11 +63,11 @@ def test_dense_unrolled(io_type, reuse_factor, backend): np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=0, atol=1e-2) -# Tests a wide range RF on streaming Conv1D/2D to ensure the unrolled Dense is correct +# Tests a wide range RF on streaming Conv1D/2D to ensure the unrolled resource kernel is correct @pytest.mark.parametrize('dim', [1, 2]) @pytest.mark.parametrize('io_type', ['io_stream']) @pytest.mark.parametrize('reuse_factor', [1, 3, 9, 27, 54, 108]) -def test_dense_unrolled_streaming_conv(dim, io_type, reuse_factor): +def test_resource_unrolled_streaming_conv(dim, io_type, reuse_factor): input_shape = (8,) * dim + (3,) X = np.random.rand(100, *input_shape) conv_class = Conv1D if dim == 1 else Conv2D @@ -66,13 +84,13 @@ def test_dense_unrolled_streaming_conv(dim, io_type, reuse_factor): keras_prediction = model.predict(X) config = config_from_keras_model(model, default_precision='ac_fixed<32, 16>', default_reuse_factor=reuse_factor) - config['Model']['Strategy'] = 'Unrolled' + config['Model']['Strategy'] = 'ResourceUnrolled' - output_dir = str(test_root_path / f'hls4mlprj_dense_unrolled_conv{dim}d_{io_type}_{reuse_factor}') + output_dir = str(test_root_path / f'hls4mlprj_resource_unrolled_conv{dim}d_{io_type}_{reuse_factor}') hls_model = convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend='Vivado', io_type=io_type) # Check if strategy was not overridden - assert list(hls_model.get_layers())[1].get_attr('strategy') == 'unrolled' if reuse_factor > 1 else 'latency' + assert list(hls_model.get_layers())[1].get_attr('strategy') == 'resource_unrolled' if reuse_factor > 1 else 'latency' hls_model.compile() @@ -85,7 +103,7 @@ def test_dense_unrolled_streaming_conv(dim, io_type, reuse_factor): @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) @pytest.mark.parametrize('static', [True, False]) @pytest.mark.parametrize('reuse_factor', [1, 4, 32, 128]) # RF=128 also tests if setting closest RF works well -def test_rnn_unrolled(rnn_layer, backend, io_type, static, reuse_factor): +def test_resource_unrolled_rnn(rnn_layer, backend, io_type, static, reuse_factor): # Subtract 0.5 to include negative values input_shape = (12, 8) X = np.random.rand(50, *input_shape) - 0.5 @@ -110,9 +128,9 @@ def test_rnn_unrolled(rnn_layer, backend, io_type, static, reuse_factor): keras_model, granularity='name', default_precision=default_precision, backend=backend ) hls_config['LayerName'][layer_name]['static'] = static - hls_config['LayerName'][layer_name]['Strategy'] = 'Unrolled' + hls_config['LayerName'][layer_name]['Strategy'] = 'ResourceUnrolled' hls_config['LayerName'][layer_name]['ReuseFactor'] = reuse_factor - prj_name = f'hls4mlprj_rnn_unrolled_{layer_name}_static_{int(static)}_{io_type}_{reuse_factor}_{backend}' + prj_name = f'hls4mlprj_resource_unrolled_rnn_{layer_name}_static_{int(static)}_{io_type}_{reuse_factor}_{backend}' output_dir = str(test_root_path / prj_name) hls_model = convert_from_keras_model( @@ -120,7 +138,7 @@ def test_rnn_unrolled(rnn_layer, backend, io_type, static, reuse_factor): ) # Check if strategy was not overridden - assert list(hls_model.get_layers())[1].get_attr('strategy') == 'unrolled' if reuse_factor > 1 else 'latency' + assert list(hls_model.get_layers())[1].get_attr('strategy') == 'resource_unrolled' if reuse_factor > 1 else 'latency' hls_model.compile() diff --git a/test/pytest/test_pipeline_style.py b/test/pytest/test_pipeline_style.py old mode 100755 new mode 100644 index f8706fa52c..17d180d487 --- a/test/pytest/test_pipeline_style.py +++ b/test/pytest/test_pipeline_style.py @@ -16,10 +16,10 @@ [ (1, 'auto', 'io_stream', 'resource', None), # io_stream should result in DATAFLOW pragma regardless of other params (2, 'auto', 'io_stream', 'latency', None), - (3, None, 'io_stream', 'unrolled', None), # None should be interpreted as 'auto' + (3, None, 'io_stream', 'resource_unrolled', None), # None should be interpreted as 'auto' (4, 'auto', 'io_parallel', 'resource', None), # Should end up with DATAFLOW pragma (5, 'auto', 'io_parallel', 'latency', None), # Should end up with PIPELINE pragma - (6, 'auto', 'io_parallel', 'unrolled', None), # Should end up with PIPELINE pragma and II + (6, 'auto', 'io_parallel', 'resource_unrolled', None), # Should end up with PIPELINE pragma and II (7, 'pipeline', 'io_stream', 'resource', None), # Should result in a warning (8, 'pipeline', 'io_parallel', 'resource', None), # Should result in a warning (9, 'pipeline', 'io_parallel', 'latency', None), # No warning From 97c5347eb2e8cb7cf68a243810e40bce2f0ce24a Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Mon, 7 Oct 2024 18:33:36 +0200 Subject: [PATCH 14/14] Move optimization API to "dsp_aware_pruning" module (new optimization tools coming) --- docs/advanced/model_optimization.rst | 14 +-- hls4ml/optimization/__init__.py | 111 +----------------- .../dsp_aware_pruning/__init__.py | 108 +++++++++++++++++ .../{ => dsp_aware_pruning}/attributes.py | 4 +- .../{ => dsp_aware_pruning}/config.py | 0 .../{ => dsp_aware_pruning}/keras/__init__.py | 14 +-- .../{ => dsp_aware_pruning}/keras/builder.py | 4 +- .../{ => dsp_aware_pruning}/keras/config.py | 0 .../{ => dsp_aware_pruning}/keras/masking.py | 6 +- .../keras/reduction.py | 2 +- .../keras/regularizers.py | 2 +- .../{ => dsp_aware_pruning}/keras/utils.py | 0 .../{ => dsp_aware_pruning}/knapsack.py | 0 .../objectives/__init__.py | 4 +- .../objectives/gpu_objectives.py | 6 +- .../objectives/vivado_objectives.py | 6 +- .../{ => dsp_aware_pruning}/scheduler.py | 0 .../test_optimization/test_attributes.py | 2 +- .../test_keras/test_masking.py | 8 +- .../test_keras/test_reduction.py | 4 +- .../test_keras/test_regularizers.py | 6 +- .../test_keras/test_weight_sharing.py | 8 +- .../pytest/test_optimization/test_knapsack.py | 2 +- .../test_optimization/test_objectives.py | 4 +- .../test_optimization/test_scheduler.py | 2 +- test/pytest/test_pipeline_style.py | 0 26 files changed, 160 insertions(+), 157 deletions(-) create mode 100644 hls4ml/optimization/dsp_aware_pruning/__init__.py rename hls4ml/optimization/{ => dsp_aware_pruning}/attributes.py (98%) rename hls4ml/optimization/{ => dsp_aware_pruning}/config.py (100%) rename hls4ml/optimization/{ => dsp_aware_pruning}/keras/__init__.py (96%) rename hls4ml/optimization/{ => dsp_aware_pruning}/keras/builder.py (98%) rename hls4ml/optimization/{ => dsp_aware_pruning}/keras/config.py (100%) rename hls4ml/optimization/{ => dsp_aware_pruning}/keras/masking.py (99%) rename hls4ml/optimization/{ => dsp_aware_pruning}/keras/reduction.py (96%) rename hls4ml/optimization/{ => dsp_aware_pruning}/keras/regularizers.py (99%) rename hls4ml/optimization/{ => dsp_aware_pruning}/keras/utils.py (100%) rename hls4ml/optimization/{ => dsp_aware_pruning}/knapsack.py (100%) rename hls4ml/optimization/{ => dsp_aware_pruning}/objectives/__init__.py (97%) rename hls4ml/optimization/{ => dsp_aware_pruning}/objectives/gpu_objectives.py (92%) rename hls4ml/optimization/{ => dsp_aware_pruning}/objectives/vivado_objectives.py (98%) rename hls4ml/optimization/{ => dsp_aware_pruning}/scheduler.py (100%) mode change 100644 => 100755 test/pytest/test_pipeline_style.py diff --git a/docs/advanced/model_optimization.rst b/docs/advanced/model_optimization.rst index c1396b3d20..41132ab619 100644 --- a/docs/advanced/model_optimization.rst +++ b/docs/advanced/model_optimization.rst @@ -13,11 +13,11 @@ The code block below showcases three use cases of the hls4ml Optimization API - from tensorflow.keras.optimizers import Adam from tensorflow.keras.metrics import CategoricalAccuracy from tensorflow.keras.losses import CategoricalCrossentropy - from hls4ml.optimization.keras import optimize_model - from hls4ml.optimization.keras.utils import get_model_sparsity - from hls4ml.optimization.attributes import get_attributes_from_keras_model - from hls4ml.optimization.objectives import ParameterEstimator - from hls4ml.optimization.scheduler import PolynomialScheduler + from hls4ml.optimization.dsp_aware_pruning.keras import optimize_model + from hls4ml.optimization.dsp_aware_pruning.keras.utils import get_model_sparsity + from hls4ml.optimization.dsp_aware_pruning.attributes import get_attributes_from_keras_model + from hls4ml.optimization.dsp_aware_pruning.objectives import ParameterEstimator + from hls4ml.optimization.dsp_aware_pruning.scheduler import PolynomialScheduler # Define baseline model and load data # X_train, y_train = ... # X_val, y_val = ... @@ -75,7 +75,7 @@ To optimize GPU FLOPs, the code is similar to above: .. code-block:: Python - from hls4ml.optimization.objectives.gpu_objectives import GPUFLOPEstimator + from hls4ml.optimization.dsp_aware_pruning.objectives.gpu_objectives import GPUFLOPEstimator # Optimize model # Note the change from ParameterEstimator to GPUFLOPEstimator @@ -98,7 +98,7 @@ Finally, optimizing Vivado DSPs is possible, given a hls4ml config: .. code-block:: Python from hls4ml.utils.config import config_from_keras_model - from hls4ml.optimization.objectives.vivado_objectives import VivadoDSPEstimator + from hls4ml.optimization.dsp_aware_pruning.objectives.vivado_objectives import VivadoDSPEstimator # Note the change from optimize_model to optimize_keras_model_for_hls4ml # The function optimize_keras_model_for_hls4ml acts as a wrapper for the function, parsing hls4ml config to model attributes diff --git a/hls4ml/optimization/__init__.py b/hls4ml/optimization/__init__.py index ab51ce1eb3..c626b70c2b 100644 --- a/hls4ml/optimization/__init__.py +++ b/hls4ml/optimization/__init__.py @@ -1,108 +1,3 @@ -import numpy as np - -from hls4ml.optimization.attributes import get_attributes_from_keras_model_and_hls4ml_config -from hls4ml.optimization.keras import optimize_model - -default_regularization_range = np.logspace(-6, -2, num=16).tolist() - - -def optimize_keras_model_for_hls4ml( - keras_model, - hls_config, - objective, - scheduler, - X_train, - y_train, - X_val, - y_val, - batch_size, - epochs, - optimizer, - loss_fn, - validation_metric, - increasing, - rtol, - callbacks=None, - ranking_metric='l1', - local=False, - verbose=False, - rewinding_epochs=1, - cutoff_bad_trials=3, - directory='hls4ml-optimization', - tuner='Bayesian', - knapsack_solver='CBC_MIP', - regularization_range=default_regularization_range, -): - ''' - Top-level function for optimizing a Keras model, given hls4ml config and a hardware objective(s) - - Args: - keras_model (keras.Model): Model to be optimized - hls_config (dict): hls4ml configuration, obtained from hls4ml.utils.config.config_from_keras_model(...) - objective (hls4ml.optimization.objectives.ObjectiveEstimator): - Parameter, hardware or user-defined objective of optimization - scheduler (hls4ml.optimization.scheduler.OptimizationScheduler): - Sparsity scheduler, choose between constant, polynomial and binary - X_train (np.array): Training inputs - y_train (np.array): Training labels - X_val (np.array): Validation inputs - y_val (np.array): Validation labels - batch_size (int): Batch size during training - epochs (int): Maximum number of epochs to fine-tune model, in one iteration of pruning - optimizer (keras.optimizers.Optimizer or equivalent-string description): Optimizer used during training - loss_fn (keras.losses.Loss or equivalent loss description): Loss function used during training - validation_metric (keras.metrics.Metric or equivalent loss description): Validation metric, used as a baseline - increasing (boolean): If the metric improves with increased values; - e.g. accuracy -> increasing = True, MSE -> increasing = False - rtol (float): Relative tolerance; - pruning stops when pruned_validation_metric < (or >) rtol * baseline_validation_metric - callbacks (list of keras.callbacks.Callback) Currently not supported, developed in future versions - ranking_metric (string): Metric used for ranking weights and structures; - currently supported l1, l2, saliency and Oracle - local (boolean): Layer-wise or global pruning - verbose (boolean): Display debug logs during model optimization - rewinding_epochs (int): Number of epochs to retrain model without weight freezing, - allows regrowth of previously pruned weights - cutoff_bad_trials (int): After how many bad trials (performance below threshold), - should model pruning / weight sharing stop - directory (string): Directory to store temporary results - tuner (str): Tuning algorithm, choose between Bayesian, Hyperband and None - knapsack_solver (str): Algorithm to solve Knapsack problem when optimizing; - default usually works well; for very large networks, greedy algorithm might be more suitable - regularization_range (list): List of suitable hyperparameters for weight decay - - Returns: - keras.Model: Optimized model - ''' - - # Extract model attributes - model_attributes = get_attributes_from_keras_model_and_hls4ml_config(keras_model, hls_config) - - # Optimize model - return optimize_model( - keras_model, - model_attributes, - objective, - scheduler, - X_train, - y_train, - X_val, - y_val, - batch_size, - epochs, - optimizer, - loss_fn, - validation_metric, - increasing, - rtol, - callbacks=callbacks, - ranking_metric=ranking_metric, - local=local, - verbose=verbose, - rewinding_epochs=rewinding_epochs, - cutoff_bad_trials=cutoff_bad_trials, - directory=directory, - tuner=tuner, - knapsack_solver=knapsack_solver, - regularization_range=regularization_range, - ) +from .dsp_aware_pruning import optimize_keras_model_for_hls4ml # noqa: F401 +from .dsp_aware_pruning.attributes import get_attributes_from_keras_model_and_hls4ml_config # noqa: F401 +from .dsp_aware_pruning.keras import optimize_model # noqa: F401 diff --git a/hls4ml/optimization/dsp_aware_pruning/__init__.py b/hls4ml/optimization/dsp_aware_pruning/__init__.py new file mode 100644 index 0000000000..69e2029e0e --- /dev/null +++ b/hls4ml/optimization/dsp_aware_pruning/__init__.py @@ -0,0 +1,108 @@ +import numpy as np + +from hls4ml.optimization.dsp_aware_pruning.attributes import get_attributes_from_keras_model_and_hls4ml_config +from hls4ml.optimization.dsp_aware_pruning.keras import optimize_model + +default_regularization_range = np.logspace(-6, -2, num=16).tolist() + + +def optimize_keras_model_for_hls4ml( + keras_model, + hls_config, + objective, + scheduler, + X_train, + y_train, + X_val, + y_val, + batch_size, + epochs, + optimizer, + loss_fn, + validation_metric, + increasing, + rtol, + callbacks=None, + ranking_metric='l1', + local=False, + verbose=False, + rewinding_epochs=1, + cutoff_bad_trials=3, + directory='hls4ml-optimization', + tuner='Bayesian', + knapsack_solver='CBC_MIP', + regularization_range=default_regularization_range, +): + ''' + Top-level function for optimizing a Keras model, given hls4ml config and a hardware objective(s) + + Args: + keras_model (keras.Model): Model to be optimized + hls_config (dict): hls4ml configuration, obtained from hls4ml.utils.config.config_from_keras_model(...) + objective (hls4ml.optimization.objectives.ObjectiveEstimator): + Parameter, hardware or user-defined objective of optimization + scheduler (hls4ml.optimization.scheduler.OptimizationScheduler): + Sparsity scheduler, choose between constant, polynomial and binary + X_train (np.array): Training inputs + y_train (np.array): Training labels + X_val (np.array): Validation inputs + y_val (np.array): Validation labels + batch_size (int): Batch size during training + epochs (int): Maximum number of epochs to fine-tune model, in one iteration of pruning + optimizer (keras.optimizers.Optimizer or equivalent-string description): Optimizer used during training + loss_fn (keras.losses.Loss or equivalent loss description): Loss function used during training + validation_metric (keras.metrics.Metric or equivalent loss description): Validation metric, used as a baseline + increasing (boolean): If the metric improves with increased values; + e.g. accuracy -> increasing = True, MSE -> increasing = False + rtol (float): Relative tolerance; + pruning stops when pruned_validation_metric < (or >) rtol * baseline_validation_metric + callbacks (list of keras.callbacks.Callback) Currently not supported, developed in future versions + ranking_metric (string): Metric used for ranking weights and structures; + currently supported l1, l2, saliency and Oracle + local (boolean): Layer-wise or global pruning + verbose (boolean): Display debug logs during model optimization + rewinding_epochs (int): Number of epochs to retrain model without weight freezing, + allows regrowth of previously pruned weights + cutoff_bad_trials (int): After how many bad trials (performance below threshold), + should model pruning / weight sharing stop + directory (string): Directory to store temporary results + tuner (str): Tuning algorithm, choose between Bayesian, Hyperband and None + knapsack_solver (str): Algorithm to solve Knapsack problem when optimizing; + default usually works well; for very large networks, greedy algorithm might be more suitable + regularization_range (list): List of suitable hyperparameters for weight decay + + Returns: + keras.Model: Optimized model + ''' + + # Extract model attributes + model_attributes = get_attributes_from_keras_model_and_hls4ml_config(keras_model, hls_config) + + # Optimize model + return optimize_model( + keras_model, + model_attributes, + objective, + scheduler, + X_train, + y_train, + X_val, + y_val, + batch_size, + epochs, + optimizer, + loss_fn, + validation_metric, + increasing, + rtol, + callbacks=callbacks, + ranking_metric=ranking_metric, + local=local, + verbose=verbose, + rewinding_epochs=rewinding_epochs, + cutoff_bad_trials=cutoff_bad_trials, + directory=directory, + tuner=tuner, + knapsack_solver=knapsack_solver, + regularization_range=regularization_range, + ) diff --git a/hls4ml/optimization/attributes.py b/hls4ml/optimization/dsp_aware_pruning/attributes.py similarity index 98% rename from hls4ml/optimization/attributes.py rename to hls4ml/optimization/dsp_aware_pruning/attributes.py index a7b6d74135..f652f27d50 100644 --- a/hls4ml/optimization/attributes.py +++ b/hls4ml/optimization/dsp_aware_pruning/attributes.py @@ -2,8 +2,8 @@ import hls4ml from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType -from hls4ml.optimization.config import SUPPORTED_STRUCTURES -from hls4ml.optimization.keras.config import SUPPORTED_LAYERS +from hls4ml.optimization.dsp_aware_pruning.config import SUPPORTED_STRUCTURES +from hls4ml.optimization.dsp_aware_pruning.keras.config import SUPPORTED_LAYERS class hls4mlAttributes: diff --git a/hls4ml/optimization/config.py b/hls4ml/optimization/dsp_aware_pruning/config.py similarity index 100% rename from hls4ml/optimization/config.py rename to hls4ml/optimization/dsp_aware_pruning/config.py diff --git a/hls4ml/optimization/keras/__init__.py b/hls4ml/optimization/dsp_aware_pruning/keras/__init__.py similarity index 96% rename from hls4ml/optimization/keras/__init__.py rename to hls4ml/optimization/dsp_aware_pruning/keras/__init__.py index d67ddd5d26..29012bd39e 100644 --- a/hls4ml/optimization/keras/__init__.py +++ b/hls4ml/optimization/dsp_aware_pruning/keras/__init__.py @@ -7,13 +7,13 @@ # Enables printing of loss tensors during custom training loop from tensorflow.python.ops.numpy_ops import np_config -import hls4ml.optimization.keras.utils as utils -from hls4ml.optimization.config import SUPPORTED_STRUCTURES -from hls4ml.optimization.keras.builder import build_optimizable_model, remove_custom_regularizers -from hls4ml.optimization.keras.config import SUPPORTED_LAYERS, SUPPORTED_METRICS, TMP_DIRECTORY -from hls4ml.optimization.keras.masking import get_model_masks -from hls4ml.optimization.keras.reduction import reduce_model -from hls4ml.optimization.scheduler import OptimizationScheduler +import hls4ml.optimization.dsp_aware_pruning.keras.utils as utils +from hls4ml.optimization.dsp_aware_pruning.config import SUPPORTED_STRUCTURES +from hls4ml.optimization.dsp_aware_pruning.keras.builder import build_optimizable_model, remove_custom_regularizers +from hls4ml.optimization.dsp_aware_pruning.keras.config import SUPPORTED_LAYERS, SUPPORTED_METRICS, TMP_DIRECTORY +from hls4ml.optimization.dsp_aware_pruning.keras.masking import get_model_masks +from hls4ml.optimization.dsp_aware_pruning.keras.reduction import reduce_model +from hls4ml.optimization.dsp_aware_pruning.scheduler import OptimizationScheduler np_config.enable_numpy_behavior() default_regularization_range = np.logspace(-6, -2, num=16).tolist() diff --git a/hls4ml/optimization/keras/builder.py b/hls4ml/optimization/dsp_aware_pruning/keras/builder.py similarity index 98% rename from hls4ml/optimization/keras/builder.py rename to hls4ml/optimization/dsp_aware_pruning/keras/builder.py index f265ccdf48..4ba39e4f7b 100644 --- a/hls4ml/optimization/keras/builder.py +++ b/hls4ml/optimization/dsp_aware_pruning/keras/builder.py @@ -8,8 +8,8 @@ from tensorflow.keras.callbacks import EarlyStopping from tensorflow.keras.layers import Conv2D, Dense -from hls4ml.optimization.keras.config import SUPPORTED_LAYERS, TMP_DIRECTORY -from hls4ml.optimization.keras.regularizers import Conv2DRegularizer, DenseRegularizer +from hls4ml.optimization.dsp_aware_pruning.keras.config import SUPPORTED_LAYERS, TMP_DIRECTORY +from hls4ml.optimization.dsp_aware_pruning.keras.regularizers import Conv2DRegularizer, DenseRegularizer co = {} _add_supported_quantized_objects(co) diff --git a/hls4ml/optimization/keras/config.py b/hls4ml/optimization/dsp_aware_pruning/keras/config.py similarity index 100% rename from hls4ml/optimization/keras/config.py rename to hls4ml/optimization/dsp_aware_pruning/keras/config.py diff --git a/hls4ml/optimization/keras/masking.py b/hls4ml/optimization/dsp_aware_pruning/keras/masking.py similarity index 99% rename from hls4ml/optimization/keras/masking.py rename to hls4ml/optimization/dsp_aware_pruning/keras/masking.py index 0e74997be8..dddeddf6f7 100644 --- a/hls4ml/optimization/keras/masking.py +++ b/hls4ml/optimization/dsp_aware_pruning/keras/masking.py @@ -6,9 +6,9 @@ from qkeras import QConv2D, QDense from tensorflow.keras.layers import Conv2D, Dense -from hls4ml.optimization.config import SUPPORTED_STRUCTURES -from hls4ml.optimization.keras.config import SUPPORTED_LAYERS, SUPPORTED_METRICS -from hls4ml.optimization.knapsack import solve_knapsack +from hls4ml.optimization.dsp_aware_pruning.config import SUPPORTED_STRUCTURES +from hls4ml.optimization.dsp_aware_pruning.keras.config import SUPPORTED_LAYERS, SUPPORTED_METRICS +from hls4ml.optimization.dsp_aware_pruning.knapsack import solve_knapsack def get_model_masks( diff --git a/hls4ml/optimization/keras/reduction.py b/hls4ml/optimization/dsp_aware_pruning/keras/reduction.py similarity index 96% rename from hls4ml/optimization/keras/reduction.py rename to hls4ml/optimization/dsp_aware_pruning/keras/reduction.py index 4ea8855aa8..12fb534799 100644 --- a/hls4ml/optimization/keras/reduction.py +++ b/hls4ml/optimization/dsp_aware_pruning/keras/reduction.py @@ -2,7 +2,7 @@ from tensorflow.keras.layers import Conv2D, Dense from tensorflow.keras.models import Sequential -from hls4ml.optimization.keras.utils import get_last_layer_with_weights +from hls4ml.optimization.dsp_aware_pruning.keras.utils import get_last_layer_with_weights def reduce_model(model): diff --git a/hls4ml/optimization/keras/regularizers.py b/hls4ml/optimization/dsp_aware_pruning/keras/regularizers.py similarity index 99% rename from hls4ml/optimization/keras/regularizers.py rename to hls4ml/optimization/dsp_aware_pruning/keras/regularizers.py index 1e885963c2..b42eb3f056 100644 --- a/hls4ml/optimization/keras/regularizers.py +++ b/hls4ml/optimization/dsp_aware_pruning/keras/regularizers.py @@ -1,7 +1,7 @@ import numpy as np import tensorflow as tf -from hls4ml.optimization.config import SUPPORTED_STRUCTURES +from hls4ml.optimization.dsp_aware_pruning.config import SUPPORTED_STRUCTURES @tf.keras.utils.register_keras_serializable(name='DenseRegularizer') diff --git a/hls4ml/optimization/keras/utils.py b/hls4ml/optimization/dsp_aware_pruning/keras/utils.py similarity index 100% rename from hls4ml/optimization/keras/utils.py rename to hls4ml/optimization/dsp_aware_pruning/keras/utils.py diff --git a/hls4ml/optimization/knapsack.py b/hls4ml/optimization/dsp_aware_pruning/knapsack.py similarity index 100% rename from hls4ml/optimization/knapsack.py rename to hls4ml/optimization/dsp_aware_pruning/knapsack.py diff --git a/hls4ml/optimization/objectives/__init__.py b/hls4ml/optimization/dsp_aware_pruning/objectives/__init__.py similarity index 97% rename from hls4ml/optimization/objectives/__init__.py rename to hls4ml/optimization/dsp_aware_pruning/objectives/__init__.py index fcbef305b6..45204aaf73 100644 --- a/hls4ml/optimization/objectives/__init__.py +++ b/hls4ml/optimization/dsp_aware_pruning/objectives/__init__.py @@ -3,8 +3,8 @@ import numpy as np -from hls4ml.optimization.attributes import OptimizationAttributes -from hls4ml.optimization.config import SUPPORTED_STRUCTURES +from hls4ml.optimization.dsp_aware_pruning.attributes import OptimizationAttributes +from hls4ml.optimization.dsp_aware_pruning.config import SUPPORTED_STRUCTURES ''' Pruning & weight sharing are formulated as an optimization problem, with the aim of minimizing some metric diff --git a/hls4ml/optimization/objectives/gpu_objectives.py b/hls4ml/optimization/dsp_aware_pruning/objectives/gpu_objectives.py similarity index 92% rename from hls4ml/optimization/objectives/gpu_objectives.py rename to hls4ml/optimization/dsp_aware_pruning/objectives/gpu_objectives.py index 8528a31839..bb3afc6397 100644 --- a/hls4ml/optimization/objectives/gpu_objectives.py +++ b/hls4ml/optimization/dsp_aware_pruning/objectives/gpu_objectives.py @@ -2,9 +2,9 @@ import numpy as np -from hls4ml.optimization.attributes import OptimizationAttributes -from hls4ml.optimization.config import SUPPORTED_STRUCTURES -from hls4ml.optimization.objectives import ObjectiveEstimator +from hls4ml.optimization.dsp_aware_pruning.attributes import OptimizationAttributes +from hls4ml.optimization.dsp_aware_pruning.config import SUPPORTED_STRUCTURES +from hls4ml.optimization.dsp_aware_pruning.objectives import ObjectiveEstimator class GPUFLOPEstimator(ObjectiveEstimator): diff --git a/hls4ml/optimization/objectives/vivado_objectives.py b/hls4ml/optimization/dsp_aware_pruning/objectives/vivado_objectives.py similarity index 98% rename from hls4ml/optimization/objectives/vivado_objectives.py rename to hls4ml/optimization/dsp_aware_pruning/objectives/vivado_objectives.py index c0c0c33e09..798542cfc0 100644 --- a/hls4ml/optimization/objectives/vivado_objectives.py +++ b/hls4ml/optimization/dsp_aware_pruning/objectives/vivado_objectives.py @@ -3,9 +3,9 @@ import numpy as np -from hls4ml.optimization.attributes import OptimizationAttributes -from hls4ml.optimization.config import SUPPORTED_STRUCTURES -from hls4ml.optimization.objectives import ObjectiveEstimator +from hls4ml.optimization.dsp_aware_pruning.attributes import OptimizationAttributes +from hls4ml.optimization.dsp_aware_pruning.config import SUPPORTED_STRUCTURES +from hls4ml.optimization.dsp_aware_pruning.objectives import ObjectiveEstimator # Optimizes DSP utilisation for Vivado backend diff --git a/hls4ml/optimization/scheduler.py b/hls4ml/optimization/dsp_aware_pruning/scheduler.py similarity index 100% rename from hls4ml/optimization/scheduler.py rename to hls4ml/optimization/dsp_aware_pruning/scheduler.py diff --git a/test/pytest/test_optimization/test_attributes.py b/test/pytest/test_optimization/test_attributes.py index 3ba8d08d14..a42d3a6751 100644 --- a/test/pytest/test_optimization/test_attributes.py +++ b/test/pytest/test_optimization/test_attributes.py @@ -1,7 +1,7 @@ from tensorflow.keras.layers import Conv2D, Dense, Flatten, ReLU from tensorflow.keras.models import Sequential -from hls4ml.optimization.attributes import get_attributes_from_keras_model_and_hls4ml_config +from hls4ml.optimization import get_attributes_from_keras_model_and_hls4ml_config from hls4ml.utils.config import config_from_keras_model diff --git a/test/pytest/test_optimization/test_keras/test_masking.py b/test/pytest/test_optimization/test_keras/test_masking.py index 5c5e60aca7..8b465d8d7e 100644 --- a/test/pytest/test_optimization/test_keras/test_masking.py +++ b/test/pytest/test_optimization/test_keras/test_masking.py @@ -4,10 +4,10 @@ from tensorflow.keras.layers import Conv2D, Dense, Flatten from tensorflow.keras.models import Sequential -from hls4ml.optimization.attributes import get_attributes_from_keras_model -from hls4ml.optimization.config import SUPPORTED_STRUCTURES -from hls4ml.optimization.keras.masking import get_model_masks -from hls4ml.optimization.objectives import ParameterEstimator +from hls4ml.optimization.dsp_aware_pruning.attributes import get_attributes_from_keras_model +from hls4ml.optimization.dsp_aware_pruning.config import SUPPORTED_STRUCTURES +from hls4ml.optimization.dsp_aware_pruning.keras.masking import get_model_masks +from hls4ml.optimization.dsp_aware_pruning.objectives import ParameterEstimator ''' In all the tests, an artifical network with one Dense/Conv2D layer and pre-determined weights is created diff --git a/test/pytest/test_optimization/test_keras/test_reduction.py b/test/pytest/test_optimization/test_keras/test_reduction.py index 7243a9123f..4bf93f7301 100644 --- a/test/pytest/test_optimization/test_keras/test_reduction.py +++ b/test/pytest/test_optimization/test_keras/test_reduction.py @@ -6,8 +6,8 @@ from tensorflow.keras.layers import AveragePooling2D, BatchNormalization, Conv2D, Dense, Flatten, MaxPooling2D, ReLU, Softmax from tensorflow.keras.models import Sequential -from hls4ml.optimization.keras.reduction import reduce_model -from hls4ml.optimization.keras.utils import get_model_sparsity +from hls4ml.optimization.dsp_aware_pruning.keras.reduction import reduce_model +from hls4ml.optimization.dsp_aware_pruning.keras.utils import get_model_sparsity pytest.skip(allow_module_level=True) diff --git a/test/pytest/test_optimization/test_keras/test_regularizers.py b/test/pytest/test_optimization/test_keras/test_regularizers.py index 9fe518caae..f643f3a79a 100644 --- a/test/pytest/test_optimization/test_keras/test_regularizers.py +++ b/test/pytest/test_optimization/test_keras/test_regularizers.py @@ -6,9 +6,9 @@ from tensorflow.keras.models import Sequential from tensorflow.keras.optimizers import Adam -from hls4ml.optimization.config import SUPPORTED_STRUCTURES -from hls4ml.optimization.keras.builder import remove_custom_regularizers -from hls4ml.optimization.keras.regularizers import Conv2DRegularizer, DenseRegularizer +from hls4ml.optimization.dsp_aware_pruning.config import SUPPORTED_STRUCTURES +from hls4ml.optimization.dsp_aware_pruning.keras.builder import remove_custom_regularizers +from hls4ml.optimization.dsp_aware_pruning.keras.regularizers import Conv2DRegularizer, DenseRegularizer # Constants pattern_offset = 4 diff --git a/test/pytest/test_optimization/test_keras/test_weight_sharing.py b/test/pytest/test_optimization/test_keras/test_weight_sharing.py index c274a84da8..be1d3a957f 100644 --- a/test/pytest/test_optimization/test_keras/test_weight_sharing.py +++ b/test/pytest/test_optimization/test_keras/test_weight_sharing.py @@ -4,10 +4,10 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.models import Sequential -from hls4ml.optimization.attributes import get_attributes_from_keras_model -from hls4ml.optimization.config import SUPPORTED_STRUCTURES -from hls4ml.optimization.keras.masking import get_model_masks -from hls4ml.optimization.objectives import ObjectiveEstimator +from hls4ml.optimization.dsp_aware_pruning.attributes import get_attributes_from_keras_model +from hls4ml.optimization.dsp_aware_pruning.config import SUPPORTED_STRUCTURES +from hls4ml.optimization.dsp_aware_pruning.keras.masking import get_model_masks +from hls4ml.optimization.dsp_aware_pruning.objectives import ObjectiveEstimator # Similar tests in test_masking.py, weight sharing instead of pruning sparsity = 0.33 diff --git a/test/pytest/test_optimization/test_knapsack.py b/test/pytest/test_optimization/test_knapsack.py index a4145c00d0..804081c8e8 100644 --- a/test/pytest/test_optimization/test_knapsack.py +++ b/test/pytest/test_optimization/test_knapsack.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from hls4ml.optimization.knapsack import solve_knapsack +from hls4ml.optimization.dsp_aware_pruning.knapsack import solve_knapsack # In the simple case below, both implementations give the optimal answer diff --git a/test/pytest/test_optimization/test_objectives.py b/test/pytest/test_optimization/test_objectives.py index a7d81befe6..2f8a6414da 100644 --- a/test/pytest/test_optimization/test_objectives.py +++ b/test/pytest/test_optimization/test_objectives.py @@ -2,8 +2,8 @@ from tensorflow.keras.layers import Conv2D, Dense, Flatten from tensorflow.keras.models import Sequential -from hls4ml.optimization.attributes import get_attributes_from_keras_model -from hls4ml.optimization.objectives import ParameterEstimator +from hls4ml.optimization.dsp_aware_pruning.attributes import get_attributes_from_keras_model +from hls4ml.optimization.dsp_aware_pruning.objectives import ParameterEstimator # Test attempts to verify one of the estimators (parameter) is correctly declared, the functions are static etc. diff --git a/test/pytest/test_optimization/test_scheduler.py b/test/pytest/test_optimization/test_scheduler.py index 2dc7642bf6..2182d1cb46 100644 --- a/test/pytest/test_optimization/test_scheduler.py +++ b/test/pytest/test_optimization/test_scheduler.py @@ -1,6 +1,6 @@ import numpy as np # Use np.testing.assert_allclose due to floating point rounding errors -from hls4ml.optimization.scheduler import BinaryScheduler, ConstantScheduler, PolynomialScheduler +from hls4ml.optimization.dsp_aware_pruning.scheduler import BinaryScheduler, ConstantScheduler, PolynomialScheduler def test_constant_scheduler(): diff --git a/test/pytest/test_pipeline_style.py b/test/pytest/test_pipeline_style.py old mode 100644 new mode 100755