diff --git a/hls4ml/optimization/fused_dotp/codegen_backends/vitis.py b/hls4ml/optimization/fused_dotp/codegen_backends/vitis.py index de3bcc1df3..727278c150 100644 --- a/hls4ml/optimization/fused_dotp/codegen_backends/vitis.py +++ b/hls4ml/optimization/fused_dotp/codegen_backends/vitis.py @@ -9,3 +9,6 @@ def type(precision: FixedPointPrecision): precision = precision.make_proper() k, b, I = precision.k, precision.b, precision.I # noqa: E741 return f'ap_{"" if k else "u"}fixed<{b},{I}>' + + def get_stream_type_name(self, name: str) -> str: + return f'{name}::value_type' diff --git a/hls4ml/optimization/fused_dotp/optimizer_pass/common.py b/hls4ml/optimization/fused_dotp/optimizer_pass/common.py index 6532c07e20..244b071e23 100644 --- a/hls4ml/optimization/fused_dotp/optimizer_pass/common.py +++ b/hls4ml/optimization/fused_dotp/optimizer_pass/common.py @@ -1,4 +1,5 @@ from functools import lru_cache +from typing import Sequence import numpy as np @@ -14,15 +15,17 @@ from ..precision import FixedPointPrecision from ..symbolic_variable import Variable from ..utils import Singleton +from .pixel_unrolled_conv import get_input_KIF_idxs def nn_codegen( kernel: np.ndarray, bias: np.ndarray | None, - KIF_in: tuple[np.ndarray, ...], + KIFs_in: Sequence[Sequence[np.ndarray]] | np.ndarray, inp_name: str, out_name: str, backend='vivado', + index: Sequence[Sequence[int]] | None = None, ): """ Codegen for conv/dense layer @@ -41,15 +44,31 @@ def nn_codegen( _backend = VitisCodegenBackend() else: raise ValueError(f'Backend {backend} not supported') - precisions = [FixedPointPrecision.from_kif(k, i, f) for k, i, f in zip(*KIF_in)] - inp = np.array( - [Variable(p, id=f'{inp_name}[{_i}]') if isinstance(p, FixedPointPrecision) else p for _i, p in enumerate(precisions)] - ) - r = compile_conv(kernel, inp) - if bias is not None: - r = r + bias - r = list(r) - return code_gen(r, _backend, out_name), r + if index is None: + assert len(KIFs_in) == 1 + index = [range(len(KIFs_in[0][0]))] + + R = [] + + def to_symbol(p, i): + if i >= 0: + if isinstance(p, FixedPointPrecision): + return Variable(p, id=f'{inp_name}[{i}]') + else: + return p + else: + return 0 + + for KIF_in, idxs in zip(KIFs_in, index): + assert len(KIF_in) == 3 + + precisions = [FixedPointPrecision.from_kif(k, i, f) for k, i, f in zip(*KIF_in)] + inp = np.array([to_symbol(p, _i) for _i, p in zip(idxs, precisions)]) + r = compile_conv(kernel, inp) + if bias is not None: + r = r + bias + R.extend(r) + return code_gen(R, _backend, out_name), R def get_input_KIF(model: ModelGraph, node: Layer): @@ -131,9 +150,13 @@ def latency_mat_vec_mul_fn_gen(model: ModelGraph, node: Layer): # Get input precision per-element - KIF_in = get_input_KIF(model, node) - backend = node.model.config.config['Backend'] - oprs, r_variables = nn_codegen(kernel, node.attributes['bias_data'], KIF_in, 'inp', 'out', backend) + if 'Conv' in node.class_name and np.prod(node.attributes.attributes.get('dilation', 1)) > 1: + KIFs_in = [get_input_KIF(model, node)] + index = None + else: + KIFs_in, index = get_input_KIF_idxs(model, node) + backend = model.config.config['Backend'] + oprs, r_variables = nn_codegen(kernel, node.attributes['bias_data'], KIFs_in, 'inp', 'out', backend, index=index) opr_code = '\n '.join(oprs) return opr_code, r_variables @@ -145,12 +168,14 @@ class UnrollCodeGenPass(OptimizerPass, metaclass=Singleton): def __init__(self, *targets: str): self.target = targets + self.backend = None def match(self, node: Layer): return any(node.class_name == target for target in self.target) def get_stream_type_name(self, name: str) -> str: - raise NotImplementedError + assert self.backend is not None, 'Backend not set' + return self.backend.get_stream_type_name(name) def transform(self, model: ModelGraph, node: Layer): @@ -164,7 +189,7 @@ def transform(self, model: ModelGraph, node: Layer): input_t_name: str = input_named_t.name output_t_name: str = output_named_t.name - io_type: str = node.model.config.get_config_value('IOType') + io_type: str = model.config.get_config_value('IOType') assert io_type in ('io_stream', 'io_parallel'), f'io_type {io_type} is unknown.' if io_type == 'io_stream': input_t_name = self.get_stream_type_name(input_t_name) diff --git a/hls4ml/optimization/fused_dotp/optimizer_pass/pixel_unrolled_conv.py b/hls4ml/optimization/fused_dotp/optimizer_pass/pixel_unrolled_conv.py new file mode 100644 index 0000000000..8d31da7de8 --- /dev/null +++ b/hls4ml/optimization/fused_dotp/optimizer_pass/pixel_unrolled_conv.py @@ -0,0 +1,137 @@ +import warnings +from typing import Sequence + +import numpy as np + +from hls4ml.backends.fpga.fpga_types import NamedType +from hls4ml.model.graph import ModelGraph +from hls4ml.model.layers import Layer +from hls4ml.model.optimizer.passes.hgq_proxy_model import FixedPointQuantizer +from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType + + +def _im2col(kernel_size: Sequence[int], arr: np.ndarray, buffer: np.ndarray, axis: int): + w = kernel_size[0] + if len(kernel_size) == 3: + for i in range(arr.shape[axis] - w + 1): + patch = np.take(arr, range(i, i + w), axis=axis) + buffer[i] = patch.flatten() + else: + for i in range(arr.shape[axis] - w + 1): + patch = arr[i : i + w] + _im2col(kernel_size[1:], patch, buffer[i], axis + 1) + + +def im2col(kernel_size: Sequence[int], arr: np.ndarray): + if len(kernel_size) < 3: + return arr + shape = [inp_d - ker_d + 1 for inp_d, ker_d in zip(arr.shape, kernel_size[:-2])] + shape.append(np.prod(kernel_size[:-1])) # type: ignore + buf = np.empty(shape, dtype=arr.dtype) + _im2col(kernel_size, arr, buf, 0) + return buf + + +def ims2cols(kernel_size: Sequence[int], *arrs: np.ndarray): + return [im2col(kernel_size, arr) for arr in arrs] + + +def pad_and_stride_inp_arr(node: Layer, arr: np.ndarray, pad_val=0): + if node.class_name.endswith('Conv2D'): + pad_top = node.attributes.attributes['pad_top'] + pad_bottom = node.attributes.attributes['pad_bottom'] + pad_left = node.attributes.attributes['pad_left'] + pad_right = node.attributes.attributes['pad_right'] + st_h = node.attributes.attributes['stride_height'] + st_w = node.attributes.attributes['stride_width'] + return np.pad(arr, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), constant_values=pad_val)[::st_h, ::st_w] + if node.class_name.endswith('Conv1D'): + pad_left = node.attributes.attributes['pad_left'] + pad_right = node.attributes.attributes['pad_right'] + st_w = node.attributes.attributes['stride_width'] + return np.pad(arr, ((pad_left, pad_right), (0, 0)), constant_values=pad_val)[::st_w] + return arr + + +def pad_and_stride_inp_arrs(node: Layer, *arrs: np.ndarray, pad_val=0): + return [pad_and_stride_inp_arr(node, arr, pad_val) for arr in arrs] + + +def get_inp_shape(node: Layer): + if node.class_name.endswith('Conv1D'): + in_width = node.attributes.attributes['in_width'] + n_chan = node.attributes.attributes['n_chan'] + return (in_width, n_chan) + if node.class_name.endswith('Conv2D'): + in_height = node.attributes.attributes['in_height'] + in_width = node.attributes.attributes['in_width'] + n_chan = node.attributes.attributes['n_chan'] + return (in_height, in_width, n_chan) + if node.class_name == 'Dense': + n_in = node.attributes.attributes['n_in'] + return (n_in,) + raise ValueError(f'Unsupported node type {node.class_name}') + + +t_KIF = tuple[tuple[np.ndarray, ...], ...] + + +def get_input_KIF_idxs(model: ModelGraph, node: Layer) -> tuple[t_KIF, list[list[int]] | None]: + """Get input precision per-channel, in the form of (k, i, f) each of shape (in_channels,)""" + + assert 'weight_data' in node.attributes, 'No weight data found' + kernel = node.attributes['weight_data'] + inp_node: Layer = model.graph[node.inputs[0]] + input_named_t: NamedType = inp_node.attributes['result_t'] + + # Get input precision per-element + *ker_inp_shape, n_out_chan = kernel.shape + pf = node.attributes.attributes.get('parallelization_factor', 1) + n_partition = node.attributes.attributes.get('n_partitions', 1) + if n_partition != 1 and pf != 1: + warnings.warn( + f'Parallelization factor {pf}!= 1 is not fully optimized for n_partition {n_partition}>1. Using one unrolled kernel for all partitions.', # noqa: E501 + stacklevel=2, + ) + pf = 1 + if model.config.get_config_value('IOType') == 'io_stream': + if pf != 1: + warnings.warn( + f'Parallelization factor {pf} is not supported for io_stream. Ignoring.', stacklevel=2 # noqa: E501 + ) + pf = 1 + + index = None + + inp_shape = get_inp_shape(node) + if pf > 1: + index = np.arange(np.prod(inp_shape)).reshape(inp_shape) + index = pad_and_stride_inp_arr(node, index, -1) + index = im2col(kernel.shape, index) + index = index.reshape(pf, index.shape[-1]) + + if isinstance(inp_node, FixedPointQuantizer): + assert inp_node.mask_kbi is not None + K, B, I = inp_node.mask_kbi # noqa: E741 + K, B, I = K.squeeze(0), B.squeeze(0), I.squeeze(0) # noqa: E741 + K, I, F = K, I - K, B - I # noqa: E741 + K, I, F = np.broadcast_to(K, inp_shape), np.broadcast_to(I, inp_shape), np.broadcast_to(F, inp_shape) # noqa: E741 + K, I, F = pad_and_stride_inp_arrs(node, K, I, F) # noqa: E741 + K, I, F = ims2cols(kernel.shape, K, I, F) # noqa: E741 + K, I, F = (x.reshape(-1, K.shape[-1]) for x in (K, I, F)) # noqa: E741 + assert K.shape == I.shape == F.shape # noqa: E741 + assert ( + len(K) % pf == 0 + ), f'Number of kernel operands ({len(K)}) must be divisible by n_partitions ({pf})' # noqa: E741 + K, I, F = np.split(K, pf, axis=0), np.split(I, pf, axis=0), np.split(F, pf, axis=0) + K, I, F = np.max(K, axis=1), np.max(I, axis=1), np.max(F, axis=1) + else: + assert isinstance(input_named_t.precision, (FixedPrecisionType, IntegerPrecisionType)) + input_t = input_named_t.precision + k, i, f = input_t.signed, input_t.integer, input_t.fractional + i -= k + dim = np.prod(ker_inp_shape) + K, I, F = np.full((pf, dim), k), np.full((pf, dim), i), np.full((pf, dim), f) + KIFs_in = tuple(tuple(x) for x in np.array([K, I, F]).transpose(1, 0, 2)) + idx = index.tolist() if index is not None else None + return KIFs_in, idx diff --git a/hls4ml/optimization/fused_dotp/optimizer_pass/vitis.py b/hls4ml/optimization/fused_dotp/optimizer_pass/vitis.py index 4a30873ab9..b826b160c1 100644 --- a/hls4ml/optimization/fused_dotp/optimizer_pass/vitis.py +++ b/hls4ml/optimization/fused_dotp/optimizer_pass/vitis.py @@ -1,9 +1,12 @@ +import numpy as np + from hls4ml.backends import get_backend from hls4ml.backends.fpga.fpga_backend import FPGABackend from hls4ml.model.graph import ModelGraph from hls4ml.model.layers import Layer from hls4ml.model.optimizer import OptimizerPass +from ..codegen_backends import VitisCodegenBackend from .common import UnrollCodeGenPass conf_template = """struct config{index}{postfix} {{ @@ -17,9 +20,39 @@ class VitisUnrollCodeGen(UnrollCodeGenPass): def __init__(self): super().__init__('Dense', 'Conv1D', 'Conv2D', 'PointwiseConv1D', 'PointwiseConv2D') + self.backend = VitisCodegenBackend() + + +class VitisFullyUnrolledConvToDense(OptimizerPass): + def match(self, node: Layer): + return ( + node.get_attr('unrolled_codegen') + and node.class_name in ('Conv1D', 'Conv2D', 'PointwiseConv1D', 'PointwiseConv2D') + and node.get_attr('n_partitions') == 1 + and node.model.config.get_config_value("IOType") == 'io_parallel' + ) - def get_stream_type_name(self, name: str) -> str: - return f'{name}::value_type' + def transform(self, model: ModelGraph, node: Layer): + class_name = 'Dense' + + name = node.name + attrs = { + 'n_out': node.get_attr('out_width', 1) * node.get_attr('out_height', 1) * node.get_attr('n_filt'), # type: ignore # noqa: E501 + 'n_in': np.prod(node.get_input_variable().shape), + 'result_t': node.get_attr('result_t'), + 'unrolled_codegen': node.get_attr('unrolled_codegen'), + 'weight_data': node.get_attr('weight_data'), + 'bias_data': node.get_attr('weight_data'), + } + new_node = model.make_node(class_name, node.name, attrs, node.inputs.copy()) + new_node.attributes[name] = node.attributes[name] + new_node.attributes['result_t'] = node.attributes['result_t'] + new_node.attributes['index'] = node.attributes['index'] + new_node.index = node.index + del new_node.attributes.attributes['accum_t'] + del new_node.attributes.attributes['weight_t'] + del new_node.attributes.attributes['bias_t'] + model.replace_node(node, new_node) class VitisDensePreTemplate(OptimizerPass): @@ -60,6 +93,8 @@ def latency_transform(self, model: ModelGraph, node: Layer): # avoid output weights and bias; alternatie entry point does not use them del node.attributes.attributes['weight_data'] del node.attributes.attributes['bias_data'] + del node.attributes.attributes['weight'] + del node.attributes.attributes['bias'] class VitisConvPreTemplate(OptimizerPass): @@ -124,6 +159,8 @@ def latency_transform(self, model: ModelGraph, node: Layer): # avoid output weights and bias; alternatie entry point does not use them del node.attributes.attributes['weight_data'] del node.attributes.attributes['bias_data'] + # del node.attributes.attributes['weight'] + # del node.attributes.attributes['bias'] unrolled_codegen = VitisUnrollCodeGen() @@ -133,5 +170,6 @@ def latency_transform(self, model: ModelGraph, node: Layer): vitis_backend: FPGABackend = get_backend('vitis') # Optimizer flow is shared vitis_backend.register_pass('unrolled_codegen', unrolled_codegen, flow='vivado:specific_types') +vitis_backend.register_pass('fully_unrolled_conv_to_dense', VitisFullyUnrolledConvToDense(), flow='vivado:specific_types') vitis_backend.register_pass('dense_pre_template', vitis_dense_pre_template, flow='vivado:specific_types') vitis_backend.register_pass('conv_pre_template', vitis_conv_pre_template, flow='vivado:specific_types')