[mlir][tosa] Update tosa.avg_pool2d for bit-exact TOSA behavior

The normalization component of average pool has a very specific
rounding behavior for computing the division for floating-point
values. Updated the lowering so that the bit-exact version is
implemented.

Also includes a fix for the stride handling when computing how much of
each averaging window overlaps valid (non-padded) input.
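For each output element the divisor is the number of kernel taps that land
on valid (non-padded) input, which depends on the output position, the
padding, and the stride. Below is a standalone sketch of that count,
mirroring the lowering's padFn/coverageFn; the function names and the
four-entry pad layout are illustrative, not from the patch:

#include <algorithm>
#include <cstdint>

// Valid-tap count along one spatial axis for output index `idx`.
// `isizeM1` is the output extent minus one, matching the iH/iW values
// the lowering feeds to coverageFn. Terms with pad == 0 vanish, which
// is why the lowering's padFn can skip them entirely.
int64_t axisCoverage(int64_t idx, int64_t isizeM1, int64_t kernel,
                     int64_t stride, int64_t padLow, int64_t padHigh) {
  // Output indices must be scaled by the stride to give input-space
  // distances from each edge; the missing multiply is the bug fixed here.
  int64_t left = idx * stride;
  int64_t right = (isizeM1 - idx) * stride;
  int64_t val = kernel;
  val += std::min<int64_t>(left - padLow, 0);   // taps lost to low padding
  val += std::min<int64_t>(right - padHigh, 0); // taps lost to high padding
  return std::max<int64_t>(val, 1);             // never divide by zero
}

// Divisor for output position (y, x), given a [top, bottom, left, right]
// pad layout (the lowering's pad array also covers the non-spatial dims).
int64_t poolCount(int64_t y, int64_t x, int64_t outH, int64_t outW,
                  const int64_t kernel[2], const int64_t stride[2],
                  const int64_t pad[4]) {
  return axisCoverage(y, outH - 1, kernel[0], stride[0], pad[0], pad[1]) *
         axisCoverage(x, outW - 1, kernel[1], stride[1], pad[2], pad[3]);
}

The floating-point path converts this count with sitofp and divides; the
quantized path folds it into a tosa.apply_scale, as shown in the diff below.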

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D141339
Author: Rob Suderman
Date: 2023-01-25 17:58:03 +00:00
Committed by: Robert Suderman
Parent: 6ec446ddce
Commit: b67b024d58

2 changed files with 218 additions and 179 deletions

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

@@ -817,12 +817,17 @@ public:
     // Normalize the summed value by the number of elements grouped in each
     // pool.
     auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
+    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
+    Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
+    Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);
+    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    iH = rewriter.create<arith::SubIOp>(loc, iH, one);
+    iW = rewriter.create<arith::SubIOp>(loc, iW, one);
     Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
         loc, resultTy.getShape(), resultETy, dynamicDims);
-    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
     auto genericOp = rewriter.create<linalg::GenericOp>(
         loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
         ValueRange{genericEmptyTensor},
@@ -830,60 +835,59 @@ public:
         getNParallelLoopsAttrs(resultTy.getRank()),
         [&](OpBuilder &b, Location loc, ValueRange args) {
           auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-          auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-          auto iH = rewriter.create<arith::ConstantIndexOp>(
-              loc, poolingOpTy.getDimSize(1) - 1);
-          auto iW = rewriter.create<arith::ConstantIndexOp>(
-              loc, poolingOpTy.getDimSize(2) - 1);
-          // Compute the indices from either end.
-          auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
-          auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
-          auto y1 = rewriter.create<arith::SubIOp>(loc, iH, y0);
-          auto x1 = rewriter.create<arith::SubIOp>(loc, iW, x0);
           // Determines what the portion of valid input is covered by the
           // kernel.
-          auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
+          auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
             if (pad == 0)
-              return v;
+              return valid;
             auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
-            Value dx = rewriter.create<arith::SubIOp>(loc, x, padVal);
+            Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);
             Value cmp = rewriter.create<arith::CmpIOp>(
-                loc, arith::CmpIPredicate::slt, dx, zero);
-            Value offset = rewriter.create<arith::SelectOp>(loc, cmp, dx, zero);
-            return rewriter.create<arith::AddIOp>(loc, v, offset)->getResult(0);
+                loc, arith::CmpIPredicate::slt, dpos, zero);
+            Value offset =
+                rewriter.create<arith::SelectOp>(loc, cmp, dpos, zero);
+            return rewriter.create<arith::AddIOp>(loc, valid, offset)
+                ->getResult(0);
           };
-          // Compute the vertical component of coverage.
-          auto kH0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[0]);
-          auto kH1 = padFn(kH0, y0, pad[2]);
-          auto kH2 = padFn(kH1, y1, pad[3]);
-          auto kHCmp = rewriter.create<arith::CmpIOp>(
-              loc, arith::CmpIPredicate::slt, kH2, one);
-          auto kH3 = rewriter.create<arith::SelectOp>(loc, kHCmp, one, kH2);
-          // compute the horizontal component of coverage.
-          auto kW0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[1]);
-          auto kW1 = padFn(kW0, x0, pad[4]);
-          auto kW2 = padFn(kW1, x1, pad[5]);
-          auto kWCmp = rewriter.create<arith::CmpIOp>(
-              loc, arith::CmpIPredicate::slt, kW2, one);
-          auto kW3 = rewriter.create<arith::SelectOp>(loc, kWCmp, one, kW2);
+          auto coverageFn = [&](int64_t i, Value isize) -> Value {
+            Value strideVal =
+                rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
+            Value val =
+                rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);
+            // Find the position relative to the input tensor's ends.
+            Value left = rewriter.create<linalg::IndexOp>(loc, i);
+            Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
+            left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
+            right = rewriter.create<arith::MulIOp>(loc, right, strideVal);
+            // Determine how much padding was included.
+            val = padFn(val, left, pad[i * 2]);
+            val = padFn(val, right, pad[i * 2 + 1]);
+            Value cmp = rewriter.create<arith::CmpIOp>(
+                loc, arith::CmpIPredicate::slt, val, one);
+            return rewriter.create<arith::SelectOp>(loc, cmp, one, val);
+          };
+          // Compute the indices from either end.
+          Value kH3 = coverageFn(1, iH);
+          Value kW3 = coverageFn(2, iW);
           // Compute the total number of elements and normalize.
-          Value count = rewriter.create<arith::MulIOp>(loc, kH3, kW3);
-          auto countI = rewriter.create<arith::IndexCastOp>(
-              loc, rewriter.getI32Type(), count);
+          auto count = rewriter.create<arith::IndexCastOp>(
+              loc, rewriter.getI32Type(),
+              rewriter.create<arith::MulIOp>(loc, kH3, kW3));
           // Divide by the number of summed values. For floats this is just
           // a div however for quantized values input normalization had
           // to be applied.
           Value poolVal = args[0];
           if (accETy.isa<FloatType>()) {
-            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, countI);
+            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
             poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                           ->getResult(0);
           } else {
@@ -895,33 +899,52 @@ public:
              auto inputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
              Value offset =
-                 rewriter.create<arith::MulIOp>(loc, accETy, countI, inputZp);
+                 rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
              poolVal =
                  rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
            }
-           // Compute the multiplier and shift values for the quantization
-           // normalization. Preferably we would want to compute more bits
-           // however 32-bits should be enough for compute. Honestly we
-           // should probably straight divide.
-           int64_t numerator = ((1 << 30) + 1);
-           int64_t shift = 30;
-           Value numeratorVal = rewriter.create<arith::ConstantOp>(
-               loc, rewriter.getI32IntegerAttr(numerator));
-           Value multiplierVal =
-               rewriter
-                   .create<arith::DivUIOp>(loc, rewriter.getI32Type(),
-                                           numeratorVal, countI)
-                   .getResult();
-           Value shiftVal = rewriter.create<arith::ConstantOp>(
-               loc, rewriter.getI8IntegerAttr(shift));
+           // Compute: k = 32 - count_leading_zeros(value - 1)
+           Value one32 = rewriter.create<arith::ConstantOp>(
+               loc, rewriter.getI32IntegerAttr(1));
+           Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
+               loc, rewriter.getI32IntegerAttr(32));
+           Value countSubOne =
+               rewriter.create<arith::SubIOp>(loc, count, one32);
+           Value leadingZeros =
+               rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
+           Value k =
+               rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);
+           // Compute: numerator = ((1 << 30) + 1) << k
+           Value k64 =
+               rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k);
+           Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
+               loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
+           Value numerator =
+               rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);
+           // Compute: scale.multiplier = numerator / value;
+           Value count64 = rewriter.create<arith::ExtUIOp>(
+               loc, rewriter.getI64Type(), count);
+           Value multiplier =
+               rewriter.create<arith::DivUIOp>(loc, numerator, count64);
+           multiplier = rewriter.create<arith::TruncIOp>(
+               loc, rewriter.getI32Type(), multiplier);
+           // Compute: scale.shift = 30 + k
+           Value k8 =
+               rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k);
+           Value thirty8 = rewriter.create<arith::ConstantOp>(
+               loc, rewriter.getI8IntegerAttr(30));
+           Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);
            auto scaled =
                rewriter
-                   .create<tosa::ApplyScaleOp>(
-                       loc, rewriter.getI32Type(), poolVal, multiplierVal,
-                       shiftVal, rewriter.getBoolAttr(false))
+                   .create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
+                                               poolVal, multiplier, shift,
+                                               rewriter.getBoolAttr(false))
                    .getResult();
            // If we have quantization information we need to apply output
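The quantized path no longer divides by a single constant-folded
multiplier with a fixed shift of 30; it now builds the TOSA reciprocal
scale at runtime: k = 32 - clz(count - 1), multiplier =
(((1 << 30) + 1) << k) / count, shift = 30 + k. A minimal standalone
model of that arithmetic, and of the single-rounding apply_scale used
here (double_round is false); the names are illustrative:

#include <cstdint>

struct Scale {
  int32_t multiplier;
  int8_t shift;
};

// Fixed-point reciprocal of `count`, the number of averaged elements.
Scale reciprocalScale(uint32_t count) {
  // k = 32 - count_leading_zeros(count - 1). MLIR's math.ctlz defines
  // ctlz(0) = 32 so the lowering needs no guard; plain C++ does, since
  // __builtin_clz(0) is undefined.
  int32_t k = count > 1 ? 32 - __builtin_clz(count - 1) : 0;
  // numerator = ((1 << 30) + 1) << k, computed in 64 bits.
  int64_t numerator = ((int64_t(1) << 30) + 1) << k;
  // multiplier ~= 2^(30 + k) / count, truncated back to 32 bits.
  return Scale{int32_t(numerator / count), int8_t(30 + k)};
}

// Single-round apply_scale: (value * multiplier + 2^(shift - 1)) >> shift.
int32_t applyScale32(int32_t value, Scale s) {
  int64_t round = int64_t(1) << (s.shift - 1);
  return int32_t((int64_t(value) * s.multiplier + round) >> s.shift);
}

For a fully interior 4x4 window, count = 16 gives k = 4, multiplier =
2^30 + 1, and shift = 34, i.e. a correctly rounded division by 16.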

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

@@ -200,148 +200,164 @@ func.func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
   %0 = "tosa.max_pool2d"(%arg0) {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi32>) -> (tensor<1x4x32x62xi32>)
   return
 }
 // -----
-// CHECK-LABEL: @avg_pool
-func.func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
-  // Initial piece computes the sum of the pooling region, with appropriate padding.
-  // CHECK: [[CONST:%.+]] = arith.constant 0
-  // CHECK: [[PAD:%.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
-  // CHECK: [[CONST:%.+]] = arith.constant 0
-  // CHECK: [[POOLINIT:%.+]] = tensor.empty()
-  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CONST]]{{.*}}outs([[POOLINIT]]
-  // CHECK: [[KERNEL:%.+]] = tensor.empty()
-  // CHECK: [[POOL:%.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x5x33x62xf32>)
-  // CHECK: [[INIT:%.+]] = tensor.empty()
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[POOL]] : tensor<1x5x33x62xf32>) outs([[INIT]] : tensor<1x5x33x62xf32>)
-  // CHECK: ^bb0(%[[BBARG1:[a-zA-Z0-9_]+]]: f32,
-  // CHECK: [[ZERO:%.0]] = arith.constant 0
-  // CHECK: [[ONE:%.+]] = arith.constant 1
-  // CHECK: [[HEIGHT:%.+]] = arith.constant 4
-  // CHECK: [[WIDTH:%.+]] = arith.constant 32
-  // CHECK: [[IDX1:%.+]] = linalg.index 1
-  // CHECK: [[IDX2:%.+]] = linalg.index 2
-  // The large block below computes what portion of the kernel is within non-padded input.
-  // CHECK: [[NY:%.+]] = arith.subi [[HEIGHT]], [[IDX1]]
-  // CHECK: [[NX:%.+]] = arith.subi [[WIDTH]], [[IDX2]]
-  // CHECK: [[KH:%.+]] = arith.constant 4
-  // CHECK: [[PAD0:%.+]] = arith.constant 1
-  // CHECK: [[SUBP0:%.+]] = arith.subi [[IDX1]], [[PAD0]]
-  // CHECK: [[P0CMP:%.+]] = arith.cmpi slt, [[SUBP0]], [[ZERO]]
-  // CHECK: [[SELP0:%.+]] = arith.select [[P0CMP]], [[SUBP0]], [[ZERO]]
-  // CHECK: [[ADDP0:%.+]] = arith.addi [[KH]], [[SELP0]]
-  // CHECK: [[PAD1:%.+]] = arith.constant 1
-  // CHECK: [[SUBP1:%.+]] = arith.subi [[NY]], [[PAD1]]
-  // CHECK: [[P1CMP:%.+]] = arith.cmpi slt, [[SUBP1]], [[ZERO]]
-  // CHECK: [[SELP1:%.+]] = arith.select [[P1CMP]], [[SUBP1]], [[ZERO]]
-  // CHECK: [[ADDP1:%.+]] = arith.addi [[ADDP0]], [[SELP1]]
-  // CHECK: [[YCMP:%.+]] = arith.cmpi slt, [[ADDP1]], [[ONE]]
-  // CHECK: [[YSEL:%.+]] = arith.select [[YCMP]], [[ONE]], [[ADDP1]]
-  // CHECK: [[KW:%.+]] = arith.constant 4 : index
-  // CHECK: [[PAD2:%.+]] = arith.constant 1 : index
-  // CHECK: [[SUBP2:%.+]] = arith.subi [[IDX2]], [[PAD2]]
-  // CHECK: [[P2CMP:%.+]] = arith.cmpi slt, [[SUBP2]], [[ZERO]]
-  // CHECK: [[SELP2:%.+]] = arith.select [[P2CMP]], [[SUBP2]], [[ZERO]]
-  // CHECK: [[ADDP2:%.+]] = arith.addi [[KW]], [[SELP2]]
-  // CHECK: [[PAD3:%.+]] = arith.constant 1 : index
-  // CHECK: [[SUBP3:%.+]] = arith.subi [[NX]], [[PAD3]]
-  // CHECK: [[P3CMP:%.+]] = arith.cmpi slt, [[SUBP3]], [[ZERO]]
-  // CHECK: [[SELP3:%.+]] = arith.select [[P3CMP]], [[SUBP3]], [[ZERO]]
-  // CHECK: [[ADDP3:%.+]] = arith.addi [[ADDP2]], [[SELP3]]
-  // CHECK: [[XCMP:%.+]] = arith.cmpi slt, [[ADDP3]], [[ONE]]
-  // CHECK: [[XSEL:%.+]] = arith.select [[XCMP]], [[ONE]], [[ADDP3]]
-  // Given the valid coverage of the pooling region, normalize the summation.
-  // CHECK: [[C:%.+]] = arith.muli [[YSEL]], [[XSEL]]
-  // CHECK: [[CI:%.+]] = arith.index_cast [[C]]
-  // CHECK: [[CF:%.+]] = arith.sitofp [[CI]]
-  // CHECK: [[RESULT:%.+]] = arith.divf %[[BBARG1]], [[CF]]
-  // CHECK: linalg.yield [[RESULT]]
+// CHECK-LABEL: @avg_pool_f32
+func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
+  // Apply padding to the input:
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
+  // CHECK:   tensor.yield %[[F0]] : f32
+  // Fill the pooling target:
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x5x33x62xf32>
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EMPTY]] : tensor<1x5x33x62xf32>)
+  // Compute the sum padding:
+  // CHECK: %[[KERNEL:.+]] = tensor.empty() : tensor<4x4xf32>
+  // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum
+  // CHECK-SAME: dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+  // CHECK-SAME: ins(%[[PAD]], %[[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>)
+  // CHECK-SAME: outs(%[[FILL]] : tensor<1x5x33x62xf32>)
+  // Compute dimension based constants:
+  // CHECK: %[[I1:.+]] = arith.constant 1 : index
+  // CHECK: %[[DIM1:.+]] = tensor.dim %[[POOL]], %[[I1]]
+  // CHECK: %[[I2:.+]] = arith.constant 2 : index
+  // CHECK: %[[DIM2:.+]] = tensor.dim %[[POOL]], %[[I2]]
+  // CHECK: %[[ONE:.+]] = arith.constant 1 : index
+  // CHECK: %[[HEIGHT:.+]] = arith.subi %[[DIM1]], %[[ONE]] : index
+  // CHECK: %[[WIDTH:.+]] = arith.subi %[[DIM2]], %[[ONE]] : index
+  // Divide the sum pooling by the number of summed values.
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x5x33x62xf32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK-SAME: indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[POOL]] : tensor<1x5x33x62xf32>)
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x5x33x62xf32>)
+  // CHECK: ^bb0(%[[IN:.+]]: f32, %{{.+}}: f32)
+  // CHECK: %[[ZERO:.+]] = arith.constant 0
+  // Compute how much of the height does not include padding:
+  // CHECK: %[[STRIDE:.+]] = arith.constant 1
+  // CHECK: %[[KSIZE:.+]] = arith.constant 4
+  // CHECK: %[[START:.+]] = linalg.index 1
+  // CHECK: %[[END:.+]] = arith.subi %[[HEIGHT]], %[[START]]
+  // CHECK: %[[SRC_START:.+]] = arith.muli %[[START]], %[[STRIDE]]
+  // CHECK: %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
+  // CHECK: %[[PAD_START:.+]] = arith.constant 1
+  // CHECK: %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
+  // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+  // CHECK: %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
+  // CHECK: %[[PAD_END:.+]] = arith.constant 1
+  // CHECK: %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
+  // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+  // CHECK: %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
+  // CHECK: %[[KHEIGHT:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+  // Compute how much of the width does not include padding:
+  // CHECK: %[[STRIDE:.+]] = arith.constant 1
+  // CHECK: %[[KSIZE:.+]] = arith.constant 4
+  // CHECK: %[[START:.+]] = linalg.index 2
+  // CHECK: %[[END:.+]] = arith.subi %[[WIDTH]], %[[START]]
+  // CHECK: %[[SRC_START:.+]] = arith.muli %[[START]], %[[STRIDE]]
+  // CHECK: %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
+  // CHECK: %[[PAD_START:.+]] = arith.constant 1
+  // CHECK: %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
+  // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+  // CHECK: %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
+  // CHECK: %[[PAD_END:.+]] = arith.constant 1
+  // CHECK: %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
+  // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+  // CHECK: %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
+  // CHECK: %[[KWIDTH:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+  // Divide the summed value by the number of values summed.
+  // CHECK: %[[COUNT:.+]] = arith.muli %[[KHEIGHT]], %[[KWIDTH]]
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[COUNT]]
+  // CHECK: %[[FLT:.+]] = arith.sitofp %[[CAST]]
+  // CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
+  // CHECK: linalg.yield %[[DIV]]
   %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
   return %0 : tensor<1x5x33x62xf32>
 }
+// -----
+// CHECK-LABEL: @avg_pool_i8
+func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK-SAME: indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[POOL]] : tensor<1x5x33x62xi32>)
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x5x33x62xi8>)
+  // CHECK: ^bb0(%[[IN:.+]]: i32, %{{.+}}: i8)
+  // Only different behavior is how the division is performed.
+  // First we compute the mul and shift values for average pool:
+  // CHECK: %[[COUNT:.+]] = arith.muli %21, %35
+  // CHECK: %[[ICAST:.+]] = arith.index_cast %[[COUNT]]
+  // CHECK: %[[C1:.+]] = arith.constant 1
+  // CHECK: %[[C32:.+]] = arith.constant 32
+  // CHECK: %[[ISUB:.+]] = arith.subi %[[ICAST]], %[[C1]]
+  // CHECK: %[[CTLZ:.+]] = math.ctlz %[[ISUB]]
+  // CHECK: %[[SUB:.+]] = arith.subi %[[C32]], %[[CTLZ]]
+  // CHECK: %[[EXT:.+]] = arith.extui %[[SUB]]
+  // CHECK: %[[CBIG:.+]] = arith.constant 1073741825
+  // CHECK: %[[SHL:.+]] = arith.shli %[[CBIG]], %[[EXT]]
+  // CHECK: %[[IEXT:.+]] = arith.extui %[[ICAST]]
+  // CHECK: %[[DIV:.+]] = arith.divui %[[SHL]], %[[IEXT]]
+  // CHECK: %[[TRUNC_MUL:.+]] = arith.trunci %[[DIV]]
+  // CHECK: %[[TRUNC_SHIFT:.+]] = arith.trunci %[[SUB]]
+  // CHECK: %[[C30:.+]] = arith.constant 30
+  // CHECK: %[[SHIFT:.+]] = arith.addi %[[TRUNC_SHIFT]], %[[C30]] : i8
+  // CHECK: %[[SCALED:.+]] = "tosa.apply_scale"(%[[IN]], %[[TRUNC_MUL]], %[[SHIFT]]) {double_round = false}
+  // Perform the normalization.
+  // CHECK: %[[CMIN:.+]] = arith.constant -128
+  // CHECK: %[[CMAX:.+]] = arith.constant 127
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[SCALED]], %[[CMIN]]
+  // CHECK: %[[SEL:.+]] = arith.select %[[CMP]], %[[CMIN]], %[[SCALED]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[CMAX]], %[[SCALED]]
+  // CHECK: %[[CLAMP:.+]] = arith.select %[[CMP]], %[[CMAX]], %[[SEL]]
+  // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLAMP]]
+  // CHECK: linalg.yield %[[TRUNC]]
+  %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>)
+  return %0 : tensor<1x5x33x62xi8>
+}
 // -----
 // CHECK-LABEL: @avg_pool_dyn
 func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>) {
   // The calculations remain the same as above, only testing for dyn behavior
-  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
   // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
-  // CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
-  // CHECK: %[[POOLINIT:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[KERNEL:.+]] = tensor.empty()
-  // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PAD]], %[[KERNEL]] : tensor<?x8x36x62xf32>, tensor<4x4xf32>) outs(%[[FILL]] : tensor<?x5x33x62xf32>)
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL]] : tensor<?x5x33x62xf32>) outs(%[[INIT]] : tensor<?x5x33x62xf32>)
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[PADDED:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
+  // CHECK:   tensor.yield %[[F0]]
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x5x33x62xf32>
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EMPTY]] : tensor<?x5x33x62xf32>)
+  // CHECK: %[[KERNEL:.+]] = tensor.empty() : tensor<4x4xf32>
+  // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum
+  // CHECK-SAME: dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>
+  // CHECK-SAME: ins(%[[PADDED]], %[[KERNEL]] : tensor<?x8x36x62xf32>, tensor<4x4xf32>)
+  // CHECK-SAME: outs(%[[FILL]] : tensor<?x5x33x62xf32>) -> tensor<?x5x33x62xf32>
+  // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x5x33x62xf32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
   %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
   return %0 : tensor<?x5x33x62xf32>
 }
 // -----
-// CHECK-LABEL: @avg_pool_i8
-func.func @avg_pool_i8(%arg0 : tensor<1x128x128x2xi8>) -> () {
-  // CHECK: linalg.pooling_nhwc_sum
-  // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG1:[a-zA-Z0-9_]+]]: i32,
-  // CHECK: %[[INZP:.+]] = arith.constant -128
-  // CHECK: %[[INZP_OFF:.+]] = arith.muli %{{.+}}, %[[INZP]]
-  // CHECK: %[[OFFSETED:.+]] = arith.subi %[[BBARG1]], %[[INZP_OFF]]
-  // CHECK: %[[NUMERATOR:.+]] = arith.constant 1073741825
-  // CHECK: %[[MULTIPLIER:.+]] = arith.divui %[[NUMERATOR]], %{{.+}}
-  // CHECK: %[[SHIFT:.+]] = arith.constant 30
-  // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false}
-  // CHECK: %[[OUTZP:.+]] = arith.constant -128
-  // CHECK: %[[OUT:.+]] = arith.addi %[[SCALE]], %[[OUTZP]]
-  // CHECK: %[[MIN:.+]] = arith.constant -128
-  // CHECK: %[[MAX:.+]] = arith.constant 127
-  // CHECK: %[[CMP_MIN:.+]] = arith.cmpi slt, %[[OUT]], %[[MIN]]
-  // CHECK: %[[CLMP_MIN:.+]] = arith.select %[[CMP_MIN]], %[[MIN]], %[[OUT]]
-  // CHECK: %[[CMP_MAX:.+]] = arith.cmpi slt, %[[MAX]], %[[OUT]]
-  // CHECK: %[[CLMP_MAX:.+]] = arith.select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]]
-  // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLMP_MAX]]
-  // CHECK: linalg.yield %[[TRUNC]]
-  %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 4, 4>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.unary_quant<input_zp = -128, output_zp = -128>, stride = array<i64: 4, 4>} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8>
-  return
-}
-// -----
-// CHECK-LABEL: @avg_pool_i16
-func.func @avg_pool_i16(%arg0 : tensor<1x128x128x2xi16>) -> () {
-  // CHECK: linalg.pooling_nhwc_sum
-  // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG1:[a-zA-Z0-9_]+]]: i32,
-  // CHECK: %[[INZP:.+]] = arith.constant -128
-  // CHECK: %[[INZP_OFF:.+]] = arith.muli %{{.+}}, %[[INZP]]
-  // CHECK: %[[OFFSETED:.+]] = arith.subi %[[BBARG1]], %[[INZP_OFF]]
-  // CHECK: %[[NUMERATOR:.+]] = arith.constant 1073741825
-  // CHECK: %[[MULTIPLIER:.+]] = arith.divui %[[NUMERATOR]], %{{.+}}
-  // CHECK: %[[SHIFT:.+]] = arith.constant 30
-  // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false}
-  // CHECK: %[[OUTZP:.+]] = arith.constant -128
-  // CHECK: %[[OUT:.+]] = arith.addi %[[SCALE]], %[[OUTZP]]
-  // CHECK: %[[MIN:.+]] = arith.constant -32768
-  // CHECK: %[[MAX:.+]] = arith.constant 32767
-  // CHECK: %[[CMP_MIN:.+]] = arith.cmpi slt, %[[OUT]], %[[MIN]]
-  // CHECK: %[[CLMP_MIN:.+]] = arith.select %[[CMP_MIN]], %[[MIN]], %[[OUT]]
-  // CHECK: %[[CMP_MAX:.+]] = arith.cmpi slt, %[[MAX]], %[[OUT]]
-  // CHECK: %[[CLMP_MAX:.+]] = arith.select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]]
-  // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLMP_MAX]]
-  // CHECK: linalg.yield %[[TRUNC]]
-  %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 4, 4>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.unary_quant<input_zp = -128, output_zp = -128>, stride = array<i64: 4, 4>} : (tensor<1x128x128x2xi16>) -> tensor<1x32x32x2xi16>
-  return
-}
-// -----
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>