From 4eb07466f2248dcd48b54654a9094c41b6b63d03 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Tue, 19 Nov 2024 21:19:27 +0000
Subject: [PATCH] Fix scale handling when scale is a tensor

scale[0] does not return a scalar if scale is a multidimensional array
(e.g. QONNX returns an array of shape 1x18); use scale.item(0), which
always returns a Python scalar.

---
 hls4ml/model/optimizer/passes/quant_opt.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py
index e2e3d1401..04d539374 100644
--- a/hls4ml/model/optimizer/passes/quant_opt.py
+++ b/hls4ml/model/optimizer/passes/quant_opt.py
@@ -167,8 +167,8 @@ def match(self, node):
             scale_unit_or_po2 = (scale == np.ones_like(scale)).all()
             if not scale_unit_or_po2 and _ALSO_MATCH_PO2:
                 # This optimization only works if all scales are the same
-                if np.all(scale[0] == scale):
-                    mantissa, _ = np.frexp(scale[0])
+                if np.all(scale.item(0) == scale):
+                    mantissa, _ = np.frexp(scale.item(0))
                     scale_unit_or_po2 = mantissa == 0.5
 
             is_match = scale_unit_or_po2
@@ -187,7 +187,7 @@ def transform(self, model, node):
         integer = bitwidth
         scale = node.get_attr('scale')
         if _ALSO_MATCH_PO2 and not (scale == np.ones_like(scale)).all():
-            _, exp = np.frexp(scale[0])  # know that np.all(scale[0] == scale) must be true
+            _, exp = np.frexp(scale.item(0))  # know that np.all(scale.item(0) == scale) must be true
             integer = bitwidth + exp - 1
 
         precision, quantizer = _calculate_precision_quantizer(bitwidth, integer, signed, narrow, rounding_mode)
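
For illustration (not part of the patch): a minimal standalone numpy sketch of
the failure mode this fixes, assuming a per-channel scale of shape 1x18 as in
the QONNX example above; the variable names are hypothetical.

import numpy as np

# Hypothetical scale as QONNX may export it: shape (1, 18), all entries 0.25.
scale = np.full((1, 18), 0.25)

# scale[0] returns the first row (shape (18,)), not a scalar ...
print(scale[0].shape)              # (18,)
# ... so frexp yields arrays and `mantissa == 0.5` is an array, not a bool.
mantissa, exp = np.frexp(scale[0])
print(mantissa.shape)              # (18,)

# scale.item(0) returns element 0 of the flattened array as a Python float,
# which is what the power-of-two check in the optimizer pass needs.
first = scale.item(0)
print(type(first))                 # <class 'float'>
mantissa, exp = np.frexp(first)
print(mantissa == 0.5, exp)        # True -1  -> 0.25 == 0.5 * 2**-1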