Skip to content

Commit

Permalink
Merge pull request #942 from xmos/feature/int16-sum
Browse files Browse the repository at this point in the history
Feature/int16 sum
  • Loading branch information
panickal-xmos authored Nov 19, 2024
2 parents ded91c9 + d8a11ed commit f60ef11
Show file tree
Hide file tree
Showing 38 changed files with 257 additions and 39 deletions.
1 change: 1 addition & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def dailyDeviceTest = {
runPytestDevice("8x8/test_concatenate", "-n 1 --tc 1", "concat_1")
runPytestDevice("8x8/test_concatenate", "-n 1 --tc 5", "concat_5")
runPytestDevice("8x8/test_mean", "-n 1 --tc 1", "mean_1")
runPytestDevice("16x8/test_mean", "-n 1 --tc 1", "16x8_mean_1")
runPytestDevice("8x8/test_lstm", "-n 1 --tc 1", "lstm_1")
runPytestDevice("8x8/test_lstm", "-n 1", "lstm_5")
runPytestDevice("complex_models/8x8/test_cnn_classifier", "-n 1 --tc 1", "cnn_classifier_1")
Expand Down
1 change: 1 addition & 0 deletions integration_tests/models/16x8/test_mean/params.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
MAX_ABS_ERROR: 1.0
Binary file not shown.
5 changes: 5 additions & 0 deletions integration_tests/models/16x8/test_mean/test_mean_1.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// This test reduces the 3rd axis of a 4D tensor while keeping dimensions.
func.func @main(%arg0: tensor<1x5x8x16x!quant.uniform<i8:f32, 0.0078426999971270561:-1>> {tf_saved_model.index_path = ["input_2"]}) -> (tensor<1x5x1x16x!quant.uniform<i8:f32, 0.0078426999971270561:-1>> {tf_saved_model.index_path = ["tf.mean_1"]}) attributes {tf.entry_function = {inputs = "serving_default_input_2:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<1xi32>, value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<1x5x8x16x!quant.uniform<i8:f32, 0.0078426999971270561:-1>>, tensor<1xi32>) -> tensor<1x5x1x16x!quant.uniform<i8:f32, 0.0078426999971270561:-1>>
return %1 : tensor<1x5x1x16x!quant.uniform<i8:f32, 0.0078426999971270561:-1>>
}
Binary file not shown.
6 changes: 6 additions & 0 deletions integration_tests/models/16x8/test_mean/test_mean_10.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces the 2nd and 3rd axes of a 4D tensor with consecutive axes and keep_dims = true.
func.func @main(%arg0: tensor<8x5x10x12x!quant.uniform<i8:f32, 0.008:2>>) -> (tensor<8x1x1x12x!quant.uniform<i8:f32, 0.008:2>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<8x5x10x12x!quant.uniform<i8:f32, 0.008:2>>, tensor<2xi32>) -> tensor<8x1x1x12x!quant.uniform<i8:f32, 0.008:2>>
return %1 : tensor<8x1x1x12x!quant.uniform<i8:f32, 0.008:2>>
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces the 2nd and 3rd axes of a 4D tensor with consecutive axes and keep_dims = true.
func.func @main(%arg0: tensor<8x5x10x12x!quant.uniform<i16:f32, 0.008>>) -> (tensor<8x1x1x12x!quant.uniform<i16:f32, 0.008>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<8x5x10x12x!quant.uniform<i16:f32, 0.008>>, tensor<2xi32>) -> tensor<8x1x1x12x!quant.uniform<i16:f32, 0.008>>
return %1 : tensor<8x1x1x12x!quant.uniform<i16:f32, 0.008>>
}
6 changes: 6 additions & 0 deletions integration_tests/models/16x8/test_mean/test_mean_11.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces the 2nd and 3rd axes of a 4D tensor with consecutive axes and keep_dims = true.
func.func @main(%arg0: tensor<8x5x10x12x!quant.uniform<i8:f32, 0.008:2>>) -> (tensor<8x1x1x12x!quant.uniform<i8:f32, 0.008:2>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<8x5x10x12x!quant.uniform<i8:f32, 0.008:2>>, tensor<2xi32>) -> tensor<8x1x1x12x!quant.uniform<i8:f32, 0.008:2>>
return %1 : tensor<8x1x1x12x!quant.uniform<i8:f32, 0.008:2>>
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces the 2nd and 3rd axes of a 4D tensor with consecutive axes and keep_dims = true.
func.func @main(%arg0: tensor<8x5x10x12x!quant.uniform<i16:f32, 0.008>>) -> (tensor<8x1x1x12x!quant.uniform<i16:f32, 0.008>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<8x5x10x12x!quant.uniform<i16:f32, 0.008>>, tensor<2xi32>) -> tensor<8x1x1x12x!quant.uniform<i16:f32, 0.008>>
return %1 : tensor<8x1x1x12x!quant.uniform<i16:f32, 0.008>>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// This test reduces the 3rd axis of a 4D tensor while keeping dimensions.
func.func @main(%arg0: tensor<1x5x8x16x!quant.uniform<i16:f32, 0.0078426999971270561>> {tf_saved_model.index_path = ["input_2"]}) -> (tensor<1x5x1x16x!quant.uniform<i16:f32, 0.0078426999971270561>> {tf_saved_model.index_path = ["tf.mean_1"]}) attributes {tf.entry_function = {inputs = "serving_default_input_2:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<1xi32>, value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<1x5x8x16x!quant.uniform<i16:f32, 0.0078426999971270561>>, tensor<1xi32>) -> tensor<1x5x1x16x!quant.uniform<i16:f32, 0.0078426999971270561>>
return %1 : tensor<1x5x1x16x!quant.uniform<i16:f32, 0.0078426999971270561>>
}
6 changes: 6 additions & 0 deletions integration_tests/models/16x8/test_mean/test_mean_2.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces the 3rd axis of a 4D tensor without keeping dimensions, with different input/output quantization parameters.
func.func @main(%arg0: tensor<2x3x4x5x!quant.uniform<i8:f32, 0.005:-128>>) -> (tensor<2x3x5x!quant.uniform<i8:f32, 0.006:-127>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<1xi32>, value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = false} : (tensor<2x3x4x5x!quant.uniform<i8:f32, 0.005:-128>>, tensor<1xi32>) -> tensor<2x3x5x!quant.uniform<i8:f32, 0.006:-127>>
return %1 : tensor<2x3x5x!quant.uniform<i8:f32, 0.006:-127>>
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces the 3rd axis of a 4D tensor without keeping dimensions, with different input/output quantization parameters.
func.func @main(%arg0: tensor<2x3x4x5x!quant.uniform<i16:f32, 0.005>>) -> (tensor<2x3x5x!quant.uniform<i16:f32, 0.006>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<1xi32>, value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = false} : (tensor<2x3x4x5x!quant.uniform<i16:f32, 0.005>>, tensor<1xi32>) -> tensor<2x3x5x!quant.uniform<i16:f32, 0.006>>
return %1 : tensor<2x3x5x!quant.uniform<i16:f32, 0.006>>
}
6 changes: 6 additions & 0 deletions integration_tests/models/16x8/test_mean/test_mean_3.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces the 2nd and 4th axes of a 5D tensor while keeping dimensions.
func.func @main(%arg0: tensor<4x3x5x7x6x!quant.uniform<i8:f32, 0.0045:0>>) -> (tensor<4x1x5x1x6x!quant.uniform<i8:f32, 0.0045:0>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[1, 3]> : tensor<2xi32>} : () -> tensor<2xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<4x3x5x7x6x!quant.uniform<i8:f32, 0.0045:0>>, tensor<2xi32>) -> tensor<4x1x5x1x6x!quant.uniform<i8:f32, 0.0045:0>>
return %1 : tensor<4x1x5x1x6x!quant.uniform<i8:f32, 0.0045:0>>
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces the 2nd and 4th axes of a 5D tensor while keeping dimensions.
func.func @main(%arg0: tensor<4x3x5x7x6x!quant.uniform<i16:f32, 0.0045>>) -> (tensor<4x1x5x1x6x!quant.uniform<i16:f32, 0.0045>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[1, 3]> : tensor<2xi32>} : () -> tensor<2xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<4x3x5x7x6x!quant.uniform<i16:f32, 0.0045>>, tensor<2xi32>) -> tensor<4x1x5x1x6x!quant.uniform<i16:f32, 0.0045>>
return %1 : tensor<4x1x5x1x6x!quant.uniform<i16:f32, 0.0045>>
}
6 changes: 6 additions & 0 deletions integration_tests/models/16x8/test_mean/test_mean_4.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces the 1st axis of a 3D tensor without keeping dimensions.
func.func @main(%arg0: tensor<10x20x30x!quant.uniform<i8:f32, 0.003:-5>>) -> (tensor<20x30x!quant.uniform<i8:f32, 0.003:-5>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<1xi32>, value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = false} : (tensor<10x20x30x!quant.uniform<i8:f32, 0.003:-5>>, tensor<1xi32>) -> tensor<20x30x!quant.uniform<i8:f32, 0.003:-5>>
return %1 : tensor<20x30x!quant.uniform<i8:f32, 0.003:-5>>
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces the 1st axis of a 3D tensor without keeping dimensions.
func.func @main(%arg0: tensor<10x20x30x!quant.uniform<i16:f32, 0.003>>) -> (tensor<20x30x!quant.uniform<i16:f32, 0.003>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<1xi32>, value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = false} : (tensor<10x20x30x!quant.uniform<i16:f32, 0.003>>, tensor<1xi32>) -> tensor<20x30x!quant.uniform<i16:f32, 0.003>>
return %1 : tensor<20x30x!quant.uniform<i16:f32, 0.003>>
}
6 changes: 6 additions & 0 deletions integration_tests/models/16x8/test_mean/test_mean_5.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces all axes of a 2D tensor while keeping dimensions.
func.func @main(%arg0: tensor<5x7x!quant.uniform<i8:f32, 0.002:-3>>) -> (tensor<1x1x!quant.uniform<i8:f32, 0.002:-3>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<5x7x!quant.uniform<i8:f32, 0.002:-3>>, tensor<2xi32>) -> tensor<1x1x!quant.uniform<i8:f32, 0.002:-3>>
return %1 : tensor<1x1x!quant.uniform<i8:f32, 0.002:-3>>
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces all axes of a 2D tensor while keeping dimensions.
func.func @main(%arg0: tensor<5x7x!quant.uniform<i16:f32, 0.002>>) -> (tensor<1x1x!quant.uniform<i16:f32, 0.002>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<5x7x!quant.uniform<i16:f32, 0.002>>, tensor<2xi32>) -> tensor<1x1x!quant.uniform<i16:f32, 0.002>>
return %1 : tensor<1x1x!quant.uniform<i16:f32, 0.002>>
}
6 changes: 6 additions & 0 deletions integration_tests/models/16x8/test_mean/test_mean_6.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces a 1D tensor to a scalar.
func.func @main(%arg0: tensor<15x!quant.uniform<i8:f32, 0.009:0>>) -> (tensor<!quant.uniform<i8:f32, 0.009:0>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<1xi32>, value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = false} : (tensor<15x!quant.uniform<i8:f32, 0.009:0>>, tensor<1xi32>) -> tensor<!quant.uniform<i8:f32, 0.009:0>>
return %1 : tensor<!quant.uniform<i8:f32, 0.009:0>>
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces a 1D tensor to a scalar.
func.func @main(%arg0: tensor<15x!quant.uniform<i16:f32, 0.009>>) -> (tensor<!quant.uniform<i16:f32, 0.009>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<1xi32>, value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = false} : (tensor<15x!quant.uniform<i16:f32, 0.009>>, tensor<1xi32>) -> tensor<!quant.uniform<i16:f32, 0.009>>
return %1 : tensor<!quant.uniform<i16:f32, 0.009>>
}
6 changes: 6 additions & 0 deletions integration_tests/models/16x8/test_mean/test_mean_7.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces the 2nd and 3rd axes of a 3D tensor with different input/output quantization parameters.
func.func @main(%arg0: tensor<5x6x7x!quant.uniform<i8:f32, 0.004:-2>>) -> (tensor<5x!quant.uniform<i8:f32, 0.0035:-1>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = false} : (tensor<5x6x7x!quant.uniform<i8:f32, 0.004:-2>>, tensor<2xi32>) -> tensor<5x!quant.uniform<i8:f32, 0.0035:-1>>
return %1 : tensor<5x!quant.uniform<i8:f32, 0.0035:-1>>
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// This test reduces the 2nd and 3rd axes of a 3D tensor with different input/output quantization parameters.
func.func @main(%arg0: tensor<5x6x7x!quant.uniform<i16:f32, 0.004>>) -> (tensor<5x!quant.uniform<i16:f32, 0.0035>>) {
%0 = "tfl.pseudo_qconst"() {qtype = tensor<2xi32>, value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%1 = "tfl.mean"(%arg0, %0) {keep_dims = false} : (tensor<5x6x7x!quant.uniform<i16:f32, 0.004>>, tensor<2xi32>) -> tensor<5x!quant.uniform<i16:f32, 0.0035>>
return %1 : tensor<5x!quant.uniform<i16:f32, 0.0035>>
}
53 changes: 53 additions & 0 deletions integration_tests/models/16x8/test_mean/translate_mlir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os
import re


def translate_int8_to_int16_mlir(mlir_code):
    """
    Translates MLIR code from using int8 quantization to int16 quantization,
    removing any zero-point specification.

    Only per-tensor types spelled exactly as ``!quant.uniform<i8:f32, SCALE>``
    or ``!quant.uniform<i8:f32, SCALE:ZERO_POINT>`` are rewritten. Existing
    ``i16`` types that carry a zero-point also have it stripped.

    Args:
        mlir_code (str): The original MLIR code string with int8 quantization.
    Returns:
        str: Translated MLIR code with int16 quantization.
    """
    # Step 1: Replace int8 quantization with int16 quantization in tensor types
    mlir_code = re.sub(
        r"!quant\.uniform<i8:f32, ([^>]+)>", r"!quant.uniform<i16:f32, \1>", mlir_code
    )

    # Step 2: Remove any zero-point by eliminating it from the parameter list.
    # The scale capture must not be allowed to cross the closing '>': for a
    # type that has no zero-point, an unrestricted capture would run on to the
    # ':' inside the *next* "i16:f32" in the text and mangle that next type.
    mlir_code = re.sub(
        r"!quant\.uniform<i16:f32, ([^,:>]+):[^>]+>",
        r"!quant.uniform<i16:f32, \1>",
        mlir_code,
    )

    return mlir_code


def process_mlir_files_in_directory(directory="."):
    """
    Processes all .mlir files in a directory by translating int8 quantization
    to int16, and saving the output next to the source file with the
    _int16.mlir suffix.

    Files that already end in _int16.mlir are skipped, so the function can be
    re-run without translating its own output.

    Args:
        directory (str): Directory to scan. Defaults to the current working
            directory, which matches the behaviour when called with no
            arguments.
    """
    # os.listdir returns entries in arbitrary order; sort so the processing
    # order and the printed log are the same on every run.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(".mlir") or filename.endswith("_int16.mlir"):
            continue
        with open(os.path.join(directory, filename), "r") as file:
            mlir_code = file.read()

        # Translate the MLIR code
        translated_mlir_code = translate_int8_to_int16_mlir(mlir_code)

        # Save the translated code to a new file in the same directory
        new_filename = f"{os.path.splitext(filename)[0]}_int16.mlir"
        with open(os.path.join(directory, new_filename), "w") as new_file:
            new_file.write(translated_mlir_code)
        print(f"Processed {filename} -> {new_filename}")


# Run the conversion only when this file is executed as a script, so that
# importing the module does not touch any files.
if __name__ == "__main__":
    process_mlir_files_in_directory()
2 changes: 1 addition & 1 deletion third_party/lib_nn
17 changes: 17 additions & 0 deletions xformer/IR/XCoreOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,23 @@ def XC_MeanOp : XC_Op<"mean", [Pure]> {
let results = (outs TensorOf<[QI8]> : $output);
}

// XCore mean op for 16-bit quantized tensors: input and output are both QI16.
// The ReplaceMean pass creates this op in place of tfl.mean when both the
// input and output element types are signed 16-bit quantized. It carries no
// zero-point attributes.
def XC_MeanI16Op : XC_Op<"meani16", [Pure]> {
let summary = "Mean int16 op";

let description = [{Mean int16 op.}];

let arguments = (ins
TensorOf<[QI16]>:$input,

// Filled by ReplaceMean from its beginDims / meanDims / endDims values, in
// that order. Presumably the element counts before, across and after the
// reduced axes -- TODO confirm against the pass.
I32Attr:$start,
I32Attr:$mean,
I32Attr:$end,
// Filled by ReplaceMean as (input scale / output scale / meanDims).
F32Attr:$scale_mul
);

let results = (outs TensorOf<[QI16]> : $output);
}

def XC_MulOp : XC_Op<"mul", [Pure, XC_MemoryOverlappable]> {
let summary = "Mul op";

Expand Down
53 changes: 33 additions & 20 deletions xformer/Transforms/ReplaceMean.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// Copyright 2021 XMOS LIMITED. This Software is subject to the terms of the
// XMOS Public License: Version 1

#include "IR/XCoreOps.h"
#include "Utils/Util.h"

Expand All @@ -16,7 +13,7 @@ extern "C" {
namespace mlir::xcore {

namespace {
// Replace TFL Mean with Mean for XCore.
// Replace TFL Mean with Mean or Mean16 for XCore.
struct ReplaceMean
: public PassWrapper<ReplaceMean, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReplaceMean)
Expand All @@ -26,7 +23,7 @@ struct ReplaceMean
}
StringRef getArgument() const final { return "xcore-replace-mean"; }
StringRef getDescription() const final {
return "Replace TFL Mean with Mean for XCore.";
return "Replace TFL Mean with Mean or Mean16 for XCore.";
}
void runOnOperation() override;
};
Expand All @@ -41,7 +38,9 @@ struct ReplaceMeanPattern : public OpRewritePattern<TFL::MeanOp> {
auto output = meanOp.getOutput();

DenseElementsAttr axisAttr;
matchPattern(meanOp.getAxis(), m_Constant(&axisAttr));
if (!matchPattern(meanOp.getAxis(), m_Constant(&axisAttr))) {
return failure();
}
auto axisValues = axisAttr.getValues<int32_t>();
std::vector<int32_t> axis(axisValues.begin(), axisValues.end());
int32_t minAxis = *std::min_element(axis.begin(), axis.end());
Expand All @@ -52,14 +51,19 @@ struct ReplaceMeanPattern : public OpRewritePattern<TFL::MeanOp> {

auto inputType = input.getType().cast<ShapedType>();
auto outputType = output.getType().cast<ShapedType>();
if (!utils::isNBitSignedQType<8>(inputType.getElementType()) ||
!utils::isNBitSignedQType<8>(outputType.getElementType())) {

// Check if input and output are either int8 or int16.
bool isInt8 = utils::isNBitSignedQType<8>(inputType.getElementType()) &&
utils::isNBitSignedQType<8>(outputType.getElementType());

bool isInt16 = utils::isNBitSignedQType<16>(inputType.getElementType()) &&
utils::isNBitSignedQType<16>(outputType.getElementType());

if (!(isInt8 || isInt16)) {
return failure();
}

auto inputShape = inputType.getShape();
auto outputShape = outputType.getShape();

int rank = inputShape.size();

int beginDims = 1;
Expand All @@ -80,23 +84,32 @@ struct ReplaceMeanPattern : public OpRewritePattern<TFL::MeanOp> {
auto inputQType = utils::getQType(input);
auto outputQType = utils::getQType(output);

float inZeroPoint = static_cast<float>(inputQType.getZeroPoint());
float outZeroPoint = static_cast<float>(outputQType.getZeroPoint());
float scaleMul = inputQType.getScale() / outputQType.getScale() /
static_cast<float>(meanDims);
auto scaleMulAttr = rewriter.getF32FloatAttr(scaleMul);

auto beginDimsAttr = rewriter.getI32IntegerAttr(beginDims);
auto endDimsAttr = rewriter.getI32IntegerAttr(endDims);
auto meanDimsAttr = rewriter.getI32IntegerAttr(meanDims);
auto inZeroPointAttr = rewriter.getF32FloatAttr(inZeroPoint);
auto outZeroPointAttr = rewriter.getF32FloatAttr(outZeroPoint);
auto scaleMulAttr = rewriter.getF32FloatAttr(scaleMul);

auto xcMeanOp = rewriter.create<MeanOp>(
meanOp.getLoc(), meanOp.getType(), meanOp.getInput(), beginDimsAttr,
meanDimsAttr, endDimsAttr, inZeroPointAttr, outZeroPointAttr,
scaleMulAttr);
rewriter.replaceOp(meanOp, xcMeanOp.getOutput());
if (isInt8) {
float inZeroPoint = static_cast<float>(inputQType.getZeroPoint());
float outZeroPoint = static_cast<float>(outputQType.getZeroPoint());
auto inZeroPointAttr = rewriter.getF32FloatAttr(inZeroPoint);
auto outZeroPointAttr = rewriter.getF32FloatAttr(outZeroPoint);

auto xcMeanOp = rewriter.create<MeanOp>(
meanOp.getLoc(), meanOp.getType(), meanOp.getInput(), beginDimsAttr,
meanDimsAttr, endDimsAttr, inZeroPointAttr, outZeroPointAttr,
scaleMulAttr);
rewriter.replaceOp(meanOp, xcMeanOp.getOutput());
} else { // isInt16
// Zero points are always zero for int16 and are not passed to MeanI16Op.
auto xcMeanOp = rewriter.create<MeanI16Op>(
meanOp.getLoc(), meanOp.getType(), meanOp.getInput(), beginDimsAttr,
meanDimsAttr, endDimsAttr, scaleMulAttr);
rewriter.replaceOp(meanOp, xcMeanOp.getOutput());
}

return success();
}
Expand Down
Loading

0 comments on commit f60ef11

Please sign in to comment.