diff --git a/hls4ml/backends/catapult/catapult_backend.py b/hls4ml/backends/catapult/catapult_backend.py index 5c85bf9b7..6d1c17a3c 100644 --- a/hls4ml/backends/catapult/catapult_backend.py +++ b/hls4ml/backends/catapult/catapult_backend.py @@ -88,6 +88,7 @@ def _register_flows(self): init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name) streaming_passes = [ + 'catapult:inplace_stream_flatten', # Inform downstream changed packsize in case of skipping flatten 'catapult:reshape_stream', 'catapult:clone_output', 'catapult:insert_zero_padding_before_conv1d', diff --git a/hls4ml/backends/fpga/passes/clone.py b/hls4ml/backends/fpga/passes/clone.py index 0c1f7f2e0..834e78e45 100644 --- a/hls4ml/backends/fpga/passes/clone.py +++ b/hls4ml/backends/fpga/passes/clone.py @@ -1,4 +1,4 @@ -import numpy as np +from math import prod from hls4ml.backends.template import FunctionCallTemplate from hls4ml.model.layers import Layer, register_layer @@ -54,41 +54,61 @@ def match(self, node): if isinstance(node, Clone): return False - return True + # Not needed for io_parallel + io_type = node.model.config.get_config_value('IOType') + if io_type != 'io_stream': + return False + + # Check if the output is used more than once + output_map = node.get_output_use_map() + in_output = node.name in node.model.outputs + for output in node.outputs: + if len(output_map[output]) + in_output > 1: + # model output also need a stream + return True + + return False def transform(self, model, node): - if model.config.get_config_value('IOType') != 'io_stream': - return False output_map = node.get_output_use_map() + in_output = node.name in node.model.outputs transformed = False for output in node.outputs: - if len(output_map[output]) > 1: - if len(output_map[output]) > 3: - print( - 'WARNING: Cloning output {} of {} ({}) more than 3 times not currently supported'.format( - output, node.__class__.__name__, node.name - ) - ) - return False - out_var = node.get_output_variable(output) - for i, layer in enumerate(output_map[output], 1): - attrs = {'size': np.prod(out_var.shape)} - idx = layer.inputs.index(output) - layer.inputs[idx] = output + '_cpy' + str(i) - - clone_layer: Clone = model.make_node( - Clone, - 'clone_' + node.name, - attrs, - [output], - [output + '_cpy' + str(i + 1) for i in range(len(output_map[output]))], - ) - for i in range(len(output_map[output])): - key = output + '_cpy' + str(i + 1) - clone_layer.attributes[key].type = node.get_output_variable().type - model.insert_node(clone_layer) - transformed = True + + n_outputs = len(output_map[output]) + in_output + if n_outputs == 1: + continue + if n_outputs > 3: + msg = f'ERROR: Cloning output {output} of {node.class_name}\ + ({node.name}) more than 3 times not currently supported' + raise ValueError(msg) + + out_var = node.get_output_variable(output) + attrs = {'size': prod(out_var.shape)} + + init_stream_idx = 1 + if in_output: + # If the value is used as output, add one extra stream + idx = node.model.outputs.index(node.name) + node.model.outputs[idx] = node.name + '_cpy1' + init_stream_idx = 2 + for i, layer in enumerate(output_map[output], init_stream_idx): + idx = layer.inputs.index(output) + layer.inputs[idx] = output + f'_cpy{i}' + + clone_layer: Clone = model.make_node( + Clone, + 'clone_' + node.name, + attrs, + [output], + [output + '_cpy' + str(i + 1) for i in range(n_outputs)], + ) + for i in range(n_outputs): + key = output + '_cpy' + str(i + 1) + clone_layer.attributes[key].type = node.attributes['result_t'] + model.insert_node(clone_layer) + transformed = True return transformed diff --git a/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py b/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py index 532becc9d..82efe6710 100644 --- a/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py +++ b/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py @@ -11,14 +11,21 @@ class InplaceParallelReshape(OptimizerPass): """ def match(self, node): - return isinstance(node, Reshape) - - def transform(self, model, node): - if model.config.get_config_value('IOType') != 'io_parallel': + if not isinstance(node, Reshape): return False + return node.model.config.get_config_value('IOType') == 'io_parallel' + def transform(self, model, node): outvar = node.get_output_variable() invar = node.get_input_variable() newoutvar = InplaceTensorVariable(outvar, invar) node.set_attr(node.outputs[0], newoutvar) + if node.name in model.outputs: + prev_node = node.get_input_node() + assert ( + prev_node.name not in model.outputs + ), f"Cannot output node {prev_node.name}: reshape is a no-op in io_parallel.\ + As a result, the previous node {prev_node.name}'s output will be used as the\ + output. However, this node is already an output." + model.outputs = [name if name != node.name else prev_node.name for name in model.outputs] return False diff --git a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py index a16ffefc4..be4994e96 100644 --- a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py +++ b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py @@ -11,13 +11,20 @@ class InplaceStreamFlatten(OptimizerPass): """ def match(self, node): - # Reshape acts as a Flatten layer when the result has 1 dimension - return isinstance(node, Reshape) and len(node.get_output_variable().shape) == 1 + # Layers require flatten data can gather it from the stream, no need for repacking. + # Reshape acts as a Flatten layer when the result has 1 dimension. Make it a inplace tensor if it happens. - def transform(self, model, node): - if model.config.get_config_value('IOType') != 'io_stream': + if node.model.config.get_config_value('IOType') != 'io_stream': + return False + if not (isinstance(node, Reshape) and len(node.get_output_variable().shape) == 1): + # If is not flatten return False + if node.name in node.model.outputs: + # If used as model output. Output shape shall be preserved in this case. + return False + return True + def transform(self, model, node): outvar = node.get_output_variable() invar = node.get_input_variable() newoutvar = InplaceTensorVariable(outvar, invar) diff --git a/hls4ml/backends/fpga/passes/repack_stream.py b/hls4ml/backends/fpga/passes/repack_stream.py index 2408ec5eb..9a77dddb2 100644 --- a/hls4ml/backends/fpga/passes/repack_stream.py +++ b/hls4ml/backends/fpga/passes/repack_stream.py @@ -49,7 +49,9 @@ class ReshapeStream(OptimizerPass): def match(self, node): # do not run optimizer pass for a flatten layer (1 output dimension) - return isinstance(node, Reshape) and len(node.get_output_variable().shape) > 1 + if not isinstance(node, Reshape): + return False + return len(node.get_output_variable().shape) > 1 or node.name in node.model.outputs def transform(self, model, node): if model.config.get_config_value('IOType') != 'io_stream': diff --git a/hls4ml/backends/quartus/quartus_backend.py b/hls4ml/backends/quartus/quartus_backend.py index aecad642c..683d3f77b 100644 --- a/hls4ml/backends/quartus/quartus_backend.py +++ b/hls4ml/backends/quartus/quartus_backend.py @@ -48,7 +48,12 @@ def _register_flows(self): initializers = self._get_layer_initializers() init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name) - streaming_passes = ['quartus:reshape_stream', 'quartus:clone_output'] + streaming_passes = [ + 'quartus:inplace_stream_flatten', # Inform downstream changed packsize in case of skipping flatten + 'quartus:reshape_stream', + 'quartus:clone_output', + ] + streaming_flow = register_flow('streaming', streaming_passes, requires=[init_flow], backend=self.name) quartus_types = [ diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 9f8a5171d..e88af278f 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -79,6 +79,7 @@ def _register_flows(self): init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name) streaming_passes = [ + 'vivado:inplace_stream_flatten', # Inform downstream changed packsize in case of skipping flatten 'vivado:reshape_stream', 'vivado:clone_output', 'vivado:insert_zero_padding_before_conv1d', diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index cf715fd76..520f96ba5 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -506,6 +506,8 @@ def insert_node(self, node, before=None, input_idx=0): if next_node is not None: next_node.inputs[input_idx] = node.outputs[0] + else: + self.outputs = [node.outputs[0] if name == prev_node.outputs[0] else name for name in self.outputs] new_graph = OrderedDict() for k, v in self.graph.items(): @@ -514,47 +516,57 @@ def insert_node(self, node, before=None, input_idx=0): new_graph[node.name] = node self.graph = new_graph - self._update_model_outputs() def remove_node(self, node, rewire=True): - """Remove a node from a graph. + """Removes a node from the graph. - By default, this function can connect the outputs of previous node to the input of next one. - Note that when removing a leaf node `rewire` should be set to `False`. + By default, this function connects the outputs of the previous + node to the inputs of the next node. If the removed node has multiple + input/output tensors, an exception is raised. Args: - node (Layer): The node to remove - rewire (bool, optional): If `True`, connects the outputs of the previous node - to the inputs of the next node + node (Layer): The node to remove. + rewire (bool, optional): Deprecated, has no effect. Raises: - Exception: If an attempt is made to rewire a leaf node or a node with multiple - inputs/outputs. + Exception: If an attempt is made to rewire a node with + multiple inputs/outputs. + Note: + The `rewire` parameter is deprecated and has no effect. """ - if rewire: - inputs = [inp for inp in node.inputs if inp] - outputs = [outp for outp in node.outputs if outp] - if len(inputs) > 1 or len(outputs) > 1: - raise Exception('Cannot rewire a node with multiple inputs/outputs') - prev_node = node.get_input_node(node.inputs[0]) + + inputs = [inp for inp in node.inputs if inp] + outputs = [outp for outp in node.outputs if outp] + + if len(inputs) > 1 or len(outputs) > 1: + raise Exception('Cannot delete a node with multiple inputs/outputs') + + if len(inputs) == 1: + # Connect inputs -> $outputs + if node.name in self.outputs: + msg = f'Remove leaf node {node.name} will connect its input node {inputs[0]} to output, but it already is.' + assert inputs[0] not in self.outputs, msg + self.outputs = [inputs[0] if name == node.name else name for name in self.outputs] + + if len(outputs) == 1 and len(inputs) == 1: + inp_var = node.get_input_variable() + out_var = node.get_output_variable() + + # fmt: off + assert (np.prod(inp_var.shape) == np.prod(out_var.shape)), \ + f'Input and output shapes do not match for {node.name}: {inp_var.shape} -> {out_var.shape}' + # fmt: on + next_nodes = [x for x in self.graph.values() if node.outputs[0] in x.inputs] - if prev_node is not None: - if len(next_nodes) > 0: - for next_node in next_nodes: - for i, _ in enumerate(next_node.inputs): - if node.outputs[0] == next_node.inputs[i]: - next_node.inputs[i] = prev_node.outputs[0] - break - else: - if not node.outputs[0] in self.outputs: - raise Exception('Cannot rewire a node without child') - else: - raise Exception('Cannot rewire a node without a parent') + for next_node in next_nodes: + # Connect inputs -> next + for i, nxt_inp in enumerate(next_node.inputs): + if outputs[0] == nxt_inp: + next_node.inputs[i] = inputs[0] del self.output_vars[node.outputs[0]] del self.graph[node.name] - self._update_model_outputs() def replace_node(self, old_node, new_node): """Replace an existing node in the graph with a new one. @@ -584,7 +596,11 @@ def replace_node(self, old_node, new_node): node.outputs[i] = repl[n] self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items()) - self._update_model_outputs() + + old_name = old_node.name + if old_name in self.outputs: + new_name = new_node.name + self.outputs = [new_name if name == old_name else name for name in self.outputs] def split_node(self, old_node, new_node1, new_node2): """Replace an existing node in the graph with two nodes in sequence. @@ -622,17 +638,9 @@ def split_node(self, old_node, new_node1, new_node2): else: new_graph[key] = value self.graph = new_graph - self._update_model_outputs() - - def _update_model_outputs(self): - '''Update the model outputs - All node outputs and inputs are found. The model outputs are set to all node outputs - that are not also node inputs. - ''' - node_outputs = [out for node in self.graph.values() for out in node.outputs] - node_inputs = [inp for node in self.graph.values() for inp in node.inputs] - self.outputs = [out for out in node_outputs if out not in node_inputs] + if old_node.name in self.outputs: + self.outputs = [new_node2.name if name == old_node.name else name for name in self.outputs] def next_layer(self): self.index += 1 diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_stream.h index c76bfba5a..ec2e9bfb1 100644 --- a/hls4ml/templates/catapult/nnet_utils/nnet_stream.h +++ b/hls4ml/templates/catapult/nnet_utils/nnet_stream.h @@ -41,6 +41,26 @@ void clone_stream(ac_channel &data, ac_channel &res1, ac_channel< } } +template +void clone_stream(ac_channel &data, ac_channel &res1, ac_channel &res2, ac_channel &res3) { +#ifndef __SYNTHESIS__ + while (data.available(1)) +#endif + { + data_T in_data = data.read(); + res_T out_data; + + ClonePack: + for (int j = 0; j < data_T::size; j++) { + out_data[j] = in_data[j]; + } + + res1.write(out_data); + res2.write(out_data); + res3.write(out_data); + } +} + template void repack_stream(ac_channel &data, ac_channel &res) { if (data_T::size == res_T::size) { for (int i = 0; i < N / data_T::size; i++) { diff --git a/hls4ml/utils/config.py b/hls4ml/utils/config.py index 303bb9c18..1a25fb9c3 100644 --- a/hls4ml/utils/config.py +++ b/hls4ml/utils/config.py @@ -284,6 +284,7 @@ def config_from_pytorch_model( default_reuse_factor=1, channels_last_conversion='full', transpose_outputs=True, + max_precision=None, ): """Create an HLS conversion config given the PyTorch model. @@ -304,7 +305,8 @@ def config_from_pytorch_model( will generate config keys for every layer separately, allowing for highly specific configuration tweaks. backend(str, optional): Name of the backend to use - default_precision (str, optional): Default precision to use. Defaults to 'fixed<16,6>'. + default_precision (str, optional): Default precision to use. Defaults to 'fixed<16,6>'. Note, this must + be an explicit precision: 'auto' is not allowed. default_reuse_factor (int, optional): Default reuse factor. Defaults to 1. channels_last_conversion (string, optional): Configures the conversion of pytorch layers to 'channels_last' dataformate. Can be set to 'full', 'internal', or 'off'. If 'full', both the inputs @@ -313,6 +315,8 @@ def config_from_pytorch_model( transpose_outputs (bool, optional): Set to 'False' if the output should not be transposed from channels_last into channels_first data format. Defaults to 'False'. If False, outputs needs to be transposed manually. + max_precision (str or None, optional): Maximum width precision to use. Defaults to None, meaning no maximum. + Note: Only integer and fixed precisions are supported Raises: Exception: If PyTorch model has layers not supported by hls4ml. @@ -324,11 +328,16 @@ def config_from_pytorch_model( config = {} model_config = {} - model_config['Precision'] = default_precision + model_config['Precision'] = {} + model_config['Precision']['default'] = default_precision + if max_precision is not None: + model_config['Precision']['maximum'] = max_precision model_config['ReuseFactor'] = default_reuse_factor model_config['ChannelsLastConversion'] = channels_last_conversion model_config['TransposeOutputs'] = transpose_outputs model_config['Strategy'] = 'Latency' + model_config['BramFactor'] = 1_000_000_000 + model_config['TraceOutput'] = False config['Model'] = model_config config['PytorchModel'] = model @@ -372,7 +381,7 @@ def make_layer_config(layer): if name.endswith('_t'): name = name[:-2] if attr.default is None: - precision_cfg[name] = default_precision + precision_cfg[name] = 'auto' else: precision_cfg[name] = str(attr.default) elif attr.name == 'reuse_factor': diff --git a/test/pytest/test_multiout_network.py b/test/pytest/test_multiout_network.py index 15e23ff79..366fac7fb 100644 --- a/test/pytest/test_multiout_network.py +++ b/test/pytest/test_multiout_network.py @@ -19,6 +19,21 @@ def model(): return model +@pytest.fixture(scope='module') +def model_corner_cases(): + in1 = keras.layers.Input(shape=(24, 8)) + in2 = keras.layers.Input(shape=(16)) + out1 = keras.layers.Conv1D(1, 3)(in1) + out1 = keras.layers.Flatten()(out1) + out2 = keras.layers.Dense(16, activation='relu')(out1) + out2 = keras.layers.Add()([out2, in2]) + out3 = keras.layers.Dense(2)(out1) + out4 = keras.layers.Dense(2)(out2) + out4 = keras.layers.Flatten()(out4) + model = keras.models.Model(inputs=[in1, in2], outputs=[out1, out2, out3, out4]) + return model + + @pytest.fixture(scope='module') def data(): X = np.random.normal(0, 1, (1000, 10)) @@ -26,18 +41,20 @@ def data(): return X +@pytest.fixture(scope='module') +def data_corner_cases(): + X1 = np.random.normal(0, 1, (1000, 24, 8)) + X2 = np.random.normal(0, 1, (1000, 16)) + X1 = np.clip(X1, -16, 15) + X2 = np.clip(X2, -16, 15) + return X1, X2 + + @pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Vitis']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) -def test_multi_clone(model, data, backend: str, io_type: str): +def test_multi_output_nn(model, data, backend: str, io_type: str): output_dir = str(test_root_path / f'hls4mlprj_multiout_network_{backend}_{io_type}') hls_config = {'Model': {'Precision': 'fixed<32,5>', 'ReuseFactor': 1}} - layer_config = { - 'dense1': {'Precision': {'result': 'fixed<35,5>'}}, - 'dense2': {'Precision': {'result': 'fixed<40,5>'}}, - 'dense1_linear': {'Precision': {'result': 'fixed<35,5>'}}, - 'dense2_linear': {'Precision': {'result': 'fixed<40,5>'}}, - } - hls_config['LayerName'] = layer_config model_hls = convert_from_keras_model( model, backend=backend, output_dir=output_dir, hls_config=hls_config, io_type=io_type ) @@ -50,3 +67,32 @@ def test_multi_clone(model, data, backend: str, io_type: str): assert np.allclose(r_hls[0], r_keras[0], atol=1e-5, rtol=0) assert np.allclose(r_hls[1], r_keras[1], atol=1e-5, rtol=0) + + +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Vitis', 'Catapult', 'OneAPI']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +@pytest.mark.parametrize('strategy', ['latency', 'resource']) +def test_multi_output_nn_corner_cases(model_corner_cases, data_corner_cases, backend: str, io_type: str, strategy: str): + """Cover corner cases, when: + - a layer outputs both to the next layer(s) and to the model output + - when an node removal/insertion is triggered internally + - a reshape in io_parallel, or flatten in io_stream layer's output is used multiple times + - and as layer output + - and by layer taking multiple inputs + - a Flatten layer outputs to the model output in io_stream + """ + output_dir = str(test_root_path / f'hls4mlprj_multiout_network_2_{backend}_{io_type}_{strategy}') + hls_config = {'Model': {'Precision': 'fixed<32,5>', 'ReuseFactor': 1}, 'Strategy': strategy} + + model_hls = convert_from_keras_model( + model_corner_cases, backend=backend, output_dir=output_dir, hls_config=hls_config, io_type=io_type + ) + + model_hls.compile() + r_hls = model_hls.predict(data_corner_cases) + r_keras = model_corner_cases.predict(data_corner_cases, verbose=0, batch_size=1000) + + assert np.allclose(r_hls[0], r_keras[0], atol=1e-5, rtol=0) + assert np.allclose(r_hls[1], r_keras[1], atol=1e-5, rtol=0) + assert np.allclose(r_hls[2], r_keras[2], atol=1e-5, rtol=0) + assert np.allclose(r_hls[3], r_keras[3], atol=1e-5, rtol=0)