Commit a5e32c5

Merge branch 'main' into resize_pr

nghielme authored Nov 15, 2024
2 parents c7f6983 + 51cb83c
Showing 11 changed files with 216 additions and 90 deletions.
1 change: 1 addition & 0 deletions hls4ml/backends/catapult/catapult_backend.py
@@ -88,6 +88,7 @@ def _register_flows(self):
init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name)

streaming_passes = [
            'catapult:inplace_stream_flatten',  # Inform downstream layers of the changed pack size when a flatten is skipped
'catapult:reshape_stream',
'catapult:clone_output',
'catapult:insert_zero_padding_before_conv1d',
80 changes: 50 additions & 30 deletions hls4ml/backends/fpga/passes/clone.py
@@ -1,4 +1,4 @@
import numpy as np
from math import prod

from hls4ml.backends.template import FunctionCallTemplate
from hls4ml.model.layers import Layer, register_layer
@@ -54,41 +54,61 @@ def match(self, node):
if isinstance(node, Clone):
return False

return True
# Not needed for io_parallel
io_type = node.model.config.get_config_value('IOType')
if io_type != 'io_stream':
return False

# Check if the output is used more than once
output_map = node.get_output_use_map()
in_output = node.name in node.model.outputs
for output in node.outputs:
if len(output_map[output]) + in_output > 1:
                # a model output also needs its own stream
return True

return False

def transform(self, model, node):
if model.config.get_config_value('IOType') != 'io_stream':
return False

output_map = node.get_output_use_map()
in_output = node.name in node.model.outputs

transformed = False
for output in node.outputs:
if len(output_map[output]) > 1:
if len(output_map[output]) > 3:
print(
'WARNING: Cloning output {} of {} ({}) more than 3 times not currently supported'.format(
output, node.__class__.__name__, node.name
)
)
return False
out_var = node.get_output_variable(output)
for i, layer in enumerate(output_map[output], 1):
attrs = {'size': np.prod(out_var.shape)}
idx = layer.inputs.index(output)
layer.inputs[idx] = output + '_cpy' + str(i)

clone_layer: Clone = model.make_node(
Clone,
'clone_' + node.name,
attrs,
[output],
[output + '_cpy' + str(i + 1) for i in range(len(output_map[output]))],
)
for i in range(len(output_map[output])):
key = output + '_cpy' + str(i + 1)
clone_layer.attributes[key].type = node.get_output_variable().type
model.insert_node(clone_layer)
transformed = True

n_outputs = len(output_map[output]) + in_output
if n_outputs == 1:
continue
if n_outputs > 3:
                msg = (
                    f'ERROR: Cloning output {output} of {node.class_name} ({node.name}) '
                    'more than 3 times not currently supported'
                )
raise ValueError(msg)

out_var = node.get_output_variable(output)
attrs = {'size': prod(out_var.shape)}

init_stream_idx = 1
if in_output:
# If the value is used as output, add one extra stream
idx = node.model.outputs.index(node.name)
node.model.outputs[idx] = node.name + '_cpy1'
init_stream_idx = 2
for i, layer in enumerate(output_map[output], init_stream_idx):
idx = layer.inputs.index(output)
layer.inputs[idx] = output + f'_cpy{i}'

clone_layer: Clone = model.make_node(
Clone,
'clone_' + node.name,
attrs,
[output],
[output + '_cpy' + str(i + 1) for i in range(n_outputs)],
)
for i in range(n_outputs):
key = output + '_cpy' + str(i + 1)
clone_layer.attributes[key].type = node.attributes['result_t']
model.insert_node(clone_layer)
transformed = True

return transformed
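
The constraint driving this pass is worth spelling out: a hardware stream is a FIFO that is consumed destructively, so a tensor read by several downstream layers (or by a layer and the model output) needs one physical stream per reader. A minimal Python sketch of the idea behind the clone_stream templates (illustrative only; names and types are placeholders, not hls4ml code):

from collections import deque

# A stream is a FIFO: every read pops a pack. Two consumers of the same FIFO
# would each see only part of the data, so the Clone node duplicates every
# pack into one FIFO per consumer ('_cpy1', '_cpy2', ...).
def clone_stream(data, res1, res2):
    while data:
        pack = data.popleft()
        res1.append(pack)
        res2.append(pack)

src = deque([[1, 2], [3, 4], [5, 6]])  # stream of packs
cpy1, cpy2 = deque(), deque()
clone_stream(src, cpy1, cpy2)
assert list(cpy1) == list(cpy2) == [[1, 2], [3, 4], [5, 6]]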
15 changes: 11 additions & 4 deletions hls4ml/backends/fpga/passes/inplace_parallel_reshape.py
@@ -11,14 +11,21 @@ class InplaceParallelReshape(OptimizerPass):
"""

def match(self, node):
return isinstance(node, Reshape)

def transform(self, model, node):
if model.config.get_config_value('IOType') != 'io_parallel':
if not isinstance(node, Reshape):
return False
return node.model.config.get_config_value('IOType') == 'io_parallel'

def transform(self, model, node):
outvar = node.get_output_variable()
invar = node.get_input_variable()
newoutvar = InplaceTensorVariable(outvar, invar)
node.set_attr(node.outputs[0], newoutvar)
if node.name in model.outputs:
prev_node = node.get_input_node()
            assert prev_node.name not in model.outputs, (
                f'Cannot output node {prev_node.name}: reshape is a no-op in io_parallel. '
                f"As a result, the previous node {prev_node.name}'s output will be used as the "
                'output. However, this node is already an output.'
            )
model.outputs = [name if name != node.name else prev_node.name for name in model.outputs]
return False
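
For context on why the reshape itself can vanish in io_parallel: the tensor is a flat array whose memory layout is unchanged by a reshape, so the pass simply aliases the output variable to the input. A toy numpy analogy (illustrative; InplaceTensorVariable is the real mechanism, not shown here):

import numpy as np

# Reshaping a contiguous buffer changes only the logical shape; the 'output'
# can alias the input storage with no copy, which is what makes the layer a
# no-op in io_parallel.
invar = np.arange(12, dtype=np.float32).reshape(3, 4)
outvar = invar.reshape(12)  # a view, not a copy
assert np.shares_memory(invar, outvar)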
15 changes: 11 additions & 4 deletions hls4ml/backends/fpga/passes/inplace_stream_flatten.py
@@ -11,13 +11,20 @@ class InplaceStreamFlatten(OptimizerPass):
"""

def match(self, node):
# Reshape acts as a Flatten layer when the result has 1 dimension
return isinstance(node, Reshape) and len(node.get_output_variable().shape) == 1
        # Layers that require flattened data can gather it directly from the stream; no repacking is needed.
        # Reshape acts as a Flatten layer when the result is one-dimensional; make the output an inplace tensor when that happens.

def transform(self, model, node):
if model.config.get_config_value('IOType') != 'io_stream':
if node.model.config.get_config_value('IOType') != 'io_stream':
return False
if not (isinstance(node, Reshape) and len(node.get_output_variable().shape) == 1):
            # not a flatten (output is not one-dimensional)
return False
if node.name in node.model.outputs:
            # If used as a model output, the output shape must be preserved.
return False
return True

def transform(self, model, node):
outvar = node.get_output_variable()
invar = node.get_input_variable()
newoutvar = InplaceTensorVariable(outvar, invar)
4 changes: 3 additions & 1 deletion hls4ml/backends/fpga/passes/repack_stream.py
@@ -49,7 +49,9 @@ class ReshapeStream(OptimizerPass):

def match(self, node):
# do not run optimizer pass for a flatten layer (1 output dimension)
return isinstance(node, Reshape) and len(node.get_output_variable().shape) > 1
if not isinstance(node, Reshape):
return False
return len(node.get_output_variable().shape) > 1 or node.name in node.model.outputs

def transform(self, model, node):
if model.config.get_config_value('IOType') != 'io_stream':
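
The added `node.name in node.model.outputs` clause routes flatten layers that are model outputs through ReshapeStream instead of InplaceStreamFlatten: a model output must present its declared shape on a real stream, so it gets an explicit repack rather than the inplace shortcut. A toy sketch of what repacking means (illustrative; it mirrors the element-order-preserving semantics of the nnet repack_stream template, not its implementation):

# Regroup a stream of input packs into output packs of a different size,
# preserving element order.
def repack(stream, out_pack_size):
    flat = [x for pack in stream for x in pack]
    return [flat[i:i + out_pack_size] for i in range(0, len(flat), out_pack_size)]

assert repack([[1, 2, 3, 4], [5, 6, 7, 8]], 2) == [[1, 2], [3, 4], [5, 6], [7, 8]]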
7 changes: 6 additions & 1 deletion hls4ml/backends/quartus/quartus_backend.py
@@ -48,7 +48,12 @@ def _register_flows(self):
initializers = self._get_layer_initializers()
init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name)

streaming_passes = ['quartus:reshape_stream', 'quartus:clone_output']
streaming_passes = [
            'quartus:inplace_stream_flatten',  # Inform downstream layers of the changed pack size when a flatten is skipped
'quartus:reshape_stream',
'quartus:clone_output',
]

streaming_flow = register_flow('streaming', streaming_passes, requires=[init_flow], backend=self.name)

quartus_types = [
1 change: 1 addition & 0 deletions hls4ml/backends/vivado/vivado_backend.py
@@ -79,6 +79,7 @@ def _register_flows(self):
init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name)

streaming_passes = [
            'vivado:inplace_stream_flatten',  # Inform downstream layers of the changed pack size when a flatten is skipped
'vivado:reshape_stream',
'vivado:clone_output',
'vivado:insert_zero_padding_before_conv1d',
86 changes: 47 additions & 39 deletions hls4ml/model/graph.py
@@ -506,6 +506,8 @@ def insert_node(self, node, before=None, input_idx=0):

if next_node is not None:
next_node.inputs[input_idx] = node.outputs[0]
else:
self.outputs = [node.outputs[0] if name == prev_node.outputs[0] else name for name in self.outputs]

new_graph = OrderedDict()
for k, v in self.graph.items():
@@ -514,47 +516,57 @@ def insert_node(self, node, before=None, input_idx=0):
new_graph[node.name] = node

self.graph = new_graph
self._update_model_outputs()

def remove_node(self, node, rewire=True):
"""Remove a node from a graph.
"""Removes a node from the graph.
By default, this function can connect the outputs of previous node to the input of next one.
Note that when removing a leaf node `rewire` should be set to `False`.
By default, this function connects the outputs of the previous
node to the inputs of the next node. If the removed node has multiple
input/output tensors, an exception is raised.
Args:
node (Layer): The node to remove
rewire (bool, optional): If `True`, connects the outputs of the previous node
to the inputs of the next node
node (Layer): The node to remove.
rewire (bool, optional): Deprecated, has no effect.
Raises:
Exception: If an attempt is made to rewire a leaf node or a node with multiple
inputs/outputs.
Exception: If an attempt is made to rewire a node with
multiple inputs/outputs.
Note:
The `rewire` parameter is deprecated and has no effect.
"""
if rewire:
inputs = [inp for inp in node.inputs if inp]
outputs = [outp for outp in node.outputs if outp]
if len(inputs) > 1 or len(outputs) > 1:
raise Exception('Cannot rewire a node with multiple inputs/outputs')
prev_node = node.get_input_node(node.inputs[0])

inputs = [inp for inp in node.inputs if inp]
outputs = [outp for outp in node.outputs if outp]

if len(inputs) > 1 or len(outputs) > 1:
raise Exception('Cannot delete a node with multiple inputs/outputs')

if len(inputs) == 1:
# Connect inputs -> $outputs
if node.name in self.outputs:
                msg = f'Removing leaf node {node.name} would make its input {inputs[0]} a model output, but it is already one.'
assert inputs[0] not in self.outputs, msg
self.outputs = [inputs[0] if name == node.name else name for name in self.outputs]

if len(outputs) == 1 and len(inputs) == 1:
inp_var = node.get_input_variable()
out_var = node.get_output_variable()

# fmt: off
assert (np.prod(inp_var.shape) == np.prod(out_var.shape)), \
f'Input and output shapes do not match for {node.name}: {inp_var.shape} -> {out_var.shape}'
# fmt: on

next_nodes = [x for x in self.graph.values() if node.outputs[0] in x.inputs]
if prev_node is not None:
if len(next_nodes) > 0:
for next_node in next_nodes:
for i, _ in enumerate(next_node.inputs):
if node.outputs[0] == next_node.inputs[i]:
next_node.inputs[i] = prev_node.outputs[0]
break
else:
if not node.outputs[0] in self.outputs:
raise Exception('Cannot rewire a node without child')
else:
raise Exception('Cannot rewire a node without a parent')
for next_node in next_nodes:
# Connect inputs -> next
for i, nxt_inp in enumerate(next_node.inputs):
if outputs[0] == nxt_inp:
next_node.inputs[i] = inputs[0]

del self.output_vars[node.outputs[0]]
del self.graph[node.name]
self._update_model_outputs()

def replace_node(self, old_node, new_node):
"""Replace an existing node in the graph with a new one.
@@ -584,7 +596,11 @@ def replace_node(self, old_node, new_node):
node.outputs[i] = repl[n]

self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items())
self._update_model_outputs()

old_name = old_node.name
if old_name in self.outputs:
new_name = new_node.name
self.outputs = [new_name if name == old_name else name for name in self.outputs]

def split_node(self, old_node, new_node1, new_node2):
"""Replace an existing node in the graph with two nodes in sequence.
@@ -622,17 +638,9 @@ def split_node(self, old_node, new_node1, new_node2):
else:
new_graph[key] = value
self.graph = new_graph
self._update_model_outputs()

def _update_model_outputs(self):
'''Update the model outputs

All node outputs and inputs are found. The model outputs are set to all node outputs
that are not also node inputs.
'''
node_outputs = [out for node in self.graph.values() for out in node.outputs]
node_inputs = [inp for node in self.graph.values() for inp in node.inputs]
self.outputs = [out for out in node_outputs if out not in node_inputs]
if old_node.name in self.outputs:
self.outputs = [new_node2.name if name == old_node.name else name for name in self.outputs]

def next_layer(self):
self.index += 1
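
The common thread in these graph.py changes: _update_model_outputs(), which recomputed self.outputs as every tensor not consumed by another node, is removed, and each graph edit (insert_node, remove_node, replace_node, split_node) now patches self.outputs by name. A toy sketch of that renaming invariant (illustrative, not the hls4ml API):

# Each graph surgery maps an old tensor name to its replacement in the
# model's output list, leaving all other outputs untouched.
def patch_outputs(outputs, old_name, new_name):
    return [new_name if name == old_name else name for name in outputs]

outputs = ['flatten']
# remove_node on 'flatten' whose input tensor is 'conv2':
assert patch_outputs(outputs, 'flatten', 'conv2') == ['conv2']
# split_node replacing 'flatten' with 'reshape1' -> 'reshape2' keeps the
# second (final) node as the model output:
assert patch_outputs(outputs, 'flatten', 'reshape2') == ['reshape2']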
20 changes: 20 additions & 0 deletions hls4ml/templates/catapult/nnet_utils/nnet_stream.h
@@ -41,6 +41,26 @@ void clone_stream(ac_channel<data_T> &data, ac_channel<res_T> &res1, ac_channel<
}
}

template <class data_T, class res_T, int N>
void clone_stream(ac_channel<data_T> &data, ac_channel<res_T> &res1, ac_channel<res_T> &res2, ac_channel<res_T> &res3) {
#ifndef __SYNTHESIS__
while (data.available(1))
#endif
{
data_T in_data = data.read();
res_T out_data;

ClonePack:
for (int j = 0; j < data_T::size; j++) {
out_data[j] = in_data[j];
}

res1.write(out_data);
res2.write(out_data);
res3.write(out_data);
}
}

template <class data_T, class res_T, int N> void repack_stream(ac_channel<data_T> &data, ac_channel<res_T> &res) {
if (data_T::size == res_T::size) {
for (int i = 0; i < N / data_T::size; i++) {
15 changes: 12 additions & 3 deletions hls4ml/utils/config.py
@@ -284,6 +284,7 @@ def config_from_pytorch_model(
default_reuse_factor=1,
channels_last_conversion='full',
transpose_outputs=True,
max_precision=None,
):
"""Create an HLS conversion config given the PyTorch model.
@@ -304,7 +305,8 @@
will generate config keys for every layer separately, allowing for highly specific
configuration tweaks.
        backend (str, optional): Name of the backend to use.
default_precision (str, optional): Default precision to use. Defaults to 'fixed<16,6>'.
default_precision (str, optional): Default precision to use. Defaults to 'fixed<16,6>'. Note, this must
be an explicit precision: 'auto' is not allowed.
default_reuse_factor (int, optional): Default reuse factor. Defaults to 1.
channels_last_conversion (string, optional): Configures the conversion of pytorch layers to
            'channels_last' data format. Can be set to 'full', 'internal', or 'off'. If 'full', both the inputs
Expand All @@ -313,6 +315,8 @@ def config_from_pytorch_model(
        transpose_outputs (bool, optional): Set to 'False' if the output should not be transposed from
            channels_last into channels_first data format. Defaults to 'True'. If 'False', outputs need
            to be transposed manually.
max_precision (str or None, optional): Maximum width precision to use. Defaults to None, meaning no maximum.
            Note: only integer and fixed precisions are supported.
Raises:
Exception: If PyTorch model has layers not supported by hls4ml.
@@ -324,11 +328,16 @@
config = {}

model_config = {}
model_config['Precision'] = default_precision
model_config['Precision'] = {}
model_config['Precision']['default'] = default_precision
if max_precision is not None:
model_config['Precision']['maximum'] = max_precision
model_config['ReuseFactor'] = default_reuse_factor
model_config['ChannelsLastConversion'] = channels_last_conversion
model_config['TransposeOutputs'] = transpose_outputs
model_config['Strategy'] = 'Latency'
model_config['BramFactor'] = 1_000_000_000
model_config['TraceOutput'] = False

config['Model'] = model_config
config['PytorchModel'] = model
@@ -372,7 +381,7 @@ def make_layer_config(layer):
if name.endswith('_t'):
name = name[:-2]
if attr.default is None:
precision_cfg[name] = default_precision
precision_cfg[name] = 'auto'
else:
precision_cfg[name] = str(attr.default)
elif attr.name == 'reuse_factor':
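
Taken together, the config changes make per-layer precisions default to 'auto' (to be inferred later) while default_precision must stay explicit and max_precision caps the inferred widths. A usage sketch (hedged: the model, input shape, and backend below are placeholders; only the keyword arguments documented in this diff are taken from the source):

import hls4ml
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 4))

config = hls4ml.utils.config.config_from_pytorch_model(
    model,
    (16,),                            # example input shape (placeholder)
    granularity='name',               # per-layer config keys, precision 'auto'
    backend='Vivado',
    default_precision='fixed<16,6>',  # must be explicit; 'auto' not allowed
    max_precision='fixed<24,10>',     # upper bound for inferred precisions
)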