Unrolled implementation for latency dense/conv layers #1014

Closed
wants to merge 15 commits
27 changes: 27 additions & 0 deletions hls4ml/backends/vivado/passes/convolution_templates.py
@@ -24,6 +24,7 @@
    typedef {weight_t.name} weight_t;
    template<class x_T, class y_T>
    using product = nnet::product::{product_type}<x_T, y_T>;
    constexpr static auto unrolled_fn = {unrolled_fn_name};
}};\n"""

# Conv1D templates
@@ -93,13 +94,18 @@ def format(self, node):

        conv_config = self.template.format(**params)

        mult_config = node.get_attr('dense_config')
        if mult_config is not None:
            return mult_config + '\n' + conv_config

        mult_params = self._default_config_params(node)
        mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width')
        mult_params['n_out'] = node.get_attr('n_filt')
        mult_params['nzeros'] = node.get_weights('weight').nzeros
        mult_params['product_type'] = get_backend('vivado').product_type(
            node.get_input_variable().type.precision, node.get_weights('weight').type.precision
        )
        mult_params.setdefault('unrolled_fn_name', 'nullptr')
        mult_config = self.mult_template.format(**mult_params)

        return mult_config + '\n' + conv_config
@@ -118,6 +124,14 @@ def format(self, node):

        return self.template.format(**params)

    def match(self, node):
        if node.get_attr('unrolled_codegen') is not None:
            io_type = node.model.config.get_config_value("IOType")
            if io_type == 'io_parallel':
                # Unrolled implementations use an alternate entry point
                return False
        return super().match(node)


class DepthwiseConv1DFunctionTemplate(Conv1DFunctionTemplate):
    def __init__(self):
@@ -206,13 +220,18 @@ def format(self, node):

        conv_config = self.template.format(**params)

        mult_config = node.get_attr('dense_config')
        if mult_config is not None:
            return mult_config + '\n' + conv_config

        mult_params = self._default_config_params(node)
        mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_height') * node.get_attr('filt_width')
        mult_params['n_out'] = node.get_attr('n_filt')
        mult_params['nzeros'] = node.get_weights('weight').nzeros
        mult_params['product_type'] = get_backend('vivado').product_type(
            node.get_input_variable().type.precision, node.get_weights('weight').type.precision
        )
        mult_params.setdefault('unrolled_fn_name', 'nullptr')
        mult_config = self.mult_template.format(**mult_params)

        return mult_config + '\n' + conv_config
@@ -231,6 +250,14 @@ def format(self, node):

        return self.template.format(**params)

    def match(self, node):
        if node.get_attr('unrolled_codegen') is not None:
            io_type = node.model.config.get_config_value("IOType")
            if io_type == 'io_parallel':
                # Unrolled implementations use an alternate entry point
                return False
        return super().match(node)


class DepthwiseConv2DFunctionTemplate(Conv2DFunctionTemplate):
    def __init__(self):
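The `constexpr static auto unrolled_fn = {unrolled_fn_name};` field added to both mult config templates is filled through the same `str.format` machinery as the rest of the struct, with `setdefault('unrolled_fn_name', 'nullptr')` supplying a null pointer whenever no unrolled function was generated for the layer. A minimal, self-contained sketch of that substitution (the template here is condensed to the fields relevant to this PR, and the layer attributes are hypothetical):

```python
# Condensed stand-in for the conv mult config template above.
mult_config_template = """struct config{index}_mult : nnet::dense_config {{
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    constexpr static auto unrolled_fn = {unrolled_fn_name};
}};"""

# Hypothetical layer attributes; 'unrolled_fn_name' is absent, as for a layer
# with no generated unrolled code.
params = {'index': 2, 'n_in': 9, 'n_out': 16}
params.setdefault('unrolled_fn_name', 'nullptr')  # fallback added in this PR

print(mult_config_template.format(**params))
# struct config2_mult : nnet::dense_config {
#     static const unsigned n_in = 9;
#     static const unsigned n_out = 16;
#     constexpr static auto unrolled_fn = nullptr;
# };
```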
14 changes: 14 additions & 0 deletions hls4ml/backends/vivado/passes/core_templates.py
@@ -21,6 +21,7 @@
    typedef {index_t.name} index_t;
    template<class x_T, class y_T>
    using product = nnet::product::{product_type}<x_T, y_T>;
    constexpr static auto unrolled_fn = {unrolled_fn_name};
}};\n"""

dense_function_template = 'nnet::dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
@@ -34,12 +35,17 @@ def __init__(self):
        self.template = dense_config_template

    def format(self, node):
        mult_config = node.get_attr('dense_config')
        if mult_config is not None:
            return mult_config

        params = self._default_config_params(node)
        params['nzeros'] = node.get_weights('weight').nzeros
        params['nonzeros'] = node.get_weights('weight').nonzeros
        params['product_type'] = get_backend('vivado').product_type(
            node.get_input_variable().type.precision, node.get_weights('weight').type.precision
        )
        params.setdefault('unrolled_fn_name', 'nullptr')

        return self.template.format(**params)

@@ -56,6 +62,14 @@ def format(self, node):

        return self.template.format(**params)

    def match(self, node):
        if node.get_attr('unrolled_codegen') is not None:
            io_type = node.model.config.get_config_value("IOType")
            if io_type == 'io_parallel':
                # Unrolled implementations use an alternate entry point
                return False
        return super().match(node)


# BatchNormalization templates

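The three `match` overrides added above (Conv1D, Conv2D, Dense) all encode the same rule: a node that carries generated unrolled code is skipped by the standard function template when `IOType` is `io_parallel`, so the writer can route it to the alternate unrolled entry point instead. A self-contained sketch of that gate, using stand-in objects rather than hls4ml's real node/model classes:

```python
# Minimal stand-ins for hls4ml's node/model/config objects; just enough
# structure to exercise the guard, not the real classes.
class _Config:
    def __init__(self, io_type):
        self._io_type = io_type

    def get_config_value(self, key):
        return self._io_type if key == 'IOType' else None


class _Node:
    def __init__(self, io_type, unrolled=None):
        self.model = type('_Model', (), {'config': _Config(io_type)})()
        self._attrs = {'unrolled_codegen': unrolled}

    def get_attr(self, name):
        return self._attrs.get(name)


def standard_template_matches(node):
    """Mirror of the match() guard added in this PR."""
    if node.get_attr('unrolled_codegen') is not None:
        if node.model.config.get_config_value('IOType') == 'io_parallel':
            return False  # handled by the unrolled entry point instead
    return True


print(standard_template_matches(_Node('io_parallel', unrolled=object())))  # False
print(standard_template_matches(_Node('io_stream', unrolled=object())))    # True
print(standard_template_matches(_Node('io_parallel')))                     # True
```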
111 changes: 4 additions & 107 deletions hls4ml/optimization/__init__.py
@@ -1,108 +1,5 @@
import numpy as np
from .dsp_aware_pruning import optimize_keras_model_for_hls4ml # noqa: F401
from .dsp_aware_pruning.attributes import get_attributes_from_keras_model_and_hls4ml_config # noqa: F401
from .dsp_aware_pruning.keras import optimize_model # noqa: F401

from hls4ml.optimization.attributes import get_attributes_from_keras_model_and_hls4ml_config
from hls4ml.optimization.keras import optimize_model

default_regularization_range = np.logspace(-6, -2, num=16).tolist()


def optimize_keras_model_for_hls4ml(
    keras_model,
    hls_config,
    objective,
    scheduler,
    X_train,
    y_train,
    X_val,
    y_val,
    batch_size,
    epochs,
    optimizer,
    loss_fn,
    validation_metric,
    increasing,
    rtol,
    callbacks=None,
    ranking_metric='l1',
    local=False,
    verbose=False,
    rewinding_epochs=1,
    cutoff_bad_trials=3,
    directory='hls4ml-optimization',
    tuner='Bayesian',
    knapsack_solver='CBC_MIP',
    regularization_range=default_regularization_range,
):
    '''
    Top-level function for optimizing a Keras model, given an hls4ml config and one or more hardware objectives

    Args:
        keras_model (keras.Model): Model to be optimized
        hls_config (dict): hls4ml configuration, obtained from hls4ml.utils.config.config_from_keras_model(...)
        objective (hls4ml.optimization.objectives.ObjectiveEstimator):
            Parameter, hardware or user-defined objective of optimization
        scheduler (hls4ml.optimization.scheduler.OptimizationScheduler):
            Sparsity scheduler, choose between constant, polynomial and binary
        X_train (np.array): Training inputs
        y_train (np.array): Training labels
        X_val (np.array): Validation inputs
        y_val (np.array): Validation labels
        batch_size (int): Batch size during training
        epochs (int): Maximum number of epochs to fine-tune model, in one iteration of pruning
        optimizer (keras.optimizers.Optimizer or equivalent string description): Optimizer used during training
        loss_fn (keras.losses.Loss or equivalent loss description): Loss function used during training
        validation_metric (keras.metrics.Metric or equivalent metric description): Validation metric, used as a baseline
        increasing (boolean): If the metric improves with increased values;
            e.g. accuracy -> increasing = True, MSE -> increasing = False
        rtol (float): Relative tolerance;
            pruning stops when pruned_validation_metric < (or >) rtol * baseline_validation_metric
        callbacks (list of keras.callbacks.Callback): Currently not supported; to be developed in future versions
        ranking_metric (string): Metric used for ranking weights and structures;
            currently supported: l1, l2, saliency and Oracle
        local (boolean): Layer-wise or global pruning
        verbose (boolean): Display debug logs during model optimization
        rewinding_epochs (int): Number of epochs to retrain the model without weight freezing,
            allowing regrowth of previously pruned weights
        cutoff_bad_trials (int): Number of consecutive bad trials (performance below threshold)
            after which pruning / weight sharing stops
        directory (string): Directory to store temporary results
        tuner (str): Tuning algorithm, choose between Bayesian, Hyperband and None
        knapsack_solver (str): Algorithm for solving the knapsack problem during optimization;
            the default usually works well, but for very large networks a greedy algorithm may be more suitable
        regularization_range (list): List of suitable hyperparameters for weight decay

    Returns:
        keras.Model: Optimized model
    '''

    # Extract model attributes
    model_attributes = get_attributes_from_keras_model_and_hls4ml_config(keras_model, hls_config)

    # Optimize model
    return optimize_model(
        keras_model,
        model_attributes,
        objective,
        scheduler,
        X_train,
        y_train,
        X_val,
        y_val,
        batch_size,
        epochs,
        optimizer,
        loss_fn,
        validation_metric,
        increasing,
        rtol,
        callbacks=callbacks,
        ranking_metric=ranking_metric,
        local=local,
        verbose=verbose,
        rewinding_epochs=rewinding_epochs,
        cutoff_bad_trials=cutoff_bad_trials,
        directory=directory,
        tuner=tuner,
        knapsack_solver=knapsack_solver,
        regularization_range=regularization_range,
    )
# from .fused_dotp.hook import use_unrolled # noqa: F401
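The three re-exports at the top of the file keep the public API stable after the move: code written against the old flat layout and code using the new explicit path resolve to the same objects.

```python
# Old top-level path, still valid thanks to the re-exports above:
from hls4ml.optimization import optimize_keras_model_for_hls4ml

# New explicit path after the move to the dsp_aware_pruning subpackage:
from hls4ml.optimization.dsp_aware_pruning import optimize_keras_model_for_hls4ml
```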
108 changes: 108 additions & 0 deletions hls4ml/optimization/dsp_aware_pruning/__init__.py
@@ -0,0 +1,108 @@
import numpy as np

from hls4ml.optimization.dsp_aware_pruning.attributes import get_attributes_from_keras_model_and_hls4ml_config
from hls4ml.optimization.dsp_aware_pruning.keras import optimize_model

default_regularization_range = np.logspace(-6, -2, num=16).tolist()


def optimize_keras_model_for_hls4ml(
    keras_model,
    hls_config,
    objective,
    scheduler,
    X_train,
    y_train,
    X_val,
    y_val,
    batch_size,
    epochs,
    optimizer,
    loss_fn,
    validation_metric,
    increasing,
    rtol,
    callbacks=None,
    ranking_metric='l1',
    local=False,
    verbose=False,
    rewinding_epochs=1,
    cutoff_bad_trials=3,
    directory='hls4ml-optimization',
    tuner='Bayesian',
    knapsack_solver='CBC_MIP',
    regularization_range=default_regularization_range,
):
    '''
    Top-level function for optimizing a Keras model, given an hls4ml config and one or more hardware objectives

    Args:
        keras_model (keras.Model): Model to be optimized
        hls_config (dict): hls4ml configuration, obtained from hls4ml.utils.config.config_from_keras_model(...)
        objective (hls4ml.optimization.objectives.ObjectiveEstimator):
            Parameter, hardware or user-defined objective of optimization
        scheduler (hls4ml.optimization.scheduler.OptimizationScheduler):
            Sparsity scheduler, choose between constant, polynomial and binary
        X_train (np.array): Training inputs
        y_train (np.array): Training labels
        X_val (np.array): Validation inputs
        y_val (np.array): Validation labels
        batch_size (int): Batch size during training
        epochs (int): Maximum number of epochs to fine-tune model, in one iteration of pruning
        optimizer (keras.optimizers.Optimizer or equivalent string description): Optimizer used during training
        loss_fn (keras.losses.Loss or equivalent loss description): Loss function used during training
        validation_metric (keras.metrics.Metric or equivalent metric description): Validation metric, used as a baseline
        increasing (boolean): If the metric improves with increased values;
            e.g. accuracy -> increasing = True, MSE -> increasing = False
        rtol (float): Relative tolerance;
            pruning stops when pruned_validation_metric < (or >) rtol * baseline_validation_metric
        callbacks (list of keras.callbacks.Callback): Currently not supported; to be developed in future versions
        ranking_metric (string): Metric used for ranking weights and structures;
            currently supported: l1, l2, saliency and Oracle
        local (boolean): Layer-wise or global pruning
        verbose (boolean): Display debug logs during model optimization
        rewinding_epochs (int): Number of epochs to retrain the model without weight freezing,
            allowing regrowth of previously pruned weights
        cutoff_bad_trials (int): Number of consecutive bad trials (performance below threshold)
            after which pruning / weight sharing stops
        directory (string): Directory to store temporary results
        tuner (str): Tuning algorithm, choose between Bayesian, Hyperband and None
        knapsack_solver (str): Algorithm for solving the knapsack problem during optimization;
            the default usually works well, but for very large networks a greedy algorithm may be more suitable
        regularization_range (list): List of suitable hyperparameters for weight decay

    Returns:
        keras.Model: Optimized model
    '''

    # Extract model attributes
    model_attributes = get_attributes_from_keras_model_and_hls4ml_config(keras_model, hls_config)

    # Optimize model
    return optimize_model(
        keras_model,
        model_attributes,
        objective,
        scheduler,
        X_train,
        y_train,
        X_val,
        y_val,
        batch_size,
        epochs,
        optimizer,
        loss_fn,
        validation_metric,
        increasing,
        rtol,
        callbacks=callbacks,
        ranking_metric=ranking_metric,
        local=local,
        verbose=verbose,
        rewinding_epochs=rewinding_epochs,
        cutoff_bad_trials=cutoff_bad_trials,
        directory=directory,
        tuner=tuner,
        knapsack_solver=knapsack_solver,
        regularization_range=regularization_range,
    )
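For reference, a hedged usage sketch assembled from the docstring above. The objective and scheduler class names (`ParameterEstimator`, `ConstantScheduler`) and their post-move import paths are assumptions about the surrounding package, not something this diff shows; the model and data are toy placeholders.

```python
import numpy as np
from tensorflow import keras

import hls4ml
from hls4ml.optimization.dsp_aware_pruning import optimize_keras_model_for_hls4ml

# Assumed names/locations, mirroring the docstring's references:
from hls4ml.optimization.dsp_aware_pruning.objectives import ParameterEstimator
from hls4ml.optimization.dsp_aware_pruning.scheduler import ConstantScheduler

# Toy model and data so the sketch is self-contained
keras_model = keras.Sequential(
    [keras.layers.Dense(16, activation='relu', input_shape=(8,)), keras.layers.Dense(4, activation='softmax')]
)
X = np.random.rand(256, 8)
y = keras.utils.to_categorical(np.random.randint(0, 4, 256), 4)

hls_config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name')

optimized = optimize_keras_model_for_hls4ml(
    keras_model,
    hls_config,
    objective=ParameterEstimator,
    scheduler=ConstantScheduler(),  # constructor defaults assumed
    X_train=X,
    y_train=y,
    X_val=X,
    y_val=y,
    batch_size=32,
    epochs=2,
    optimizer='adam',
    loss_fn='categorical_crossentropy',
    validation_metric='accuracy',
    increasing=True,  # accuracy improves as it grows
    rtol=0.98,  # stop once pruned accuracy drops below 0.98 * baseline
)
```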
@@ -2,8 +2,8 @@

import hls4ml
from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType
from hls4ml.optimization.config import SUPPORTED_STRUCTURES
from hls4ml.optimization.keras.config import SUPPORTED_LAYERS
from hls4ml.optimization.dsp_aware_pruning.config import SUPPORTED_STRUCTURES
from hls4ml.optimization.dsp_aware_pruning.keras.config import SUPPORTED_LAYERS


class hls4mlAttributes:
@@ -4,18 +4,14 @@
import numpy as np
import tensorflow as tf

# Enables printing of loss tensors during custom training loop
from tensorflow.python.ops.numpy_ops import np_config

import hls4ml.optimization.keras.utils as utils
from hls4ml.optimization.config import SUPPORTED_STRUCTURES
from hls4ml.optimization.keras.builder import build_optimizable_model, remove_custom_regularizers
from hls4ml.optimization.keras.config import SUPPORTED_LAYERS, SUPPORTED_METRICS, TMP_DIRECTORY
from hls4ml.optimization.keras.masking import get_model_masks
from hls4ml.optimization.keras.reduction import reduce_model
from hls4ml.optimization.scheduler import OptimizationScheduler

np_config.enable_numpy_behavior()
import hls4ml.optimization.dsp_aware_pruning.keras.utils as utils
from hls4ml.optimization.dsp_aware_pruning.config import SUPPORTED_STRUCTURES
from hls4ml.optimization.dsp_aware_pruning.keras.builder import build_optimizable_model, remove_custom_regularizers
from hls4ml.optimization.dsp_aware_pruning.keras.config import SUPPORTED_LAYERS, SUPPORTED_METRICS, TMP_DIRECTORY
from hls4ml.optimization.dsp_aware_pruning.keras.masking import get_model_masks
from hls4ml.optimization.dsp_aware_pruning.keras.reduction import reduce_model
from hls4ml.optimization.dsp_aware_pruning.scheduler import OptimizationScheduler

default_regularization_range = np.logspace(-6, -2, num=16).tolist()


@@ -8,8 +8,8 @@
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Conv2D, Dense

from hls4ml.optimization.keras.config import SUPPORTED_LAYERS, TMP_DIRECTORY
from hls4ml.optimization.keras.regularizers import Conv2DRegularizer, DenseRegularizer
from hls4ml.optimization.dsp_aware_pruning.keras.config import SUPPORTED_LAYERS, TMP_DIRECTORY
from hls4ml.optimization.dsp_aware_pruning.keras.regularizers import Conv2DRegularizer, DenseRegularizer

co = {}
_add_supported_quantized_objects(co)