From 645576e77907b49ab994379e71667c739aeb5780 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Mon, 28 Oct 2024 17:33:02 +0000 Subject: [PATCH 01/18] isolate code_gen in namespace change --- example-models | 2 +- hls4ml/backends/fpga/fpga_backend.py | 4 +- hls4ml/backends/template.py | 1 + .../vivado/passes/convolution_templates.py | 46 +++++++++++-------- .../backends/vivado/passes/core_templates.py | 12 +++-- .../vivado/passes/recurrent_templates.py | 20 ++++---- hls4ml/writer/vivado_writer.py | 4 ++ 7 files changed, 53 insertions(+), 36 deletions(-) diff --git a/example-models b/example-models index 3cfbcfd062..ff74f73dbc 160000 --- a/example-models +++ b/example-models @@ -1 +1 @@ -Subproject commit 3cfbcfd062f60492507d21ff0e91559b3bdd6550 +Subproject commit ff74f73dbc253d1aa7de1603ee10ede551919548 diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index 5c85682354..f922e77b4c 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -727,7 +727,7 @@ def generate_conv1d_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, ke generated_code = ( "template\n" - "class fill_buffer_{index} : public FillConv1DBuffer {{\n" + "class fill_buffer_{index} : public nnet::FillConv1DBuffer {{\n" " public:\n" " static void fill_buffer(\n" " data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],\n" @@ -857,7 +857,7 @@ def generate_conv2d_line_buffer_fn( generated_code = ( "template\n" - "class fill_buffer_{index} : public FillConv2DBuffer {{\n" + "class fill_buffer_{index} : public nnet::FillConv2DBuffer {{\n" " public:\n" " static void fill_buffer(\n" " data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan],\n" diff --git a/hls4ml/backends/template.py b/hls4ml/backends/template.py index f7f6fe313a..c3b61d33d7 100644 --- a/hls4ml/backends/template.py +++ b/hls4ml/backends/template.py @@ -62,6 +62,7 @@ def _default_config_params(self, layer): params = self._default_params(layer) params['iotype'] = layer.model.config.get_config_value('IOType') params['reuse'] = layer.get_attr('reuse_factor') + params['namespace'] = layer.model.config.get_writer_config().get('Namespace', 'nnet') return params diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index dd77bee85e..2c07620375 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -23,7 +23,7 @@ typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; template - using kernel = nnet::{dense_function}; + using kernel = {dense_function}; template using product = nnet::product::{product_type}; }};\n""" @@ -53,7 +53,7 @@ static const unsigned n_partitions = {n_partitions}; static const unsigned n_pixels = out_width / n_partitions; template - using fill_buffer = nnet::{fill_fn}; + using fill_buffer = {fill_fn}; typedef {accum_t.name} accum_t; typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; @@ -89,9 +89,10 @@ def format(self, node): params['scale_index_type'] = 'scale_index_regular' if node.model.config.get_config_value('IOType') == 'io_parallel': - params['fill_fn'] = f'fill_buffer_{node.index}' + namespace = params['namespace'] + params['fill_fn'] = f'{namespace}::fill_buffer_{node.index}' else: - params['fill_fn'] = 'FillConv1DBuffer' + params['fill_fn'] = 'nnet::FillConv1DBuffer' conv_config = self.template.format(**params) @@ -103,16 +104,18 @@ def format(self, node): 
node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) + namespace = params['namespace'] + if node.get_attr('strategy').lower() == 'latency': - mult_params['dense_function'] = 'DenseLatency' + mult_params['dense_function'] = 'nnet::DenseLatency' elif node.get_attr('strategy').lower() == 'resource': if int(mult_params['reuse_factor']) <= int(mult_params['n_in']): - mult_params['dense_function'] = 'DenseResource_rf_leq_nin' + mult_params['dense_function'] = 'nnet::DenseResource_rf_leq_nin' else: - mult_params['dense_function'] = 'DenseResource_rf_gt_nin_rem0' + mult_params['dense_function'] = 'nnet::DenseResource_rf_gt_nin_rem0' # The 3rd case is never used elif node.get_attr('strategy').lower() == 'resource_unrolled': - mult_params['dense_function'] = f'dense_resource_unrolled_{node.index}' + mult_params['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}' mult_config = self.mult_template.format(**mult_params) @@ -170,7 +173,7 @@ def __init__(self): static const unsigned n_partitions = {n_partitions}; static const unsigned n_pixels = out_height * out_width / n_partitions; template - using fill_buffer = nnet::{fill_fn}; + using fill_buffer = {fill_fn}; typedef {accum_t.name} accum_t; typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; @@ -214,9 +217,10 @@ def format(self, node): params['scale_index_width_type'] = 'scale_index_regular' if node.model.config.get_config_value('IOType') == 'io_parallel': - params['fill_fn'] = f'fill_buffer_{node.index}' + namespace = params['namespace'] + params['fill_fn'] = f'{namespace}::fill_buffer_{node.index}' else: - params['fill_fn'] = 'FillConv2DBuffer' + params['fill_fn'] = 'nnet::FillConv2DBuffer' conv_config = self.template.format(**params) @@ -313,9 +317,10 @@ def format(self, node): params['weight_t'] = node.get_weights('depthwise').type params['bias_t'] = node.get_weights('zero_bias').type if node.model.config.get_config_value('IOType') == 'io_parallel': - params['fill_fn'] = f'fill_buffer_{node.index}_dw' + namespace = params['namespace'] + params['fill_fn'] = f'{namespace}::fill_buffer_{node.index}_dw' else: - params['fill_fn'] = 'FillConv1DBuffer' + params['fill_fn'] = 'nnet::FillConv1DBuffer' if node.get_attr('unscaled'): params['scale_index_type'] = 'scale_index_unscaled' @@ -359,9 +364,10 @@ def format(self, node): params['min_width'] = params['in_width'] params['instructions'] = '0' if node.model.config.get_config_value('IOType') == 'io_parallel': - params['fill_fn'] = f'fill_buffer_{node.index}_pw' + namespace = params['namespace'] + params['fill_fn'] = f'{namespace}::fill_buffer_{node.index}_pw' else: - params['fill_fn'] = 'FillConv1DBuffer' + params['fill_fn'] = 'nnet::FillConv1DBuffer' if node.get_attr('unscaled'): params['scale_index_type'] = 'scale_index_unscaled' @@ -446,9 +452,10 @@ def format(self, node): params['index'] = str(node.index) + '_depthwise' params['weight_t'] = node.get_weights('depthwise').type if node.model.config.get_config_value('IOType') == 'io_parallel': - params['fill_fn'] = f'fill_buffer_{node.index}_dw' + namespace = params['namespace'] + params['fill_fn'] = f'{namespace}::fill_buffer_{node.index}_dw' else: - params['fill_fn'] = 'FillConv2DBuffer' + params['fill_fn'] = 'nnet::FillConv2DBuffer' if node.get_attr('unscaled_h'): params['scale_index_height_type'] = 'scale_index_unscaled' @@ -500,9 +507,10 @@ def format(self, node): params['min_width'] = params['in_width'] params['instructions'] = '0' if node.model.config.get_config_value('IOType') == 
'io_parallel': - params['fill_fn'] = f'fill_buffer_{node.index}_pw' + namespace = params['namespace'] + params['fill_fn'] = f'{namespace}::fill_buffer_{node.index}_pw' else: - params['fill_fn'] = 'FillConv2DBuffer' + params['fill_fn'] = 'nnet::FillConv2DBuffer' if node.get_attr('unscaled_h'): params['scale_index_height_type'] = 'scale_index_unscaled' diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index 836da6e68a..1393cdfb49 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ b/hls4ml/backends/vivado/passes/core_templates.py @@ -20,7 +20,7 @@ typedef {weight_t.name} weight_t; typedef {index_t.name} index_t; template - using kernel = nnet::{dense_function}; + using kernel = {dense_function}; template using product = nnet::product::{product_type}; }};\n""" @@ -43,16 +43,18 @@ def format(self, node): node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) + namespace = params['namespace'] + if node.get_attr('strategy').lower() == 'latency': - params['dense_function'] = 'DenseLatency' + params['dense_function'] = 'nnet::DenseLatency' elif node.get_attr('strategy').lower() == 'resource': if int(params['reuse_factor']) <= int(params['n_in']): - params['dense_function'] = 'DenseResource_rf_leq_nin' + params['dense_function'] = 'nnet::DenseResource_rf_leq_nin' else: - params['dense_function'] = 'DenseResource_rf_gt_nin_rem0' + params['dense_function'] = 'nnet::DenseResource_rf_gt_nin_rem0' # The 3rd case is never used elif node.get_attr('strategy').lower() == 'resource_unrolled': - params['dense_function'] = f'dense_resource_unrolled_{node.index}' + params['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}' return self.template.format(**params) diff --git a/hls4ml/backends/vivado/passes/recurrent_templates.py b/hls4ml/backends/vivado/passes/recurrent_templates.py index 939713af22..91a6c7ef8d 100644 --- a/hls4ml/backends/vivado/passes/recurrent_templates.py +++ b/hls4ml/backends/vivado/passes/recurrent_templates.py @@ -17,7 +17,7 @@ typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; template - using kernel = nnet::{dense_function}; + using kernel = {dense_function}; template using product = nnet::product::{product_type}; }};\n""" @@ -141,16 +141,18 @@ def format(self, node): mult_params1['nzeros'] = node.get_weights('weight').nzeros mult_params1['nonzeros'] = node.get_weights('weight').nonzeros + namespace = params['namespace'] + if node.get_attr('strategy').lower() == 'latency': - mult_params1['dense_function'] = 'DenseLatency' + mult_params1['dense_function'] = 'nnet::DenseLatency' elif node.get_attr('strategy').lower() == 'resource': if int(mult_params1['reuse_factor']) <= int(mult_params1['n_in']): - mult_params1['dense_function'] = 'DenseResource_rf_leq_nin' + mult_params1['dense_function'] = 'nnet::DenseResource_rf_leq_nin' else: - mult_params1['dense_function'] = 'DenseResource_rf_gt_nin_rem0' + mult_params1['dense_function'] = 'nnet::DenseResource_rf_gt_nin_rem0' # The 3rd case is never used elif node.get_attr('strategy').lower() == 'resource_unrolled': - mult_params1['dense_function'] = f'dense_resource_unrolled_{node.index}_1' + mult_params1['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}_1' if node.get_attr('return_sequences'): mult_params2['n_in'] = node.get_output_variable().shape[1] @@ -167,15 +169,15 @@ def format(self, node): mult_params2['nonzeros'] = node.get_weights('recurrent_weight').nonzeros if 
node.get_attr('strategy').lower() == 'latency': - mult_params2['dense_function'] = 'DenseLatency' + mult_params2['dense_function'] = 'nnet::DenseLatency' elif node.get_attr('strategy').lower() == 'resource': if int(mult_params2['reuse_factor']) <= int(mult_params2['n_in']): - mult_params2['dense_function'] = 'DenseResource_rf_leq_nin' + mult_params2['dense_function'] = 'nnet::DenseResource_rf_leq_nin' else: - mult_params2['dense_function'] = 'DenseResource_rf_gt_nin_rem0' + mult_params2['dense_function'] = 'nnet::DenseResource_rf_gt_nin_rem0' # The 3rd case is never used elif node.get_attr('strategy').lower() == 'resource_unrolled': - mult_params2['dense_function'] = f'dense_resource_unrolled_{node.index}_2' + mult_params2['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}_2' mult_config1 = self.mult1_template.format(**mult_params1) mult_config2 = self.mult2_template.format(**mult_params2) diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index 5ab13736ec..275e73813e 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -790,6 +790,7 @@ def write_generated_code(self, model): contents = f.readlines() f.close() f = open(path, 'w') + namespace = model.config.get_writer_config().get('Namespace', None) for line in contents: if '// hls4ml insert code' in line: @@ -799,6 +800,9 @@ def write_generated_code(self, model): newline += str(generated_code) else: newline = line + if namespace is not None: + if 'namespace nnet' in newline: + newline = newline.replace('namespace nnet', f'namespace {namespace}') f.write(newline) f.close() From 193c7f52d42c786856acf51ba6afca8712356724 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Wed, 30 Oct 2024 01:36:10 +0000 Subject: [PATCH 02/18] leftover --- hls4ml/backends/vivado/passes/convolution_templates.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index 2c07620375..97652032d5 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -232,16 +232,17 @@ def format(self, node): node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) + namespace = params['namespace'] if node.get_attr('strategy').lower() == 'latency': - mult_params['dense_function'] = 'DenseLatency' + mult_params['dense_function'] = 'nnet::DenseLatency' elif node.get_attr('strategy').lower() == 'resource': if int(mult_params['reuse_factor']) <= int(mult_params['n_in']): - mult_params['dense_function'] = 'DenseResource_rf_leq_nin' + mult_params['dense_function'] = 'nnet::DenseResource_rf_leq_nin' else: - mult_params['dense_function'] = 'DenseResource_rf_gt_nin_rem0' + mult_params['dense_function'] = 'nnet::DenseResource_rf_gt_nin_rem0' # The 3rd case is never used elif node.get_attr('strategy').lower() == 'resource_unrolled': - mult_params['dense_function'] = f'dense_resource_unrolled_{node.index}' + mult_params['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}' mult_config = self.mult_template.format(**mult_params) From 958840073d47edae3c9515e10043f9ba02451e85 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Wed, 30 Oct 2024 05:47:47 +0000 Subject: [PATCH 03/18] leftover for hgq --- hls4ml/backends/fpga/passes/hgq_proxy_model.py | 4 +++- hls4ml/backends/template.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git 
a/hls4ml/backends/fpga/passes/hgq_proxy_model.py b/hls4ml/backends/fpga/passes/hgq_proxy_model.py index 5ec1200ac7..81f980ef50 100644 --- a/hls4ml/backends/fpga/passes/hgq_proxy_model.py +++ b/hls4ml/backends/fpga/passes/hgq_proxy_model.py @@ -75,10 +75,12 @@ def transform(self, model, node: FixedPointQuantizer): class ProcessFixedPointQuantizerCall(FunctionCallTemplate): def __init__(self): super().__init__(FixedPointQuantizer, include_header=[]) - self.template = 'nnet::{name}<{input_t}, {output_t}>({input}, {output});' + self.template = '{namespace}::{name}<{input_t}, {output_t}>({input}, {output});' def format(self, node): params = self._default_function_params(node) + namespace = node.model.config.get_writer_config()['Namespace'] or 'nnet' + params['namespace'] = namespace return self.template.format(**params) diff --git a/hls4ml/backends/template.py b/hls4ml/backends/template.py index c3b61d33d7..228077dbbb 100644 --- a/hls4ml/backends/template.py +++ b/hls4ml/backends/template.py @@ -62,7 +62,7 @@ def _default_config_params(self, layer): params = self._default_params(layer) params['iotype'] = layer.model.config.get_config_value('IOType') params['reuse'] = layer.get_attr('reuse_factor') - params['namespace'] = layer.model.config.get_writer_config().get('Namespace', 'nnet') + params['namespace'] = layer.model.config.get_writer_config()['Namespace'] or 'nnet' return params From 6b23f0443c2b3dd722264da66b6892a788739e34 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Wed, 30 Oct 2024 05:50:21 +0000 Subject: [PATCH 04/18] preliminary jit compiler support --- hls4ml/converters/__init__.py | 7 ++ hls4ml/model/graph.py | 105 ++++++++++++++++- .../templates/catapult/firmware/myproject.cpp | 2 + hls4ml/templates/catapult/myproject_test.cpp | 2 + .../templates/vivado/firmware/myproject.cpp | 11 +- hls4ml/writer/catapult_writer.py | 1 + hls4ml/writer/oneapi_writer.py | 1 + hls4ml/writer/quartus_writer.py | 1 + hls4ml/writer/symbolic_writer.py | 1 + hls4ml/writer/vivado_writer.py | 10 +- hls4ml/writer/writers.py | 107 ++++++++++++++++++ 11 files changed, 233 insertions(+), 15 deletions(-) diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index 092e53b3d3..7ba2edda38 100644 --- a/hls4ml/converters/__init__.py +++ b/hls4ml/converters/__init__.py @@ -220,6 +220,13 @@ def convert_from_keras_model( ModelGraph: hls4ml model. 
""" + if os.environ.get('HLS4ML_USE_JIT', '0') == '1': + import random + import string + + rand_namespace = ''.join(random.choices(string.ascii_lowercase + string.ascii_uppercase, k=16)) + kwargs.setdefault('namespace', rand_namespace) + config = create_config(output_dir=output_dir, project_name=project_name, backend=backend, **kwargs) config['KerasModel'] = model diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 709c3db3ff..1d68d49c9f 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -1,7 +1,10 @@ import ctypes import os import platform +import sys from collections import OrderedDict +from pathlib import Path +from typing import Sequence import numpy as np import numpy.ctypeslib as npc @@ -347,6 +350,8 @@ def __init__(self, config, layer_list, inputs=None, outputs=None): for flow in self.config.flows: self.apply_flow(flow) + self._is_jitted = False + def _find_output_variable_names(self, layer_list, layer_names): """Given a list of all layers, and a list input/output names, find the names of their outputs that will be used as the name of the output variables.""" @@ -677,15 +682,53 @@ def write(self): self.config.backend.write(self) - def compile(self): + def compile(self, jit=None): """Compile the generated project and link the library into current environment. Users should call this function if they want to use `predict` functionality for simulation. """ self.write() - self._compile() + self._compile(jit=jit) + + def _compile(self, jit=None): + if jit is None: + jit = os.environ.get('HLS4ML_USE_JIT', '0') == '1' + if jit: + self.jit_compile() + else: + self.compile_shared_lib() + + def jit_compile(self): + if self._is_jitted: + print('INFO: JIT Compilation does not support recompilation at the moment, skipping', file=sys.stderr) + return + import cppyy + + prj_path = Path(self.config.config['OutputDir']).absolute() + prj_name = self.config.config['ProjectName'] + prj_top_path = prj_path / f'firmware/{prj_name}.cpp' + jit_bridge_path = prj_path / f'{prj_name}_jit_bridge.cpp' + + cppyy.cppdef("#define HLS4ML_EXTERNAL_WEIGHT_LOAD") + cppyy.add_include_path(str(prj_path)) + cppyy.add_include_path(str(prj_path / 'firmware/ap_types')) + cppyy.include(str(prj_top_path)) + cppyy.include(str(jit_bridge_path)) - def _compile(self): + cpp_namespace = getattr(cppyy.gbl, self.config.writer_config['Namespace']) + + # Load weights from to c arrays, if not initialized in headers + if self.config.config['WriterConfig']['WriteWeightsTxt']: + for w in self.get_weight_variables(): + cpp_w = getattr(cpp_namespace, w.name, None) + if cpp_w is None: + continue + fill_fn = cpp_namespace.fill_weight[w.type.name] + fill_fn(cpp_w, w.data.ravel().astype(np.float32)) + + self._is_jitted = True + + def compile_shared_lib(self): lib_name = self.config.backend.compile(self) if self._top_function_lib is not None: if platform.system() == "Linux": @@ -757,7 +800,61 @@ def _compute_n_samples(self, x): return int(n_sample) - def predict(self, x): + def predict(self, x, jit=None) -> np.ndarray | list[np.ndarray]: + """Predict with the cpp-compiled model. + The exact prediction method depends on the backend used to compile the model: + + - Vivado/Vitis/Quartus/Catapult: The model can be compiled with g++/clang++, or cling (cppyy). This function will call the compiled model directly, depending on which backend was used. + + - OneAPI: The model must be compiled with icpx. This function will call the compiled model directly. 
+ """ # noqa: E501 + + if jit is None: + jit = self._is_jitted + + if jit: + return self.jit_predict(x) + else: + return self.ctypes_predict(x) + + def jit_predict(self, x: np.ndarray | Sequence[np.ndarray]): + import cppyy + + assert cppyy.gbl, 'cppyy not initialized, please call model.jit_compile() or model.compile(jit=True) first' + namespace: str | None = self.config.writer_config['Namespace'] # type: ignore + assert namespace is not None, 'Namespace must be set in the writer config for jit_predict to work' + assert hasattr( + cppyy.gbl, namespace + ), f'Namespace {namespace} not found in cppyy.gbl. Did you call model.jit_compile()?' + cpp_namespace = getattr(cppyy.gbl, namespace) + top_fn_template = cpp_namespace.batch_inference + + output_shapes = [out.shape for out in self.get_output_variables()] + n_inputs = len(self.get_input_variables()) + n_outputs = len(output_shapes) + + dtype = np.float32 + cpp_dtype = cppyy.gbl.float + if getattr(x[0], 'dtype', None) != np.float32: + dtype = np.float64 + cpp_dtype = cppyy.gbl.double + + top_fn = top_fn_template[cpp_dtype] + + if n_inputs > 1: + assert len(x) == n_inputs, f'Expected {n_inputs} inputs, got {len(x)}.' + args = (np.ascontiguousarray(xi, dtype=dtype) for xi in x) + ret = top_fn(*args) + + else: + ret = top_fn(np.ascontiguousarray(x, dtype=dtype)) + + if n_outputs == 1: + return np.array(ret[0]).reshape(-1, *output_shapes[0]) + else: + return [np.array(r).reshape(-1, *s) for r, s in zip(ret, output_shapes)] + + def ctypes_predict(self, x): top_function, ctype = self._get_top_function(x) n_samples = self._compute_n_samples(x) n_inputs = len(self.get_input_variables()) diff --git a/hls4ml/templates/catapult/firmware/myproject.cpp b/hls4ml/templates/catapult/firmware/myproject.cpp index bdb0570f8b..1912af31a8 100755 --- a/hls4ml/templates/catapult/firmware/myproject.cpp +++ b/hls4ml/templates/catapult/firmware/myproject.cpp @@ -13,12 +13,14 @@ void CCS_BLOCK(myproject)( // hls-fpga-machine-learning insert IO +#ifndef HLS4ML_EXTERNAL_WEIGHT_LOAD #ifndef __SYNTHESIS__ static bool loaded_weights = false; if (!loaded_weights) { // hls-fpga-machine-learning insert load weights loaded_weights = true; } +#endif #endif // **************************************** diff --git a/hls4ml/templates/catapult/myproject_test.cpp b/hls4ml/templates/catapult/myproject_test.cpp index 66b87f6741..d99a3b0dd9 100755 --- a/hls4ml/templates/catapult/myproject_test.cpp +++ b/hls4ml/templates/catapult/myproject_test.cpp @@ -79,12 +79,14 @@ CCS_MAIN(int argc, char *argv[]) { #endif std::ofstream fout(RESULTS_LOG); +#ifndef HLS4ML_EXTERNAL_WEIGHT_LOAD #ifndef __SYNTHESIS__ static bool loaded_weights = false; if (!loaded_weights) { // hls-fpga-machine-learning insert load weights loaded_weights = true; } +#endif #endif std::string iline; std::string pline; diff --git a/hls4ml/templates/vivado/firmware/myproject.cpp b/hls4ml/templates/vivado/firmware/myproject.cpp index 5ba7f118ba..13adc62129 100644 --- a/hls4ml/templates/vivado/firmware/myproject.cpp +++ b/hls4ml/templates/vivado/firmware/myproject.cpp @@ -11,8 +11,15 @@ void myproject( // hls-fpga-machine-learning insert IO - // hls-fpga-machine-learning insert load weights - +#ifndef HLS4ML_EXTERNAL_WEIGHT_LOAD +#ifndef __SYNTHESIS__ + static bool loaded_weights = false; + if (!loaded_weights) { + // hls-fpga-machine-learning insert load weights + loaded_weights = true; + } +#endif +#endif // **************************************** // NETWORK INSTANTIATION // **************************************** diff --git 
a/hls4ml/writer/catapult_writer.py b/hls4ml/writer/catapult_writer.py index 7db1063206..c634fd565a 100755 --- a/hls4ml/writer/catapult_writer.py +++ b/hls4ml/writer/catapult_writer.py @@ -921,6 +921,7 @@ def write_hls(self, model): self.write_parameters(model) self.write_test_bench(model) self.write_bridge(model) + self.write_jit_bridge(model) self.write_build_script(model) self.write_nnet_utils(model) self.write_generated_code(model) diff --git a/hls4ml/writer/oneapi_writer.py b/hls4ml/writer/oneapi_writer.py index fe633214f6..b2c61d85d8 100644 --- a/hls4ml/writer/oneapi_writer.py +++ b/hls4ml/writer/oneapi_writer.py @@ -961,6 +961,7 @@ def write_hls(self, model): self.write_parameters(model) self.write_test_bench(model) self.write_bridge(model) + self.write_jit_bridge(model) self.write_build_script(model) self.write_nnet_utils(model) self.write_activation_tables(model) diff --git a/hls4ml/writer/quartus_writer.py b/hls4ml/writer/quartus_writer.py index 932a8b6a6d..faf826c3d0 100644 --- a/hls4ml/writer/quartus_writer.py +++ b/hls4ml/writer/quartus_writer.py @@ -1356,6 +1356,7 @@ def write_hls(self, model): self.write_parameters(model) self.write_test_bench(model) self.write_bridge(model) + self.write_jit_bridge(model) self.write_build_script(model) self.write_nnet_utils(model) self.write_activation_tables(model) diff --git a/hls4ml/writer/symbolic_writer.py b/hls4ml/writer/symbolic_writer.py index 76d56b1533..b535f0f934 100644 --- a/hls4ml/writer/symbolic_writer.py +++ b/hls4ml/writer/symbolic_writer.py @@ -111,6 +111,7 @@ def write_hls(self, model): self.write_parameters(model) self.write_test_bench(model) self.write_bridge(model) + self.write_jit_bridge(model) self.write_build_script(model) self.write_nnet_utils(model) self.write_generated_code(model) diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index 275e73813e..f450e96040 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -165,11 +165,6 @@ def write_project_cpp(self, model): elif '// hls-fpga-machine-learning insert load weights' in line: newline = line if model.config.get_writer_config()['WriteWeightsTxt']: - - newline += '#ifndef __SYNTHESIS__\n' - newline += ' static bool loaded_weights = false;\n' - newline += ' if (!loaded_weights) {\n' - for layer in model.get_layers(): for w in layer.get_weights(): if w.weight_class == 'CompressedWeightVariable': @@ -191,10 +186,6 @@ def write_project_cpp(self, model): w.type.name, w.data_length, w.name, w.name ) - newline += ' loaded_weights = true;' - newline += ' }\n' - newline += '#endif' - # Add input/output type elif '// hls-fpga-machine-learning insert IO' in line: newline = line @@ -853,6 +844,7 @@ def write_hls(self, model): self.write_parameters(model) self.write_test_bench(model) self.write_bridge(model) + self.write_jit_bridge(model) self.write_build_script(model) self.write_nnet_utils(model) self.write_generated_code(model) diff --git a/hls4ml/writer/writers.py b/hls4ml/writer/writers.py index 54caec1d11..2888372ac6 100644 --- a/hls4ml/writer/writers.py +++ b/hls4ml/writer/writers.py @@ -1,3 +1,89 @@ +import typing +from math import prod +from pathlib import Path + +if typing.TYPE_CHECKING: + from hls4ml.model import ModelGraph + + +def create_jit_bridge_fn(model: 'ModelGraph'): + inp_vars = model.get_input_variables() + out_vars = model.get_output_variables() + + inp_shapes = [tuple(v.shape) for v in inp_vars] + out_shapes = [tuple(v.shape) for v in out_vars] + + inp_names = [v.name for v in inp_vars] + out_names = 
[v.name for v in out_vars]
+
+    inp_sizes = [prod(s) for s in inp_shapes]
+    out_sizes = [prod(s) for s in out_shapes]
+
+    n_out = len(out_names)
+
+    input_def = '\n        '.join(f'std::vector<T> {v.name}, ' for v in inp_vars)[:-2]
+
+    inp_size_def = '\n        '.join(f'constexpr size_t {n}_size = {s};' for n, s in zip(inp_names, inp_sizes))
+    out_size_def = '\n        '.join(f'constexpr size_t {n}_size = {s};' for n, s in zip(out_names, out_sizes))
+
+    ptr_buf_def = '\n        '.join(f'T* {v.name}_ptr = {v.name}.data();' for v in inp_vars + out_vars)
+    n_samples_def = f'{inp_vars[0].name}.size() / {inp_vars[0].name}_size'
+
+    assertions_def = ' ||\n            '.join(f'({n}.size() != {n}_size * n_samples)' for n in inp_names)
+
+    inp_args_def_list = [f'{n}_ptr + i * {n}_size,' for n in inp_names]
+    out_args_def_list = [f'{n}_ptr + i * {n}_size,' for n in out_names]
+    args_def = ('\n' + ' ' * 12).join(inp_args_def_list + out_args_def_list)[:-1]
+
+    out_var_def = '\n    '.join(f'std::vector<T> {v.name}({v.name}_size * n_samples);' for v in out_vars)
+
+    _ret_template_arg = ('std::vector<T>, ' * n_out)[:-2]
+    _ret_tuple_arg = ', '.join(out_names)
+    return_def = f'std::tuple<{_ret_template_arg}>({_ret_tuple_arg})'
+
+    cpp_fn = f"""
+    template <typename T>
+    auto batch_inference(
+        {input_def}
+    ){{
+        {inp_size_def}
+        {out_size_def}
+
+        size_t n_samples = {n_samples_def};
+
+        if (
+            {assertions_def}
+        )
+            throw std::runtime_error("Invalid input sizes: number of samples or input sizes do not match");
+
+        {out_var_def}
+
+        {ptr_buf_def}
+
+        for (int i = 0; i < n_samples; i++) {{
+            myproject_float(
+                {args_def}
+            );
+        }}
+
+        return {return_def};
+    }}
+"""
+    return cpp_fn
+
+
+def create_jit_weight_filler(model: 'ModelGraph'):
+    filler_fn = """
+    template <typename T>
+    void fill_weight(T weight[], std::vector<float> vec) {{
+        for (size_t i = 0; i < vec.size(); i++) {{
+            weight[i] = vec[i];
+        }}
+    }}
+"""
+    return filler_fn
+
+
 class Writer:
     def __init__(self):
         pass
@@ -5,6 +91,27 @@ def __init__(self):
     def write_hls(self, model):
         raise NotImplementedError

+    def write_jit_bridge(self, model: 'ModelGraph'):
+        """Write the Python-C++ bridge for JIT compilation (myproject_jit_bridge.cpp)"""
+
+        prj_name = model.config.get_project_name()
+        prj_path = Path(model.config.get_output_dir())
+        path_c_bridge = prj_path / f'{prj_name}_bridge.cpp'
+        path_jit_bridge = prj_path / f'{prj_name}_jit_bridge.cpp'
+
+        namespace = model.config.get_writer_config()['Namespace'] or 'nnet'
+
+        cpp_source_bridge = path_c_bridge.read_text()
+
+        jit_bridge_fn = create_jit_bridge_fn(model)
+        weight_filler_fn = create_jit_weight_filler(model)
+
+        _plugin_code = f'namespace {namespace} {{\n' + jit_bridge_fn + '\n\n' + weight_filler_fn + '\n'
+
+        cpp_source_bridge = cpp_source_bridge.replace('extern "C" {', _plugin_code)
+
+        path_jit_bridge.write_text(cpp_source_bridge)
+

 writer_map = {}

From 5be84ced0e3cca74ec471cf4b21100b2d713f6cf Mon Sep 17 00:00:00 2001
From: Chang Sun
Date: Wed, 30 Oct 2024 17:14:10 +0000
Subject: [PATCH 05/18] isolate cppyy in subprocess

---
 hls4ml/model/graph.py | 162 ++++++++++++++++++++++++++++------------
 1 file changed, 115 insertions(+), 47 deletions(-)

diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py
index 1d68d49c9f..771a479342 100644
--- a/hls4ml/model/graph.py
+++ b/hls4ml/model/graph.py
@@ -1,8 +1,10 @@
 import ctypes
+import gc
 import os
 import platform
 import sys
 from collections import OrderedDict
+from multiprocessing import Pool, shared_memory
 from pathlib import Path
 from typing import Sequence

 import numpy as np
 import numpy.ctypeslib as npc
@@ -16,6 +18,64 @@
 from hls4ml.utils.string_utils import 
convert_to_snake_case +def compile_worker(config, var_types: Sequence[str], var_names: Sequence[str], shm_names: Sequence[str]): + import cppyy + import numpy as np + + prj_path = Path(config['OutputDir']).absolute() + prj_name = config['ProjectName'] + prj_top_path = prj_path / f'firmware/{prj_name}.cpp' + jit_bridge_path = prj_path / f'{prj_name}_jit_bridge.cpp' + + cppyy.cppdef("#define HLS4ML_EXTERNAL_WEIGHT_LOAD") + cppyy.add_include_path(str(prj_path)) + if (prj_path / 'firmware/ap_types').exists(): + cppyy.add_include_path(str(prj_path / 'firmware/ap_types')) + if (prj_path / 'firmware/ac_types').exists(): + cppyy.add_include_path(str(prj_path / 'firmware/ac_types')) + cppyy.include(str(prj_top_path)) + cppyy.include(str(jit_bridge_path)) + + cpp_namespace = getattr(cppyy.gbl, config['Namespace']) + + # Load weights from to c arrays, if not initialized in headers + if config['WriteWeightsTxt']: + for w_type, w_name, shm_name in zip(var_types, var_names, shm_names): + shm = shared_memory.SharedMemory(name=shm_name, create=False) + wval = np.ndarray((shm.size // np.dtype(np.float32).itemsize,), dtype=np.float32, buffer=shm.buf) + cpp_w = getattr(cpp_namespace, w_name, None) + if cpp_w is None: + continue + fill_fn = cpp_namespace.fill_weight[w_type] + fill_fn(cpp_w, wval.ravel().astype(np.float32)) + shm.close() + shm.unlink() + import gc + + gc.collect() + + +def inference_worker(namespace: str, *args: list[np.ndarray]): + import cppyy + + cpp_namespace = getattr(cppyy.gbl, namespace) + top_fn_template = cpp_namespace.batch_inference + + cpp_dtype = cppyy.gbl.float + if getattr(args[0], 'dtype', None) != np.float32: + cpp_dtype = cppyy.gbl.double + + top_fn = top_fn_template[cpp_dtype] + + ret = top_fn(*args) + + import gc + + gc.collect() + + return [np.array(r) for r in ret] + + class HLSConfig: """The configuration class as stored in the ModelGraph. 
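The two workers above are the core of this patch: all cppyy/Cling state lives in a disposable child process, and the parent exchanges only plain numpy arrays with it. A minimal, self-contained sketch of that isolation pattern (the function names, the toy cppdef'd kernel, and the include path are illustrative, not hls4ml API):

from multiprocessing import Pool

import numpy as np


def _worker_compile(project_dir):
    # Runs in the child process: cppyy is imported here, so the Cling
    # interpreter never touches the parent Python process and dies with
    # the pool.
    import cppyy

    cppyy.add_include_path(project_dir)
    # Stand-in for cppyy.include()-ing the generated project sources.
    cppyy.cppdef('double scale(double x) { return 2.0 * x; }')


def _worker_infer(x):
    import cppyy  # already initialized by _worker_compile in this same process

    return np.array([cppyy.gbl.scale(float(v)) for v in x], dtype=np.float64)


if __name__ == '__main__':
    pool = Pool(1)  # exactly one worker, so Cling state persists across calls
    pool.apply(_worker_compile, ('/tmp/prj',))
    print(pool.apply(_worker_infer, (np.arange(4, dtype=np.float64),)))
    pool.close()
    pool.join()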
@@ -350,7 +410,7 @@ def __init__(self, config, layer_list, inputs=None, outputs=None): for flow in self.config.flows: self.apply_flow(flow) - self._is_jitted = False + self._jit_process = None def _find_output_variable_names(self, layer_list, layer_names): """Given a list of all layers, and a list input/output names, find the names of their outputs that will be used @@ -693,40 +753,46 @@ def compile(self, jit=None): def _compile(self, jit=None): if jit is None: jit = os.environ.get('HLS4ML_USE_JIT', '0') == '1' + + if self.config.config['Backend'].lower() == 'oneapi': + print('INFO: OneAPI backend does not support JIT compilation, using shared library compilation', file=sys.stderr) + jit = False + if jit: self.jit_compile() else: self.compile_shared_lib() def jit_compile(self): - if self._is_jitted: - print('INFO: JIT Compilation does not support recompilation at the moment, skipping', file=sys.stderr) - return - import cppyy - - prj_path = Path(self.config.config['OutputDir']).absolute() - prj_name = self.config.config['ProjectName'] - prj_top_path = prj_path / f'firmware/{prj_name}.cpp' - jit_bridge_path = prj_path / f'{prj_name}_jit_bridge.cpp' - - cppyy.cppdef("#define HLS4ML_EXTERNAL_WEIGHT_LOAD") - cppyy.add_include_path(str(prj_path)) - cppyy.add_include_path(str(prj_path / 'firmware/ap_types')) - cppyy.include(str(prj_top_path)) - cppyy.include(str(jit_bridge_path)) - - cpp_namespace = getattr(cppyy.gbl, self.config.writer_config['Namespace']) - - # Load weights from to c arrays, if not initialized in headers - if self.config.config['WriterConfig']['WriteWeightsTxt']: - for w in self.get_weight_variables(): - cpp_w = getattr(cpp_namespace, w.name, None) - if cpp_w is None: - continue - fill_fn = cpp_namespace.fill_weight[w.type.name] - fill_fn(cpp_w, w.data.ravel().astype(np.float32)) - - self._is_jitted = True + if self._jit_process is not None: + self._jit_process.close() + self._jit_process.join() + self._jit_process = None + gc.collect() + + self._jit_process = Pool(1) + variables = self.get_weight_variables() + var_types = [v.type.name for v in variables] + var_names = [v.name for v in variables] + var_values = [v.data.astype(np.float32) for v in variables] + + shm_names = [] + for wval in var_values: + shm = shared_memory.SharedMemory(create=True, size=wval.nbytes) + np.ndarray(wval.size, dtype=wval.dtype, buffer=shm.buf)[:] = wval.ravel() + shm_names.append(shm.name) + shm.close() + + config = { + 'OutputDir': self.config.config['OutputDir'], + 'ProjectName': self.config.config['ProjectName'], + 'Namespace': self.config.config['WriterConfig']['Namespace'], + 'WriteWeightsTxt': self.config.config['WriterConfig'].get('WriteWeightsTxt', False), + } + + self._jit_process.apply(compile_worker, (config, var_types, var_names, shm_names)) + + gc.collect() def compile_shared_lib(self): lib_name = self.config.backend.compile(self) @@ -810,7 +876,7 @@ def predict(self, x, jit=None) -> np.ndarray | list[np.ndarray]: """ # noqa: E501 if jit is None: - jit = self._is_jitted + jit = self._jit_process is not None if jit: return self.jit_predict(x) @@ -818,36 +884,31 @@ def predict(self, x, jit=None) -> np.ndarray | list[np.ndarray]: return self.ctypes_predict(x) def jit_predict(self, x: np.ndarray | Sequence[np.ndarray]): - import cppyy - assert cppyy.gbl, 'cppyy not initialized, please call model.jit_compile() or model.compile(jit=True) first' + assert self._jit_process is not None, 'Model not jit compiled' + # assert cppyy.gbl, 'cppyy not initialized, please call model.jit_compile() or 
model.compile(jit=True) first' namespace: str | None = self.config.writer_config['Namespace'] # type: ignore assert namespace is not None, 'Namespace must be set in the writer config for jit_predict to work' - assert hasattr( - cppyy.gbl, namespace - ), f'Namespace {namespace} not found in cppyy.gbl. Did you call model.jit_compile()?' - cpp_namespace = getattr(cppyy.gbl, namespace) - top_fn_template = cpp_namespace.batch_inference + # assert hasattr( + # cppyy.gbl, namespace + # ), f'Namespace {namespace} not found in cppyy.gbl. Did you call model.jit_compile()?' output_shapes = [out.shape for out in self.get_output_variables()] n_inputs = len(self.get_input_variables()) n_outputs = len(output_shapes) dtype = np.float32 - cpp_dtype = cppyy.gbl.float if getattr(x[0], 'dtype', None) != np.float32: dtype = np.float64 - cpp_dtype = cppyy.gbl.double - top_fn = top_fn_template[cpp_dtype] + if n_inputs == 1: + if isinstance(x, np.ndarray): + x = (x,) - if n_inputs > 1: - assert len(x) == n_inputs, f'Expected {n_inputs} inputs, got {len(x)}.' - args = (np.ascontiguousarray(xi, dtype=dtype) for xi in x) - ret = top_fn(*args) + assert len(x) == n_inputs, f'Expected {n_inputs} inputs, got {len(x)}.' + args = (np.ascontiguousarray(xi, dtype=dtype) for xi in x) - else: - ret = top_fn(np.ascontiguousarray(x, dtype=dtype)) + ret = self._jit_process.apply(inference_worker, (namespace, *args)) if n_outputs == 1: return np.array(ret[0]).reshape(-1, *output_shapes[0]) @@ -891,7 +952,7 @@ def ctypes_predict(self, x): def trace(self, x): print(f'Recompiling {self.config.get_project_name()} with tracing') self.config.trace_output = True - self.compile() + self.compile(jit=False) top_function, ctype = self._get_top_function(x) n_samples = self._compute_n_samples(x) @@ -982,3 +1043,10 @@ def build(self, **kwargs): self.write() return self.config.backend.build(self, **kwargs) + + def __del__(self): + if self._jit_process is not None: + self._jit_process.close() + self._jit_process.join() + self._jit_process = None + gc.collect() From 635f08247d2d585088b21bd60583685036bd119b Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Wed, 30 Oct 2024 17:23:31 +0000 Subject: [PATCH 06/18] pass array directly, no shm --- hls4ml/model/graph.py | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 771a479342..89952bd742 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -4,7 +4,7 @@ import platform import sys from collections import OrderedDict -from multiprocessing import Pool, shared_memory +from multiprocessing import Pool from pathlib import Path from typing import Sequence @@ -18,9 +18,8 @@ from hls4ml.utils.string_utils import convert_to_snake_case -def compile_worker(config, var_types: Sequence[str], var_names: Sequence[str], shm_names: Sequence[str]): +def compile_worker(config, var_types: Sequence[str], var_names: Sequence[str], var_vals: Sequence[np.ndarray]): import cppyy - import numpy as np prj_path = Path(config['OutputDir']).absolute() prj_name = config['ProjectName'] @@ -40,16 +39,12 @@ def compile_worker(config, var_types: Sequence[str], var_names: Sequence[str], s # Load weights from to c arrays, if not initialized in headers if config['WriteWeightsTxt']: - for w_type, w_name, shm_name in zip(var_types, var_names, shm_names): - shm = shared_memory.SharedMemory(name=shm_name, create=False) - wval = np.ndarray((shm.size // np.dtype(np.float32).itemsize,), dtype=np.float32, buffer=shm.buf) + for w_type, w_name, 
w_val in zip(var_types, var_names, var_vals): cpp_w = getattr(cpp_namespace, w_name, None) if cpp_w is None: continue fill_fn = cpp_namespace.fill_weight[w_type] - fill_fn(cpp_w, wval.ravel().astype(np.float32)) - shm.close() - shm.unlink() + fill_fn(cpp_w, w_val.ravel().astype(np.float32)) import gc gc.collect() @@ -776,13 +771,6 @@ def jit_compile(self): var_names = [v.name for v in variables] var_values = [v.data.astype(np.float32) for v in variables] - shm_names = [] - for wval in var_values: - shm = shared_memory.SharedMemory(create=True, size=wval.nbytes) - np.ndarray(wval.size, dtype=wval.dtype, buffer=shm.buf)[:] = wval.ravel() - shm_names.append(shm.name) - shm.close() - config = { 'OutputDir': self.config.config['OutputDir'], 'ProjectName': self.config.config['ProjectName'], @@ -790,7 +778,7 @@ def jit_compile(self): 'WriteWeightsTxt': self.config.config['WriterConfig'].get('WriteWeightsTxt', False), } - self._jit_process.apply(compile_worker, (config, var_types, var_names, shm_names)) + self._jit_process.apply(compile_worker, (config, var_types, var_names, var_values)) gc.collect() @@ -886,12 +874,8 @@ def predict(self, x, jit=None) -> np.ndarray | list[np.ndarray]: def jit_predict(self, x: np.ndarray | Sequence[np.ndarray]): assert self._jit_process is not None, 'Model not jit compiled' - # assert cppyy.gbl, 'cppyy not initialized, please call model.jit_compile() or model.compile(jit=True) first' namespace: str | None = self.config.writer_config['Namespace'] # type: ignore assert namespace is not None, 'Namespace must be set in the writer config for jit_predict to work' - # assert hasattr( - # cppyy.gbl, namespace - # ), f'Namespace {namespace} not found in cppyy.gbl. Did you call model.jit_compile()?' output_shapes = [out.shape for out in self.get_output_variables()] n_inputs = len(self.get_input_variables()) From 6a0d80f8324fcdb23461792a0e1187d41455170f Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Wed, 30 Oct 2024 17:57:49 +0000 Subject: [PATCH 07/18] namespace related fix --- hls4ml/backends/fpga/passes/hgq_proxy_model.py | 2 +- hls4ml/backends/template.py | 2 +- hls4ml/converters/__init__.py | 7 ------- hls4ml/model/graph.py | 2 +- 4 files changed, 3 insertions(+), 10 deletions(-) diff --git a/hls4ml/backends/fpga/passes/hgq_proxy_model.py b/hls4ml/backends/fpga/passes/hgq_proxy_model.py index 81f980ef50..928b4dbd26 100644 --- a/hls4ml/backends/fpga/passes/hgq_proxy_model.py +++ b/hls4ml/backends/fpga/passes/hgq_proxy_model.py @@ -79,7 +79,7 @@ def __init__(self): def format(self, node): params = self._default_function_params(node) - namespace = node.model.config.get_writer_config()['Namespace'] or 'nnet' + namespace = node.model.config.get_writer_config.get('Namespace', None) or 'nnet' params['namespace'] = namespace return self.template.format(**params) diff --git a/hls4ml/backends/template.py b/hls4ml/backends/template.py index 228077dbbb..ac9c5e9f43 100644 --- a/hls4ml/backends/template.py +++ b/hls4ml/backends/template.py @@ -62,7 +62,7 @@ def _default_config_params(self, layer): params = self._default_params(layer) params['iotype'] = layer.model.config.get_config_value('IOType') params['reuse'] = layer.get_attr('reuse_factor') - params['namespace'] = layer.model.config.get_writer_config()['Namespace'] or 'nnet' + params['namespace'] = layer.model.config.get_writer_config().get('Namespace', None) or 'nnet' return params diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index 7ba2edda38..092e53b3d3 100644 --- a/hls4ml/converters/__init__.py 
+++ b/hls4ml/converters/__init__.py @@ -220,13 +220,6 @@ def convert_from_keras_model( ModelGraph: hls4ml model. """ - if os.environ.get('HLS4ML_USE_JIT', '0') == '1': - import random - import string - - rand_namespace = ''.join(random.choices(string.ascii_lowercase + string.ascii_uppercase, k=16)) - kwargs.setdefault('namespace', rand_namespace) - config = create_config(output_dir=output_dir, project_name=project_name, backend=backend, **kwargs) config['KerasModel'] = model diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 89952bd742..d2a7b2388c 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -774,7 +774,7 @@ def jit_compile(self): config = { 'OutputDir': self.config.config['OutputDir'], 'ProjectName': self.config.config['ProjectName'], - 'Namespace': self.config.config['WriterConfig']['Namespace'], + 'Namespace': self.config.config['WriterConfig'].get('Namespace', None) or 'nnet', 'WriteWeightsTxt': self.config.config['WriterConfig'].get('WriteWeightsTxt', False), } From cd169da49c07f3e0e1b9b67fef01fa5f101a8323 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Wed, 30 Oct 2024 20:03:17 +0000 Subject: [PATCH 08/18] fix silly typo --- hls4ml/backends/fpga/passes/hgq_proxy_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hls4ml/backends/fpga/passes/hgq_proxy_model.py b/hls4ml/backends/fpga/passes/hgq_proxy_model.py index 928b4dbd26..3c985c88c4 100644 --- a/hls4ml/backends/fpga/passes/hgq_proxy_model.py +++ b/hls4ml/backends/fpga/passes/hgq_proxy_model.py @@ -79,7 +79,7 @@ def __init__(self): def format(self, node): params = self._default_function_params(node) - namespace = node.model.config.get_writer_config.get('Namespace', None) or 'nnet' + namespace = node.model.config.writer_config.get('Namespace', None) or 'nnet' params['namespace'] = namespace return self.template.format(**params) From 59e00ffbb1dcb66cabb86f4e5e3d73113d0453ff Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Wed, 30 Oct 2024 21:00:54 +0000 Subject: [PATCH 09/18] fix quartus jit --- hls4ml/model/graph.py | 33 ++++++++++++++----- hls4ml/writer/writers.py | 71 ++++++++++++++++++++++------------------ 2 files changed, 64 insertions(+), 40 deletions(-) diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index d2a7b2388c..831ba426fa 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -25,13 +25,30 @@ def compile_worker(config, var_types: Sequence[str], var_names: Sequence[str], v prj_name = config['ProjectName'] prj_top_path = prj_path / f'firmware/{prj_name}.cpp' jit_bridge_path = prj_path / f'{prj_name}_jit_bridge.cpp' + backend = config['Backend'].lower() cppyy.cppdef("#define HLS4ML_EXTERNAL_WEIGHT_LOAD") cppyy.add_include_path(str(prj_path)) - if (prj_path / 'firmware/ap_types').exists(): - cppyy.add_include_path(str(prj_path / 'firmware/ap_types')) - if (prj_path / 'firmware/ac_types').exists(): - cppyy.add_include_path(str(prj_path / 'firmware/ac_types')) + + print(f'Compiling {prj_top_path} for {backend} backend') + match backend: + case 'vivado' | 'vitis': + print('Adding Vivado paths') + cppyy.add_include_path(str(prj_path / 'firmware/ap_types')) + case 'quartus': + cppyy.add_include_path(str(prj_path / 'firmware/ap_types')) + cppyy.add_include_path(str(prj_path / 'firmware/ac_types')) + # case 'catapult': # Doesn't work, unknown error + # print('Adding MGC paths') + # MGC_HOME = os.environ.get('MGC_HOME', '') + # if MGC_HOME: + # cppyy.add_include_path(f'{MGC_HOME}/shared/include') + # 
cppyy.add_include_path(f'{MGC_HOME}/shared/include/nnet_utils') + # cppyy.add_include_path(str(prj_path / 'firmware/ac_types/include')) + # cppyy.add_include_path(str(prj_path / 'firmware/ac_math/include')) + # cppyy.add_include_path(str(prj_path / 'firmware/ac_simutils/include')) + # cppyy.add_include_path(str(prj_path / 'firmware/nnet_utils')) + cppyy.include(str(prj_top_path)) cppyy.include(str(jit_bridge_path)) @@ -774,8 +791,9 @@ def jit_compile(self): config = { 'OutputDir': self.config.config['OutputDir'], 'ProjectName': self.config.config['ProjectName'], - 'Namespace': self.config.config['WriterConfig'].get('Namespace', None) or 'nnet', - 'WriteWeightsTxt': self.config.config['WriterConfig'].get('WriteWeightsTxt', False), + 'Namespace': self.config.writer_config.get('Namespace', None) or 'nnet', + 'WriteWeightsTxt': self.config.writer_config.get('WriteWeightsTxt', False), + 'Backend': self.config.config['Backend'], } self._jit_process.apply(compile_worker, (config, var_types, var_names, var_values)) @@ -874,8 +892,7 @@ def predict(self, x, jit=None) -> np.ndarray | list[np.ndarray]: def jit_predict(self, x: np.ndarray | Sequence[np.ndarray]): assert self._jit_process is not None, 'Model not jit compiled' - namespace: str | None = self.config.writer_config['Namespace'] # type: ignore - assert namespace is not None, 'Namespace must be set in the writer config for jit_predict to work' + namespace: str = self.config.writer_config.get('Namespace', None) or 'nnet' # type: ignore output_shapes = [out.shape for out in self.get_output_variables()] n_inputs = len(self.get_input_variables()) diff --git a/hls4ml/writer/writers.py b/hls4ml/writer/writers.py index 2888372ac6..86580b415c 100644 --- a/hls4ml/writer/writers.py +++ b/hls4ml/writer/writers.py @@ -1,3 +1,4 @@ +import re import typing from math import prod from pathlib import Path @@ -23,63 +24,63 @@ def create_jit_bridge_fn(model: 'ModelGraph'): input_def = '\n '.join(f'std::vector {v.name}, ' for v in inp_vars)[:-2] - inp_size_def = '\n '.join(f'constexpr size_t {n}_size = {s};' for n, s in zip(inp_names, inp_sizes)) - out_size_def = '\n '.join(f'constexpr size_t {n}_size = {s};' for n, s in zip(out_names, out_sizes)) + inp_size_def = '\n '.join(f'constexpr size_t {n}_size = {s};' for n, s in zip(inp_names, inp_sizes)) + out_size_def = '\n '.join(f'constexpr size_t {n}_size = {s};' for n, s in zip(out_names, out_sizes)) - ptr_buf_def = '\n '.join(f'T* {v.name}_ptr = {v.name}.data();' for v in inp_vars + out_vars) + ptr_buf_def = '\n '.join(f'T* {v.name}_ptr = {v.name}.data();' for v in inp_vars + out_vars) n_samples_def = f'{inp_vars[0].name}.size() / {inp_vars[0].name}_size' - assertions_def = ' ||\n '.join(f'({n}.size() != {n}_size * n_samples)' for n in inp_names) + assertions_def = ' ||\n '.join(f'({n}.size() != {n}_size * n_samples)' for n in inp_names) inp_args_def_list = [f'{n}_ptr + i * {n}_size,' for n in inp_names] out_args_def_list = [f'{n}_ptr + i * {n}_size,' for n in out_names] args_def = ('\n' + ' ' * 12).join(inp_args_def_list + out_args_def_list)[:-1] - out_var_def = '\n '.join(f'std::vector {v.name}({v.name}_size * n_samples);' for v in out_vars) + out_var_def = '\n '.join(f'std::vector {v.name}({v.name}_size * n_samples);' for v in out_vars) _ret_template_arg = ('std::vector, ' * n_out)[:-2] _ret_tuple_arg = ', '.join(out_names) return_def = f'std::tuple<{_ret_template_arg}>({_ret_tuple_arg})' cpp_fn = f""" - template - auto batch_inference( - {input_def} - ){{ - {inp_size_def} - {out_size_def} +template +auto 
batch_inference( + {input_def} +){{ + {inp_size_def} + {out_size_def} - size_t n_samples = {n_samples_def}; + size_t n_samples = {n_samples_def}; - if ( - {assertions_def} - ) - throw std::runtime_error("Invalid input sizes: number of samples or input sizes do not match"); + if ( + {assertions_def} + ) + throw std::runtime_error("Invalid input sizes: number of samples or input sizes do not match"); - {out_var_def} + {out_var_def} - {ptr_buf_def} + {ptr_buf_def} - for (int i = 0; i < n_samples; i++) {{ - myproject_float( - {args_def} - ); - }} - - return {return_def}; + for (int i = 0; i < n_samples; i++) {{ + myproject_float( + {args_def} + ); }} + + return {return_def}; +}} """ return cpp_fn def create_jit_weight_filler(model: 'ModelGraph'): filler_fn = """ - template - void fill_weight(T weight[], std::vector vec) {{ - for (size_t i = 0; i < vec.size(); i++) {{ - weight[i] = vec[i]; - }} +template +void fill_weight(T weight[], std::vector vec) {{ + for (size_t i = 0; i < vec.size(); i++) {{ + weight[i] = vec[i]; }} +}} """ return filler_fn @@ -103,12 +104,18 @@ def write_jit_bridge(self, model: 'ModelGraph'): cpp_source_bridge = path_c_bridge.read_text() + # Remove the unsigned short &const_size_... arguments from the function signature + # For Quartus + m = re.compile(r',\s*unsigned short &const_size_\w+', re.MULTILINE) + cpp_source_bridge = m.sub('', cpp_source_bridge) + jit_bridge_fn = create_jit_bridge_fn(model) weight_filler_fn = create_jit_weight_filler(model) - _plugin_code = f'namespace {namespace} {{\n' + jit_bridge_fn + '\n\n' + weight_filler_fn + '\n' + _plugin_code = f'namespace {namespace} {{\n' + jit_bridge_fn + '\n\n' + weight_filler_fn + '\n}\n#endif' - cpp_source_bridge = cpp_source_bridge.replace('extern "C" {', _plugin_code) + cpp_source_bridge = cpp_source_bridge.replace('extern "C" {', f'namespace {namespace} {{') + cpp_source_bridge = cpp_source_bridge.replace('#endif', _plugin_code) path_jit_bridge.write_text(cpp_source_bridge) From fe6de2be95cd43f0f9e813eca5909c435e4370fb Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Wed, 30 Oct 2024 21:03:11 +0000 Subject: [PATCH 10/18] ward unsupported backend jit --- hls4ml/model/graph.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 831ba426fa..fe07aa7f96 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -764,11 +764,17 @@ def compile(self, jit=None): def _compile(self, jit=None): if jit is None: - jit = os.environ.get('HLS4ML_USE_JIT', '0') == '1' - - if self.config.config['Backend'].lower() == 'oneapi': - print('INFO: OneAPI backend does not support JIT compilation, using shared library compilation', file=sys.stderr) jit = False + backend = self.config.config['Backend'].lower() + if backend == 'vivado' or backend == 'vitis': + jit = os.environ.get('HLS4ML_USE_JIT', '0') == '1' + + if self.config.config['Backend'].lower() not in ['vivado', 'vitis', 'quartus']: + if jit: + print( + 'JIT compilation is not supported for this backend, falling back to normal compilation', file=sys.stderr + ) + jit = False if jit: self.jit_compile() From 311507cb62481cc4c8b88382f76b0ba135136652 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Wed, 30 Oct 2024 21:56:50 +0000 Subject: [PATCH 11/18] fix default namespace param update --- hls4ml/model/graph.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index fe07aa7f96..5c4bbe7bf8 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ 
-58,6 +58,8 @@ def compile_worker(config, var_types: Sequence[str], var_names: Sequence[str], v if config['WriteWeightsTxt']: for w_type, w_name, w_val in zip(var_types, var_names, var_vals): cpp_w = getattr(cpp_namespace, w_name, None) + if cpp_w is None: + cpp_w = getattr(cppyy.gbl, w_name, None) if cpp_w is None: continue fill_fn = cpp_namespace.fill_weight[w_type] From f8fb6dfe5a278289a7554476e29cfc4ff73e6197 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Wed, 30 Oct 2024 15:37:36 -0700 Subject: [PATCH 12/18] disable unsupported backends, fix fp64 entry --- hls4ml/model/graph.py | 17 +++++++++++++---- hls4ml/writer/catapult_writer.py | 2 +- hls4ml/writer/oneapi_writer.py | 2 +- hls4ml/writer/writers.py | 19 ++++++++++++++----- 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 5c4bbe7bf8..4977123165 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -893,7 +893,14 @@ def predict(self, x, jit=None) -> np.ndarray | list[np.ndarray]: jit = self._jit_process is not None if jit: - return self.jit_predict(x) + r = self.jit_predict(x) + + # Match original predict output shape + if isinstance(r, np.ndarray): + if len(r) == 1: + return r[0] + return r.reshape(r.shape[0], -1) + return [ri.reshape(ri.shape[0], -1) for ri in r] else: return self.ctypes_predict(x) @@ -1055,7 +1062,9 @@ def build(self, **kwargs): def __del__(self): if self._jit_process is not None: - self._jit_process.close() - self._jit_process.join() - self._jit_process = None + try: + self._jit_process.close() + self._jit_process.join() + except Exception: + pass gc.collect() diff --git a/hls4ml/writer/catapult_writer.py b/hls4ml/writer/catapult_writer.py index c634fd565a..72e72ec64f 100755 --- a/hls4ml/writer/catapult_writer.py +++ b/hls4ml/writer/catapult_writer.py @@ -921,7 +921,7 @@ def write_hls(self, model): self.write_parameters(model) self.write_test_bench(model) self.write_bridge(model) - self.write_jit_bridge(model) + # self.write_jit_bridge(model) self.write_build_script(model) self.write_nnet_utils(model) self.write_generated_code(model) diff --git a/hls4ml/writer/oneapi_writer.py b/hls4ml/writer/oneapi_writer.py index b2c61d85d8..949ab76515 100644 --- a/hls4ml/writer/oneapi_writer.py +++ b/hls4ml/writer/oneapi_writer.py @@ -961,7 +961,7 @@ def write_hls(self, model): self.write_parameters(model) self.write_test_bench(model) self.write_bridge(model) - self.write_jit_bridge(model) + # self.write_jit_bridge(model) self.write_build_script(model) self.write_nnet_utils(model) self.write_activation_tables(model) diff --git a/hls4ml/writer/writers.py b/hls4ml/writer/writers.py index 86580b415c..f0c454f2a1 100644 --- a/hls4ml/writer/writers.py +++ b/hls4ml/writer/writers.py @@ -10,6 +10,7 @@ def create_jit_bridge_fn(model: 'ModelGraph'): inp_vars = model.get_input_variables() out_vars = model.get_output_variables() + prj_name = model.config.config['ProjectName'] inp_shapes = [tuple(v.shape) for v in inp_vars] out_shapes = [tuple(v.shape) for v in out_vars] @@ -61,12 +62,20 @@ def create_jit_bridge_fn(model: 'ModelGraph'): {ptr_buf_def} - for (int i = 0; i < n_samples; i++) {{ - myproject_float( - {args_def} - ); + if (std::is_same::value) {{ + for (int i = 0; i < n_samples; i++) + {prj_name}_double( + {args_def} + ); + }} else if (std::is_same::value) {{ + for (int i = 0; i < n_samples; i++) + {prj_name}_float( + {args_def} + ); + }} else {{ + throw std::runtime_error("Unsupported type"); }} - + return {return_def}; }} """ From 
be23f7f608ba6728853f3d9ecc2893fc19cae232 Mon Sep 17 00:00:00 2001
From: Chang Sun
Date: Wed, 30 Oct 2024 16:16:47 -0700
Subject: [PATCH 13/18] allow openmp opt-in

---
 hls4ml/writer/writers.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/hls4ml/writer/writers.py b/hls4ml/writer/writers.py
index f0c454f2a1..2b9f67ffe7 100644
--- a/hls4ml/writer/writers.py
+++ b/hls4ml/writer/writers.py
@@ -62,18 +62,19 @@ def create_jit_bridge_fn(model: 'ModelGraph'):

     {ptr_buf_def}

-    if (std::is_same<T, double>::value) {{
-        for (int i = 0; i < n_samples; i++)
+    #pragma omp parallel for
+    for (int i = 0; i < n_samples; i++) {{
+        if (std::is_same<T, double>::value) {{
             {prj_name}_double(
                 {args_def}
             );
-    }} else if (std::is_same<T, float>::value) {{
-        for (int i = 0; i < n_samples; i++)
+        }} else if (std::is_same<T, float>::value) {{
             {prj_name}_float(
                 {args_def}
             );
-    }} else {{
-        throw std::runtime_error("Unsupported type");
+        }} else {{
+            throw std::runtime_error("Unsupported type");
+        }}
     }}

     return {return_def};

From dfb681640f21ba6ddd39ffacf5db1b9609ff2e8f Mon Sep 17 00:00:00 2001
From: Chang Sun
Date: Wed, 30 Oct 2024 16:17:15 -0700
Subject: [PATCH 14/18] ward unsupported exponent weight load

---
 hls4ml/model/graph.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py
index 4977123165..6f911b9f06 100644
--- a/hls4ml/model/graph.py
+++ b/hls4ml/model/graph.py
@@ -771,13 +771,19 @@ def _compile(self, jit=None):
         if backend == 'vivado' or backend == 'vitis':
             jit = os.environ.get('HLS4ML_USE_JIT', '0') == '1'

-        if self.config.config['Backend'].lower() not in ['vivado', 'vitis', 'quartus']:
-            if jit:
-                print(
-                    'JIT compilation is not supported for this backend, falling back to normal compilation', file=sys.stderr
-                )
-            jit = False
+        if self.config.config['Backend'].lower() not in ['vivado', 'vitis', 'quartus'] and jit:
+            print(
+                'JIT compilation is not supported for this backend, falling back to normal compilation', file=sys.stderr
+            )
+            jit = False

+        write_weights_txt = self.config.writer_config.get('WriteWeightsTxt', False)
+        if write_weights_txt and any(w.type.name.startswith('exponent_') for w in self.get_weight_variables()) and jit:
+            print(
+                'JIT compilation is not supported for models with exponent weights, falling back to normal compilation',
+                file=sys.stderr,
+            )
+            jit = False
         if jit:
             self.jit_compile()
         else:

From ef4341311155c4f3d14633726de5e706fc24807c Mon Sep 17 00:00:00 2001
From: Chang Sun
Date: Wed, 30 Oct 2024 16:54:59 -0700
Subject: [PATCH 15/18] rm unused ap_types in quartus build, ward "."
in var name

---
 hls4ml/model/graph.py                 |  1 -
 hls4ml/templates/quartus/build_lib.sh |  2 +-
 hls4ml/writer/writers.py              | 12 ++++++------
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py
index 6f911b9f06..9d8461bc62 100644
--- a/hls4ml/model/graph.py
+++ b/hls4ml/model/graph.py
@@ -36,7 +36,6 @@ def compile_worker(config, var_types: Sequence[str], var_names: Sequence[str], v
             print('Adding Vivado paths')
             cppyy.add_include_path(str(prj_path / 'firmware/ap_types'))
         case 'quartus':
-            cppyy.add_include_path(str(prj_path / 'firmware/ap_types'))
             cppyy.add_include_path(str(prj_path / 'firmware/ac_types'))
         # case 'catapult':  # Doesn't work, unknown error
         #     print('Adding MGC paths')
diff --git a/hls4ml/templates/quartus/build_lib.sh b/hls4ml/templates/quartus/build_lib.sh
index 5514a9cc75..59b3160672 100755
--- a/hls4ml/templates/quartus/build_lib.sh
+++ b/hls4ml/templates/quartus/build_lib.sh
@@ -8,7 +8,7 @@ elif [[ "$OSTYPE" == "darwin"* ]]; then
     CFLAGS="-O3 -fPIC -std=c++11"
 fi
 LDFLAGS=
-INCFLAGS="-Ifirmware/ac_types/ -Ifirmware/ap_types/"
+INCFLAGS="-Ifirmware/ac_types/"
 PROJECT=myproject
 LIB_STAMP=mystamp
diff --git a/hls4ml/writer/writers.py b/hls4ml/writer/writers.py
index 2b9f67ffe7..a8ff2ba490 100644
--- a/hls4ml/writer/writers.py
+++ b/hls4ml/writer/writers.py
@@ -15,21 +15,21 @@ def create_jit_bridge_fn(model: 'ModelGraph'):
     inp_shapes = [tuple(v.shape) for v in inp_vars]
     out_shapes = [tuple(v.shape) for v in out_vars]

-    inp_names = [v.name for v in inp_vars]
-    out_names = [v.name for v in out_vars]
+    inp_names = [v.name.replace('.', '_') for v in inp_vars]
+    out_names = [v.name.replace('.', '_') for v in out_vars]

     inp_sizes = [prod(s) for s in inp_shapes]
     out_sizes = [prod(s) for s in out_shapes]
     n_out = len(out_names)

-    input_def = '\n    '.join(f'std::vector<T> {v.name}, ' for v in inp_vars)[:-2]
+    input_def = '\n    '.join(f'std::vector<T> {name}, ' for name in inp_names)[:-2]

     inp_size_def = '\n    '.join(f'constexpr size_t {n}_size = {s};' for n, s in zip(inp_names, inp_sizes))
     out_size_def = '\n    '.join(f'constexpr size_t {n}_size = {s};' for n, s in zip(out_names, out_sizes))

-    ptr_buf_def = '\n    '.join(f'T* {v.name}_ptr = {v.name}.data();' for v in inp_vars + out_vars)
-    n_samples_def = f'{inp_vars[0].name}.size() / {inp_vars[0].name}_size'
+    ptr_buf_def = '\n    '.join(f'T* {name}_ptr = {name}.data();' for name in inp_names + out_names)
+    n_samples_def = f'{inp_names[0]}.size() / {inp_names[0]}_size'

     assertions_def = ' ||\n        '.join(f'({n}.size() != {n}_size * n_samples)' for n in inp_names)

@@ -37,7 +37,7 @@ def create_jit_bridge_fn(model: 'ModelGraph'):
     out_args_def_list = [f'{n}_ptr + i * {n}_size,' for n in out_names]
     args_def = ('\n' + ' ' * 12).join(inp_args_def_list + out_args_def_list)[:-1]

-    out_var_def = '\n    '.join(f'std::vector<T> {v.name}({v.name}_size * n_samples);' for v in out_vars)
+    out_var_def = '\n    '.join(f'std::vector<T> {name}({name}_size * n_samples);' for name in out_names)

     _ret_template_arg = ('std::vector<T>, ' * n_out)[:-2]
     _ret_tuple_arg = ', '.join(out_names)

From d23d31812a12c9eceb8b67e748ffba94f29ff2c2 Mon Sep 17 00:00:00 2001
From: Chang Sun
Date: Wed, 30 Oct 2024 16:55:10 -0700
Subject: [PATCH 16/18] add test

---
 test/pytest/test_jit.py | 48 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)
 create mode 100644 test/pytest/test_jit.py

diff --git a/test/pytest/test_jit.py b/test/pytest/test_jit.py
new file mode 100644
index 0000000000..616085cb3e
--- /dev/null
+++ b/test/pytest/test_jit.py
@@ -0,0 +1,48
@@ +from pathlib import Path + +import numpy as np +import pytest +from keras.layers import Dense +from tensorflow import keras + +from hls4ml.converters import convert_from_keras_model + +test_root_path = Path(__file__).parent + + +@pytest.fixture(scope='module') +def model(): + in1 = keras.Input(shape=(10,)) + in2 = keras.Input(shape=(9,)) + x = Dense(8, name='dense1')(in1) + y = Dense(7, name='dense2')(in2) + model = keras.Model([in1, in2], [x, y]) + return model + + +@pytest.fixture(scope='module') +def data(): + IN1 = np.random.normal(0, 1, (1000, 10)).astype(np.float32) + IN2 = np.random.normal(0, 1, (1000, 9)).astype(np.float32) + return IN1, IN2 + + +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Vitis']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +@pytest.mark.parametrize('strategy', ['Resource', 'Latency']) +def test_jit(model, data, backend: str, io_type: str, strategy: str): + output_dir = str(test_root_path / f'hls4mlprj_jit_{backend}_{io_type}_{strategy}') + + model_hls = convert_from_keras_model( + model, backend=backend, output_dir=output_dir, io_type=io_type, hls_config={'Model': {'Strategy': strategy, 'ReuseFactor': 1, 'Precision': 'ap_fixed<16,6>'}} + ) + + model_hls.write() + model_hls.compile_shared_lib() + model_hls.jit_compile() + + ctypes_pred = model_hls.predict(data, jit=False) + jit_pred = model_hls.predict(data, jit=True) + + assert np.all(ctypes_pred[0] == jit_pred[0]) + assert np.all(ctypes_pred[1] == jit_pred[1]) From 50fd4c8e39d3f759b8bc626690a4cfe1b9ef84a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 00:02:53 +0000 Subject: [PATCH 17/18] [pre-commit.ci] auto fixes from pre-commit hooks --- hls4ml/model/graph.py | 4 +--- hls4ml/writer/writers.py | 2 +- test/pytest/test_jit.py | 6 +++++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 9d8461bc62..4f123b5b2c 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -771,9 +771,7 @@ def _compile(self, jit=None): jit = os.environ.get('HLS4ML_USE_JIT', '0') == '1' if self.config.config['Backend'].lower() not in ['vivado', 'vitis', 'quartus'] and jit: - print( - 'JIT compilation is not supported for this backend, falling back to normal compilation', file=sys.stderr - ) + print('JIT compilation is not supported for this backend, falling back to normal compilation', file=sys.stderr) jit = False write_weights_txt = self.config.writer_config.get('WriteWeightsTxt', False) diff --git a/hls4ml/writer/writers.py b/hls4ml/writer/writers.py index a8ff2ba490..c4a9fe1c90 100644 --- a/hls4ml/writer/writers.py +++ b/hls4ml/writer/writers.py @@ -76,7 +76,7 @@ def create_jit_bridge_fn(model: 'ModelGraph'): throw std::runtime_error("Unsupported type"); }} }} - + return {return_def}; }} """ diff --git a/test/pytest/test_jit.py b/test/pytest/test_jit.py index 616085cb3e..71fca21596 100644 --- a/test/pytest/test_jit.py +++ b/test/pytest/test_jit.py @@ -34,7 +34,11 @@ def test_jit(model, data, backend: str, io_type: str, strategy: str): output_dir = str(test_root_path / f'hls4mlprj_jit_{backend}_{io_type}_{strategy}') model_hls = convert_from_keras_model( - model, backend=backend, output_dir=output_dir, io_type=io_type, hls_config={'Model': {'Strategy': strategy, 'ReuseFactor': 1, 'Precision': 'ap_fixed<16,6>'}} + model, + backend=backend, + output_dir=output_dir, + io_type=io_type, + hls_config={'Model': {'Strategy': strategy, 'ReuseFactor': 
1, 'Precision': 'ap_fixed<16,6>'}}, ) model_hls.write() From 7d56b5284f4ef9e73d2d2e785a93e79ea4d31dc8 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Thu, 31 Oct 2024 00:47:33 +0000 Subject: [PATCH 18/18] install cppyy for testing --- setup.cfg | 2 ++ test/pytest/ci-template.yml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 9b7ef45f8f..447e18843d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,8 @@ pytest_randomly.random_seeder = [options.extras_require] HGQ = HGQ~=0.2.0 +jit = + cppyy optimization = keras-tuner==1.1.3 ortools==9.4.1874 diff --git a/test/pytest/ci-template.yml b/test/pytest/ci-template.yml index f6aa700415..64a51ff2e6 100644 --- a/test/pytest/ci-template.yml +++ b/test/pytest/ci-template.yml @@ -9,7 +9,7 @@ - git submodule update --init --recursive hls4ml/templates/catapult/ - if [ $EXAMPLEMODEL == 1 ]; then git submodule update --init example-models; fi - conda activate hls4ml-testing - - pip install .[testing,sr,optimization] + - pip install .[testing,sr,optimization,jit] script: - cd test/pytest - pytest $PYTESTFILE -rA --cov-report xml --cov-report term --cov=hls4ml --junitxml=report.xml --randomly-seed=42 --randomly-dont-reorganize --randomly-dont-reset-seed
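
For reference, a minimal sketch of how the JIT path added by this series can be exercised, modeled on test/pytest/test_jit.py above. The toy model, output directory, and config values are illustrative only, not part of the patch; it assumes the optional cppyy dependency is installed (pip install hls4ml[jit], per the setup.cfg change).

import numpy as np
from tensorflow import keras

from hls4ml.converters import convert_from_keras_model

# Illustrative toy model; any Keras model hls4ml supports would do.
in1 = keras.Input(shape=(10,))
out = keras.layers.Dense(8, name='dense1')(in1)
model = keras.Model(in1, out)

hls_model = convert_from_keras_model(
    model,
    backend='Vivado',  # this series wires up JIT for Vivado, Vitis and Quartus only
    output_dir='hls4mlprj_jit_demo',  # illustrative path
    io_type='io_parallel',
    hls_config={'Model': {'Strategy': 'Latency', 'ReuseFactor': 1, 'Precision': 'ap_fixed<16,6>'}},
)

hls_model.write()
hls_model.compile_shared_lib()  # conventional ctypes path, kept as the fallback
hls_model.jit_compile()         # cppyy-based in-process path added by this series

x = np.random.normal(0, 1, (1000, 10)).astype(np.float32)
y_jit = hls_model.predict(x, jit=True)
y_lib = hls_model.predict(x, jit=False)
assert np.all(y_jit == y_lib)  # the test above checks exact agreement of the two paths

Alternatively, per the _compile change in PATCH 10, exporting HLS4ML_USE_JIT=1 before calling compile() opts in without touching call sites, on the Vivado and Vitis backends only.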