diff --git a/example-models b/example-models index d40894b03..6a82da23a 160000 --- a/example-models +++ b/example-models @@ -1 +1 @@ -Subproject commit d40894b03f840a32da43a5adea0531ffc1db216e +Subproject commit 6a82da23ad24c238fe156ed4d0aa907db547dbcf diff --git a/hls4ml/converters/onnx/reshape.py b/hls4ml/converters/onnx/reshape.py index 9ef20f03d..f11796b6d 100644 --- a/hls4ml/converters/onnx/reshape.py +++ b/hls4ml/converters/onnx/reshape.py @@ -1,4 +1,4 @@ -from hls4ml.converters.onnx_to_hls import onnx_handler +from hls4ml.converters.onnx_to_hls import get_onnx_attribute, onnx_handler @onnx_handler('Transpose') @@ -36,3 +36,25 @@ def parse_flatten_layer(node, input_names, input_shapes, graph): layer['target_shape'] = [-1] # does not contain batch dimension return layer + + +@onnx_handler('Resize') +def parse_resize_layer(node, input_names, input_shapes, graph): + layer = {} + layer['name'] = node.name + layer['class_name'] = 'Resize' + layer['inputs'] = input_names + layer['outputs'] = list(node.output) + layer['in_height'] = input_shapes[0][2] + layer['in_width'] = input_shapes[0][1] + layer['out_width'] = input_shapes[0][1] + layer['out_height'] = input_shapes[0][2] + layer['n_chan'] = input_shapes[0][3] + layer['algorithm'] = get_onnx_attribute(node, 'mode') + # The following is used in initialize() method. + # Probably a better solution would be to have a channels last parameter at QONNX level + layer['data_format'] = ( + 'channels_last' if any(node.domain == 'qonnx.custom_op.channels_last' for node in graph.node) else 'channels_first' + ) + + return layer diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 7dbeb4567..891f187ea 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -1147,20 +1147,67 @@ class Resize(Layer): def initialize(self): inp = self.get_input_variable() - if self.get_attr('data_format') == 'channels_last': - if len(inp.shape) == 2: # 1D -> width + chan - shape = [self.get_attr('out_width'), self.get_attr('n_chan')] - dims = [f'OUT_WIDTH_{self.index}', f'N_CHAN_{self.index}'] - elif len(inp.shape) == 3: # 2D -> height + width + chan - shape = [self.get_attr('out_height'), self.get_attr('out_width'), self.get_attr('n_chan')] - dims = [f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}', f'N_CHAN_{self.index}'] + if len(self.inputs) > 1: + # In order to be correctly ingested by hls4ml the QONNX resize node should have 3 inputs set with RoI left empty + if len(self.inputs) == 2: + raise Exception( + 'The number of inputs to Resize node is equal to 2. ' + 'In this case, either one is trying to use a version 10 node ' + 'or one is using the RoI parameter only to perform the resize operation, ' + 'both not supported in hls4ml' + ) + if len(self.inputs) == 4: + raise Exception('Sizes parameter is not supported by hls4ml. Use scales instead') + # get the scales of Resize node from QONNX frontend + # see doc here: https://onnx.ai/onnx/operators/onnx__Resize.html + scales_idx = 2 if len(self.inputs) == 3 or len(self.inputs) == 4 else 1 + scales = self.get_input_node(self.inputs[scales_idx]).get_attr('value') + if len(scales) == 4: # Resize 2D + self.set_attr('out_width', int(self.get_attr('in_width') * scales[1])) + self.set_attr('out_height', int(self.get_attr('in_height') * scales[2])) + self.set_attr('n_chan', int(self.get_attr('n_chan') * scales[3])) + elif len(scales) == 3: # Resize 1D + self.set_attr('out_width', int(self.get_attr('in_width') * scales[1])) + self.set_attr('n_chan', int(self.get_attr('n_chan') * scales[2])) + else: + raise Exception('Resize 1D and Resize 2D are the ones supported in hls4ml') + if self.get_attr('data_format') == 'channels_last': + if len(inp.shape) == 2: # 1D -> width + chan + shape = [int(self.get_attr('out_width')), int(self.get_attr('n_chan'))] + dims = [f'OUT_WIDTH_{self.index}', f'N_CHAN_{self.index}'] + elif len(inp.shape) == 3: # 2D -> height + width + chan + shape = [ + int(self.get_attr('out_height')), + int(self.get_attr('out_width')), + int(self.get_attr('n_chan')), + ] + dims = [f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}', f'N_CHAN_{self.index}'] + else: + if len(inp.shape) == 2: # 1D -> width + chan + shape = [int(self.get_attr('n_chan')), int(self.get_attr('out_width'))] + dims = [f'N_CHAN_{self.index}', f'OUT_WIDTH_{self.index}'] + elif len(inp.shape) == 3: # 2D -> height + width + chan + shape = [ + int(self.get_attr('n_chan')), + int(self.get_attr('out_height')), + int(self.get_attr('out_width')), + ] + dims = [f'N_CHAN_{self.index}', f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}'] else: - if len(inp.shape) == 2: # 1D -> width + chan - shape = [self.get_attr('n_chan'), self.get_attr('out_width')] - dims = [f'N_CHAN_{self.index}', f'OUT_WIDTH_{self.index}'] - elif len(inp.shape) == 3: # 2D -> height + width + chan - shape = [self.get_attr('n_chan'), self.get_attr('out_height'), self.get_attr('out_width')] - dims = [f'N_CHAN_{self.index}', f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}'] + if self.get_attr('data_format') == 'channels_last': + if len(inp.shape) == 2: # 1D -> width + chan + shape = [self.get_attr('out_width'), self.get_attr('n_chan')] + dims = [f'OUT_WIDTH_{self.index}', f'N_CHAN_{self.index}'] + elif len(inp.shape) == 3: # 2D -> height + width + chan + shape = [self.get_attr('out_height'), self.get_attr('out_width'), self.get_attr('n_chan')] + dims = [f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}', f'N_CHAN_{self.index}'] + else: + if len(inp.shape) == 2: # 1D -> width + chan + shape = [self.get_attr('n_chan'), self.get_attr('out_width')] + dims = [f'N_CHAN_{self.index}', f'OUT_WIDTH_{self.index}'] + elif len(inp.shape) == 3: # 2D -> height + width + chan + shape = [self.get_attr('n_chan'), self.get_attr('out_height'), self.get_attr('out_width')] + dims = [f'N_CHAN_{self.index}', f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}'] self.add_output_variable(shape, dims, precision=inp.type.precision) diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index 0edd549b2..3302e3c69 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -34,6 +34,7 @@ 'parse_qonnx', [ 'reshape_constant', + 'resize_remove_constants', 'quant_constant_parameters', 'quant_to_activation', 'fuse_quant_with_constant', diff --git a/hls4ml/model/optimizer/passes/resize_remove_constants.py b/hls4ml/model/optimizer/passes/resize_remove_constants.py new file mode 100644 index 000000000..69039c60a --- /dev/null +++ b/hls4ml/model/optimizer/passes/resize_remove_constants.py @@ -0,0 +1,38 @@ +from warnings import warn + +from hls4ml.model.layers import Constant, Resize +from hls4ml.model.optimizer import OptimizerPass + + +class ResizeRemoveConstants(OptimizerPass): + """ + This optimizer is intended to clean the Resize node from RoI and Scales parameters that if left cause issues in hls4ml. + """ + + def match(self, node): + is_match = isinstance(node, Resize) and len(node.inputs) > 1 + return is_match + + def transform(self, model, node): + """ + Remove RoI and Scale Constant from new shape input. + """ + # see doc here: https://onnx.ai/onnx/operators/onnx__Resize.html + roi_index = 1 + scales_idx = 2 + scales_node = node.get_input_node(node.inputs[scales_idx]) + node.inputs[scales_idx] = '' + if not isinstance(scales_node, Constant): + raise RuntimeError("Non-constant shape inputs are not supported") + model.remove_node(scales_node, rewire=False) + # RoI position is always 1 when present + roi_node = node.get_input_node(node.inputs[roi_index]) + if roi_node.get_attr('value'): + warn('RoI value vector is not empty. Consider that RoI is not supported in hls4ml', stacklevel=2) + node.inputs[roi_index] = '' + if not isinstance(roi_node, Constant): + raise RuntimeError("Non-constant RoI inputs are not supported") + model.remove_node(roi_node, rewire=False) + # Clean all the '' inputs + node.inputs = list(filter(None, node.inputs)) + return True diff --git a/test/pytest/test_qonnx.py b/test/pytest/test_qonnx.py index f822c591a..f48f26862 100644 --- a/test/pytest/test_qonnx.py +++ b/test/pytest/test_qonnx.py @@ -101,6 +101,32 @@ def sep_conv_model(): return model +@pytest.fixture(scope='module') +def branched_model(): + """ + Load branched model using separable convs, already channels-last and cleaned + """ + dl_file = str(example_model_path / "onnx/branched_model_ch_last.onnx") + assert os.path.isfile(dl_file) + + model = ModelWrapper(dl_file) + + return model + + +@pytest.fixture(scope='module') +def tiny_unet_model(): + """ + Load tiny unet model, already channels-last and cleaned + """ + dl_file = str(example_model_path / "onnx/tiny_unet_ch_last.onnx") + assert os.path.isfile(dl_file) + + model = ModelWrapper(dl_file) + + return model + + @pytest.fixture(scope='module') def two_layer_keras_model(): """ @@ -309,6 +335,58 @@ def test_sep_conv(sep_conv_model, backend): np.testing.assert_allclose(y_qonnx.ravel(), y_hls4ml.ravel(), atol=1e-2, rtol=1) +@pytest.mark.parametrize('backend', ['Vitis']) +def test_branched_model(branched_model, backend): + model = branched_model + ishape = tuple(model.get_tensor_shape(model.graph.input[0].name)) + X = np.random.uniform(low=0, high=1, size=np.prod(ishape)).reshape(ishape) + X = (np.round(X * 2**16) * 2**-16).astype(np.float32) + idict = {model.graph.input[0].name: X} + y_qonnx = oxe.execute_onnx(model, idict)[model.graph.output[0].name] + + config = hls4ml.utils.config.config_from_onnx_model( + model, granularity='name', backend=backend, default_precision='fixed<32,16>' + ) + hls_model = hls4ml.converters.convert_from_onnx_model( + model, + output_dir=str(test_root_path / f'hls4mlprj_qonnx_branched_model_{backend}'), + io_type='io_stream', + backend=backend, + hls_config=config, + ) + hls_model.compile() + y_hls4ml = hls_model.predict(np.ascontiguousarray(X)) + + np.testing.assert_array_equal(y_qonnx.ravel(), y_hls4ml.ravel()) + + +@pytest.mark.parametrize('backend', ['Vitis']) +def test_tiny_unet_model(tiny_unet_model, backend): + + model = tiny_unet_model + ishape = tuple(model.get_tensor_shape(model.graph.input[0].name)) + X = np.random.uniform(low=0, high=1, size=np.prod(ishape)).reshape(ishape) + X = (np.round(X * 2**16) * 2**-16).astype(np.float32) + idict = {model.graph.input[0].name: X} + y_qonnx = oxe.execute_onnx(model, idict)[model.graph.output[0].name] + + config = hls4ml.utils.config.config_from_onnx_model( + model, granularity='name', backend=backend, default_precision='fixed<32,16>' + ) + + hls_model = hls4ml.converters.convert_from_onnx_model( + model, + output_dir=str(test_root_path / f'hls4mlprj_qonnx_tiny_unet_model_{backend}'), + io_type='io_stream', + backend=backend, + hls_config=config, + ) + hls_model.compile() + y_hls4ml = hls_model.predict(np.ascontiguousarray(X)) + + np.testing.assert_array_equal(y_qonnx.ravel(), y_hls4ml.ravel()) + + @pytest.mark.parametrize( 'model_name', [