diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py index e2e3d1401..04d539374 100644 --- a/hls4ml/model/optimizer/passes/quant_opt.py +++ b/hls4ml/model/optimizer/passes/quant_opt.py @@ -167,8 +167,8 @@ def match(self, node): scale_unit_or_po2 = (scale == np.ones_like(scale)).all() if not scale_unit_or_po2 and _ALSO_MATCH_PO2: # This optimization only works if all scales are the same - if np.all(scale[0] == scale): - mantissa, _ = np.frexp(scale[0]) + if np.all(scale.item(0) == scale): + mantissa, _ = np.frexp(scale.item(0)) scale_unit_or_po2 = mantissa == 0.5 is_match = scale_unit_or_po2 @@ -187,7 +187,7 @@ def transform(self, model, node): integer = bitwidth scale = node.get_attr('scale') if _ALSO_MATCH_PO2 and not (scale == np.ones_like(scale)).all(): - _, exp = np.frexp(scale[0]) # know that np.all(scale[0] == scale) must be true + _, exp = np.frexp(scale.item(0)) # know that np.all(scale.item(0) == scale) must be true integer = bitwidth + exp - 1 precision, quantizer = _calculate_precision_quantizer(bitwidth, integer, signed, narrow, rounding_mode)