Skip to content

Commit

Permalink
support pointwise conv layers
Browse files Browse the repository at this point in the history
  • Loading branch information
calad0i committed Jul 22, 2024
1 parent 3d0d6fd commit d7e6652
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions hls4ml/optimization/fused_dotp/optimizer_pass/vitis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class VitisUnrollCodeGen(UnrollCodeGenPass):
def __init__(self):
super().__init__('Dense', 'Conv1D', 'Conv2D')
super().__init__('Dense', 'Conv1D', 'Conv2D', 'PointwiseConv1D', 'PointwiseConv2D')

def get_stream_type_name(self, name: str) -> str:
return f'{name}::value_type'
Expand Down Expand Up @@ -66,7 +66,12 @@ class VitisConvPreTemplate(OptimizerPass):
def match(self, node: Layer):
if node.get_attr('implementation') != 'linebuffer':
return False
return node.get_attr('unrolled_codegen') and node.class_name in ('Conv1D', 'Conv2D')
return node.get_attr('unrolled_codegen') and node.class_name in (
'Conv1D',
'Conv2D',
'PointwiseConv1D',
'PointwiseConv2D',
)

def transform(self, model: ModelGraph, node: Layer):
io_type = model.config.get_config_value("IOType")
Expand Down Expand Up @@ -94,9 +99,13 @@ def latency_transform(self, model: ModelGraph, node: Layer):
node.attributes.attributes['dense_config'] = config_cpp

# override function_cpp
if node.class_name == 'Conv1D':
class_name = node.class_name
if class_name.startswith('Pointwise'):
class_name = class_name[9:]

if class_name == 'Conv1D':
fn_name = f'conv_1d<config{node.index}>'
elif node.class_name == 'Conv2D':
elif class_name == 'Conv2D':
fn_name = f'conv_2d<config{node.index}>'
else:
raise ValueError(f'Unsupported layer type {node.class_name}')
Expand All @@ -107,7 +116,7 @@ def latency_transform(self, model: ModelGraph, node: Layer):
include_headers = [
'nnet_utils/nnet_unrolled.h',
'nnet_utils/nnet_dense_latency.h',
f'nnet_utils/nnet_{node.class_name.lower()}.h',
f'nnet_utils/nnet_{class_name.lower()}.h',
'nnet_utils/nnet_conv_stream.h', # some properties defined in config need this
]
node.attributes.attributes['include_header'] = include_headers
Expand Down

0 comments on commit d7e6652

Please sign in to comment.