From d7e665253d81fa41597b2a3333c88c2c6e129fec Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Sun, 21 Jul 2024 23:15:22 -0700 Subject: [PATCH] support pointwise conv layers --- .../fused_dotp/optimizer_pass/vitis.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/hls4ml/optimization/fused_dotp/optimizer_pass/vitis.py b/hls4ml/optimization/fused_dotp/optimizer_pass/vitis.py index ccc1260da3..4a30873ab9 100644 --- a/hls4ml/optimization/fused_dotp/optimizer_pass/vitis.py +++ b/hls4ml/optimization/fused_dotp/optimizer_pass/vitis.py @@ -16,7 +16,7 @@ class VitisUnrollCodeGen(UnrollCodeGenPass): def __init__(self): - super().__init__('Dense', 'Conv1D', 'Conv2D') + super().__init__('Dense', 'Conv1D', 'Conv2D', 'PointwiseConv1D', 'PointwiseConv2D') def get_stream_type_name(self, name: str) -> str: return f'{name}::value_type' @@ -66,7 +66,12 @@ class VitisConvPreTemplate(OptimizerPass): def match(self, node: Layer): if node.get_attr('implementation') != 'linebuffer': return False - return node.get_attr('unrolled_codegen') and node.class_name in ('Conv1D', 'Conv2D') + return node.get_attr('unrolled_codegen') and node.class_name in ( + 'Conv1D', + 'Conv2D', + 'PointwiseConv1D', + 'PointwiseConv2D', + ) def transform(self, model: ModelGraph, node: Layer): io_type = model.config.get_config_value("IOType") @@ -94,9 +99,13 @@ def latency_transform(self, model: ModelGraph, node: Layer): node.attributes.attributes['dense_config'] = config_cpp # override function_cpp - if node.class_name == 'Conv1D': + class_name = node.class_name + if class_name.startswith('Pointwise'): + class_name = class_name[9:] + + if class_name == 'Conv1D': fn_name = f'conv_1d' - elif node.class_name == 'Conv2D': + elif class_name == 'Conv2D': fn_name = f'conv_2d' else: raise ValueError(f'Unsupported layer type {node.class_name}') @@ -107,7 +116,7 @@ def latency_transform(self, model: ModelGraph, node: Layer): include_headers = [ 'nnet_utils/nnet_unrolled.h', 'nnet_utils/nnet_dense_latency.h', - f'nnet_utils/nnet_{node.class_name.lower()}.h', + f'nnet_utils/nnet_{class_name.lower()}.h', 'nnet_utils/nnet_conv_stream.h', # some properties defined in config need this ] node.attributes.attributes['include_header'] = include_headers