From 54657f9f8a0da5a5128eb500bf6a32e251cafb24 Mon Sep 17 00:00:00 2001
From: Jan-Frederik Schulte
Date: Wed, 6 Nov 2024 09:24:51 -0500
Subject: [PATCH 01/13] make auto default precision for pytorch parser

---
 hls4ml/utils/config.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/hls4ml/utils/config.py b/hls4ml/utils/config.py
index 1bd9ff25ef..14b8cf0b83 100644
--- a/hls4ml/utils/config.py
+++ b/hls4ml/utils/config.py
@@ -284,6 +284,7 @@ def config_from_pytorch_model(
     default_reuse_factor=1,
     channels_last_conversion='full',
     transpose_outputs=True,
+    max_precision=None,
 ):
     """Create an HLS conversion config given the PyTorch model.

@@ -304,7 +305,8 @@ def config_from_pytorch_model(
             will generate config keys for every layer separately, allowing for highly specific
             configuration tweaks.
         backend (str, optional): Name of the backend to use
-        default_precision (str, optional): Default precision to use. Defaults to 'fixed<16,6>'.
+        default_precision (str, optional): Default precision to use. Defaults to 'fixed<16,6>'. Note that this must
+            be an explicit precision: 'auto' is not allowed.
         default_reuse_factor (int, optional): Default reuse factor. Defaults to 1.
         channels_last_conversion (string, optional): Configures the conversion of pytorch layers to
             'channels_last' data format. Can be set to 'full', 'internal', or 'off'. If 'full', both the inputs
@@ -313,6 +315,8 @@ def config_from_pytorch_model(
         transpose_outputs (bool, optional): Set to 'False' if the output should not be transposed from
             channels_last into channels_first data format. Defaults to 'False'. If False, outputs need to be
             transposed manually.
+        max_precision (str or None, optional): Maximum width precision to use. Defaults to None, meaning no maximum.
+            Note: Only integer and fixed precisions are supported.

     Raises:
         Exception: If PyTorch model has layers not supported by hls4ml.
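
For orientation before the next hunk, a minimal usage sketch of the new keyword. It is illustrative only: `model` and the example input shape here are hypothetical, and granularity='name' is assumed so that per-layer precision keys (which now default to 'auto') are emitted.

    import hls4ml

    # 'model' is a placeholder for a torch.nn.Module; the input shape is made up.
    config = hls4ml.utils.config_from_pytorch_model(
        model,
        (10,),                            # example input shape
        granularity='name',               # emit per-layer keys, which default to 'auto'
        default_precision='fixed<16,6>',  # must be explicit; 'auto' is not allowed here
        max_precision='fixed<24,12>',     # new: optional cap on inferred widths
    )
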
@@ -324,7 +328,10 @@ def config_from_pytorch_model( config = {} model_config = {} - model_config['Precision'] = default_precision + model_config['Precision'] = {} + model_config['Precision']['default'] = default_precision + if max_precision is not None: + model_config['Precision']['maximum'] = max_precision model_config['ReuseFactor'] = default_reuse_factor model_config['ChannelsLastConversion'] = channels_last_conversion model_config['TransposeOutputs'] = transpose_outputs @@ -372,7 +379,7 @@ def make_layer_config(layer): if name.endswith('_t'): name = name[:-2] if attr.default is None: - precision_cfg[name] = default_precision + precision_cfg[name] = 'auto' else: precision_cfg[name] = str(attr.default) elif attr.name == 'reuse_factor': From 01d4f793b4b62a09e8b930753063b12ea3cebe4d Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Mon, 11 Nov 2024 12:40:03 -0500 Subject: [PATCH 02/13] more default settings suggested by Jovan --- hls4ml/utils/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hls4ml/utils/config.py b/hls4ml/utils/config.py index 14b8cf0b83..6938cf180a 100644 --- a/hls4ml/utils/config.py +++ b/hls4ml/utils/config.py @@ -336,6 +336,8 @@ def config_from_pytorch_model( model_config['ChannelsLastConversion'] = channels_last_conversion model_config['TransposeOutputs'] = transpose_outputs model_config['Strategy'] = 'Latency' + model_config['BramFactor'] = 1_000_000_000 + model_config['TraceOutput'] = False config['Model'] = model_config config['PytorchModel'] = model From 677c738bfba40a700b78efbae2a47fa2838ea457 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 29 Oct 2024 21:52:50 +0000 Subject: [PATCH 03/13] multi output and flatten@streaming fix --- hls4ml/backends/fpga/passes/clone.py | 76 ++++++++++++------- .../fpga/passes/inplace_parallel_reshape.py | 6 ++ .../fpga/passes/inplace_stream_flatten.py | 15 +++- hls4ml/model/graph.py | 73 +++++++++--------- test/pytest/test_multiout_network.py | 52 +++++++++++-- 5 files changed, 143 insertions(+), 79 deletions(-) diff --git a/hls4ml/backends/fpga/passes/clone.py b/hls4ml/backends/fpga/passes/clone.py index 306e839900..0e1ee363c8 100644 --- a/hls4ml/backends/fpga/passes/clone.py +++ b/hls4ml/backends/fpga/passes/clone.py @@ -54,41 +54,59 @@ def match(self, node): if isinstance(node, Clone): return False - return True + # Not needed for io_parallel + io_type = node.model.config.get_config_value('IOType') + if io_type != 'io_stream': + return False + + # Check if the output is used more than once + output_map = node.get_output_use_map() + in_output = node.name in node.model.outputs + for output in node.outputs: + if len(output_map[output]) + in_output > 1: + # model output also need a stream + return True + + return False def transform(self, model, node): - if model.config.get_config_value('IOType') != 'io_stream': - return False output_map = node.get_output_use_map() + in_output = node.name in node.model.outputs transformed = False for output in node.outputs: - if len(output_map[output]) > 1: - if len(output_map[output]) > 3: - print( - 'WARNING: Cloning output {} of {} ({}) more than 3 times not currently supported'.format( - output, node.__class__.__name__, node.name - ) - ) - return False - out_var = node.get_output_variable(output) - for i, layer in enumerate(output_map[output], 1): - attrs = {'size': np.prod(out_var.shape)} - idx = layer.inputs.index(output) - layer.inputs[idx] = output + '_cpy' + str(i) - - clone_layer: Clone = model.make_node( - Clone, - 'clone_' + node.name, - attrs, - [output], - 
[output + '_cpy' + str(i + 1) for i in range(len(output_map[output]))], - ) - for i in range(len(output_map[output])): - key = output + '_cpy' + str(i + 1) - clone_layer.attributes[key].type = node.attributes['result_t'] - model.insert_node(clone_layer) - transformed = True + n_outputs = len(output_map[output]) + in_output + if n_outputs == 1: + continue + if n_outputs > 3: + msg = f'ERROR: Cloning output {output} of {node.__class__.__name__} ({node.name}) more than 3 times not currently supported' # noqa: E501 + raise ValueError(msg) + + out_var = node.get_output_variable(output) + attrs = {'size': np.prod(out_var.shape)} + + i0 = 1 + if in_output: + # If the value is used as output, add one extra stream + idx = node.model.outputs.index(node.name) + node.model.outputs[idx] = node.name + '_cpy1' + i0 = 2 + for i, layer in enumerate(output_map[output], i0): + idx = layer.inputs.index(output) + layer.inputs[idx] = output + f'_cpy{i}' + + clone_layer: Clone = model.make_node( + Clone, + 'clone_' + node.name, + attrs, + [output], + [output + '_cpy' + str(i + 1) for i in range(n_outputs)], + ) + for i in range(n_outputs): + key = output + '_cpy' + str(i + 1) + clone_layer.attributes[key].type = node.attributes['result_t'] + model.insert_node(clone_layer) + transformed = True return transformed diff --git a/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py b/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py index 532becc9db..e0580e946e 100644 --- a/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py +++ b/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py @@ -21,4 +21,10 @@ def transform(self, model, node): invar = node.get_input_variable() newoutvar = InplaceTensorVariable(outvar, invar) node.set_attr(node.outputs[0], newoutvar) + if node.name in model.outputs: + prev_node = node.get_input_node() + assert ( + prev_node.name not in model.outputs + ), f"Cannot output node {prev_node.name}: reshape is a no-op in io_parallel. As a result, the previous node {prev_node.name}'s output will be used as the output. However, this node is already an output." # noqa: E501 + model.outputs = [name if name != node.name else prev_node.name for name in model.outputs] return False diff --git a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py index a16ffefc4a..ed54ad9ace 100644 --- a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py +++ b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py @@ -12,14 +12,21 @@ class InplaceStreamFlatten(OptimizerPass): def match(self, node): # Reshape acts as a Flatten layer when the result has 1 dimension - return isinstance(node, Reshape) and len(node.get_output_variable().shape) == 1 - - def transform(self, model, node): - if model.config.get_config_value('IOType') != 'io_stream': + if not (isinstance(node, Reshape) and len(node.get_output_variable().shape) == 1): + # Reshape with multiple outputs will be kept as is, or repack cannot handle different shapes return False + io_type = node.model.config.get_config_value('IOType') + return io_type == 'io_stream' + def transform(self, model, node): outvar = node.get_output_variable() invar = node.get_input_variable() newoutvar = InplaceTensorVariable(outvar, invar) node.set_attr(node.outputs[0], newoutvar) + if node.name in model.outputs: + prev_node = node.get_input_node() + assert ( + prev_node.name not in model.outputs + ), f"Cannot output node {prev_node.name}: In io_stream, flatten with a single output is a no-op. 
As a result, the previous node {prev_node.name}'s output will be used as the output. However, this node is already an output." # noqa: E501 + model.outputs = [name if name != node.name else prev_node.name for name in model.outputs] return False diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index cf715fd767..6dbcde90d5 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -506,6 +506,8 @@ def insert_node(self, node, before=None, input_idx=0): if next_node is not None: next_node.inputs[input_idx] = node.outputs[0] + else: + self.outputs = [node.outputs[0] if name == prev_node.outputs[0] else name for name in self.outputs] new_graph = OrderedDict() for k, v in self.graph.items(): @@ -514,47 +516,46 @@ def insert_node(self, node, before=None, input_idx=0): new_graph[node.name] = node self.graph = new_graph - self._update_model_outputs() def remove_node(self, node, rewire=True): """Remove a node from a graph. - By default, this function can connect the outputs of previous node to the input of next one. - Note that when removing a leaf node `rewire` should be set to `False`. + By default, this function can connect the outputs of previous node to the input of next one. If the removed node has multiple inputs/outputs tensors, an exception is raised. Args: node (Layer): The node to remove - rewire (bool, optional): If `True`, connects the outputs of the previous node - to the inputs of the next node + rewire (bool, optional): Deprecated, no effect Raises: - Exception: If an attempt is made to rewire a leaf node or a node with multiple - inputs/outputs. + Exception: If an attempt is made to rewire a node with multiple inputs/outputs. + """ # noqa: E501 - """ - if rewire: - inputs = [inp for inp in node.inputs if inp] - outputs = [outp for outp in node.outputs if outp] - if len(inputs) > 1 or len(outputs) > 1: - raise Exception('Cannot rewire a node with multiple inputs/outputs') - prev_node = node.get_input_node(node.inputs[0]) + inputs = [inp for inp in node.inputs if inp] + outputs = [outp for outp in node.outputs if outp] + + inp_var = node.get_input_variable() + out_var = node.get_output_variable() + + assert inp_var.shape == out_var.shape, f'Input and output shapes do not match for {node.name}' + + if len(inputs) > 1 or len(outputs) > 1: + raise Exception('Cannot delete a node with multiple inputs/outputs') + + if len(inputs) == 1: + # Connect inputs -> $outputs + if node.name in self.outputs: + self.outputs = [inputs[0] if name == node.name else name for name in self.outputs] + + if len(outputs) == 1 and len(inputs) == 1: next_nodes = [x for x in self.graph.values() if node.outputs[0] in x.inputs] - if prev_node is not None: - if len(next_nodes) > 0: - for next_node in next_nodes: - for i, _ in enumerate(next_node.inputs): - if node.outputs[0] == next_node.inputs[i]: - next_node.inputs[i] = prev_node.outputs[0] - break - else: - if not node.outputs[0] in self.outputs: - raise Exception('Cannot rewire a node without child') - else: - raise Exception('Cannot rewire a node without a parent') + for next_node in next_nodes: + # Connect inputs -> next + for i, nxt_inp in enumerate(next_node.inputs): + if outputs[0] == nxt_inp: + next_node.inputs[i] = inputs[0] del self.output_vars[node.outputs[0]] del self.graph[node.name] - self._update_model_outputs() def replace_node(self, old_node, new_node): """Replace an existing node in the graph with a new one. 
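
The rewiring above is easiest to see on a toy chain. Below is a plain-Python sketch (the names are hypothetical, not part of the patch) of what remove_node does for a chain inp -> reshape -> dense when 'reshape' is deleted while also being listed among the model outputs.

    # Tensor names on the node being removed (single input, single output).
    node_name, inputs, outputs = 'reshape', ['inp'], ['reshape']
    dense_inputs = ['reshape']     # the downstream consumer
    model_outputs = ['reshape']    # the removed node also drives a model output

    # Model outputs are repointed at the removed node's input ...
    model_outputs = [inputs[0] if name == node_name else name for name in model_outputs]
    # ... and so is every downstream consumer of the removed output.
    dense_inputs = [inputs[0] if name == outputs[0] else name for name in dense_inputs]

    print(model_outputs, dense_inputs)  # ['inp'] ['inp']
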
@@ -584,7 +585,11 @@ def replace_node(self, old_node, new_node): node.outputs[i] = repl[n] self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items()) - self._update_model_outputs() + + old_name = old_node.name + if old_name in self.outputs: + new_name = new_node.name + self.outputs = [new_name if name == old_name else name for name in self.outputs] def split_node(self, old_node, new_node1, new_node2): """Replace an existing node in the graph with two nodes in sequence. @@ -622,17 +627,9 @@ def split_node(self, old_node, new_node1, new_node2): else: new_graph[key] = value self.graph = new_graph - self._update_model_outputs() - - def _update_model_outputs(self): - '''Update the model outputs - All node outputs and inputs are found. The model outputs are set to all node outputs - that are not also node inputs. - ''' - node_outputs = [out for node in self.graph.values() for out in node.outputs] - node_inputs = [inp for node in self.graph.values() for inp in node.inputs] - self.outputs = [out for out in node_outputs if out not in node_inputs] + if old_node.name in self.outputs: + self.outputs = [new_node2.name if name == old_node.name else name for name in self.outputs] def next_layer(self): self.index += 1 diff --git a/test/pytest/test_multiout_network.py b/test/pytest/test_multiout_network.py index 15e23ff79a..2e780c8930 100644 --- a/test/pytest/test_multiout_network.py +++ b/test/pytest/test_multiout_network.py @@ -19,6 +19,19 @@ def model(): return model +@pytest.fixture(scope='module') +def model2(): + in1 = keras.layers.Input(shape=(24, 8)) + in2 = keras.layers.Input(shape=(16)) + out1 = keras.layers.Conv1D(1, 3)(in1) + out1 = keras.layers.Flatten()(out1) + out2 = keras.layers.Dense(16, activation='relu')(out1) + out2 = keras.layers.Add()([out2, in2]) + out3 = keras.layers.Dense(2)(out1) + model = keras.models.Model(inputs=[in1, in2], outputs=[out1, out2, out3]) + return model + + @pytest.fixture(scope='module') def data(): X = np.random.normal(0, 1, (1000, 10)) @@ -26,18 +39,20 @@ def data(): return X +@pytest.fixture(scope='module') +def data2(): + X1 = np.random.normal(0, 1, (1000, 24, 8)) + X2 = np.random.normal(0, 1, (1000, 16)) + X1 = np.clip(X1, -16, 15) + X2 = np.clip(X2, -16, 15) + return X1, X2 + + @pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Vitis']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) -def test_multi_clone(model, data, backend: str, io_type: str): +def test_multi_output_nn(model, data, backend: str, io_type: str): output_dir = str(test_root_path / f'hls4mlprj_multiout_network_{backend}_{io_type}') hls_config = {'Model': {'Precision': 'fixed<32,5>', 'ReuseFactor': 1}} - layer_config = { - 'dense1': {'Precision': {'result': 'fixed<35,5>'}}, - 'dense2': {'Precision': {'result': 'fixed<40,5>'}}, - 'dense1_linear': {'Precision': {'result': 'fixed<35,5>'}}, - 'dense2_linear': {'Precision': {'result': 'fixed<40,5>'}}, - } - hls_config['LayerName'] = layer_config model_hls = convert_from_keras_model( model, backend=backend, output_dir=output_dir, hls_config=hls_config, io_type=io_type ) @@ -50,3 +65,24 @@ def test_multi_clone(model, data, backend: str, io_type: str): assert np.allclose(r_hls[0], r_keras[0], atol=1e-5, rtol=0) assert np.allclose(r_hls[1], r_keras[1], atol=1e-5, rtol=0) + + +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Vitis']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +@pytest.mark.parametrize('strategy', ['latency', 'resource']) 
+def test_multi_output_nn_2(model2, data2, backend: str, io_type: str, strategy: str): + """Cover corner case where a flatten layer is cloned multiple times, and used as model output""" + output_dir = str(test_root_path / f'hls4mlprj_multiout_network_2_{backend}_{io_type}_{strategy}') + hls_config = {'Model': {'Precision': 'fixed<32,5>', 'ReuseFactor': 1}, 'Strategy': strategy} + + model_hls = convert_from_keras_model( + model2, backend=backend, output_dir=output_dir, hls_config=hls_config, io_type=io_type + ) + + model_hls.compile() + r_hls = model_hls.predict(data2) + r_keras = model2.predict(data2, verbose=0, batch_size=1000) + + assert np.allclose(r_hls[0], r_keras[0], atol=1e-5, rtol=0) + assert np.allclose(r_hls[1], r_keras[1], atol=1e-5, rtol=0) + assert np.allclose(r_hls[2], r_keras[2], atol=1e-5, rtol=0) From 04cbe83762b213d47f141ea84859c5a67fb1be0d Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 29 Oct 2024 21:44:33 +0000 Subject: [PATCH 04/13] relaxing remove node shape check cond --- hls4ml/model/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 6dbcde90d5..a641e38d26 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -536,7 +536,7 @@ def remove_node(self, node, rewire=True): inp_var = node.get_input_variable() out_var = node.get_output_variable() - assert inp_var.shape == out_var.shape, f'Input and output shapes do not match for {node.name}' + assert np.prod(inp_var.shape) == np.prod(out_var.shape), f'Input and output shapes do not match for {node.name}' if len(inputs) > 1 or len(outputs) > 1: raise Exception('Cannot delete a node with multiple inputs/outputs') From e73b3d30e9dfdb3926db48518f109dad475849a4 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 29 Oct 2024 22:57:59 +0000 Subject: [PATCH 05/13] fix regression errors --- hls4ml/backends/catapult/catapult_backend.py | 1 + .../backends/fpga/passes/inplace_parallel_reshape.py | 7 +++---- hls4ml/backends/fpga/passes/inplace_stream_flatten.py | 10 +++------- hls4ml/backends/quartus/quartus_backend.py | 7 ++++++- hls4ml/backends/vivado/vivado_backend.py | 1 + 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/hls4ml/backends/catapult/catapult_backend.py b/hls4ml/backends/catapult/catapult_backend.py index 5c85bf9b7e..6d1c17a3c6 100644 --- a/hls4ml/backends/catapult/catapult_backend.py +++ b/hls4ml/backends/catapult/catapult_backend.py @@ -88,6 +88,7 @@ def _register_flows(self): init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name) streaming_passes = [ + 'catapult:inplace_stream_flatten', # Inform downstream changed packsize in case of skipping flatten 'catapult:reshape_stream', 'catapult:clone_output', 'catapult:insert_zero_padding_before_conv1d', diff --git a/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py b/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py index e0580e946e..a56531a310 100644 --- a/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py +++ b/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py @@ -11,12 +11,11 @@ class InplaceParallelReshape(OptimizerPass): """ def match(self, node): - return isinstance(node, Reshape) + if not isinstance(node, Reshape): + return + return node.model.config.get_config_value('IOType') == 'io_parallel' def transform(self, model, node): - if model.config.get_config_value('IOType') != 'io_parallel': - return False - outvar = node.get_output_variable() invar = node.get_input_variable() newoutvar = 
InplaceTensorVariable(outvar, invar) diff --git a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py index ed54ad9ace..a209b02cb8 100644 --- a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py +++ b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py @@ -12,9 +12,11 @@ class InplaceStreamFlatten(OptimizerPass): def match(self, node): # Reshape acts as a Flatten layer when the result has 1 dimension - if not (isinstance(node, Reshape) and len(node.get_output_variable().shape) == 1): + if not (isinstance(node, Reshape)): # Reshape with multiple outputs will be kept as is, or repack cannot handle different shapes return False + if len(node.get_output_variable().shape) + node.name in node.model.outputs != 1: + return False io_type = node.model.config.get_config_value('IOType') return io_type == 'io_stream' @@ -23,10 +25,4 @@ def transform(self, model, node): invar = node.get_input_variable() newoutvar = InplaceTensorVariable(outvar, invar) node.set_attr(node.outputs[0], newoutvar) - if node.name in model.outputs: - prev_node = node.get_input_node() - assert ( - prev_node.name not in model.outputs - ), f"Cannot output node {prev_node.name}: In io_stream, flatten with a single output is a no-op. As a result, the previous node {prev_node.name}'s output will be used as the output. However, this node is already an output." # noqa: E501 - model.outputs = [name if name != node.name else prev_node.name for name in model.outputs] return False diff --git a/hls4ml/backends/quartus/quartus_backend.py b/hls4ml/backends/quartus/quartus_backend.py index aecad642c6..683d3f77b1 100644 --- a/hls4ml/backends/quartus/quartus_backend.py +++ b/hls4ml/backends/quartus/quartus_backend.py @@ -48,7 +48,12 @@ def _register_flows(self): initializers = self._get_layer_initializers() init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name) - streaming_passes = ['quartus:reshape_stream', 'quartus:clone_output'] + streaming_passes = [ + 'quartus:inplace_stream_flatten', # Inform downstream changed packsize in case of skipping flatten + 'quartus:reshape_stream', + 'quartus:clone_output', + ] + streaming_flow = register_flow('streaming', streaming_passes, requires=[init_flow], backend=self.name) quartus_types = [ diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 9f8a5171d3..e88af278f0 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -79,6 +79,7 @@ def _register_flows(self): init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name) streaming_passes = [ + 'vivado:inplace_stream_flatten', # Inform downstream changed packsize in case of skipping flatten 'vivado:reshape_stream', 'vivado:clone_output', 'vivado:insert_zero_padding_before_conv1d', From b055233b1c07622f34d52b75ab81d4648829ba28 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 29 Oct 2024 23:06:41 +0000 Subject: [PATCH 06/13] catapult and oneapi tests --- hls4ml/backends/fpga/passes/inplace_stream_flatten.py | 2 +- test/pytest/test_multiout_network.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py index a209b02cb8..b2c9710e71 100644 --- a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py +++ b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py @@ -15,7 +15,7 @@ def match(self, node): 
if not (isinstance(node, Reshape)): # Reshape with multiple outputs will be kept as is, or repack cannot handle different shapes return False - if len(node.get_output_variable().shape) + node.name in node.model.outputs != 1: + if len(node.get_output_variable().shape) + (node.name in node.model.outputs) != 1: return False io_type = node.model.config.get_config_value('IOType') return io_type == 'io_stream' diff --git a/test/pytest/test_multiout_network.py b/test/pytest/test_multiout_network.py index 2e780c8930..88e884f5ce 100644 --- a/test/pytest/test_multiout_network.py +++ b/test/pytest/test_multiout_network.py @@ -67,7 +67,7 @@ def test_multi_output_nn(model, data, backend: str, io_type: str): assert np.allclose(r_hls[1], r_keras[1], atol=1e-5, rtol=0) -@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Vitis']) +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Vitis', 'Catapult', 'OneAPI']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) @pytest.mark.parametrize('strategy', ['latency', 'resource']) def test_multi_output_nn_2(model2, data2, backend: str, io_type: str, strategy: str): From 26b4c54b03d307cead6adc1a1b4c4e048731c80a Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 29 Oct 2024 23:17:59 +0000 Subject: [PATCH 07/13] add catapult 3-clone --- .../catapult/nnet_utils/nnet_stream.h | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_stream.h index c76bfba5a6..ec2e9bfb1a 100644 --- a/hls4ml/templates/catapult/nnet_utils/nnet_stream.h +++ b/hls4ml/templates/catapult/nnet_utils/nnet_stream.h @@ -41,6 +41,26 @@ void clone_stream(ac_channel &data, ac_channel &res1, ac_channel< } } +template +void clone_stream(ac_channel &data, ac_channel &res1, ac_channel &res2, ac_channel &res3) { +#ifndef __SYNTHESIS__ + while (data.available(1)) +#endif + { + data_T in_data = data.read(); + res_T out_data; + + ClonePack: + for (int j = 0; j < data_T::size; j++) { + out_data[j] = in_data[j]; + } + + res1.write(out_data); + res2.write(out_data); + res3.write(out_data); + } +} + template void repack_stream(ac_channel &data, ac_channel &res) { if (data_T::size == res_T::size) { for (int i = 0; i < N / data_T::size; i++) { From 702e4eb654e4b92027649540e8ec63f6cf864fb1 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 29 Oct 2024 23:31:02 +0000 Subject: [PATCH 08/13] rm ill-condition _ --- hls4ml/backends/fpga/passes/inplace_stream_flatten.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py index b2c9710e71..a41efe6fd6 100644 --- a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py +++ b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py @@ -12,11 +12,9 @@ class InplaceStreamFlatten(OptimizerPass): def match(self, node): # Reshape acts as a Flatten layer when the result has 1 dimension - if not (isinstance(node, Reshape)): + if not (isinstance(node, Reshape) and len(node.get_output_variable().shape)) == 1: # Reshape with multiple outputs will be kept as is, or repack cannot handle different shapes return False - if len(node.get_output_variable().shape) + (node.name in node.model.outputs) != 1: - return False io_type = node.model.config.get_config_value('IOType') return io_type == 'io_stream' From d016612f5355a9b4ef65073510ba63a1b1f974ab Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 8 Nov 2024 18:12:59 +0000 Subject: 
[PATCH 09/13] allow removing nodes w/o i or o --- hls4ml/model/graph.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index a641e38d26..550296602c 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -533,11 +533,6 @@ def remove_node(self, node, rewire=True): inputs = [inp for inp in node.inputs if inp] outputs = [outp for outp in node.outputs if outp] - inp_var = node.get_input_variable() - out_var = node.get_output_variable() - - assert np.prod(inp_var.shape) == np.prod(out_var.shape), f'Input and output shapes do not match for {node.name}' - if len(inputs) > 1 or len(outputs) > 1: raise Exception('Cannot delete a node with multiple inputs/outputs') @@ -547,6 +542,14 @@ def remove_node(self, node, rewire=True): self.outputs = [inputs[0] if name == node.name else name for name in self.outputs] if len(outputs) == 1 and len(inputs) == 1: + inp_var = node.get_input_variable() + out_var = node.get_output_variable() + + # fmt: off + assert (np.prod(inp_var.shape) == np.prod(out_var.shape)), \ + f'Input and output shapes do not match for {node.name}: {inp_var.shape} -> {out_var.shape}' + # fmt: on + next_nodes = [x for x in self.graph.values() if node.outputs[0] in x.inputs] for next_node in next_nodes: # Connect inputs -> next From d1a3b7533e5cf90ce0dbbdf64fac15b2f2b49599 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Sat, 9 Nov 2024 03:31:00 +0000 Subject: [PATCH 10/13] chore --- hls4ml/backends/fpga/passes/clone.py | 6 +++--- hls4ml/model/graph.py | 13 ++++++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/hls4ml/backends/fpga/passes/clone.py b/hls4ml/backends/fpga/passes/clone.py index 0e1ee363c8..c640d6f37e 100644 --- a/hls4ml/backends/fpga/passes/clone.py +++ b/hls4ml/backends/fpga/passes/clone.py @@ -86,13 +86,13 @@ def transform(self, model, node): out_var = node.get_output_variable(output) attrs = {'size': np.prod(out_var.shape)} - i0 = 1 + init_stream_idx = 1 if in_output: # If the value is used as output, add one extra stream idx = node.model.outputs.index(node.name) node.model.outputs[idx] = node.name + '_cpy1' - i0 = 2 - for i, layer in enumerate(output_map[output], i0): + init_stream_idx = 2 + for i, layer in enumerate(output_map[output], init_stream_idx): idx = layer.inputs.index(output) layer.inputs[idx] = output + f'_cpy{i}' diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 550296602c..ffd173b8f2 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -520,15 +520,18 @@ def insert_node(self, node, before=None, input_idx=0): def remove_node(self, node, rewire=True): """Remove a node from a graph. - By default, this function can connect the outputs of previous node to the input of next one. If the removed node has multiple inputs/outputs tensors, an exception is raised. + By default, this function can connect the outputs of previous + node to the input of next one. If the removed node has multiple + inputs/outputs tensors, an exception is raised. Args: - node (Layer): The node to remove - rewire (bool, optional): Deprecated, no effect + node (Layer): The node to remove rewire (bool, optional): + Deprecated, no effect Raises: - Exception: If an attempt is made to rewire a node with multiple inputs/outputs. - """ # noqa: E501 + Exception: If an attempt is made to rewire a node with + multiple inputs/outputs. 
+ """ inputs = [inp for inp in node.inputs if inp] outputs = [outp for outp in node.outputs if outp] From bf6fe7a567996e6b3f752a05763f3fc4ff9b44b2 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Sun, 10 Nov 2024 21:49:57 +0000 Subject: [PATCH 11/13] cosmatic --- hls4ml/backends/fpga/passes/clone.py | 7 ++++--- .../fpga/passes/inplace_parallel_reshape.py | 6 ++++-- .../fpga/passes/inplace_stream_flatten.py | 3 +-- hls4ml/model/graph.py | 16 ++++++++------- test/pytest/test_multiout_network.py | 20 ++++++++++++------- 5 files changed, 31 insertions(+), 21 deletions(-) diff --git a/hls4ml/backends/fpga/passes/clone.py b/hls4ml/backends/fpga/passes/clone.py index c640d6f37e..a36d96dfa8 100644 --- a/hls4ml/backends/fpga/passes/clone.py +++ b/hls4ml/backends/fpga/passes/clone.py @@ -1,4 +1,4 @@ -import numpy as np +from math import prod from hls4ml.backends.template import FunctionCallTemplate from hls4ml.model.layers import Layer, register_layer @@ -80,11 +80,12 @@ def transform(self, model, node): if n_outputs == 1: continue if n_outputs > 3: - msg = f'ERROR: Cloning output {output} of {node.__class__.__name__} ({node.name}) more than 3 times not currently supported' # noqa: E501 + msg = f'ERROR: Cloning output {output} of {node.class_name}\ + ({node.name}) more than 3 times not currently supported' raise ValueError(msg) out_var = node.get_output_variable(output) - attrs = {'size': np.prod(out_var.shape)} + attrs = {'size': prod(out_var.shape)} init_stream_idx = 1 if in_output: diff --git a/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py b/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py index a56531a310..82efe67100 100644 --- a/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py +++ b/hls4ml/backends/fpga/passes/inplace_parallel_reshape.py @@ -12,7 +12,7 @@ class InplaceParallelReshape(OptimizerPass): def match(self, node): if not isinstance(node, Reshape): - return + return False return node.model.config.get_config_value('IOType') == 'io_parallel' def transform(self, model, node): @@ -24,6 +24,8 @@ def transform(self, model, node): prev_node = node.get_input_node() assert ( prev_node.name not in model.outputs - ), f"Cannot output node {prev_node.name}: reshape is a no-op in io_parallel. As a result, the previous node {prev_node.name}'s output will be used as the output. However, this node is already an output." # noqa: E501 + ), f"Cannot output node {prev_node.name}: reshape is a no-op in io_parallel.\ + As a result, the previous node {prev_node.name}'s output will be used as the\ + output. However, this node is already an output." 
            model.outputs = [name if name != node.name else prev_node.name for name in model.outputs]
         return False
diff --git a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py
index a41efe6fd6..69720632b3 100644
--- a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py
+++ b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py
@@ -15,8 +15,7 @@ def match(self, node):
         if not (isinstance(node, Reshape) and len(node.get_output_variable().shape)) == 1:
             # Reshape with multiple outputs will be kept as is, or repack cannot handle different shapes
             return False
-        io_type = node.model.config.get_config_value('IOType')
-        return io_type == 'io_stream'
+        return node.model.config.get_config_value('IOType') == 'io_stream'

     def transform(self, model, node):
         outvar = node.get_output_variable()
diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py
index ffd173b8f2..fff970052e 100644
--- a/hls4ml/model/graph.py
+++ b/hls4ml/model/graph.py
@@ -524,14 +524,16 @@ def remove_node(self, node, rewire=True):
         node to the input of next one. If the removed node has multiple
         inputs/outputs tensors, an exception is raised.

-        Args:
-            node (Layer): The node to remove rewire (bool, optional):
-            Deprecated, no effect
+        :param node: The node to remove.
+        :type node: Layer
+        :param rewire: Deprecated, no effect.
+        :type rewire: bool, optional

-        Raises:
-            Exception: If an attempt is made to rewire a node with
-            multiple inputs/outputs.
+        :raises Exception: If an attempt is made to rewire a node with
+            multiple inputs/outputs.

-        """
+        .. deprecated:: 1.0
+            The `rewire` parameter is deprecated and has no effect."""

         inputs = [inp for inp in node.inputs if inp]
         outputs = [outp for outp in node.outputs if outp]
diff --git a/test/pytest/test_multiout_network.py b/test/pytest/test_multiout_network.py
index 88e884f5ce..1cf90d6141 100644
--- a/test/pytest/test_multiout_network.py
+++ b/test/pytest/test_multiout_network.py
@@ -20,7 +20,7 @@ def model():

 @pytest.fixture(scope='module')
-def model2():
+def model_corner_cases():
     in1 = keras.layers.Input(shape=(24, 8))
     in2 = keras.layers.Input(shape=(16))
     out1 = keras.layers.Conv1D(1, 3)(in1)
@@ -40,7 +40,7 @@ def data():

 @pytest.fixture(scope='module')
-def data2():
+def data_corner_cases():
     X1 = np.random.normal(0, 1, (1000, 24, 8))
     X2 = np.random.normal(0, 1, (1000, 16))

@@ -70,18 +70,24 @@ def test_multi_output_nn(model, data, backend: str, io_type: str):

 @pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Vitis', 'Catapult', 'OneAPI'])
 @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
 @pytest.mark.parametrize('strategy', ['latency', 'resource'])
-def test_multi_output_nn_2(model2, data2, backend: str, io_type: str, strategy: str):
-    """Cover corner case where a flatten layer is cloned multiple times, and used as model output"""
+def test_multi_output_nn_corner_cases(model_corner_cases, data_corner_cases, backend: str, io_type: str, strategy: str):
+    """Cover corner cases, when:
+    - a layer outputs both to the next layer(s) and to the model output
+    - when a node removal/insertion is triggered internally
+    - a reshape in io_parallel, or flatten in io_stream layer's output is used multiple times
+        - and as layer output
+        - and by layer taking multiple inputs
+    """
     output_dir = str(test_root_path / f'hls4mlprj_multiout_network_2_{backend}_{io_type}_{strategy}')
     hls_config = {'Model': {'Precision': 'fixed<32,5>', 'ReuseFactor': 1}, 'Strategy': strategy}

     model_hls = 
convert_from_keras_model( - model2, backend=backend, output_dir=output_dir, hls_config=hls_config, io_type=io_type + model_corner_cases, backend=backend, output_dir=output_dir, hls_config=hls_config, io_type=io_type ) model_hls.compile() - r_hls = model_hls.predict(data2) - r_keras = model2.predict(data2, verbose=0, batch_size=1000) + r_hls = model_hls.predict(data_corner_cases) + r_keras = model_corner_cases.predict(data_corner_cases, verbose=0, batch_size=1000) assert np.allclose(r_hls[0], r_keras[0], atol=1e-5, rtol=0) assert np.allclose(r_hls[1], r_keras[1], atol=1e-5, rtol=0) From 7b58c1d0f4ed5108455fab7123d1469f9f207219 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Wed, 13 Nov 2024 05:20:37 +0000 Subject: [PATCH 12/13] typo and docstring --- .../fpga/passes/inplace_stream_flatten.py | 2 +- hls4ml/model/graph.py | 25 ++++++++++--------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py index 69720632b3..5eb649f568 100644 --- a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py +++ b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py @@ -12,7 +12,7 @@ class InplaceStreamFlatten(OptimizerPass): def match(self, node): # Reshape acts as a Flatten layer when the result has 1 dimension - if not (isinstance(node, Reshape) and len(node.get_output_variable().shape)) == 1: + if not (isinstance(node, Reshape) and len(node.get_output_variable().shape) == 1): # Reshape with multiple outputs will be kept as is, or repack cannot handle different shapes return False return node.model.config.get_config_value('IOType') == 'io_stream' diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index fff970052e..76962c6372 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -518,22 +518,23 @@ def insert_node(self, node, before=None, input_idx=0): self.graph = new_graph def remove_node(self, node, rewire=True): - """Remove a node from a graph. + """Removes a node from the graph. - By default, this function can connect the outputs of previous - node to the input of next one. If the removed node has multiple - inputs/outputs tensors, an exception is raised. + By default, this function connects the outputs of the previous + node to the inputs of the next node. If the removed node has multiple + input/output tensors, an exception is raised. - :param node: The node to remove. - :type node: Layer - :param rewire: Deprecated, no effect. - :type rewire: bool, optional + Args: + node (Layer): The node to remove. + rewire (bool, optional): Deprecated, has no effect. - :raises Exception: If an attempt is made to rewire a node with - multiple inputs/outputs. + Raises: + Exception: If an attempt is made to rewire a node with + multiple inputs/outputs. - .. deprecated:: 1.0 - The `rewire` parameter is deprecated and has no effect.""" + Note: + The `rewire` parameter is deprecated and has no effect. 
+ """ inputs = [inp for inp in node.inputs if inp] outputs = [outp for outp in node.outputs if outp] From ef2e8f4727a2701b22a1ec68e79e8a1f39e3b5ae Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Wed, 13 Nov 2024 06:04:07 +0000 Subject: [PATCH 13/13] allow io_stream if used as model output --- .../backends/fpga/passes/inplace_stream_flatten.py | 13 ++++++++++--- hls4ml/backends/fpga/passes/repack_stream.py | 4 +++- hls4ml/model/graph.py | 2 ++ test/pytest/test_multiout_network.py | 8 ++++++-- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py index 5eb649f568..be4994e96e 100644 --- a/hls4ml/backends/fpga/passes/inplace_stream_flatten.py +++ b/hls4ml/backends/fpga/passes/inplace_stream_flatten.py @@ -11,11 +11,18 @@ class InplaceStreamFlatten(OptimizerPass): """ def match(self, node): - # Reshape acts as a Flatten layer when the result has 1 dimension + # Layers require flatten data can gather it from the stream, no need for repacking. + # Reshape acts as a Flatten layer when the result has 1 dimension. Make it a inplace tensor if it happens. + + if node.model.config.get_config_value('IOType') != 'io_stream': + return False if not (isinstance(node, Reshape) and len(node.get_output_variable().shape) == 1): - # Reshape with multiple outputs will be kept as is, or repack cannot handle different shapes + # If is not flatten + return False + if node.name in node.model.outputs: + # If used as model output. Output shape shall be preserved in this case. return False - return node.model.config.get_config_value('IOType') == 'io_stream' + return True def transform(self, model, node): outvar = node.get_output_variable() diff --git a/hls4ml/backends/fpga/passes/repack_stream.py b/hls4ml/backends/fpga/passes/repack_stream.py index 2408ec5ebe..9a77dddb29 100644 --- a/hls4ml/backends/fpga/passes/repack_stream.py +++ b/hls4ml/backends/fpga/passes/repack_stream.py @@ -49,7 +49,9 @@ class ReshapeStream(OptimizerPass): def match(self, node): # do not run optimizer pass for a flatten layer (1 output dimension) - return isinstance(node, Reshape) and len(node.get_output_variable().shape) > 1 + if not isinstance(node, Reshape): + return False + return len(node.get_output_variable().shape) > 1 or node.name in node.model.outputs def transform(self, model, node): if model.config.get_config_value('IOType') != 'io_stream': diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 76962c6372..520f96ba5f 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -545,6 +545,8 @@ def remove_node(self, node, rewire=True): if len(inputs) == 1: # Connect inputs -> $outputs if node.name in self.outputs: + msg = f'Remove leaf node {node.name} will connect its input node {inputs[0]} to output, but it already is.' 
assert inputs[0] not in self.outputs, msg
                 self.outputs = [inputs[0] if name == node.name else name for name in self.outputs]

         if len(outputs) == 1 and len(inputs) == 1:
diff --git a/test/pytest/test_multiout_network.py b/test/pytest/test_multiout_network.py
index 1cf90d6141..366fac7fb5 100644
--- a/test/pytest/test_multiout_network.py
+++ b/test/pytest/test_multiout_network.py
@@ -28,7 +28,9 @@ def model_corner_cases():
     out2 = keras.layers.Dense(16, activation='relu')(out1)
     out2 = keras.layers.Add()([out2, in2])
     out3 = keras.layers.Dense(2)(out1)
-    model = keras.models.Model(inputs=[in1, in2], outputs=[out1, out2, out3])
+    out4 = keras.layers.Dense(2)(out2)
+    out4 = keras.layers.Flatten()(out4)
+    model = keras.models.Model(inputs=[in1, in2], outputs=[out1, out2, out3, out4])
     return model

@@ -76,7 +78,8 @@ def test_multi_output_nn_corner_cases(model_corner_cases, data_corner_cases, bac
     - when a node removal/insertion is triggered internally
     - a reshape in io_parallel, or flatten in io_stream layer's output is used multiple times
         - and as layer output
-        - and by layer taking multiple inputs
+    - and by layer taking multiple inputs
+    - a Flatten layer outputs to the model output in io_stream
     """
     output_dir = str(test_root_path / f'hls4mlprj_multiout_network_2_{backend}_{io_type}_{strategy}')
     hls_config = {'Model': {'Precision': 'fixed<32,5>', 'ReuseFactor': 1}, 'Strategy': strategy}
@@ -92,3 +95,4 @@ def test_multi_output_nn_corner_cases(model_corner_cases, data_corner_cases, bac
     assert np.allclose(r_hls[0], r_keras[0], atol=1e-5, rtol=0)
     assert np.allclose(r_hls[1], r_keras[1], atol=1e-5, rtol=0)
     assert np.allclose(r_hls[2], r_keras[2], atol=1e-5, rtol=0)
+    assert np.allclose(r_hls[3], r_keras[3], atol=1e-5, rtol=0)
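
Taken together, the series leaves InplaceStreamFlatten with a small decision table. The sketch below is a condensed plain-Python paraphrase of the final match logic, not the actual pass: the boolean arguments stand in for queries on the real node and model objects.

    def becomes_inplace(io_type: str, is_reshape: bool, out_ndim: int, is_model_output: bool) -> bool:
        if io_type != 'io_stream':
            return False  # the pass only applies to streaming IO
        if not (is_reshape and out_ndim == 1):
            return False  # only flatten-like reshapes qualify
        # Flattens that drive a model output keep a real stream (ReshapeStream
        # now matches them instead), so the output shape is preserved.
        return not is_model_output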