diff --git a/docs/api/configuration.rst b/docs/api/configuration.rst index 091f88e619..72d677d196 100644 --- a/docs/api/configuration.rst +++ b/docs/api/configuration.rst @@ -9,6 +9,7 @@ We currently support two ways of setting hls4ml's model configuration. This page .. contents:: \ +The Python API approach is recommended for most users as there are more utilities to help create the configuration dictionaries. **NOTE:** @@ -16,8 +17,10 @@ We currently support two ways of setting hls4ml's model configuration. This page * One important part of ``hls4ml`` to remember is that the user is responsible for the format of the inputs. There is no automatic formatting or normalization so this must be done in the training. -* +.. + * For developers, you might also want to check out this section: `Detailed configuration in converted hls codes <#detailed-configuration-in-converted-hls-codes>`_. + *Broken link* ---- @@ -31,11 +34,26 @@ Using hls4ml, you can quickly generate a simple configuration dictionary from a import hls4ml config = hls4ml.utils.config_from_keras_model(model, granularity='model') -For more advanced and detailed configuration, you can also set them through the created dictionary. For example, to change the reuse factor: +This Python dictionary can be edited as needed. A more advanced configuration can be generated, for example, by: + +.. code-block:: python + + import hls4ml + config = hls4ml.utils.config_from_keras_model( + model, + granularity='name', + default_precision='fixed<16,6>', + backend='Vitis') + +This will include per-layer configuration based on the model. Including the backend is recommended because some configuration options depend on the backend. Note that the precisions at the +higher granularities usually default to 'auto', which means that ``hls4ml`` will try to set them automatically. Note that higher granularity settings take precedence +over model-level settings. See :py:class:`~hls4ml.utils.config.config_from_keras_model` for more information on the various options. + +One can override specific values before using the configuration: .. code-block:: python - config['Model']['ReuseFactor'] = 2 + config['LayerName']['fc1']['ReuseFactor'] = 2 Or to set the precision of a specific layer's weight: @@ -45,6 +63,20 @@ Or to set the precision of a specific layer's weight: To better understand how the configuration hierachy works, refer to the next section for more details. +Finally, one uses the configuration to create an HLS model: + +.. code-block:: python + + hls_model = hls4ml.converters.convert_from_keras_model( + model, + hls_config=config, + output_dir="my_project_dir", + io_type='io_stream', + backend='Vitis' + ) + +See :py:class:`~hls4ml.converters.convert_from_keras_model` for more information on the various options. + ---- 2. YAML Configuration file diff --git a/docs/setup.rst b/docs/setup.rst index f99b2f2dcb..a735281c3f 100644 --- a/docs/setup.rst +++ b/docs/setup.rst @@ -57,7 +57,7 @@ To run FPGA synthesis, installation of following tools is required: * Xilinx Vivado HLS 2018.2 to 2020.1 for synthesis for Xilinx FPGAs - * Vitis HLS 2022.1 or newer is required for synthesis for Xilinx FPGAs using the experimental ``Vitis`` backend. + * Vitis HLS 2022.2 or newer is required for synthesis for Xilinx FPGAs using the ``Vitis`` backend. 
* Intel Quartus 20.1 to 21.4 for the synthesis for Intel FPGAs diff --git a/docs/status.rst b/docs/status.rst index e4cac5e735..4ff4d33282 100644 --- a/docs/status.rst +++ b/docs/status.rst @@ -81,7 +81,7 @@ Other feature notes: * ``hls4ml`` is tested on Linux, and supports * Vivado HLS versions 2018.2 to 2020.1 * Intel HLS versions 20.1 to 21.4 - * Vitis HLS versions 2020.2 to 2022.2 (experimentally) + * Vitis HLS versions 2022.2 to 2024.1 * Windows and macOS are not supported * BDT support has moved to the `Conifer `__ package diff --git a/hls4ml/backends/catapult/catapult_backend.py b/hls4ml/backends/catapult/catapult_backend.py index 0583e80dab..d939e1f30b 100644 --- a/hls4ml/backends/catapult/catapult_backend.py +++ b/hls4ml/backends/catapult/catapult_backend.py @@ -110,6 +110,7 @@ def _register_flows(self): 'catapult:inplace_stream_flatten', 'catapult:skip_softmax', 'catapult:fix_softmax_table_size', + 'infer_precision_types', ] optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name) diff --git a/hls4ml/backends/catapult/passes/conv_same_pad.py b/hls4ml/backends/catapult/passes/conv_same_pad.py index bb8354a3d0..8946e493fc 100755 --- a/hls4ml/backends/catapult/passes/conv_same_pad.py +++ b/hls4ml/backends/catapult/passes/conv_same_pad.py @@ -6,10 +6,8 @@ class InsertZeroPaddingBeforeConv1D(OptimizerPass): name = 'insert_zero_padding_before_conv1d' def match(self, node): - is_match = ( - isinstance(node, (Conv1D, SeparableConv1D)) - and ((node.get_attr('padding') == 'same') or (node.get_attr('padding') == 'causal')) - and node.get_attr('filt_width') != 1 + is_match = isinstance(node, (Conv1D, SeparableConv1D)) and ( + (node.get_attr('pad_left') != 0) or (node.get_attr('pad_right') != 0) ) return is_match @@ -37,7 +35,6 @@ def transform(self, model, node): } # Switch Conv1D layer padding to 'valid' - node.set_attr('padding', 'valid') node.set_attr('pad_left', 0) node.set_attr('pad_right', 0) node.set_attr('in_width', out_width) @@ -54,11 +51,11 @@ class InsertZeroPaddingBeforeConv2D(OptimizerPass): name = 'insert_zero_padding_before_conv2d' def match(self, node): - is_match = ( - isinstance(node, (Conv2D, SeparableConv2D)) - and node.get_attr('padding') == 'same' - and node.get_attr('filt_height') != 1 - and node.get_attr('filt_width') != 1 + is_match = isinstance(node, (Conv2D, SeparableConv2D)) and ( + (node.get_attr('pad_left') != 0) + or (node.get_attr('pad_right') != 0) + or (node.get_attr('pad_top') != 0) + or (node.get_attr('pad_bottom') != 0) ) return is_match @@ -93,7 +90,6 @@ def transform(self, model, node): } # Switch Conv2D layer padding to 'valid' - node.set_attr('padding', 'valid') node.set_attr('pad_top', 0) node.set_attr('pad_bottom', 0) node.set_attr('pad_left', 0) diff --git a/hls4ml/backends/vivado/passes/conv_same_pad.py b/hls4ml/backends/vivado/passes/conv_same_pad.py index bb8354a3d0..8946e493fc 100644 --- a/hls4ml/backends/vivado/passes/conv_same_pad.py +++ b/hls4ml/backends/vivado/passes/conv_same_pad.py @@ -6,10 +6,8 @@ class InsertZeroPaddingBeforeConv1D(OptimizerPass): name = 'insert_zero_padding_before_conv1d' def match(self, node): - is_match = ( - isinstance(node, (Conv1D, SeparableConv1D)) - and ((node.get_attr('padding') == 'same') or (node.get_attr('padding') == 'causal')) - and node.get_attr('filt_width') != 1 + is_match = isinstance(node, (Conv1D, SeparableConv1D)) and ( + (node.get_attr('pad_left') != 0) or (node.get_attr('pad_right') != 0) ) return is_match @@ -37,7 +35,6 @@ def transform(self, 
model, node): } # Switch Conv1D layer padding to 'valid' - node.set_attr('padding', 'valid') node.set_attr('pad_left', 0) node.set_attr('pad_right', 0) node.set_attr('in_width', out_width) @@ -54,11 +51,11 @@ class InsertZeroPaddingBeforeConv2D(OptimizerPass): name = 'insert_zero_padding_before_conv2d' def match(self, node): - is_match = ( - isinstance(node, (Conv2D, SeparableConv2D)) - and node.get_attr('padding') == 'same' - and node.get_attr('filt_height') != 1 - and node.get_attr('filt_width') != 1 + is_match = isinstance(node, (Conv2D, SeparableConv2D)) and ( + (node.get_attr('pad_left') != 0) + or (node.get_attr('pad_right') != 0) + or (node.get_attr('pad_top') != 0) + or (node.get_attr('pad_bottom') != 0) ) return is_match @@ -93,7 +90,6 @@ def transform(self, model, node): } # Switch Conv2D layer padding to 'valid' - node.set_attr('padding', 'valid') node.set_attr('pad_top', 0) node.set_attr('pad_bottom', 0) node.set_attr('pad_left', 0) diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index 3bd6d06c3b..092e53b3d3 100644 --- a/hls4ml/converters/__init__.py +++ b/hls4ml/converters/__init__.py @@ -10,6 +10,8 @@ from hls4ml.converters.keras_to_hls import get_supported_keras_layers # noqa: F401 from hls4ml.converters.keras_to_hls import parse_keras_model # noqa: F401 from hls4ml.converters.keras_to_hls import keras_to_hls, register_keras_layer_handler + +# from hls4ml.converters.pytorch_to_hls import parse_pytorch_model # noqa: F401 from hls4ml.model import ModelGraph from hls4ml.utils.config import create_config from hls4ml.utils.symbolic_utils import LUTFunction @@ -238,7 +240,6 @@ def convert_from_keras_model( def convert_from_pytorch_model( model, - input_shape, output_dir='my-hls-test', project_name='myproject', input_data_tb=None, @@ -251,7 +252,6 @@ def convert_from_pytorch_model( Args: model: PyTorch model to convert. - input_shape (list): The shape of the input tensor. First element is the batch size, needs to be None output_dir (str, optional): Output directory of the generated HLS project. Defaults to 'my-hls-test'. project_name (str, optional): Name of the HLS project. Defaults to 'myproject'. 
input_data_tb (str, optional): String representing the path of input data in .npy or .dat format that will be @@ -293,7 +293,6 @@ def convert_from_pytorch_model( config = create_config(output_dir=output_dir, project_name=project_name, backend=backend, **kwargs) config['PytorchModel'] = model - config['InputShape'] = input_shape config['InputData'] = input_data_tb config['OutputPredictions'] = output_data_tb config['HLSConfig'] = {} @@ -301,9 +300,9 @@ def convert_from_pytorch_model( if hls_config is None: hls_config = {} - model_config = hls_config.get('Model', None) + model_config = hls_config.get('Model') config['HLSConfig']['Model'] = _check_model_config(model_config) - + config['InputShape'] = hls_config.get('InputShape') _check_hls_config(config, hls_config) return pytorch_to_hls(config) diff --git a/hls4ml/converters/keras/convolution.py b/hls4ml/converters/keras/convolution.py index d223d55dfb..950a672692 100644 --- a/hls4ml/converters/keras/convolution.py +++ b/hls4ml/converters/keras/convolution.py @@ -30,10 +30,9 @@ def parse_conv1d_layer(keras_layer, input_names, input_shapes, data_reader): layer['n_filt'] = layer['n_chan'] * layer.get('depth_multiplier') layer['filt_width'] = keras_layer['config']['kernel_size'][0] layer['stride_width'] = keras_layer['config']['strides'][0] - layer['padding'] = keras_layer['config']['padding'] (layer['out_width'], layer['pad_left'], layer['pad_right']) = compute_padding_1d( - layer['padding'], layer['in_width'], layer['stride_width'], layer['filt_width'] + keras_layer['config']['padding'], layer['in_width'], layer['stride_width'], layer['filt_width'] ) if layer['data_format'] == 'channels_last': @@ -74,7 +73,6 @@ def parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader): layer['filt_width'] = keras_layer['config']['kernel_size'][1] layer['stride_height'] = keras_layer['config']['strides'][0] layer['stride_width'] = keras_layer['config']['strides'][1] - layer['padding'] = keras_layer['config']['padding'] ( layer['out_height'], @@ -84,7 +82,7 @@ def parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader): layer['pad_left'], layer['pad_right'], ) = compute_padding_2d( - layer['padding'], + keras_layer['config']['padding'], layer['in_height'], layer['in_width'], layer['stride_height'], diff --git a/hls4ml/converters/keras/pooling.py b/hls4ml/converters/keras/pooling.py index f0e00242b0..14d6a9236a 100644 --- a/hls4ml/converters/keras/pooling.py +++ b/hls4ml/converters/keras/pooling.py @@ -15,10 +15,9 @@ def parse_pooling_layer(keras_layer, input_names, input_shapes, data_reader): layer['pool_width'] = keras_layer['config']['pool_size'][0] layer['stride_width'] = keras_layer['config']['strides'][0] - layer['padding'] = keras_layer['config']['padding'] (layer['n_out'], layer['pad_left'], layer['pad_right']) = compute_padding_1d( - layer['padding'], layer['n_in'], layer['stride_width'], layer['pool_width'] + keras_layer['config']['padding'], layer['n_in'], layer['stride_width'], layer['pool_width'] ) if layer['data_format'] == 'channels_last': @@ -32,7 +31,6 @@ def parse_pooling_layer(keras_layer, input_names, input_shapes, data_reader): layer['stride_width'] = keras_layer['config']['strides'][1] layer['pool_height'] = keras_layer['config']['pool_size'][0] layer['pool_width'] = keras_layer['config']['pool_size'][1] - layer['padding'] = keras_layer['config']['padding'] ( layer['out_height'], @@ -42,7 +40,7 @@ def parse_pooling_layer(keras_layer, input_names, input_shapes, data_reader): layer['pad_left'], 
layer['pad_right'], ) = compute_padding_2d( - layer['padding'], + keras_layer['config']['padding'], layer['in_height'], layer['in_width'], layer['stride_height'], diff --git a/hls4ml/converters/pytorch/convolution.py b/hls4ml/converters/pytorch/convolution.py index 5c0d4d2d4c..40295e0865 100644 --- a/hls4ml/converters/pytorch/convolution.py +++ b/hls4ml/converters/pytorch/convolution.py @@ -35,11 +35,6 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c else: padding = class_object.padding - if padding == 0: # No padding, i.e., 'VALID' padding in Keras/Tensorflow - layer['padding'] = 'valid' - else: # Only 'valid' and 'same' padding are available in Keras - layer['padding'] = 'same' - # Ouput info (layer['out_width'], pad_left, pad_right) = compute_padding_1d_pytorch( padding, layer['in_width'], layer['stride_width'], layer['filt_width'], layer['dilation'] @@ -84,11 +79,6 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c layer['pad_top'] = layer['pad_bottom'] = class_object.padding[0] layer['pad_left'] = layer['pad_right'] = class_object.padding[1] - if all(x == 0 for x in class_object.padding): # No padding, i.e., 'VALID' padding in Keras/Tensorflow - layer['padding'] = 'valid' - else: # Only 'valid' and 'same' padding are available in Keras - layer['padding'] = 'same' - # Ouput info (layer['out_height'], layer['out_width'], _, _, _, _) = compute_padding_2d_pytorch( class_object.padding, diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index 0262fdab03..d3ba470bf5 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -43,14 +43,12 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod layer = {} layer['class_name'] = operation - layer['activation'] = layer['class_name'] + layer['activation'] = layer['class_name'].lower() layer['name'] = layer_name layer['inputs'] = input_names - # if layer['class_name'] != 'Activation': - # layer['activation'] = layer['class_name'] if node.op == 'call_module': - if layer['class_name'] == 'ReLU' or layer['class_name'] == 'Sigmoid': + if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']: layer['class_name'] = 'Activation' if layer['class_name'] == 'LeakyReLU': layer['activ_param'] = class_object.negative_slope @@ -68,7 +66,7 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod if hasattr(node, 'dim'): layer['axis'] = class_object.dim else: - if layer['class_name'] == 'ReLU' or layer['class_name'] == 'Sigmoid': + if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']: layer['class_name'] = 'Activation' if layer['class_name'] == 'LeakyReLU': layer['activ_param'] = node.kwargs['negative_slope'] diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index bd483b3690..79ca1fa5c6 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -84,6 +84,7 @@ def decorator(function): # map names of operations between toch.nn and torch.nn.functionals layer_name_map = { 'relu': 'ReLU', + 'tanh': 'Tanh', 'leaky_relu': 'LeakyReLU', 'elu': 'ELU', 'prelu': 'PReLU', @@ -102,7 +103,7 @@ def decorator(function): # ---------------------------------------------------------------- -def pytorch_to_hls(config): +def parse_pytorch_model(config, verbose=True): """Convert PyTorch model to hls4ml ModelGraph. 
Args: @@ -118,14 +119,15 @@ def pytorch_to_hls(config): # This is a list of dictionaries to hold all the layer info we need to generate HLS layer_list = [] - print('Interpreting Model ...') - + if verbose: + print('Interpreting Model ...') reader = PyTorchFileReader(config) if isinstance(config['PytorchModel'], str) else PyTorchModelReader(config) if type(reader.input_shape) is tuple: input_shapes = [list(reader.input_shape)] else: input_shapes = list(reader.input_shape) - input_shapes = [list(shape) for shape in input_shapes] + # first element needs to 'None' as placeholder for the batch size, insert it if not present + input_shapes = [[None] + list(shape) if shape[0] is not None else list(shape) for shape in input_shapes] model = reader.torch_model @@ -151,7 +153,8 @@ def pytorch_to_hls(config): output_shape = None # Loop through layers - print('Topology:') + if verbose: + print('Topology:') layer_counter = 0 n_inputs = 0 @@ -226,13 +229,14 @@ def pytorch_to_hls(config): pytorch_class, layer_name, input_names, input_shapes, node, class_object, reader, config ) - print( - 'Layer name: {}, layer type: {}, input shape: {}'.format( - layer['name'], - layer['class_name'], - input_shapes, + if verbose: + print( + 'Layer name: {}, layer type: {}, input shape: {}'.format( + layer['name'], + layer['class_name'], + input_shapes, + ) ) - ) layer_list.append(layer) assert output_shape is not None @@ -288,7 +292,12 @@ def pytorch_to_hls(config): operation, layer_name, input_names, input_shapes, node, None, reader, config ) - print('Layer name: {}, layer type: {}, input shape: {}'.format(layer['name'], layer['class_name'], input_shapes)) + if verbose: + print( + 'Layer name: {}, layer type: {}, input shape: {}'.format( + layer['name'], layer['class_name'], input_shapes + ) + ) layer_list.append(layer) assert output_shape is not None @@ -342,7 +351,12 @@ def pytorch_to_hls(config): operation, layer_name, input_names, input_shapes, node, None, reader, config ) - print('Layer name: {}, layer type: {}, input shape: {}'.format(layer['name'], layer['class_name'], input_shapes)) + if verbose: + print( + 'Layer name: {}, layer type: {}, input shape: {}'.format( + layer['name'], layer['class_name'], input_shapes + ) + ) layer_list.append(layer) assert output_shape is not None @@ -351,6 +365,11 @@ def pytorch_to_hls(config): if len(input_layers) == 0: input_layers = None + return layer_list, input_layers + + +def pytorch_to_hls(config): + layer_list, input_layers = parse_pytorch_model(config) print('Creating HLS model') hls_model = ModelGraph(config, layer_list, inputs=input_layers) return hls_model diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index de3dffb46c..77e38b0c5b 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -33,9 +33,8 @@ register_flow( 'convert', [ - 'seperable_to_depthwise_and_conv', # has to be before precision inference - 'infer_precision_types', 'channels_last_converter', + 'seperable_to_depthwise_and_conv', 'remove_transpose_before_flatten', 'remove_nop_transpose', 'remove_single_channel_transpose', @@ -45,19 +44,17 @@ 'qkeras_factorize_alpha', 'extract_ternary_threshold', 'fuse_consecutive_batch_normalization', + 'fuse_batch_normalization', 'replace_multidimensional_dense_with_conv', 'enforce_proxy_model_embedded_config', + 'eliminate_linear_activation', + # many of the above optimzers need to be done before this + 'infer_precision_types', ], ) # TODO Maybe not all QKeras optmizers belong here? 
register_flow( 'optimize', - [ - 'eliminate_linear_activation', - 'fuse_consecutive_batch_normalization', - 'fuse_batch_normalization', - 'infer_precision_types', - 'set_precision_concat', - ], + [], requires=['convert'], ) diff --git a/hls4ml/model/optimizer/passes/infer_precision.py b/hls4ml/model/optimizer/passes/infer_precision.py index 256e8a8152..bb24f2206e 100644 --- a/hls4ml/model/optimizer/passes/infer_precision.py +++ b/hls4ml/model/optimizer/passes/infer_precision.py @@ -1,9 +1,17 @@ import math +from typing import Iterable import numpy as np from hls4ml.model.optimizer import ConfigurableOptimizerPass -from hls4ml.model.types import FixedPrecisionType, UnspecifiedPrecisionType +from hls4ml.model.types import ( + FixedPrecisionType, + IntegerPrecisionType, + PrecisionType, + RoundingMode, + SaturationMode, + UnspecifiedPrecisionType, +) # TODO: The code assumes everything is Fixed or Integer precision. Need to add checks @@ -70,6 +78,12 @@ def _infer_precision(self, node, types_to_infer): if node_class in ['Dot']: return self._infer_dot_precision(node, types_to_infer) + if node_class in ['Embedding']: + return self._infer_embedding_precision(node, types_to_infer) + + if node_class in ['SimpleRNN', 'LSTM', 'GRU']: + return self._infer_rnn_precision(node, types_to_infer) + # What about quantized activation layer? Setting it to 'auto' manually will break it here. We should prevent # this in config_from_* functions @@ -79,6 +93,20 @@ def _get_default_precision(self, node): model_config = node.model.config return model_config.backend.convert_precision_string(model_config.model_precision['default']) + def _get_maximum_precision(self, node): + model_config = node.model.config + if 'maximum' in model_config.model_precision: + return model_config.backend.convert_precision_string(model_config.model_precision['maximum']) + else: + return None + + def _all_supported_types(self, types: Iterable[PrecisionType]): + """Are all the types supported for inference--currently Integer or Fixed""" + for tp in types: + if not isinstance(tp, (IntegerPrecisionType, FixedPrecisionType)): + return False + return True + def _infer_default_type(self, node, type_name): model_config = node.model.config default_precision = model_config.backend.convert_precision_string(model_config.model_precision['default']) @@ -99,9 +127,6 @@ def _infer_common_precision(self, node, types_to_infer, n_ops): inferred_types = [] input_precision = node.get_input_variable().type.precision - input_width = input_precision.width - input_integers = input_precision.integer - input_signed = input_precision.signed if 'weight_t' in types_to_infer: weight_quantizer = node.get_attr('weight_quantizer', None) @@ -113,10 +138,6 @@ def _infer_common_precision(self, node, types_to_infer, n_ops): node.weights['weight'].update_precision(node.types['weight_t'].precision) inferred_types.append('weight_t') - weight_width = node.types['weight_t'].precision.width - weight_integers = node.types['weight_t'].precision.integer - weight_signed = node.types['weight_t'].precision.signed - if 'bias_t' in types_to_infer: bias_quantizer = node.get_attr('bias_quantizer', None) if bias_quantizer is not None: @@ -127,25 +148,42 @@ def _infer_common_precision(self, node, types_to_infer, n_ops): node.weights['bias'].update_precision(node.types['bias_t'].precision) inferred_types.append('bias_t') - bias_width = node.types['bias_t'].precision.width - bias_integers = node.types['bias_t'].precision.integer - bias_signed = node.types['bias_t'].precision.signed - no_bias = 
node.weights['bias'].nonzeros == 0 and self.infer_no_bias # no bias + if self._all_supported_types((input_precision, node.types['weight_t'].precision, node.types['bias_t'].precision)): + input_width = input_precision.width + input_integers = input_precision.integer + input_signed = input_precision.signed - # using math.ceil instead of np.ceil because it returns an int - bitwidth = weight_width + input_width + math.ceil(np.log2(n_ops)) - integers = weight_integers + input_integers + math.ceil(np.log2(n_ops)) - signed = weight_signed or input_signed + weight_width = node.types['weight_t'].precision.width + weight_integers = node.types['weight_t'].precision.integer + weight_signed = node.types['weight_t'].precision.signed - frac = bitwidth - integers + bias_width = node.types['bias_t'].precision.width + bias_integers = node.types['bias_t'].precision.integer + bias_signed = node.types['bias_t'].precision.signed + no_bias = node.weights['bias'].nonzeros == 0 and self.infer_no_bias # no bias + + # using math.ceil instead of np.ceil because it returns an int + bitwidth = weight_width + input_width + math.ceil(np.log2(n_ops)) + integers = weight_integers + input_integers + math.ceil(np.log2(n_ops)) + signed = weight_signed or input_signed + + frac = bitwidth - integers - if not no_bias: - integers = max(integers + (bias_signed and not signed), bias_integers + (signed and not bias_signed)) + 1 - bitwidth = integers + max(frac, bias_width - bias_integers) - signed = signed or bias_signed + if not no_bias: + integers = max(integers + (bias_signed and not signed), bias_integers + (signed and not bias_signed)) + 1 + bitwidth = integers + max(frac, bias_width - bias_integers) + signed = signed or bias_signed - # Note: this is guaranteed to not overflow or need rounding, so it's sufficient to use the simpler form. - new_type = FixedPrecisionType(bitwidth, integers, signed) + # if max_precision is specified, limit the size to be less than max precisoin + max_precision = self._get_maximum_precision(node) + if max_precision is not None: + bitwidth = min(bitwidth, max_precision.width) + integers = min(integers, max_precision.integer) + + # Note: this is guaranteed to not overflow or need rounding, so it's sufficient to use the simpler form. + new_type = FixedPrecisionType(bitwidth, integers, signed) + else: + new_type = self._get_default_precision(node) if 'accum_t' in types_to_infer: node.types['accum_t'].name = node.name + '_accum_t' @@ -173,6 +211,8 @@ def _infer_depthconv_precision(self, node, types_to_infer): n_ops = node.get_attr('filt_height', 1) * node.get_attr('filt_width') return self._infer_common_precision(node, types_to_infer, n_ops) + # This function should generally not be called because we split sepconv to depthwise and regular (pointwise). + # It has not been updated. 
def _infer_sepconv_precision(self, node, types_to_infer): inferred_types = [] @@ -272,24 +312,35 @@ def _infer_bn_precision(self, node, types_to_infer): scale_precision = node.types['scale_t'].precision bias_precision = node.types['bias_t'].precision - after_scale_signed = scale_precision.signed or input_precision.signed - after_scale_width = input_precision.width + scale_precision.width - after_scale_integer = input_precision.integer + scale_precision.integer + if self._all_supported_types((input_precision, scale_precision, bias_precision)): + + after_scale_signed = scale_precision.signed or input_precision.signed + after_scale_width = input_precision.width + scale_precision.width + after_scale_integer = input_precision.integer + scale_precision.integer - out_precision_signed = after_scale_signed or bias_precision.signed - out_precision_integer = ( - max( - after_scale_integer + (bias_precision.signed and not after_scale_signed), - bias_precision.integer + (after_scale_signed and not bias_precision.signed), + out_precision_signed = after_scale_signed or bias_precision.signed + out_precision_integer = ( + max( + after_scale_integer + (bias_precision.signed and not after_scale_signed), + bias_precision.integer + (after_scale_signed and not bias_precision.signed), + ) + + 1 + ) + out_precision_width = out_precision_integer + max( + after_scale_width - after_scale_integer, bias_precision.fractional ) - + 1 - ) - out_precision_width = out_precision_integer + max( - after_scale_width - after_scale_integer, bias_precision.fractional - ) - # Note: this is guaranteed to not overflow or need rounding, so it's sufficient to use the simpler form. - out_precision = FixedPrecisionType(out_precision_width, out_precision_integer, out_precision_signed) + # if max_precision is specified, limit the size to be less than max precisoin + max_precision = self._get_maximum_precision(node) + if max_precision is not None: + out_precision_width = min(out_precision_width, max_precision.width) + out_precision_integer = min(out_precision_integer, max_precision.integer) + + # Note: this is guaranteed to not overflow or need rounding, so it's sufficient to use the simpler form. + out_precision = FixedPrecisionType(out_precision_width, out_precision_integer, out_precision_signed) + + else: + out_precision = self._get_default_precision(node) node.types['result_t'].name = node.name + '_result_t' node.types['result_t'].precision = out_precision @@ -305,20 +356,29 @@ def _infer_pooling_precision(self, node, types_to_infer): input_precision = node.get_input_variable().type.precision pool_op = node.attributes['pool_op'].lower() - width = input_precision.width - integer = input_precision.integer - signed = input_precision.signed + if pool_op == 'max': + # This has the benefit of working for xnor types. 
I don't think "copy" is needed + accum_type = input_precision + + elif pool_op == 'average': + if self._all_supported_types((input_precision,)): + width = input_precision.width + integer = input_precision.integer + signed = input_precision.signed + + pool_size = node.get_attr('pool_height', 1) * node.get_attr('pool_width') + extra_bits = int(np.ceil(np.log2(pool_size))) + + # for now ignore max precision in this case + accum_type = FixedPrecisionType( + width=width + extra_bits * 2, integer=integer + extra_bits, signed=signed + ) + else: + accum_type = self._get_default_precision(node) - pool_size = node.get_attr('pool_height', 1) * node.get_attr('pool_width') - if pool_op == 'average': - extra_bits = int(np.ceil(np.log2(pool_size))) - elif pool_op == 'max': - extra_bits = 0 else: raise ValueError(f'Unknown pooling operation: {pool_op}') - accum_type = FixedPrecisionType(width=width + extra_bits * 2, integer=integer + extra_bits, signed=signed) - node.types['accum_t'].name = node.name + '_accum_t' node.types['accum_t'].precision = accum_type @@ -338,22 +398,76 @@ def _infer_merge_precision(self, node, types_to_infer): op = node.get_attr('op').lower() if op in ('add', 'subtract', 'average'): - new_signed = input_1.signed or input_2.signed or op == 'subtract' - new_int = ( - max( - input_1.integer + (input_2.signed and not input_1.signed), - input_2.integer + (input_1.signed and not input_2.signed), + if self._all_supported_types((input_1, input_2)): + new_signed = input_1.signed or input_2.signed or op == 'subtract' + new_int = ( + max( + input_1.integer + (input_2.signed and not input_1.signed), + input_2.integer + (input_1.signed and not input_2.signed), + ) + + 1 ) - + 1 - ) - new_width = new_int + max(input_1.fractional, input_2.fractional) - out_precision = FixedPrecisionType(new_width, new_int, new_signed) + new_width = new_int + max(input_1.fractional, input_2.fractional) + max_precision = self._get_maximum_precision(node) + if max_precision is not None: + new_width = min(new_width, max_precision.width) + new_int = min(new_int, max_precision.integer) + out_precision = FixedPrecisionType(new_width, new_int, new_signed) + else: + out_precision = self._get_default_precision(node) elif op == 'multiply': - new_signed = input_1.signed or input_2.signed - new_int = input_1.integer + input_2.integer - new_width = input_1.width + input_2.width - out_precision = FixedPrecisionType(new_width, new_int, new_signed) + if self._all_supported_types((input_1, input_2)): + new_signed = input_1.signed or input_2.signed + new_int = input_1.integer + input_2.integer + new_width = input_1.width + input_2.width + # if max_precision is specified, limit the size to be less than max precisoin + max_precision = self._get_maximum_precision(node) + if max_precision is not None: + new_width = min(new_width, max_precision.width) + new_int = min(new_int, max_precision.integer) + out_precision = FixedPrecisionType(new_width, new_int, new_signed) + else: + out_precision = self._get_default_precision(node) elif op in ('maximum', 'minimum'): + if input_1 == input_2: + # can handle binary and potentially others + out_precision = input_1 # I assume copy is not necessary + elif self._all_supported_types((input_1, input_2)): + new_signed = input_1.signed or input_2.signed + + input_1_integer = input_1.integer + input_2_integer = input_2.integer + + # add one to integer if unsigned while new is signed + if new_signed and not input_1.signed: + input_1_integer += 1 + if new_signed and not input_2.signed: + input_2_integer += 
1 + + new_width = max(input_1.fractional, input_2.fractional) + max(input_1_integer, input_2_integer) + new_int = max(input_1_integer, input_2_integer) + out_precision = FixedPrecisionType(new_width, new_int, new_signed) + else: + out_precision = self._get_default_precision(node) + else: + print(f'Warning: not propagating weights for type {op}') + out_precision = self._get_default_precision(node) + + node.types['result_t'].name = node.name + '_result_t' + node.types['result_t'].precision = out_precision + + return ['result_t'] + + def _infer_cat_precision(self, node, types_to_infer): + assert 'result_t' in types_to_infer and len(types_to_infer) == 1 + + input_1 = node.get_input_variable(node.inputs[0]).type.precision + input_2 = node.get_input_variable(node.inputs[1]).type.precision + + if input_1 == input_2: + # can handle binary and potentially others + out_precision = input_1 # I assume copy is not necessary + elif self._all_supported_types((input_1, input_2)): new_signed = input_1.signed or input_2.signed input_1_integer = input_1.integer @@ -367,9 +481,20 @@ def _infer_merge_precision(self, node, types_to_infer): new_width = max(input_1.fractional, input_2.fractional) + max(input_1_integer, input_2_integer) new_int = max(input_1_integer, input_2_integer) - out_precision = FixedPrecisionType(new_width, new_int, new_signed) + + # if max_precision is specified, limit the size to be less than max precisoin + max_precision = self._get_maximum_precision(node) + if max_precision is not None: + new_width = min(new_width, max_precision.width) + new_int = min(new_int, max_precision.integer) + + # some logic copied from former SetPrecisionConcat optimizer + newrmode = input_1.rounding_mode if input_1.rounding_mode != RoundingMode.TRN else input_2.rounding_mode + newsmode = input_1.saturation_mode if input_1.saturation_mode != SaturationMode.WRAP else input_2.saturation_mode + newsbits = input_1.saturation_bits if input_1.saturation_bits != 0 else input_2.saturation_bits + + out_precision = FixedPrecisionType(new_width, new_int, new_signed, newrmode, newsmode, newsbits) else: - print(f'Warning: not propagating weights for type {op}') out_precision = self._get_default_precision(node) node.types['result_t'].name = node.name + '_result_t' @@ -377,46 +502,58 @@ def _infer_merge_precision(self, node, types_to_infer): return ['result_t'] - def _infer_cat_precision(self, node, types_to_infer): + def _infer_dot_precision(self, node, types_to_infer): assert 'result_t' in types_to_infer and len(types_to_infer) == 1 input_1 = node.get_input_variable(node.inputs[0]).type.precision input_2 = node.get_input_variable(node.inputs[1]).type.precision - new_signed = input_1.signed or input_2.signed - - input_1_integer = input_1.integer - input_2_integer = input_2.integer + if self._all_supported_types((input_1, input_2)): + n_in = node.get_input_variable(node.inputs[0]).shape[0] - # add one to integer if unsigned while new is signed - if new_signed and not input_1.signed: - input_1_integer += 1 - if new_signed and not input_2.signed: - input_2_integer += 1 + new_signed = input_1.signed or input_2.signed + new_width = input_1.width + input_2.width + math.ceil(np.log2(n_in)) + new_int = input_1.integer + input_2.integer + math.ceil(np.log2(n_in)) - new_width = max(input_1.fractional, input_2.fractional) + max(input_1_integer, input_2_integer) - new_int = max(input_1_integer, input_2_integer) + # if max_precision is specified, limit the size to be less than max precisoin + max_precision = 
self._get_maximum_precision(node) + if max_precision is not None: + new_width = min(new_width, max_precision.width) + new_int = min(new_int, max_precision.integer) - out_precision = FixedPrecisionType(new_width, new_int, new_signed) + out_precision = FixedPrecisionType(new_width, new_int, new_signed) + else: + out_precision = self._get_default_precision(node) node.types['result_t'].name = node.name + '_result_t' node.types['result_t'].precision = out_precision return ['result_t'] - def _infer_dot_precision(self, node, types_to_infer): - assert 'result_t' in types_to_infer and len(types_to_infer) == 1 + def _infer_embedding_precision(self, node, types_to_infer): + inferred_types = [] - input_1 = node.get_input_variable(node.inputs[0]).type.precision - input_2 = node.get_input_variable(node.inputs[1]).type.precision + if 'embeddings_t' in types_to_infer: + self._infer_default_type(node, 'embeddings_t') + node.weights['embeddings'].update_precision(node.types['embeddings_t'].precision) + inferred_types.append('embeddings_t') - n_in = node.get_input_variable(node.inputs[0]).shape[0] + if 'result_t' in types_to_infer: + out_precision = self._get_default_precision(node) + node.types['result_t'].name = node.name + '_result_t' + node.types['result_t'].precision = out_precision + inferred_types.append('result_t') - new_signed = input_1.signed or input_2.signed - new_width = input_1.width + input_2.width + math.ceil(np.log2(n_in)) - new_int = input_1.integer + input_2.integer + math.ceil(np.log2(n_in)) + return inferred_types - out_precision = FixedPrecisionType(new_width, new_int, new_signed) - node.types['result_t'].name = node.name + '_result_t' - node.types['result_t'].precision = out_precision + # TODO: This is just a placeholder + def _infer_rnn_precision(self, node, types_to_infer): + inferred_types = [] - return ['result_t'] + # for now just do the weights and leave the rest for the default catch + for weightvar in ('weight', 'bias', 'recurrent_weight', 'recurrent_bias'): + if f'{weightvar}_t' in types_to_infer: + self._infer_default_type(node, f'{weightvar}_t') + node.weights[weightvar].update_precision(node.types[f'{weightvar}_t'].precision) + inferred_types.append(f'{weightvar}_t') + + return inferred_types diff --git a/hls4ml/model/optimizer/passes/multi_dense.py b/hls4ml/model/optimizer/passes/multi_dense.py index 008011bde2..4419abf9c8 100644 --- a/hls4ml/model/optimizer/passes/multi_dense.py +++ b/hls4ml/model/optimizer/passes/multi_dense.py @@ -20,7 +20,6 @@ def transform(self, model, node): conv_attrs = { 'data_format': 'channels_last', - 'padding': 'valid', 'n_chan': input_shape[-1], 'n_filt': node.get_attr('n_out'), 'weight_data': np.expand_dims(node.get_attr('weight_data'), axis=tuple(range(dim))), diff --git a/hls4ml/model/optimizer/passes/precision_merge.py b/hls4ml/model/optimizer/passes/precision_merge.py deleted file mode 100644 index 9e79b11000..0000000000 --- a/hls4ml/model/optimizer/passes/precision_merge.py +++ /dev/null @@ -1,40 +0,0 @@ -from hls4ml.model.optimizer import OptimizerPass -from hls4ml.model.types import FixedPrecisionType, RoundingMode, SaturationMode - - -def get_concat_type(itype1, itype2): - newwidth = max(itype1.width, itype2.width) - newint = max(itype1.integer, itype2.integer) - if itype1.signed ^ itype2.signed: # XOR - newint += 1 - newwidth += 1 - newrmode = itype1.rounding_mode if itype1.rounding_mode != RoundingMode.TRN else itype2.rounding_mode - newsmode = itype1.saturation_mode if itype1.saturation_mode != SaturationMode.WRAP else 
itype2.saturation_mode - newsbits = itype1.saturation_bits if itype1.saturation_bits != 0 else itype2.saturation_bits - - newtype = FixedPrecisionType(newwidth, newint, itype1.signed or itype2.signed, newrmode, newsmode, newsbits) - return newtype - - -class SetPrecisionConcat(OptimizerPass): - def match(self, node): - if node.__class__.__name__ == 'Concatenate': - otype = node.get_output_variable().type.precision - itype1 = node.get_input_variable(node.inputs[0]).type.precision - itype2 = node.get_input_variable(node.inputs[1]).type.precision - if isinstance(otype, FixedPrecisionType) and otype != get_concat_type(itype1, itype2): - return True - return False - - def transform(self, model, node): - """ - Set concat output precision - """ - otype = node.get_output_variable().type.precision - itype1 = node.get_input_variable(node.inputs[0]).type.precision - itype2 = node.get_input_variable(node.inputs[1]).type.precision - newtype = get_concat_type(itype1, itype2) - print(f"Found {node.name} in the model, optimizing {otype} to {newtype}...") - node.get_output_variable().type.precision = newtype - - return True diff --git a/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py b/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py index 7d3b71dc96..38eef1e7d0 100644 --- a/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py +++ b/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py @@ -33,7 +33,6 @@ class SeperableToDepthwiseAndConv(OptimizerPass): 'data_format', 'depthwise_data', 'depthwise_quantizer', - 'padding', ) _pw_attributes = ('out_width', 'n_filt', 'dilation_width', 'out_height', 'dilation_height', 'data_format', 'use_bias') diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_pooling.h b/hls4ml/templates/catapult/nnet_utils/nnet_pooling.h index 82e281023b..d6ab38a960 100644 --- a/hls4ml/templates/catapult/nnet_utils/nnet_pooling.h +++ b/hls4ml/templates/catapult/nnet_utils/nnet_pooling.h @@ -107,22 +107,20 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF // TODO partition the arrays according to the reuse factor const int limit = pool_op_limit_1d(); #pragma HLS ALLOCATION function instances=CONFIG_T::pool_op limit=limit - // Add any necessary padding - unsigned padded_width = CONFIG_T::n_in + CONFIG_T::pad_left + CONFIG_T::pad_right; - if (CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0) { - padded_width -= padded_width - (padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width); - } + // Add padding and reduce input width to area covered by pooling function + static constexpr int full_padded_width = CONFIG_T::n_in + CONFIG_T::pad_left + CONFIG_T::pad_right; + static constexpr int restricted_padded_width = full_padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width; for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { // Loop over input image x in steps of stride - for (int ii = 0; ii < padded_width; ii += CONFIG_T::stride_width) { + for (int ii = 0; ii < restricted_padded_width; ii += CONFIG_T::stride_width) { data_T pool[CONFIG_T::pool_width]; #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 // Keep track of number of pixels in image vs padding region unsigned img_overlap = 0; // Loop over pool window x for (int jj = 0; jj < CONFIG_T::stride_width; jj++) { - if (ii + jj < CONFIG_T::pad_left || ii + jj >= (padded_width - CONFIG_T::pad_right)) { + if (ii + jj < CONFIG_T::pad_left || ii + jj >= (full_padded_width - CONFIG_T::pad_right)) { // Add padding pool[jj] = pad_val(); if (CONFIG_T::count_pad) { @@ -211,19 +209,17 @@ 
void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_ // TODO partition the arrays according to the reuse factor const int limit = pool_op_limit(); #pragma HLS ALLOCATION function instances=CONFIG_T::pool_op limit=limit - // Add any necessary padding - unsigned padded_height = CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom; - unsigned padded_width = CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right; - if (CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0) { - padded_height -= padded_height - (padded_height / CONFIG_T::stride_height * CONFIG_T::stride_height); - padded_width -= padded_width - (padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width); - } + // Add padding and reduce input width to area covered by pooling function + static constexpr int full_padded_width = CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right; + static constexpr int full_padded_height = CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom; + static constexpr int restricted_padded_width = full_padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width; + static constexpr int restricted_padded_height = full_padded_height / CONFIG_T::stride_height * CONFIG_T::stride_height; for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { // Loop over input image y in steps of stride - for (int ii = 0; ii < padded_height; ii += CONFIG_T::stride_height) { + for (int ii = 0; ii < restricted_padded_height; ii += CONFIG_T::stride_height) { // Loop over input image x in steps of stride - for (int jj = 0; jj < padded_width; jj += CONFIG_T::stride_width) { + for (int jj = 0; jj < restricted_padded_width; jj += CONFIG_T::stride_width) { data_T pool[CONFIG_T::pool_height * CONFIG_T::pool_width]; #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 // Keep track of number of pixels in image vs padding region @@ -232,8 +228,8 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_ for (int kk = 0; kk < CONFIG_T::stride_height; kk++) { // Loop over pool window x for (int ll = 0; ll < CONFIG_T::stride_width; ll++) { - if (ii + kk < CONFIG_T::pad_top || ii + kk >= (padded_height - CONFIG_T::pad_bottom) || - jj + ll < CONFIG_T::pad_left || jj + ll >= (padded_width - CONFIG_T::pad_right)) { + if (ii + kk < CONFIG_T::pad_top || ii + kk >= (full_padded_height - CONFIG_T::pad_bottom) || + jj + ll < CONFIG_T::pad_left || jj + ll >= (full_padded_width - CONFIG_T::pad_right)) { // Add padding pool[kk * CONFIG_T::stride_width + ll] = pad_val(); if (CONFIG_T::count_pad) { @@ -275,19 +271,17 @@ void pooling2d_cf(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_ // TODO partition the arrays according to the reuse factor const int limit = pool_op_limit(); #pragma HLS ALLOCATION function instances=CONFIG_T::pool_op limit=limit - // Add any necessary padding - unsigned padded_height = CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom; - unsigned padded_width = CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right; - if (CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0) { - padded_height -= padded_height - (padded_height / CONFIG_T::stride_height * CONFIG_T::stride_height); - padded_width -= padded_width - (padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width); - } + // Add padding and reduce input width to area covered by pooling function + static constexpr int full_padded_width = 
CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right; + static constexpr int full_padded_height = CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom; + static constexpr int restricted_padded_width = full_padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width; + static constexpr int restricted_padded_height = full_padded_height / CONFIG_T::stride_height * CONFIG_T::stride_height; for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { // Loop over input image y in steps of stride - for (int ii = 0; ii < padded_height; ii += CONFIG_T::stride_height) { + for (int ii = 0; ii < restricted_padded_height; ii += CONFIG_T::stride_height) { // Loop over input image x in steps of stride - for (int jj = 0; jj < padded_width; jj += CONFIG_T::stride_width) { + for (int jj = 0; jj < restricted_padded_width; jj += CONFIG_T::stride_width) { data_T pool[CONFIG_T::pool_height * CONFIG_T::pool_width]; #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 // Keep track of number of pixels in image vs padding region @@ -296,8 +290,8 @@ void pooling2d_cf(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_ for (int kk = 0; kk < CONFIG_T::stride_height; kk++) { // Loop over pool window x for (int ll = 0; ll < CONFIG_T::stride_width; ll++) { - if (ii + kk < CONFIG_T::pad_top || ii + kk >= (padded_height - CONFIG_T::pad_bottom) || - jj + ll < CONFIG_T::pad_left || jj + ll >= (padded_width - CONFIG_T::pad_right)) { + if (ii + kk < CONFIG_T::pad_top || ii + kk >= (full_padded_height - CONFIG_T::pad_bottom) || + jj + ll < CONFIG_T::pad_left || jj + ll >= (full_padded_width - CONFIG_T::pad_right)) { // Add padding pool[kk * CONFIG_T::stride_width + ll] = pad_val(); if (CONFIG_T::count_pad) { diff --git a/hls4ml/templates/vitis/nnet_utils/nnet_pooling.h b/hls4ml/templates/vitis/nnet_utils/nnet_pooling.h index d8ac60a839..93d23d2689 100644 --- a/hls4ml/templates/vitis/nnet_utils/nnet_pooling.h +++ b/hls4ml/templates/vitis/nnet_utils/nnet_pooling.h @@ -70,6 +70,7 @@ struct pooling1d_config { static const unsigned n_out = (n_in - pool_width) / stride_width + 1; static const unsigned pad_left = 0; static const unsigned pad_right = 0; + static const bool count_pad = false; // Pooling function static const Pool_Op pool_op = Max; }; @@ -88,14 +89,13 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF CONFIG_T::pool_op, typename CONFIG_T::accum_t> limit=limit // Add any necessary padding - unsigned padded_width = CONFIG_T::n_in + CONFIG_T::pad_left + CONFIG_T::pad_right; - if (CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0) { - padded_width -= padded_width - (padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width); - } + // Add padding and reduce input width to area covered by pooling function + static constexpr int full_padded_width = CONFIG_T::n_in + CONFIG_T::pad_left + CONFIG_T::pad_right; + static constexpr int restricted_padded_width = full_padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width; for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { // Loop over input image x in steps of stride - for (int ii = 0; ii < padded_width; ii += CONFIG_T::stride_width) { + for (int ii = 0; ii < restricted_padded_width; ii += CONFIG_T::stride_width) { unsigned overlap_pixel = 0; data_T pool[CONFIG_T::pool_width]; #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 @@ -130,6 +130,7 @@ void global_pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T r for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { data_T pool[CONFIG_T::n_in]; + 
#pragma HLS ARRAY_PARTITION variable=pool complete dim=0 for (int jj = 0; jj < CONFIG_T::n_in; jj++) { pool[jj] = data[jj * CONFIG_T::n_filt + ff]; } @@ -154,6 +155,7 @@ struct pooling2d_config { static const unsigned pad_bottom = 0; static const unsigned pad_left = 0; static const unsigned pad_right = 0; + static const bool count_pad = false; // Pooling function static const Pool_Op pool_op = Max; // Reuse factor @@ -176,18 +178,17 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_ const int limit = pool_op_limit(); #pragma HLS ALLOCATION function instances=pool_op limit=limit - unsigned padded_height = CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom; - unsigned padded_width = CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right; - if (CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0) { - padded_height -= padded_height - (padded_height / CONFIG_T::stride_height * CONFIG_T::stride_height); - padded_width -= padded_width - (padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width); - } + // Add padding and reduce input width to area covered by pooling function + static constexpr int full_padded_width = CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right; + static constexpr int full_padded_height = CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom; + static constexpr int restricted_padded_width = full_padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width; + static constexpr int restricted_padded_height = full_padded_height / CONFIG_T::stride_height * CONFIG_T::stride_height; for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { // Loop over input image y in steps of stride - for (int ii = 0; ii < padded_height; ii += CONFIG_T::stride_height) { + for (int ii = 0; ii < restricted_padded_height; ii += CONFIG_T::stride_height) { // Loop over input image x in steps of stride - for (int jj = 0; jj < padded_width; jj += CONFIG_T::stride_width) { + for (int jj = 0; jj < restricted_padded_width; jj += CONFIG_T::stride_width) { data_T pool[CONFIG_T::pool_height * CONFIG_T::pool_width]; #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 @@ -231,34 +232,35 @@ void pooling2d_cf(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_ const int limit = pool_op_limit(); #pragma HLS ALLOCATION function instances=pool_op limit=limit - // Add any necessary padding - unsigned padded_height = CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom; - unsigned padded_width = CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right; - if (CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0) { - padded_height -= padded_height - (padded_height / CONFIG_T::stride_height * CONFIG_T::stride_height); - padded_width -= padded_width - (padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width); - } + // Add padding and reduce input width to area covered by pooling function + static constexpr int full_padded_width = CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right; + static constexpr int full_padded_height = CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom; + static constexpr int restricted_padded_width = full_padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width; + static constexpr int restricted_padded_height = full_padded_height / CONFIG_T::stride_height * CONFIG_T::stride_height; for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { // Loop over input image y in 
steps of stride - for (int ii = 0; ii < padded_height; ii += CONFIG_T::stride_height) { + for (int ii = 0; ii < restricted_padded_height; ii += CONFIG_T::stride_height) { // Loop over input image x in steps of stride - for (int jj = 0; jj < padded_width; jj += CONFIG_T::stride_width) { + for (int jj = 0; jj < restricted_padded_width; jj += CONFIG_T::stride_width) { data_T pool[CONFIG_T::pool_height * CONFIG_T::pool_width]; + #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 // Keep track of number of pixels in image vs padding region unsigned img_overlap = 0; // Loop over pool window y for (int kk = 0; kk < CONFIG_T::stride_height; kk++) { // Loop over pool window x for (int ll = 0; ll < CONFIG_T::stride_width; ll++) { - if (ii + kk < CONFIG_T::pad_top || ii + kk >= (padded_height - CONFIG_T::pad_bottom) || - jj + ll < CONFIG_T::pad_left || jj + ll >= (padded_width - CONFIG_T::pad_right)) { + if (ii + kk < CONFIG_T::pad_top || ii + kk >= (full_padded_height - CONFIG_T::pad_bottom) || + jj + ll < CONFIG_T::pad_left || jj + ll >= (full_padded_width - CONFIG_T::pad_right)) { // Add padding pool[kk * CONFIG_T::stride_width + ll] = pad_val(); + if (CONFIG_T::count_pad) + img_overlap++; } else { pool[kk * CONFIG_T::stride_width + ll] = - data[(ii + kk) * CONFIG_T::in_width + ff * CONFIG_T::in_width * CONFIG_T::in_height + ll + - jj]; + data[(ii + kk - CONFIG_T::pad_top) * CONFIG_T::in_width + + ff * CONFIG_T::in_width * CONFIG_T::in_height + ll + jj - CONFIG_T::pad_left]; img_overlap++; } } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_pooling.h b/hls4ml/templates/vivado/nnet_utils/nnet_pooling.h index e6182d20db..bb9f0b3f05 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_pooling.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_pooling.h @@ -87,14 +87,13 @@ void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONF #pragma HLS ALLOCATION function instances=CONFIG_T::pool_op limit=limit // Add any necessary padding - unsigned padded_width = CONFIG_T::n_in + CONFIG_T::pad_left + CONFIG_T::pad_right; - if (CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0) { - padded_width -= padded_width - (padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width); - } + // Add padding and reduce input width to area covered by pooling function + static constexpr int full_padded_width = CONFIG_T::n_in + CONFIG_T::pad_left + CONFIG_T::pad_right; + static constexpr int restricted_padded_width = full_padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width; for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { // Loop over input image x in steps of stride - for (int ii = 0; ii < padded_width; ii += CONFIG_T::stride_width) { + for (int ii = 0; ii < restricted_padded_width; ii += CONFIG_T::stride_width) { unsigned overlap_pixel = 0; data_T pool[CONFIG_T::pool_width]; #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 @@ -176,19 +175,18 @@ void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_ const int limit = pool_op_limit(); #pragma HLS ALLOCATION function instances=CONFIG_T::pool_op limit=limit - unsigned padded_height = CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom; - unsigned padded_width = CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right; - if (CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0) { - padded_height -= padded_height - (padded_height / CONFIG_T::stride_height * CONFIG_T::stride_height); - padded_width -= padded_width - (padded_width / 
CONFIG_T::stride_width * CONFIG_T::stride_width); - } + // Add padding and reduce input width to area covered by pooling function + static constexpr int full_padded_width = CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right; + static constexpr int full_padded_height = CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom; + static constexpr int restricted_padded_width = full_padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width; + static constexpr int restricted_padded_height = full_padded_height / CONFIG_T::stride_height * CONFIG_T::stride_height; for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { // Loop over input image y in steps of stride - for (int ii = 0; ii < padded_height; ii += CONFIG_T::stride_height) { + for (int ii = 0; ii < restricted_padded_height; ii += CONFIG_T::stride_height) { // Loop over input image x in steps of stride - for (int jj = 0; jj < padded_width; jj += CONFIG_T::stride_width) { + for (int jj = 0; jj < restricted_padded_width; jj += CONFIG_T::stride_width) { data_T pool[CONFIG_T::pool_height * CONFIG_T::pool_width]; #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 @@ -231,19 +229,17 @@ void pooling2d_cf(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_ // TODO partition the arrays according to the reuse factor const int limit = pool_op_limit(); #pragma HLS ALLOCATION function instances=CONFIG_T::pool_op limit=limit - // Add any necessary padding - unsigned padded_height = CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom; - unsigned padded_width = CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right; - if (CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0) { - padded_height -= padded_height - (padded_height / CONFIG_T::stride_height * CONFIG_T::stride_height); - padded_width -= padded_width - (padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width); - } + // Add padding and reduce input width to area covered by pooling function + static constexpr int full_padded_width = CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right; + static constexpr int full_padded_height = CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom; + static constexpr int restricted_padded_width = full_padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width; + static constexpr int restricted_padded_height = full_padded_height / CONFIG_T::stride_height * CONFIG_T::stride_height; for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { // Loop over input image y in steps of stride - for (int ii = 0; ii < padded_height; ii += CONFIG_T::stride_height) { + for (int ii = 0; ii < restricted_padded_height; ii += CONFIG_T::stride_height) { // Loop over input image x in steps of stride - for (int jj = 0; jj < padded_width; jj += CONFIG_T::stride_width) { + for (int jj = 0; jj < restricted_padded_width; jj += CONFIG_T::stride_width) { data_T pool[CONFIG_T::pool_height * CONFIG_T::pool_width]; #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 // Keep track of number of pixels in image vs padding region @@ -252,8 +248,8 @@ void pooling2d_cf(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_ for (int kk = 0; kk < CONFIG_T::stride_height; kk++) { // Loop over pool window x for (int ll = 0; ll < CONFIG_T::stride_width; ll++) { - if (ii + kk < CONFIG_T::pad_top || ii + kk >= (padded_height - CONFIG_T::pad_bottom) || - jj + ll < CONFIG_T::pad_left || jj + ll >= (padded_width - CONFIG_T::pad_right)) { + if (ii + kk < CONFIG_T::pad_top || ii 
+ kk >= (full_padded_height - CONFIG_T::pad_bottom) || + jj + ll < CONFIG_T::pad_left || jj + ll >= (full_padded_width - CONFIG_T::pad_right)) { // Add padding pool[kk * CONFIG_T::stride_width + ll] = pad_val(); if (CONFIG_T::count_pad) diff --git a/hls4ml/utils/config.py b/hls4ml/utils/config.py index 1a297787d6..5cd17d02e9 100644 --- a/hls4ml/utils/config.py +++ b/hls4ml/utils/config.py @@ -112,7 +112,7 @@ def _get_precision_from_quantizer(quantizer): def config_from_keras_model( - model, granularity='model', backend=None, default_precision='fixed<16,6>', default_reuse_factor=1 + model, granularity='model', backend=None, default_precision='fixed<16,6>', default_reuse_factor=1, max_precision=None ): """Create an HLS conversion config given the Keras model. @@ -132,8 +132,11 @@ def config_from_keras_model( will generate config keys for every layer separately, allowing for highly specific configuration tweaks. backend(str, optional): Name of the backend to use - default_precision (str, optional): Default precision to use. Defaults to 'fixed<16,6>'. + default_precision (str, optional): Default precision to use. Defaults to 'fixed<16,6>'. Note, this must + be an explicit precision: 'auto' is not allowed. default_reuse_factor (int, optional): Default reuse factor. Defaults to 1. + max_precision (str or None, optional): Maximum width precision to use. Defaults to None, meaning no maximum. + Note: Only integer and fixed precisions are supported Raises: Exception: If Keras model has layers not supported by hls4ml. @@ -182,9 +185,11 @@ def make_layer_config(layer): if name.endswith('_t'): name = name[:-2] if attr.default is None: - precision_cfg[name] = default_precision + precision_cfg[name] = 'auto' else: precision_cfg[name] = str(attr.default) + elif attr.name == 'reuse_factor': + layer_config[attr.config_name] = default_reuse_factor else: if attr.default is not None: layer_config[attr.config_name] = attr.default @@ -238,7 +243,10 @@ def make_layer_config(layer): config = {} model_config = {} - model_config['Precision'] = default_precision + model_config['Precision'] = {} + model_config['Precision']['default'] = default_precision + if max_precision is not None: + model_config['Precision']['maximum'] = max_precision model_config['ReuseFactor'] = default_reuse_factor model_config['Strategy'] = 'Latency' model_config['BramFactor'] = 1_000_000_000 @@ -269,6 +277,7 @@ def make_layer_config(layer): def config_from_pytorch_model( model, + input_shape, granularity='model', backend=None, default_precision='ap_fixed<16,6>', @@ -284,6 +293,7 @@ def config_from_pytorch_model( Args: model: PyTorch model + input_shape (tuple or list of tuples): The shape of the input tensor, excluding the batch size. granularity (str, optional): Granularity of the created config. Defaults to 'model'. Can be set to 'model', 'type' and 'layer'. 
@@ -321,6 +331,83 @@ def config_from_pytorch_model( model_config['Strategy'] = 'Latency' config['Model'] = model_config + config['PytorchModel'] = model + if not (isinstance(input_shape, tuple) or (isinstance(input_shape, list) and isinstance(input_shape[0], tuple))): + raise Exception('Input shape must be tuple (single input) or list of tuples (multiple inputs)') + config['InputShape'] = input_shape + + if granularity.lower() not in ['model', 'type', 'name']: + raise Exception( + f'Invalid configuration granularity specified, expected "model", "type" or "name" got "{granularity}"' + ) + + if backend is not None: + backend = hls4ml.backends.get_backend(backend) + + from hls4ml.converters.pytorch_to_hls import parse_pytorch_model + + ( + layer_list, + _, + ) = parse_pytorch_model(config, verbose=False) + + def make_layer_config(layer): + cls_name = layer['class_name'] + if 'config' in layer.keys(): + if 'activation' in layer['config'].keys(): + if layer['config']['activation'] == 'softmax': + cls_name = 'Softmax' + + layer_cls = hls4ml.model.layers.layer_map[cls_name] + if backend is not None: + layer_cls = backend.create_layer_class(layer_cls) + + layer_config = {} + + config_attrs = [a for a in layer_cls.expected_attributes if a.configurable] + for attr in config_attrs: + if isinstance(attr, hls4ml.model.attributes.TypeAttribute): + precision_cfg = layer_config.setdefault('Precision', {}) + name = attr.name + if name.endswith('_t'): + name = name[:-2] + if attr.default is None: + precision_cfg[name] = default_precision + else: + precision_cfg[name] = str(attr.default) + elif attr.name == 'reuse_factor': + layer_config[attr.config_name] = default_reuse_factor + else: + if attr.default is not None: + layer_config[attr.config_name] = attr.default + + if layer['class_name'] == 'Input': + dtype = layer['config']['dtype'] + if dtype.startswith('int') or dtype.startswith('uint'): + typename = dtype[: dtype.index('int') + 3] + width = int(dtype[dtype.index('int') + 3 :]) + layer_config['Precision']['result'] = f'ap_{typename}<{width}>' + # elif bool, q[u]int, ... 
+ + return layer_config + + if granularity.lower() == 'type': + type_config = {} + for layer in layer_list: + if layer['class_name'] in type_config: + continue + layer_config = make_layer_config(layer) + type_config[layer['class_name']] = layer_config + + config['LayerType'] = type_config + + elif granularity.lower() == 'name': + name_config = {} + for layer in layer_list: + layer_config = make_layer_config(layer) + name_config[layer['name']] = layer_config + + config['LayerName'] = name_config return config diff --git a/test/pytest/generate_ci_yaml.py b/test/pytest/generate_ci_yaml.py index f62d752ade..b130b43cef 100644 --- a/test/pytest/generate_ci_yaml.py +++ b/test/pytest/generate_ci_yaml.py @@ -20,8 +20,12 @@ n_test_files_per_yml = int(os.environ.get('N_TESTS_PER_YAML', 4)) +# Blacklisted tests will be skipped BLACKLIST = {'test_reduction'} +# Long-running tests will not be bundled with other tests +LONGLIST = {'test_hgq_layers'} + def path_to_name(test_path): path = Path(test_path) @@ -43,9 +47,7 @@ def uses_example_model(test_filename): def generate_test_yaml(test_root='.'): test_root = Path(test_root) - test_paths = [path for path in test_root.glob('**/test_*.py') if path.stem not in BLACKLIST] - for path in test_paths: - print(path.name) + test_paths = [path for path in test_root.glob('**/test_*.py') if path.stem not in (BLACKLIST | LONGLIST)] need_example_models = [uses_example_model(path) for path in test_paths] idxs = list(range(len(need_example_models))) @@ -63,6 +65,15 @@ def generate_test_yaml(test_root='.'): yml = diff_yml else: yml.update(diff_yml) + + test_paths = [path for path in test_root.glob('**/test_*.py') if path.stem in LONGLIST] + for path in test_paths: + name = path.stem.replace('test_', '') + test_file = str(path.relative_to(test_root)) + needs_examples = uses_example_model(path) + diff_yml = yaml.safe_load(template.format(name, test_file, needs_examples)) + yml.update(diff_yml) + return yml diff --git a/test/pytest/test_backend_config.py b/test/pytest/test_backend_config.py index 346402de13..c43a7c7680 100644 --- a/test/pytest/test_backend_config.py +++ b/test/pytest/test_backend_config.py @@ -31,7 +31,7 @@ def test_backend_config(framework, backend, part, clock_period, clock_unc): convert_fn = hls4ml.converters.convert_from_keras_model else: model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.ReLU()) - config = hls4ml.utils.config_from_pytorch_model(model) + config = hls4ml.utils.config_from_pytorch_model(model, input_shape=(None, 1)) convert_fn = hls4ml.converters.convert_from_pytorch_model if clock_unc is not None: @@ -42,16 +42,27 @@ def test_backend_config(framework, backend, part, clock_period, clock_unc): test_dir = f'hls4mlprj_backend_config_{framework}_{backend}_part_{part}_period_{clock_period}_unc_{unc_str}' output_dir = test_root_path / test_dir - hls_model = convert_fn( - model, - input_shape=(None, 1), # This serves as a test of handling unexpected values by the backend in keras converer - hls_config=config, - output_dir=str(output_dir), - backend=backend, - part=part, - clock_period=clock_period, - clock_uncertainty=clock_unc, - ) + if framework == "keras": + hls_model = convert_fn( + model, + input_shape=(None, 1), # This serves as a test of handling unexpected values by the backend in the keras converter + hls_config=config, + output_dir=str(output_dir), + backend=backend, + part=part, + clock_period=clock_period, + clock_uncertainty=clock_unc, + ) + else: + hls_model = convert_fn( + model, + hls_config=config, + 
output_dir=str(output_dir), + backend=backend, + part=part, + clock_period=clock_period, + clock_uncertainty=clock_unc, + ) hls_model.write() diff --git a/test/pytest/test_batchnorm.py b/test/pytest/test_batchnorm.py index f3812cd753..34ff84695b 100644 --- a/test/pytest/test_batchnorm.py +++ b/test/pytest/test_batchnorm.py @@ -36,7 +36,9 @@ def test_batchnorm(model, data, backend, io_type): center = model.layers[0].center scale = model.layers[0].scale - config = hls4ml.utils.config_from_keras_model(model, default_precision=default_precision, granularity='name') + config = hls4ml.utils.config_from_keras_model( + model, default_precision=default_precision, granularity='name', backend=backend + ) output_dir = str(test_root_path / f'hls4mlprj_batchnorm_{backend}_{io_type}_center{center}_scale{scale}') hls_model = hls4ml.converters.convert_from_keras_model( model, backend=backend, hls_config=config, io_type=io_type, output_dir=output_dir diff --git a/test/pytest/test_batchnorm_pytorch.py b/test/pytest/test_batchnorm_pytorch.py index 1e45e7ae0f..137aee8a1e 100644 --- a/test/pytest/test_batchnorm_pytorch.py +++ b/test/pytest/test_batchnorm_pytorch.py @@ -39,10 +39,12 @@ def test_batchnorm(data, backend, io_type): default_precision = 'ac_fixed<32, 1, true>' if backend == 'Quartus' else 'ac_fixed<32, 1>' - config = hls4ml.utils.config_from_pytorch_model(model, default_precision=default_precision, granularity='name') + config = hls4ml.utils.config_from_pytorch_model( + model, (in_shape,), default_precision=default_precision, granularity='name', backend=backend + ) output_dir = str(test_root_path / f'hls4mlprj_batchnorm_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_pytorch_model( - model, (None, in_shape), backend=backend, hls_config=config, io_type=io_type, output_dir=output_dir + model, backend=backend, hls_config=config, io_type=io_type, output_dir=output_dir ) hls_model.compile() @@ -94,9 +96,13 @@ def test_batchnorm_fusion(fusion_data, backend, io_type): # We do not have an implementation of a transpose for io_stream, need to transpose inputs and outputs outside of hls4ml if io_type == 'io_stream': fusion_data = np.ascontiguousarray(fusion_data.transpose(0, 2, 1)) - config = hls4ml.utils.config_from_pytorch_model(model, channels_last_conversion='internal', transpose_outputs=False) + config = hls4ml.utils.config_from_pytorch_model( + model, (n_in, size_in_height), channels_last_conversion='internal', transpose_outputs=False + ) else: - config = hls4ml.utils.config_from_pytorch_model(model, channels_last_conversion='full', transpose_outputs=True) + config = hls4ml.utils.config_from_pytorch_model( + model, (n_in, size_in_height), channels_last_conversion='full', transpose_outputs=True + ) config['Model']['Strategy'] = 'Resource' @@ -104,7 +110,6 @@ def test_batchnorm_fusion(fusion_data, backend, io_type): output_dir = str(test_root_path / f'hls4mlprj_block_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_pytorch_model( model, - (None, n_in, size_in_height), hls_config=config, output_dir=output_dir, backend=backend, diff --git a/test/pytest/test_binary_cnn.py b/test/pytest/test_binary_cnn.py index 40af056df9..c1fa1b1551 100644 --- a/test/pytest/test_binary_cnn.py +++ b/test/pytest/test_binary_cnn.py @@ -66,7 +66,9 @@ def test_binary_cnn(backend, io_type, strategy): model2.summary() - hls_config = hls4ml.utils.config_from_keras_model(model2, granularity='name', default_precision='fixed<32,12>') + hls_config = hls4ml.utils.config_from_keras_model( + model2, 
granularity='name', default_precision='fixed<32,12>', backend=backend + ) hls_config['Model']['Strategy'] = strategy # hls_config['LayerName']['q_dense_7_softmax']['Implementation'] = 'legacy' diff --git a/test/pytest/test_causalpadding.py b/test/pytest/test_causalpadding.py index c076c99987..d91da35fac 100644 --- a/test/pytest/test_causalpadding.py +++ b/test/pytest/test_causalpadding.py @@ -23,7 +23,9 @@ def test_causalpadding(io_type, backend): data = np.expand_dims(data, axis=0) data = np.expand_dims(data, axis=-1) - config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,16>', granularity='name') + config = hls4ml.utils.config_from_keras_model( + model, default_precision='ap_fixed<32,16>', granularity='name', backend=backend + ) odir = str(test_root_path / f'hls4mlprj_validpadding_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=config, io_type=io_type, output_dir=odir, backend=backend diff --git a/test/pytest/test_clone_flatten.py b/test/pytest/test_clone_flatten.py index 5f631d027f..d819af54e7 100644 --- a/test/pytest/test_clone_flatten.py +++ b/test/pytest/test_clone_flatten.py @@ -31,9 +31,7 @@ def keras_model(): @pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Catapult']) def hls_model(keras_model, backend, io_type): hls_config = hls4ml.utils.config_from_keras_model( - keras_model, - default_precision='ap_int<6>', - granularity='name', + keras_model, default_precision='ap_int<6>', granularity='name', backend=backend ) output_dir = str(test_root_path / f'hls4mlprj_clone_flatten_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_keras_model( diff --git a/test/pytest/test_cnn_mnist_qkeras.py b/test/pytest/test_cnn_mnist_qkeras.py index b4c28c70d1..38489b5865 100644 --- a/test/pytest/test_cnn_mnist_qkeras.py +++ b/test/pytest/test_cnn_mnist_qkeras.py @@ -58,7 +58,7 @@ def mnist_model(): ) def hls_model(mnist_model, backend, io_type, strategy): keras_model = mnist_model - hls_config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name') + hls_config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name', backend=backend) hls_config['Model']['Strategy'] = strategy hls_config['LayerName']['softmax']['Strategy'] = 'Stable' output_dir = str(test_root_path / f'hls4mlprj_cnn_mnist_qkeras_{backend}_{io_type}_{strategy}') diff --git a/test/pytest/test_embed.py b/test/pytest/test_embed.py index 9950060406..632f8e75f1 100644 --- a/test/pytest/test_embed.py +++ b/test/pytest/test_embed.py @@ -28,7 +28,9 @@ def keras_model(): @pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult', 'oneAPI']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) def hls_model(keras_model, backend, io_type): - hls_config = hls4ml.utils.config_from_keras_model(keras_model, default_precision='ap_fixed<16,6>', granularity='name') + hls_config = hls4ml.utils.config_from_keras_model( + keras_model, default_precision='ap_fixed<16,6>', granularity='name', backend=backend + ) hls_config['LayerName']['embedding_input']['Precision']['result'] = 'ap_uint<4>' out_dir = str(test_root_path / 'hls4mlprj_embed_{}_{}').format(backend, io_type) hls_model = hls4ml.converters.convert_from_keras_model( diff --git a/test/pytest/test_garnet.py b/test/pytest/test_garnet.py index 67ddf77182..62bc82a8c0 100644 --- a/test/pytest/test_garnet.py +++ b/test/pytest/test_garnet.py @@ -33,7 +33,7 @@ def garnet_models(): model = Model(inputs=inputs, outputs=outputs) model.summary() - 
config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend='Vivado') config['Model'] = {} config['Model']['ReuseFactor'] = 1 config['Model']['Strategy'] = 'Latency' @@ -68,7 +68,7 @@ def garnet_stack_models(): model = Model(inputs=inputs, outputs=outputs) model.summary() - config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend='Vivado') config['Model'] = {} config['Model']['ReuseFactor'] = 1 config['Model']['Strategy'] = 'Latency' diff --git a/test/pytest/test_globalpooling.py b/test/pytest/test_globalpooling.py index e2135e41e7..57f06dbdd4 100644 --- a/test/pytest/test_globalpooling.py +++ b/test/pytest/test_globalpooling.py @@ -53,7 +53,9 @@ def keras_model_1d(request): def test_global_pool1d(backend, keras_model_1d, data_1d, io_type): model, model_type, keepdims = keras_model_1d - config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,9>', granularity='name') + config = hls4ml.utils.config_from_keras_model( + model, default_precision='ap_fixed<32,9>', granularity='name', backend=backend + ) hls_model = hls4ml.converters.convert_from_keras_model( model, @@ -108,7 +110,9 @@ def keras_model_2d(request): def test_global_pool2d(backend, keras_model_2d, data_2d, io_type): model, model_type, keepdims = keras_model_2d - config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,9>', granularity='name') + config = hls4ml.utils.config_from_keras_model( + model, default_precision='ap_fixed<32,9>', granularity='name', backend=backend + ) hls_model = hls4ml.converters.convert_from_keras_model( model, diff --git a/test/pytest/test_keras_api.py b/test/pytest/test_keras_api.py index 64320a6641..eaf36e66fc 100644 --- a/test/pytest/test_keras_api.py +++ b/test/pytest/test_keras_api.py @@ -310,7 +310,9 @@ def test_depthwise2d(backend, io_type): model.add(DepthwiseConv2D(kernel_size=(3, 3), input_shape=(32, 32, 3))) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name', default_precision='fixed<32,12>') + config = hls4ml.utils.config_from_keras_model( + model, granularity='name', default_precision='fixed<32,12>', backend=backend + ) output_dir = str(test_root_path / f'hls4mlprj_keras_api_depthwiseconv2d_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type @@ -336,7 +338,7 @@ def test_depthwise1d(backend, io_type): model.add(DepthwiseConv1D(kernel_size=3, input_shape=(32, 3))) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend=backend) output_dir = str(test_root_path / f'hls4mlprj_keras_api_depthwiseconv1d_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type diff --git a/test/pytest/test_merge_pytorch.py b/test/pytest/test_merge_pytorch.py index ac42a7bb42..1dc461e532 100644 --- a/test/pytest/test_merge_pytorch.py +++ b/test/pytest/test_merge_pytorch.py @@ -41,14 +41,16 @@ def test_merge(merge_op, io_type, backend): model = MergeModule(merge_op) model.eval() - batch_input_shape = (None,) + input_shape config = hls4ml.utils.config_from_pytorch_model( - model, default_precision='ap_fixed<32,16>', 
channels_last_conversion="internal", transpose_outputs=False + model, + [input_shape, input_shape], + default_precision='ap_fixed<32,16>', + channels_last_conversion="internal", + transpose_outputs=False, ) output_dir = str(test_root_path / f'hls4mlprj_merge_pytorch_{merge_op}_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_pytorch_model( model, - [batch_input_shape, batch_input_shape], hls_config=config, output_dir=output_dir, io_type=io_type, diff --git a/test/pytest/test_optimization/test_attributes.py b/test/pytest/test_optimization/test_attributes.py index 2669321e09..3ba8d08d14 100644 --- a/test/pytest/test_optimization/test_attributes.py +++ b/test/pytest/test_optimization/test_attributes.py @@ -38,6 +38,12 @@ def test_attributes(): cfg['Model']['Strategy'] = strategy cfg['LayerName']['dense']['ReuseFactor'] = 1 + # optimization doesn't yet support auto precision + for layer in cfg['LayerName'].values(): + for key, prec in layer['Precision'].items(): + if prec == 'auto': + layer['Precision'][key] = default_precision + # Verify correct information for every layer model_attributes = get_attributes_from_keras_model_and_hls4ml_config(model, cfg) assert len(model_attributes) == 4 diff --git a/test/pytest/test_pointwiseconv.py b/test/pytest/test_pointwiseconv.py index 2217a0f3e8..678b22bfeb 100644 --- a/test/pytest/test_pointwiseconv.py +++ b/test/pytest/test_pointwiseconv.py @@ -158,7 +158,7 @@ def test_pointwise_config(strategy): model.compile(optimizer='adam', loss='mse') - config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend='Vivado') config['Model']['Strategy'] = strategy config['LayerName']['conv2d_1x1']['Strategy'] = strategy # Will fail if the strategy is not lowercase output_dir = str(test_root_path / f'hls4mlprj_pointwise2d_config_{strategy}') diff --git a/test/pytest/test_pooling.py b/test/pytest/test_pooling.py index 353ac4c432..1486ee33fe 100644 --- a/test/pytest/test_pooling.py +++ b/test/pytest/test_pooling.py @@ -53,7 +53,9 @@ def keras_model_1d(request): def test_pool1d(backend, keras_model_1d, data_1d, io_type): model, model_type, padding = keras_model_1d - config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,9>', granularity='name') + config = hls4ml.utils.config_from_keras_model( + model, default_precision='ap_fixed<32,9>', granularity='name', backend=backend + ) hls_model = hls4ml.converters.convert_from_keras_model( model, @@ -141,7 +143,9 @@ def keras_model_2d(request): def test_pool2d(backend, keras_model_2d, data_2d, io_type): model, model_type, padding = keras_model_2d - config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,9>', granularity='name') + config = hls4ml.utils.config_from_keras_model( + model, default_precision='ap_fixed<32,9>', granularity='name', backend=backend + ) hls_model = hls4ml.converters.convert_from_keras_model( model, diff --git a/test/pytest/test_pytorch_api.py b/test/pytest/test_pytorch_api.py index 5bc0351067..b8cce4259f 100644 --- a/test/pytest/test_pytorch_api.py +++ b/test/pytest/test_pytorch_api.py @@ -32,12 +32,10 @@ def test_linear(backend, io_type): pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy() - config = config_from_pytorch_model(model) + config = config_from_pytorch_model(model, (1,)) output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_linear_{backend}_{io_type}') - hls_model = convert_from_pytorch_model( - model, (None, 
1), hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type - ) + hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) hls_model.compile() @@ -66,6 +64,7 @@ def test_linear(backend, io_type): "activation_function", [ nn.ReLU(), + nn.Tanh(), nn.LeakyReLU(negative_slope=1.0), nn.ELU(alpha=1.0), nn.PReLU(init=0.25), @@ -83,13 +82,11 @@ def test_activations(activation_function, backend, io_type): pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy() - config = config_from_pytorch_model(model) + config = config_from_pytorch_model(model, (1,)) output_dir = str( test_root_path / f'hls4mlprj_pytorch_api_activations_{activation_function.__class__.__name__}_{backend}_{io_type}' ) - hls_model = convert_from_pytorch_model( - model, (None, 1), hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type - ) + hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) hls_model.compile() hls_prediction = hls_model.predict(X_input) @@ -106,7 +103,7 @@ def test_activations(activation_function, backend, io_type): assert nNodes - 1 == len(hls_model.get_layers()) - if activation_function.__class__.__name__ == 'ReLU' or activation_function.__class__.__name__ == 'Sigmoid': + if activation_function.__class__.__name__ in ['ReLU', 'Sigmoid', 'Tanh']: assert list(hls_model.get_layers())[2].attributes['class_name'] == 'Activation' elif activation_function.__class__.__name__ == 'Threshold': assert list(hls_model.get_layers())[2].attributes['class_name'] == 'ThresholdedReLU' @@ -122,6 +119,14 @@ def forward(self, x): return nn.functional.relu(x) +class TanHModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return nn.functional.tanh(x) + + class LeakyReLuModel(nn.Module): def __init__(self): super().__init__() @@ -158,6 +163,7 @@ def forward(self, x): "activation_function", [ ReLuModel(), + TanHModel(), LeakyReLuModel(), EluModel(), SigmoidModel(), @@ -174,12 +180,10 @@ def test_activation_functionals(activation_function, backend, io_type): pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy() - config = config_from_pytorch_model(model) + config = config_from_pytorch_model(model, (1,)) fn_name = activation_function.__class__.__name__ - output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_activations_functional_relu_{backend}_{io_type}_{fn_name}') - hls_model = convert_from_pytorch_model( - model, (None, 1), hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type - ) + output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_activations_functional_{fn_name}_{backend}_{io_type}') + hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) hls_model.compile() hls_prediction = hls_model.predict(X_input) @@ -217,14 +221,14 @@ def test_conv1d(padds, backend, io_type): if io_type == 'io_stream': X_input = np.ascontiguousarray(X_input.transpose(0, 2, 1)) - config = config_from_pytorch_model(model, channels_last_conversion="internal", transpose_outputs=False) + config = config_from_pytorch_model( + model, (n_in, size_in), channels_last_conversion="internal", transpose_outputs=False + ) else: - config = config_from_pytorch_model(model, channels_last_conversion="full", transpose_outputs=True) + config = config_from_pytorch_model(model, (n_in, size_in), channels_last_conversion="full", transpose_outputs=True) 
output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_conv1d_{padds}_{backend}_{io_type}') - hls_model = convert_from_pytorch_model( - model, (None, n_in, size_in), hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type - ) + hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) hls_model.compile() from torch.fx import symbolic_trace @@ -274,7 +278,7 @@ def test_conv1d(padds, backend, io_type): act_index = 2 assert list(hls_model.get_layers())[conv_index].attributes['name'] == convNode.name assert list(hls_model.get_layers())[conv_index].attributes['class_name'] == 'Conv1D' - assert list(hls_model.get_layers())[act_index].attributes['activation'] == class_object_relu.__class__.__name__ + assert list(hls_model.get_layers())[act_index].attributes['activation'] == class_object_relu.__class__.__name__.lower() if io_type == "io_stream" and (backend == "Vivado" or backend == "Vitis") and padds == 1: assert list(hls_model.get_layers())[conv_index].attributes["in_width"] == size_in + 2 else: @@ -283,10 +287,7 @@ def test_conv1d(padds, backend, io_type): assert list(hls_model.get_layers())[conv_index].attributes['n_chan'] == class_object_conv.in_channels assert list(hls_model.get_layers())[conv_index].attributes['n_filt'] == class_object_conv.out_channels assert list(hls_model.get_layers())[conv_index].attributes['stride_width'] == class_object_conv.stride[0] - if list(hls_model.get_layers())[conv_index].attributes['padding'] == 'valid': - padding = 0 - else: - padding = 1 + padding = padds if io_type == "io_stream" and (backend == "Vivado" or backend == "Vitis") and padds == 1: padding = 1 padds = 0 @@ -328,14 +329,17 @@ def test_conv2d(padds, backend, io_type): if io_type == 'io_stream': X_input = np.ascontiguousarray(X_input.transpose(0, 2, 3, 1)) - config = config_from_pytorch_model(model, channels_last_conversion="internal", transpose_outputs=False) + config = config_from_pytorch_model( + model, (n_in, size_in_height, size_in_width), channels_last_conversion="internal", transpose_outputs=False + ) else: - config = config_from_pytorch_model(model, channels_last_conversion="full", transpose_outputs=True) + config = config_from_pytorch_model( + model, (n_in, size_in_height, size_in_width), channels_last_conversion="full", transpose_outputs=True + ) output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_conv2d_{padds}_{backend}_{io_type}') hls_model = convert_from_pytorch_model( model, - (None, n_in, size_in_height, size_in_width), hls_config=config, output_dir=output_dir, backend=backend, @@ -418,7 +422,9 @@ def test_conv2d(padds, backend, io_type): act_index = 2 assert list(hls_model.get_layers())[conv_index].attributes['name'] == convNode.name assert list(hls_model.get_layers())[conv_index].attributes['class_name'] == 'Conv2D' - assert list(hls_model.get_layers())[act_index].attributes['activation'] == class_object_relu.__class__.__name__ + assert ( + list(hls_model.get_layers())[act_index].attributes['activation'] == class_object_relu.__class__.__name__.lower() + ) assert list(hls_model.get_layers())[conv_index].attributes["in_width"] == size_in_width assert list(hls_model.get_layers())[conv_index].attributes["in_height"] == size_in_height assert list(hls_model.get_layers())[conv_index].attributes['filt_width'] == class_object_conv.kernel_size[1] @@ -427,10 +433,7 @@ def test_conv2d(padds, backend, io_type): assert list(hls_model.get_layers())[conv_index].attributes['n_filt'] == 
class_object_conv.out_channels assert list(hls_model.get_layers())[conv_index].attributes['stride_width'] == class_object_conv.stride[1] assert list(hls_model.get_layers())[conv_index].attributes['stride_height'] == class_object_conv.stride[0] - if list(hls_model.get_layers())[conv_index].attributes['padding'] == 'valid': - padding = 0 - else: - padding = 1 + padding = padds assert padding == class_object_conv.padding[0] assert list(hls_model.get_layers())[conv_index].attributes['data_format'] == 'channels_last' @@ -478,20 +481,16 @@ def test_pooling(pooling, padds, backend): size_in_height = 0 input_shape = (1, n_in, size_in_height, size_in_width) if '2d' in pooling.__name__ else (1, n_in, size_in_width) - input_shape_forHLS = ( - (None, n_in, size_in_height, size_in_width) if '2d' in pooling.__name__ else (None, n_in, size_in_width) - ) + input_shape_forHLS = (n_in, size_in_height, size_in_width) if '2d' in pooling.__name__ else (n_in, size_in_width) X_input = np.random.rand(*input_shape) model = torch.nn.Sequential(pooling(2, padding=padds)).to() model.eval() pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy() - config = config_from_pytorch_model(model) + config = config_from_pytorch_model(model, input_shape_forHLS) output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_pooling_{pooling.__name__}_padds_{padds}_backend_{backend}') - hls_model = convert_from_pytorch_model( - model, input_shape_forHLS, hls_config=config, output_dir=output_dir, backend=backend - ) + hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend) hls_model.compile() from torch.fx import symbolic_trace @@ -598,12 +597,10 @@ def test_bn(backend, io_type): pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy().flatten() - config = config_from_pytorch_model(model) + config = config_from_pytorch_model(model, (5,)) output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_bn_{backend}_{io_type}') - hls_model = convert_from_pytorch_model( - model, (None, 5), hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type - ) + hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) hls_model.compile() @@ -641,13 +638,11 @@ def test_squeeze(backend, io_type): pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy().flatten() - config = config_from_pytorch_model(model) + config = config_from_pytorch_model(model, (5,)) del config['Model']['ChannelsLastConversion'] # We don't want anything touched for this test output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_squeeze_{backend}_{io_type}') - hls_model = convert_from_pytorch_model( - model, (None, 5), hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type - ) + hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) hls_model.compile() @@ -673,11 +668,11 @@ def test_flatten(backend): input = torch.randn(1, 1, 5, 5) model = nn.Sequential(nn.Conv2d(1, 32, 5, 1, 1), nn.Flatten(), nn.ReLU()) pytorch_prediction = model(input).detach().numpy() - input_shape = (None, 1, 5, 5) + input_shape = (1, 5, 5) - config = config_from_pytorch_model(model) + config = config_from_pytorch_model(model, input_shape) output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_flatten_backend_{backend}') - hls_model = convert_from_pytorch_model(model, input_shape, hls_config=config, output_dir=output_dir, backend=backend) + hls_model = 
convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend) hls_model.compile() pred = hls_model.predict(input.detach().numpy()) @@ -719,14 +714,16 @@ def test_skipped_layers(backend, io_type): model.eval() input_shape = (3, 8) - batch_input_shape = (None,) + input_shape config = config_from_pytorch_model( - model, default_precision='ap_fixed<32,16>', channels_last_conversion="full", transpose_outputs=False + model, + input_shape, + default_precision='ap_fixed<32,16>', + channels_last_conversion="full", + transpose_outputs=False, ) output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_skipped_{backend}_{io_type}') hls_model = convert_from_pytorch_model( model, - batch_input_shape, hls_config=config, output_dir=output_dir, io_type=io_type, @@ -782,16 +779,15 @@ def forward(self, x): input_tensor = torch.randn(10, 1, 8, 8) hls_input = np.ascontiguousarray(torch.permute(input_tensor, (0, 2, 3, 1)).detach().numpy()) - batch_input_shape = (None,) + input_shape config = config_from_pytorch_model( model, + input_shape, default_precision='ap_fixed<32,16>', channels_last_conversion="full", # Crucial for testing if the first Transpose was removed ) output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_transpose_nop_{tensor_rank}d_{backend}_{io_type}') hls_model = convert_from_pytorch_model( model, - batch_input_shape, hls_config=config, output_dir=output_dir, io_type=io_type, @@ -847,12 +843,11 @@ def forward(self, x): # X_input is channels last X_input = np.ascontiguousarray(X_input.transpose(0, 2, 1)) - config = config_from_pytorch_model(model, channels_last_conversion="internal", transpose_outputs=False) + config = config_from_pytorch_model(model, (n_in, size_in), channels_last_conversion="internal", transpose_outputs=False) output_dir = str(test_root_path / f'hls4mlprj_pytorch_view_{backend}_{io_type}') hls_model = convert_from_pytorch_model( model, - (None, n_in, size_in), hls_config=config, output_dir=output_dir, backend=backend, diff --git a/test/pytest/test_qkeras.py b/test/pytest/test_qkeras.py index 12bb1940cc..3d66107c85 100644 --- a/test/pytest/test_qkeras.py +++ b/test/pytest/test_qkeras.py @@ -77,7 +77,7 @@ def convert(load_jettagging_model, strategy): ''' model = load_jettagging_model - config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend='Vivado') config['Model']['Strategy'] = strategy config['LayerName']['softmax']['exp_table_t'] = 'ap_fixed<18,8>' config['LayerName']['softmax']['inv_table_t'] = 'ap_fixed<18,4>' @@ -156,7 +156,7 @@ def test_single_dense_activation_exact(randX_100_16, bits, alpha, backend, io_ty model.add(QActivation(activation=quantized_relu(bits, 0), name='relu1')) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend=backend) output_dir = str(test_root_path / f'hls4mlprj_qkeras_single_dense_activation_exact_{bits}_{alpha}_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type @@ -205,7 +205,7 @@ def test_quantizer_special(randX_1000_1, quantizer, backend, io_type): model.add(QActivation(input_shape=(1,), activation=quantizer, name='quantizer')) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, 
granularity='name', backend=backend) output_dir = str( test_root_path / f'hls4mlprj_qkeras_quantizer_{quantizer.__class__.__name__}_{quantizer.bits}_{backend}_{io_type}' ) @@ -289,7 +289,7 @@ def test_quantizer(randX_1000_1, quantizer, backend, io_type): model.add(QActivation(input_shape=(1,), activation=quantizer, name='quantizer')) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend=backend) output_dir = str( test_root_path / 'hls4mlprj_qkeras_quantizer_{}_{}_{}_{}_{}'.format( @@ -328,7 +328,7 @@ def test_relu_negative_slope(randX_1000_1, quantizer, backend, io_type): model.add(QActivation(input_shape=(1,), activation=quantizer, name='quantizer')) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend=backend) output_dir = str( test_root_path / 'hls4mlprj_qkeras_leaky_relu_{}_{}_neg_slope_{}_{}_{}'.format( @@ -373,7 +373,7 @@ def test_qactivation_kwarg(randX_100_10, activation_quantizer, weight_quantizer) )(inputs) model = Model(inputs, outputs) - config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend='Vivado') out_dir = str(test_root_path / f'hls4mlprj_qactivation_kwarg_{activation_quantizer}') @@ -418,7 +418,9 @@ def test_quantizer_parsing(randX_100_10, backend, io_type): ) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name', default_precision='fixed<24,8>') + config = hls4ml.utils.config_from_keras_model( + model, granularity='name', default_precision='fixed<24,8>', backend=backend + ) output_dir = str(test_root_path / f'hls4mlprj_qkeras_quant_parse_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type @@ -459,7 +461,9 @@ def test_qconv2dbn(randX_100_8_8_1, backend, io_type): ) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name', default_precision='fixed<24,8>') + config = hls4ml.utils.config_from_keras_model( + model, granularity='name', default_precision='fixed<24,8>', backend=backend + ) output_dir = str(test_root_path / f'hls4mlprj_qkeras_qconv2dbn_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type @@ -500,7 +504,9 @@ def test_qdepthwiseconv2d(randX_10_32_32_3, backend, io_type): ) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name', default_precision='fixed<24,8>') + config = hls4ml.utils.config_from_keras_model( + model, granularity='name', default_precision='fixed<24,8>', backend=backend + ) output_dir = str(test_root_path / f'hls4mlprj_qkeras_qdepthwiseconv2d_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type @@ -538,7 +544,7 @@ def test_quantised_po2_bit_width(backend, io_type, strategy): y_keras = keras_model.predict(X) hls_config = hls4ml.utils.config_from_keras_model( - keras_model, granularity='name', default_precision='ap_fixed<64, 32>', default_reuse_factor=1 + keras_model, granularity='name', default_precision='ap_fixed<64, 32>', default_reuse_factor=1, backend=backend ) hls_config['Model']['Strategy'] = 
strategy output_dir = str(test_root_path / f'hls4mlprj_qkeras_quantised_po2_{backend}_{io_type}_{strategy}') @@ -573,7 +579,9 @@ def test_qsimplernn(backend): ) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name', default_precision="ap_fixed<16,1>") + config = hls4ml.utils.config_from_keras_model( + model, granularity='name', default_precision="ap_fixed<16,1>", backend=backend + ) output_dir = str(test_root_path / f'hls4mlprj_qkeras_qsimplernn_{backend}') hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend) hls_model.compile() @@ -607,7 +615,9 @@ def test_qlstm(backend): ) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name', default_precision="ap_fixed<8,1>") + config = hls4ml.utils.config_from_keras_model( + model, granularity='name', default_precision="ap_fixed<8,1>", backend=backend + ) output_dir = str(test_root_path / f'hls4mlprj_qkeras_qsimplernn_{backend}') hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend) hls_model.compile() @@ -642,7 +652,9 @@ def test_qgru(backend): ) model.compile() - config = hls4ml.utils.config_from_keras_model(model, granularity='name', default_precision="ap_fixed<8,1>") + config = hls4ml.utils.config_from_keras_model( + model, granularity='name', default_precision="ap_fixed<8,1>", backend=backend + ) output_dir = str(test_root_path / f'hls4mlprj_qkeras_qsimplernn_{backend}') hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend) hls_model.compile() diff --git a/test/pytest/test_recurrent_pytorch.py b/test/pytest/test_recurrent_pytorch.py index c1672c73b9..e4737ea675 100644 --- a/test/pytest/test_recurrent_pytorch.py +++ b/test/pytest/test_recurrent_pytorch.py @@ -32,12 +32,12 @@ def test_gru(backend, io_type): pytorch_prediction = model(torch.Tensor(X_input), torch.Tensor(h0)).detach().numpy() - config = config_from_pytorch_model(model, channels_last_conversion="off", transpose_outputs=False) + config = config_from_pytorch_model( + model, [(None, 1, 10), (None, 1, 20)], channels_last_conversion="off", transpose_outputs=False + ) output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_gru_{backend}_{io_type}') - hls_model = convert_from_pytorch_model( - model, [(None, 1, 10), (None, 1, 20)], hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type - ) + hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) hls_model.compile() @@ -69,12 +69,13 @@ def test_lstm(backend, io_type): pytorch_prediction = model(torch.Tensor(X_input), torch.Tensor(h0), torch.tensor(c0)).detach().numpy() - config = config_from_pytorch_model(model, channels_last_conversion="off", transpose_outputs=False) + config = config_from_pytorch_model( + model, [(None, 1, 10), (None, 1, 20), (None, 1, 20)], channels_last_conversion="off", transpose_outputs=False + ) output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_lstm_{backend}_{io_type}') hls_model = convert_from_pytorch_model( model, - [(None, 1, 10), (None, 1, 20), (None, 1, 20)], hls_config=config, output_dir=output_dir, backend=backend, @@ -112,11 +113,13 @@ def test_rnn(backend, io_type): pytorch_prediction = model(torch.Tensor(X_input), torch.Tensor(h0)).detach().numpy() - config = config_from_pytorch_model(model, channels_last_conversion="off", 
transpose_outputs=False) + config = config_from_pytorch_model( + model, [(1, 10), (1, 20)], channels_last_conversion="off", transpose_outputs=False + ) output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_rnn_{backend}_{io_type}') hls_model = convert_from_pytorch_model( - model, [(None, 1, 10), (None, 1, 20)], hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type + model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type ) hls_model.compile() diff --git a/test/pytest/test_rnn.py b/test/pytest/test_rnn.py index 509bc5bdb8..d2303669fe 100644 --- a/test/pytest/test_rnn.py +++ b/test/pytest/test_rnn.py @@ -25,7 +25,7 @@ def test_rnn_parsing(rnn_layer, return_sequences): model = Model(model_input, model_output) model.compile(optimizer='adam', loss='mse') - config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend='Vivado') prj_name = f'hls4mlprj_rnn_{rnn_layer.__class__.__name__.lower()}_seq_{int(return_sequences)}' output_dir = str(test_root_path / prj_name) hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir) diff --git a/test/pytest/test_sequential_parsing_pytorch.py b/test/pytest/test_sequential_parsing_pytorch.py index 569c6a5b1c..20b273400a 100644 --- a/test/pytest/test_sequential_parsing_pytorch.py +++ b/test/pytest/test_sequential_parsing_pytorch.py @@ -59,12 +59,10 @@ def test_unnamed_seq(backend, io_type, named_layers): model = seq_named else: model = seq_unnamed - config = config_from_pytorch_model(model) + config = config_from_pytorch_model(model, (1, 5, 5)) output_dir = str(test_root_path / f'hls4mlprj_pytorch_seq_unnamed_{backend}_{io_type}_{named_layers}') - convert_from_pytorch_model( - model, (None, 1, 5, 5), hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type - ) + convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) @pytest.mark.parametrize('backend', ['Vivado']) @@ -75,9 +73,7 @@ def test_named_seq(backend, io_type, named_layers): model = SeqModelNamedLayers() else: model = SeqModelUnnamedLayers() - config = config_from_pytorch_model(model) + config = config_from_pytorch_model(model, (1, 5, 5)) output_dir = str(test_root_path / f'hls4mlprj_pytorch_seq_named_{backend}_{io_type}_{named_layers}') - convert_from_pytorch_model( - model, (None, 1, 5, 5), hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type - ) + convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) diff --git a/test/pytest/test_softmax.py b/test/pytest/test_softmax.py index 19c9042465..048b6832ee 100644 --- a/test/pytest/test_softmax.py +++ b/test/pytest/test_softmax.py @@ -41,7 +41,7 @@ def test_softmax(backend, strategy, generate_data, input_bits, input_shape, tabl table_type = f'fixed<{table_bits}, RND, SAT>' - cfg = hls4ml.utils.config_from_keras_model(model, granularity='name') + cfg = hls4ml.utils.config_from_keras_model(model, granularity='name', backend=backend) cfg['LayerName']['softmax']['Strategy'] = strategy cfg['LayerName']['softmax']['inv_table_t'] = table_type cfg['LayerName']['softmax']['exp_table_t'] = table_type @@ -74,7 +74,7 @@ def test_softmax_skipped(backend, io_type): model = tf.keras.models.Sequential([dense, softmax]) model.compile() - cfg = hls4ml.utils.config_from_keras_model(model, granularity='name') + cfg = 
hls4ml.utils.config_from_keras_model(model, granularity='name', backend=backend) cfg['LayerName']['softmax']['skip'] = True odir = str(test_root_path / 'hls4mlprj_softmax_skipped_{}_{}').format(backend, io_type) diff --git a/test/pytest/test_softsign.py b/test/pytest/test_softsign.py index 31a2a1c2cf..f0089438a4 100644 --- a/test/pytest/test_softsign.py +++ b/test/pytest/test_softsign.py @@ -19,7 +19,7 @@ def test_softsign(backend, input_shape, io_type): model.add(tf.keras.layers.Activation(input_shape=input_shape, activation='softsign', name='softsign')) model.compile() - cfg = hls4ml.utils.config_from_keras_model(model, granularity='name', default_precision='fixed<20,4>') + cfg = hls4ml.utils.config_from_keras_model(model, granularity='name', default_precision='fixed<20,4>', backend=backend) # Since softsign implementation is lookup-based increasing the precision and size of the table helps with accuracy cfg['LayerName']['softsign']['table_t'] = 'fixed<20,4>' cfg['LayerName']['softsign']['table_size'] = 2048 diff --git a/test/pytest/test_trace.py b/test/pytest/test_trace.py index 14e218fd1c..b01cfcd010 100644 --- a/test/pytest/test_trace.py +++ b/test/pytest/test_trace.py @@ -39,11 +39,11 @@ def test_trace(backend, activation): keras_prediction = model.predict(X_input) - config = hls4ml.utils.config_from_keras_model(model, granularity='name') + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend=backend) for layer in config['LayerName'].keys(): config['LayerName'][layer]['Trace'] = True - output_dir = str(test_root_path / f'hls4mlprj_trace_{backend}') + output_dir = str(test_root_path / f'hls4mlprj_trace_{backend}_{activation}') hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend) diff --git a/test/pytest/test_transpose_concat.py b/test/pytest/test_transpose_concat.py index 9cfd1f288c..884d5859d5 100644 --- a/test/pytest/test_transpose_concat.py +++ b/test/pytest/test_transpose_concat.py @@ -32,7 +32,7 @@ def keras_model(): @pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI']) def hls_model(keras_model, backend, io_type): hls_config = hls4ml.utils.config_from_keras_model( - keras_model, default_precision='ap_fixed<16,3,AP_RND_CONV,AP_SAT>', granularity='name' + keras_model, default_precision='ap_fixed<16,3,AP_RND_CONV,AP_SAT>', granularity='name', backend=backend ) hls_config['LayerName']['relu']['Precision'] = 'ap_ufixed<17,3>' output_dir = str(test_root_path / f'hls4mlprj_transpose_{backend}_{io_type}') diff --git a/test/pytest/test_upsampling.py b/test/pytest/test_upsampling.py index cfddd47a50..dc29a53c27 100644 --- a/test/pytest/test_upsampling.py +++ b/test/pytest/test_upsampling.py @@ -56,7 +56,9 @@ def test_upsampling(keras_model_1d, keras_model_2d, data_1d, data_2d, model_type model = keras_model_2d data = data_2d - config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,1>', granularity='name') + config = hls4ml.utils.config_from_keras_model( + model, default_precision='ap_fixed<32,1>', granularity='name', backend=backend + ) odir = str(test_root_path / f'hls4mlprj_upsampling_{model_type}_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=config, io_type=io_type, output_dir=odir, backend=backend diff --git a/test/pytest/test_upsampling_pytorch.py b/test/pytest/test_upsampling_pytorch.py index e881c39bbf..6e0d8f78ad 100644 --- a/test/pytest/test_upsampling_pytorch.py +++ 
b/test/pytest/test_upsampling_pytorch.py @@ -55,13 +55,14 @@ def test_pytorch_upsampling1d(data_1d, io_type, backend): config = hls4ml.utils.config_from_pytorch_model( model, + (None, in_feat, in_width), default_precision='ap_fixed<16,6>', channels_last_conversion="internal", transpose_outputs=False, ) odir = str(test_root_path / f'hls4mlprj_pytorch_upsampling_1d_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_pytorch_model( - model, (None, in_feat, in_width), hls_config=config, io_type=io_type, output_dir=odir, backend=backend + model, hls_config=config, io_type=io_type, output_dir=odir, backend=backend ) hls_model.compile() @@ -84,13 +85,14 @@ def test_pytorch_upsampling2d(data_2d, io_type, backend): config = hls4ml.utils.config_from_pytorch_model( model, + (in_feat, in_height, in_width), default_precision='ap_fixed<16,6>', channels_last_conversion="full", # With conversion to channels_last transpose_outputs=True, ) odir = str(test_root_path / f'hls4mlprj_pytorch_upsampling_2d_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_pytorch_model( - model, (None, in_feat, in_height, in_width), hls_config=config, io_type=io_type, output_dir=odir, backend=backend + model, hls_config=config, io_type=io_type, output_dir=odir, backend=backend ) hls_model.compile() diff --git a/test/pytest/test_zeropadding.py b/test/pytest/test_zeropadding.py index fbac7b0997..4e65acef2f 100644 --- a/test/pytest/test_zeropadding.py +++ b/test/pytest/test_zeropadding.py @@ -60,7 +60,9 @@ def test_zeropadding(keras_model_1d, keras_model_2d, data_1d, data_2d, model_typ model = keras_model_2d data = data_2d - config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,1>', granularity='name') + config = hls4ml.utils.config_from_keras_model( + model, default_precision='ap_fixed<32,1>', granularity='name', backend=backend + ) odir = str(test_root_path / f'hls4mlprj_zeropadding_{model_type}_{backend}_{io_type}') hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=config, io_type=io_type, output_dir=odir, backend=backend