allow disable full unroll and dsp offload
calad0i committed Aug 28, 2024
1 parent 28f51bd commit 3cf2f2f
Showing 5 changed files with 79 additions and 13 deletions.
3 changes: 3 additions & 0 deletions hls4ml/optimization/fused_dotp/__init__.py
@@ -1,4 +1,7 @@
 from .codegen_backends import CodegenBackend, VitisCodegenBackend, code_gen  # noqa: F401
 from .dotp_unroll import compile_dense  # noqa: F401
+from .dotp_unroll import set_dsp_offload_threshold  # noqa: F401
+from .optimizer_pass import vitis as _  # noqa: F401
+from .optimizer_pass.pixel_unrolled_conv import enable_pixel_unroll  # noqa: F401
 from .precision import FixedPointPrecision  # noqa: F401
 from .symbolic_variable import Variable  # noqa: F401
26 changes: 23 additions & 3 deletions hls4ml/optimization/fused_dotp/dotp_unroll.py
@@ -5,6 +5,13 @@
 from . import symbolic_variable
 from .symbolic_variable import Variable
 
+DSP_OFFLOAD_THRES = 5
+
+
+def set_dsp_offload_threshold(thres: int):
+    global DSP_OFFLOAD_THRES
+    DSP_OFFLOAD_THRES = thres
+
 
 def const_fp_bits(x):
     "Number of fp bits needed to represent x exactly."
@@ -128,7 +135,7 @@ def bit_reduction(combination_mask: np.ndarray, bit_mask: list[int], bit_loc: in
 
 def to_operations(arr: np.ndarray):
     """For a 2d array as linear operator, decompose it as a series of operations.
-    y = arr @ v is equivalent to:
+    y = v @ arr is equivalent to:
     Returns:
         `shift`, `gather_tos`, `extract_froms`, `bit_extract_order`
     ```
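The docstring fix pins down the convention: for a `(ch_in, ch_out)` kernel, the decomposition implements the row-vector product `y = v @ arr`, not `arr @ v`. A quick illustrative check in plain numpy (not hls4ml code):

```python
import numpy as np

arr = np.array([[1, 2],
                [3, 4],
                [5, 6]])  # (ch_in=3, ch_out=2) linear operator
v = np.array([1, 0, 1])   # input vector of shape (ch_in,)

y = v @ arr               # shape (ch_out,) -> array([6, 8])
```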
@@ -195,10 +202,23 @@ def balanced_reduction(vec: list):
 
 
 def _compile_dense(kernel: np.ndarray, inp: np.ndarray):
-    shifts, gather_tos, extract_froms, bit_extract_order = to_operations(kernel)
     ch_in, ch_out = kernel.shape
-    buf0 = inp * 2.0**shifts
     r: list[float | Variable | list[Variable]] = np.empty((ch_out, 0), dtype=object).tolist()
+
+    _, combination_mask = get_mat_shift_mask(kernel)
+    bits_k = np.sum(np.abs(combination_mask), axis=(2))
+    bits_i = np.array([v.b for v in inp])
+    bops = bits_k * bits_i[:, None]
+    offload = np.where(bops > DSP_OFFLOAD_THRES)
+    if len(offload[0]) > 0:
+        kernel = kernel.copy()
+        for i, j in zip(*offload):
+            c = kernel[i, j]
+            kernel[i, j] = 0
+            r[j].append(c * inp[i])
+    shifts, gather_tos, extract_froms, bit_extract_order = to_operations(kernel)
+
+    buf0 = inp * 2.0**shifts
     with symbolic_variable.fuse_associative_ops(False):
         for i, (gather_to, extract_from) in enumerate(zip(gather_tos, extract_froms)):
             _buf1 = [[] for _ in range(np.max(gather_to) + 1)]
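The new block estimates a per-weight cost `bops[i, j] = bits_k[i, j] * bits_i[i]` (roughly, the nonzero CSD digit count of the weight times the input bitwidth) and strips entries above `DSP_OFFLOAD_THRES` out of the kernel before `to_operations` unrolls the remainder; the stripped products are appended to the running outputs `r[j]` as plain multiplications, which the backend can map onto DSPs. A standalone sketch of the splitting step, with digit counts and bitwidths supplied directly (the real code derives them from `get_mat_shift_mask` and symbolic `Variable`s):

```python
import numpy as np

def split_offload(kernel: np.ndarray, bits_k: np.ndarray,
                  bits_i: np.ndarray, thres: int = 5):
    """Zero out kernel entries whose estimated BOPs exceed `thres`;
    return the reduced kernel plus the (row, col, coeff) products to
    route to DSP multipliers instead of the unrolled adder tree."""
    bops = bits_k * bits_i[:, None]      # per-entry cost estimate
    offload = np.where(bops > thres)
    kernel = kernel.copy()
    offloaded = []
    for i, j in zip(*offload):
        offloaded.append((i, j, kernel[i, j]))
        kernel[i, j] = 0.0
    return kernel, offloaded

kernel = np.array([[0.5, 3.375],
                   [2.0, -1.25]])
bits_k = np.array([[1, 4], [1, 3]])      # assumed CSD digit counts
bits_i = np.array([8, 8])                # assumed input bitwidths
reduced, offloaded = split_offload(kernel, bits_k, bits_i, thres=16)
# reduced == [[0.5, 0.0], [2.0, 0.0]]
# offloaded == [(0, 1, 3.375), (1, 1, -1.25)]
```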
11 changes: 11 additions & 0 deletions hls4ml/optimization/fused_dotp/optimizer_pass/pixel_unrolled_conv.py
@@ -9,6 +9,8 @@
 from hls4ml.model.optimizer.passes.hgq_proxy_model import FixedPointQuantizer
 from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType
 
+ENABLE_PIXEL_UNROLL = False
+
 
 def _im2col(kernel_size: Sequence[int], arr: np.ndarray, buffer: np.ndarray, axis: int):
     w = kernel_size[0]
@@ -88,6 +90,9 @@ def get_input_KIF_idxs(model: ModelGraph, node: Layer) -> tuple[t_KIF, list[list
     *ker_inp_shape, n_out_chan = kernel.shape
     pf = node.attributes.attributes.get('parallelization_factor', 1)
     n_partition = node.attributes.attributes.get('n_partitions', 1)
+
+    if not ENABLE_PIXEL_UNROLL:
+        pf = 1
     if n_partition != 1 and pf != 1:
         warnings.warn(
             f'Parallelization factor {pf} != 1 is not fully optimized for n_partition {n_partition} > 1. Using one unrolled kernel for all partitions.',  # noqa: E501
@@ -135,3 +140,9 @@ def get_input_KIF_idxs(model: ModelGraph, node: Layer) -> tuple[t_KIF, list[list
     KIFs_in = tuple(tuple(x) for x in np.array([K, I, F]).transpose(1, 0, 2))
     idx = index.tolist() if index is not None else None
     return KIFs_in, idx
+
+
+def enable_pixel_unroll(enabled=True):
+    global ENABLE_PIXEL_UNROLL
+    ENABLE_PIXEL_UNROLL = enabled
+    return ENABLE_PIXEL_UNROLL
5 changes: 4 additions & 1 deletion hls4ml/optimization/fused_dotp/optimizer_pass/vitis.py
@@ -8,6 +8,7 @@
 
 from ..codegen_backends import VitisCodegenBackend
 from .common import UnrollCodeGenPass
+from .pixel_unrolled_conv import ENABLE_PIXEL_UNROLL
 
 conf_template = """struct config{index}{postfix} {{
     static const unsigned n_in = {n_in};
@@ -26,7 +27,8 @@ def __init__(self):
 class VitisFullyUnrolledConvToDense(OptimizerPass):
     def match(self, node: Layer):
         return (
-            node.get_attr('unrolled_codegen')
+            ENABLE_PIXEL_UNROLL
+            and node.get_attr('unrolled_codegen')
             and node.class_name in ('Conv1D', 'Conv2D', 'PointwiseConv1D', 'PointwiseConv2D')
             and node.get_attr('n_partitions') == 1
             and node.model.config.get_config_value("IOType") == 'io_parallel'
@@ -43,6 +45,7 @@ def transform(self, model: ModelGraph, node: Layer):
             'unrolled_codegen': node.get_attr('unrolled_codegen'),
             'weight_data': node.get_attr('weight_data'),
             'bias_data': node.get_attr('bias_data'),
+            'r_variables': node.get_attr('r_variables'),
         }
         new_node = model.make_node(class_name, node.name, attrs, node.inputs.copy())
         new_node.attributes[name] = node.attributes[name]
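The pass now only matches when pixel unrolling is enabled, and it forwards the new `r_variables` attribute so the resource surrogate (below) can find the traced variables on the rewritten node. One caveat: `vitis.py` binds `ENABLE_PIXEL_UNROLL` by value at import time, so toggling the flag via `enable_pixel_unroll()` after this module is imported is not seen by `match()`. A self-contained illustration of that Python behavior (not hls4ml code):

```python
import sys
import types

# Build a throwaway module with a mutable flag, mimicking pixel_unrolled_conv.
mod = types.ModuleType("flagmod")
exec(
    "FLAG = False\n"
    "def set_flag(v=True):\n"
    "    global FLAG\n"
    "    FLAG = v\n",
    mod.__dict__,
)
sys.modules["flagmod"] = mod

from flagmod import FLAG  # value snapshot, like vitis.py's import

mod.set_flag(True)
print(FLAG)       # False: the from-import copy never sees the update
print(mod.FLAG)   # True: attribute access on the module does
```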
47 changes: 38 additions & 9 deletions hls4ml/optimization/fused_dotp/resoure_surrogate.py
@@ -1,6 +1,9 @@
 from copy import copy
 
 import numpy as np
+import pandas as pd
+
+from hls4ml.model import ModelGraph
 
 from .precision import FixedPointPrecision
 from .symbolic_variable import Variable
@@ -36,23 +39,23 @@ def resource_bin_add(p1: FixedPointPrecision, p2: FixedPointPrecision):
 
 
 def resource_bin_mul(p1: FixedPointPrecision, p2: FixedPointPrecision):
-    return DictWrap(mul=p1.b * p2.b)
+    return DictWrap(mul=p1.b * p2.b, n_mul=1)
 
 
 def resource_bin_shift(p1: FixedPointPrecision, shift: int):
-    return DictWrap(shift=p1.b * abs(shift))
+    return DictWrap(shift=p1.b * abs(shift), n_shift=1)
 
 
 def resource_bin_max(p1: FixedPointPrecision, p2: FixedPointPrecision):
-    return DictWrap(max=max(p1.b, p2.b))
+    return DictWrap(max=max(p1.b, p2.b), n_max=1)
 
 
 def resource_bin_sub(p1: FixedPointPrecision, p2: FixedPointPrecision):
-    return DictWrap(sub=resource_bin_add(p1, p2)['add'])
+    return DictWrap(sub=resource_bin_add(p1, p2)['add'], n_sub=1)
 
 
 def resource_bin_neg(p1: FixedPointPrecision):
-    return DictWrap(neg=p1.b)
+    return DictWrap(neg=p1.b, n_neg=1)
 
 
 class ResourceSurrogate:
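Each cost helper now returns an `n_*` counter alongside the summed bit-cost, so the surrogate can report operation counts (`n_mul` becomes the DSP estimate below) rather than only aggregate widths. A sketch built on a hypothetical minimal `DictWrap` stand-in; the real class is defined earlier in this module and is assumed here to behave like a dict whose `+` adds values key-wise:

```python
class DictWrap(dict):
    """Hypothetical stand-in: a dict that adds key-wise so costs and
    counters accumulate when resources are summed."""

    def __add__(self, other):
        if isinstance(other, (int, float)) and other == 0:
            return DictWrap(self)  # lets sum() start from 0
        out = DictWrap(self)
        for k, v in dict(other).items():
            out[k] = out.get(k, 0) + v
        return out

    __radd__ = __add__


def resource_bin_mul(b1: int, b2: int) -> DictWrap:
    # b1*b2 approximates the multiplier's bit-cost; n_mul counts it once
    return DictWrap(mul=b1 * b2, n_mul=1)


print(sum([resource_bin_mul(8, 8), resource_bin_mul(6, 4)]))
# -> {'mul': 88, 'n_mul': 2}
```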
@@ -112,16 +115,42 @@ def _trace(self, v: Variable | int | float, recorded: set) -> int | DictWrap:
         recorded.add(v)
         return resource
 
-    def trace(self, r: list | np.ndarray, name: str):
+    def trace(self, r: list | np.ndarray, name: str, pf: int = 1):
         s = set()
         arr = np.array(r).ravel()
-        params: DictWrap = sum(self._trace(v, s) for v in arr)  # type: ignore
+        zero = DictWrap(add=0, sub=0, mul=0, shift=0, neg=0, max=0, depth=0)
+        zero = zero + {f'n_{k}': 0 for k in zero.keys()}
+        params: DictWrap = zero + sum(self._trace(v, s) for v in arr)  # type: ignore
         if params == 0:  # layer outputs const array, no operation performed. skip
             return
-        latency = max(v.depth for v in arr if isinstance(v, Variable))
-        params['latency'] = latency
+        depth = max(v.depth for v in arr if isinstance(v, Variable))
+        params['depth'] = depth
+        params['pf'] = pf
         self.layers[name] = params
 
+    def scan(self, model: ModelGraph):
+        for name, layer in model.graph.items():
+            r_variables = layer.attributes.attributes.get('r_variables')
+            if r_variables is None:
+                continue
+            pf = layer.attributes.attributes.get('parallelization_factor', 1)
+            self.trace(r_variables, name, pf)
+
+    def _summary(self):
+        df = pd.DataFrame.from_dict(self.layers, orient='index')
+        lut = np.round((df['add'] + df['neg'] + 2 * df['sub']) * 0.55).astype(int) * df['pf']
+        dsp = np.round(df['n_mul']).astype(int) * df['pf']
+        latency_ns = df['depth'] * 0.86
+        return df, pd.DataFrame({'LUT': lut, 'DSP': dsp, 'Latency (ns)': latency_ns})
+
+    def summary(self):
+        df, summary = self._summary()
+        return summary
+
+    def full_summary(self):
+        df, summary = self._summary()
+        return pd.concat([df, summary], axis=1)
+
 
 # def resource_addr(v: Variable):
 #     if len(v.ancestors) == 1:
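The new `scan`/`summary` API walks a compiled model, traces every layer carrying `r_variables` (attached by the optimizer pass above), and tabulates estimates with fixed surrogate coefficients: about 0.55 LUTs per bit of add/neg (subtraction bits count double) and one DSP per counted multiply, both scaled by the layer's parallelization factor, plus 0.86 ns of latency per level of tree depth. A hypothetical usage sketch, assuming `model` is an hls4ml `ModelGraph` built with these passes enabled:

```python
from hls4ml.optimization.fused_dotp.resoure_surrogate import ResourceSurrogate

surrogate = ResourceSurrogate()
surrogate.scan(model)            # `model`: a compiled ModelGraph (assumed)

print(surrogate.summary())       # per-layer LUT / DSP / Latency (ns) estimates
print(surrogate.full_summary())  # raw op counts alongside those estimates

# Worked example of the LUT formula: a layer with add=100, neg=10,
# sub=5, n_mul=3, pf=2 reports LUT = round((100 + 10 + 2*5) * 0.55) * 2
# = 132 and DSP = 3 * 2 = 6.
```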
