[mlir][tosa] Switch zero point of avgpool2d to input variable type (#128983)

This commit changes the TOSA operator AvgPool2d's zero point attributes
to inputs to align with TOSA 1.0 spec.

Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Co-authored-by: Luke Hutton <luke.hutton@arm.com>
This commit is contained in:
Tai Ly
2025-03-04 11:34:23 -06:00
committed by GitHub
parent 17bfc00f7c
commit 25a29cef31
18 changed files with 355 additions and 204 deletions

View File

@@ -5,9 +5,11 @@ profileComplianceMap = {
{{{Profile::pro_int}, {{i8T, i32T}}},
{{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
{"tosa.avg_pool2d",
{{{Profile::pro_int}, {{i8T, i32T, i8T}}},
{{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}},
{{Profile::pro_fp},
{{fp16T, fp16T, fp16T}, {fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T}}}}},
{{fp16T, fp16T, fp16T, fp16T, fp16T},
{fp16T, fp16T, fp16T, fp32T, fp16T},
{fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
{"tosa.conv2d",
{{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
{{Profile::pro_fp},
@@ -243,10 +245,10 @@ extensionComplianceMap = {
{{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
{{Extension::bf16}, {{bf16T, i32T}}}}},
{"tosa.avg_pool2d",
{{{Extension::int16}, {{i16T, i32T, i16T}}},
{{Extension::fp8e4m3}, {{fp8e4m3T, fp16T, fp8e4m3T}}},
{{Extension::fp8e5m2}, {{fp8e5m2T, fp16T, fp8e5m2T}}},
{{Extension::bf16}, {{bf16T, fp32T, bf16T}}}}},
{{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}},
{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}},
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}},
{{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
{"tosa.conv2d",
{{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
{{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},

View File

@@ -79,12 +79,12 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
let arguments = (ins
Tosa_Tensor4D:$input,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$output_zp,
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<I32Attr>:$input_zp,
OptionalAttr<I32Attr>:$output_zp
TypeAttrOf<Tosa_AccType>:$acc_type
);
let results = (outs
@@ -97,6 +97,14 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
];
let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
let extraClassDeclaration = [{
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getOutputZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyOutputZeroPoint(int64_t zp);
}];
let hasVerifier = 1;
}
@@ -116,8 +124,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
@@ -136,8 +144,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
];
let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getWeightZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];
@@ -161,8 +169,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
Tosa_Tensor5D:$input,
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,
Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
@@ -181,8 +189,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
];
let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getWeightZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];
@@ -207,8 +215,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
@@ -227,8 +235,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
];
let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getWeightZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];
@@ -412,8 +420,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
@@ -431,8 +439,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
];
let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getWeightZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];

View File

@@ -149,6 +149,7 @@ def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
def Tosa_ScalarIntOrFloatTensor : TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>;
// We include unranked tensors as a supported type for all possible tosa
// Tensors as unranked does not guarantee invalid. If unranked tensors exist

View File

@@ -260,18 +260,26 @@ public:
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
// Get and verify zero points.
int64_t inputZpVal;
int64_t weightZpVal;
if (op.getInputZeroPoint(inputZpVal).failed() ||
op.getWeightZeroPoint(weightZpVal).failed())
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp))
return rewriter.notifyMatchFailure(
op, "bail out if zero points cannot statically be determined");
op, "input zero point cannot be statically determined");
if (op.verifyInputZeroPoint(inputZpVal).failed() ||
op.verifyWeightZeroPoint(weightZpVal).failed())
FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
if (failed(maybeWZp))
return rewriter.notifyMatchFailure(
op, "zero point must be zero for non-int8 integer types");
op, "weight zero point cannot be statically determined");
int64_t inputZpVal = *maybeIZp;
int64_t weightZpVal = *maybeWZp;
if (op.verifyInputZeroPoint(inputZpVal).failed())
return rewriter.notifyMatchFailure(
op, "input zero point must be zero for non-int8 integer types");
if (op.verifyWeightZeroPoint(weightZpVal).failed())
return rewriter.notifyMatchFailure(
op, "weight zero point must be zero for non-int8 integer types");
bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
@@ -448,18 +456,26 @@ public:
/*kernelSizeDims=*/{0, 1}, rewriter);
// Get and verify zero points.
int64_t inputZpVal;
int64_t weightZpVal;
if (op.getInputZeroPoint(inputZpVal).failed() ||
op.getWeightZeroPoint(weightZpVal).failed())
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
if (failed(maybeIZp))
return rewriter.notifyMatchFailure(
op, "bail out if zero points cannot statically be determined");
op, "input zero point cannot be statically determined");
if (failed(maybeWZp))
return rewriter.notifyMatchFailure(
op, "weight zero point cannot be statically determined");
if (op.verifyInputZeroPoint(inputZpVal).failed() ||
op.verifyWeightZeroPoint(weightZpVal).failed())
int64_t inputZpVal = *maybeIZp;
int64_t weightZpVal = *maybeWZp;
if (op.verifyInputZeroPoint(inputZpVal).failed())
return rewriter.notifyMatchFailure(
op, "zero point must be zero for non-int8 integer types");
op, "input zero point must be zero for non-int8 integer types");
if (op.verifyWeightZeroPoint(weightZpVal).failed())
return rewriter.notifyMatchFailure(
op, "weight zero point must be zero for non-int8 integer types");
bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
auto weightShape = weightTy.getShape();
@@ -809,6 +825,18 @@ public:
return failure();
SmallVector<Value> dynamicDims = *dynamicDimsOr;
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
if (failed(maybeIZp))
return rewriter.notifyMatchFailure(
op, "input zero point could not be statically determined");
if (failed(maybeOZp))
return rewriter.notifyMatchFailure(
op, "output zero point could not be statically determined");
int64_t inputZpVal = *maybeIZp;
int64_t outputZpVal = *maybeOZp;
// Apply padding as necessary.
llvm::SmallVector<int64_t> pad;
pad.resize(2, 0);
@@ -928,9 +956,9 @@ public:
// If we have quantization information we need to apply an offset
// for the input zp value.
if (op.getInputZp()) {
auto inputZp =
rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
if (inputZpVal != 0) {
auto inputZp = rewriter.create<arith::ConstantOp>(
loc, b.getIntegerAttr(accETy, inputZpVal));
Value offset =
rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
poolVal =
@@ -982,9 +1010,9 @@ public:
// If we have quantization information we need to apply output
// zeropoint.
if (op.getOutputZp()) {
auto outputZp =
rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
if (outputZpVal != 0) {
auto outputZp = rewriter.create<arith::ConstantOp>(
loc, b.getIntegerAttr(scaled.getType(), outputZpVal));
scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
.getResult();
}

View File

@@ -321,17 +321,13 @@ static LogicalResult verifyConvOp(T op) {
<< weightEType << " and " << weightZpEType;
}
int64_t inputZpVal;
if (op.getInputZeroPoint(inputZpVal).succeeded() &&
op.verifyInputZeroPoint(inputZpVal).failed())
return op.emitOpError(
"input zero point must be zero for non-int8 integer types");
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
return failure();
int64_t weightZpVal;
if (op.getWeightZeroPoint(weightZpVal).succeeded() &&
op.verifyWeightZeroPoint(weightZpVal).failed())
return op.emitOpError(
"weight zero point must be zero for non-int8 integer types");
FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
return failure();
return success();
}
@@ -455,18 +451,10 @@ LogicalResult tosa::ArgMaxOp::verify() {
}
LogicalResult tosa::AvgPool2dOp::verify() {
auto inputType = llvm::cast<ShapedType>(getInput().getType());
auto inputETy = inputType.getElementType();
auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
inputETy = quantType.getStorageType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
resultETy = quantType.getStorageType();
const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
auto accType = getAccType();
if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
@@ -481,6 +469,24 @@ LogicalResult tosa::AvgPool2dOp::verify() {
if (inputETy.isF32() && !accType.isF32())
return emitOpError("accumulator type for f32 tensor is not f32");
if (inputETy != inputZpETy)
return emitOpError("expect both input and its zero point are the same "
"element type, got ")
<< inputETy << " and " << inputZpETy;
if (resultETy != outputZpETy)
return emitOpError("expect both output and its zero point are the same "
"element type, got ")
<< resultETy << " and " << outputZpETy;
FailureOr<int64_t> maybeIZp = getInputZeroPoint();
if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
return failure();
FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
return failure();
if ((inputETy.isF32() && resultETy.isF32()) ||
(inputETy.isF16() && resultETy.isF16()) ||
(inputETy.isBF16() && resultETy.isBF16()) ||
@@ -629,27 +635,48 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
}
/// Both the tosa.avg_pool2d and unary ops use the same
/// UnaruOpQuantizationAttr but avg_pool operator has its own builder as it
/// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it
/// has additional parameters not part of the unary ops.
static void
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
Type outputType, Value input,
DenseArrayAttr kernel, DenseArrayAttr stride,
DenseArrayAttr pad, TypeAttr accType) {
result.addOperands(input);
const Location loc{result.location};
int64_t inputZp{0};
int64_t outputZp{0};
if (auto quantAttr =
buildUnaryOpQuantizationAttr(builder, input, outputType)) {
inputZp = quantAttr.getInputZp();
outputZp = quantAttr.getOutputZp();
}
const std::optional<Value> inputZpOp =
createZeroPointTensor(builder, loc, input.getType(), inputZp);
if (!inputZpOp) {
(void)emitError(
loc,
"Failed to create input zero point tensor for quantized AVG_POOL2D op");
}
const std::optional<Value> outputZpOp =
createZeroPointTensor(builder, loc, outputType, outputZp);
if (!outputZpOp) {
(void)emitError(loc, "Failed to create output zero point tensor for "
"quantized AVG_POOL2D op");
}
if (inputZpOp && outputZpOp) {
result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
} else {
// failed to create one or more zero points above: just add input as
// operands this will trigger error in building the op because of missing
// zero points
result.addOperands({input});
}
result.addAttribute("kernel", kernel);
result.addAttribute("stride", stride);
result.addAttribute("pad", pad);
result.addAttribute("acc_type", accType);
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
if (quantAttr) {
result.addAttribute("input_zp",
builder.getI32IntegerAttr(
static_cast<int32_t>(quantAttr.getInputZp())));
result.addAttribute("output_zp",
builder.getI32IntegerAttr(
static_cast<int32_t>(quantAttr.getOutputZp())));
}
result.types.push_back(outputType);
}
@@ -1471,77 +1498,68 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
return mlir::success();
}
// return failure if val is not a constant
// set zp to -1 if val is non-zero float or val is not integer nor float
// otherwise set zp to val's constant value
template <typename T>
static LogicalResult getZeroPoint(T op, Value val, int64_t &zp) {
static FailureOr<int64_t> getZeroPoint(T op, Value val) {
ElementsAttr zpAttr;
if (!matchPattern(val, m_Constant(&zpAttr))) {
return failure();
}
Type zpElemType = zpAttr.getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(zpElemType)) {
zp = quantType.getZeroPoint();
return success();
}
if (llvm::isa<FloatType>(zpElemType)) {
if (!zpAttr.getValues<APFloat>()[0].isZero())
return op.emitOpError(
"non-zero zero point is not allowed for float types");
zp = 0;
return success();
if (zpAttr.getValues<APFloat>()[0].isZero()) {
return 0;
}
// return non-zero value to trigger error check
return -1;
}
if (llvm::isa<IntegerType>(zpElemType)) {
zp = zpAttr.getValues<APInt>()[0].getSExtValue();
return success();
return zpAttr.getValues<APInt>()[0].getSExtValue();
}
return op.emitOpError("zero point is not allowed for unsupported types");
// return non-zero value to trigger error check
return -1;
}
template <typename T>
static LogicalResult verifyZeroPoint(T op, Value val, int64_t &zp) {
// TODO clean it up when the entire zero point (attribute -> input tensor
// type) change is done. Remaining Matmul, Rescale, Negate, and AvgPool2D.
if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
!std::is_same_v<T, DepthwiseConv2DOp> &&
!std::is_same_v<T, TransposeConv2DOp>)
return failure();
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
const std::string &operand) {
Type zpElemType = getElementTypeOrSelf(val);
if (!zpElemType.isIntOrFloat())
return op.emitOpError("zero point is not integer or float typss");
if (!zpElemType.isInteger(8) && zp != 0)
return op.emitOpError("zero point must be zero for non-int8 integer types");
if (zp < -128 || zp > 127)
return failure();
if (!zpElemType.isInteger(8) && zp != 0) {
// convert operand to lower case for error message
std::string lower = operand;
std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
return op.emitOpError()
<< lower << " zero point must be zero for non-int8 integer types";
}
return success();
}
#define ZERO_POINT_HELPER(OP) \
LogicalResult tosa::OP::getInputZeroPoint(int64_t &zp) { \
return getZeroPoint(*this, getInputZp(), zp); \
#define ZERO_POINT_HELPER(OP, OPERAND_NAME) \
FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
return getZeroPoint(*this, get##OPERAND_NAME##Zp()); \
} \
LogicalResult tosa::OP::getWeightZeroPoint(int64_t &zp) { \
return getZeroPoint(*this, getWeightZp(), zp); \
} \
LogicalResult tosa::OP::verifyInputZeroPoint(int64_t zp) { \
return verifyZeroPoint(*this, getInputZp(), zp); \
} \
LogicalResult tosa::OP::verifyWeightZeroPoint(int64_t zp) { \
return verifyZeroPoint(*this, getWeightZp(), zp); \
LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
}
ZERO_POINT_HELPER(Conv2DOp)
ZERO_POINT_HELPER(Conv3DOp)
ZERO_POINT_HELPER(DepthwiseConv2DOp)
ZERO_POINT_HELPER(TransposeConv2DOp)
ZERO_POINT_HELPER(Conv2DOp, Input)
ZERO_POINT_HELPER(Conv2DOp, Weight)
ZERO_POINT_HELPER(Conv3DOp, Input)
ZERO_POINT_HELPER(Conv3DOp, Weight)
ZERO_POINT_HELPER(DepthwiseConv2DOp, Input)
ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight)
ZERO_POINT_HELPER(TransposeConv2DOp, Input)
ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
ZERO_POINT_HELPER(AvgPool2dOp, Input)
ZERO_POINT_HELPER(AvgPool2dOp, Output)
#undef ZERO_POINT_HELPER
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(

View File

@@ -54,18 +54,24 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
return rewriter.notifyMatchFailure(op, "unsupported type");
// Get and verify zero points.
int64_t iZp;
int64_t wZp;
if (op.getInputZeroPoint(iZp).failed() ||
op.getWeightZeroPoint(wZp).failed())
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp))
return rewriter.notifyMatchFailure(
op, "bail out if zero points cannot statically be determined");
op, "input zero point cannot be statically determined");
if (op.verifyInputZeroPoint(iZp).failed() ||
op.verifyWeightZeroPoint(wZp).failed())
FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
if (failed(maybeWZp))
return rewriter.notifyMatchFailure(
op, "zero point must be zero for non-int8 integer types");
op, "weight zero point cannot be statically determined");
int64_t iZp = *maybeIZp;
int64_t wZp = *maybeWZp;
if (op.verifyInputZeroPoint(iZp).failed())
return rewriter.notifyMatchFailure(
op, "input zero point must be zero for non-int8 integer types");
if (op.verifyWeightZeroPoint(wZp).failed())
return rewriter.notifyMatchFailure(
op, "weight zero point must be zero for non-int8 integer types");
// Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
ArrayRef<int64_t> inputShape = inputType.getShape();

View File

@@ -135,18 +135,26 @@ public:
getTosaConstShape(rewriter, op->getLoc(), weightPadding);
// Get and verify zero points.
int64_t inputZpVal;
int64_t weightZpVal;
if (op.getInputZeroPoint(inputZpVal).failed() ||
op.getWeightZeroPoint(weightZpVal).failed())
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp))
return rewriter.notifyMatchFailure(
op, "bail out if zero points cannot statically be determined");
op, "input zero point cannot be statically determined");
if (op.verifyInputZeroPoint(inputZpVal).failed() ||
op.verifyWeightZeroPoint(weightZpVal).failed())
FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
if (failed(maybeWZp))
return rewriter.notifyMatchFailure(
op, "zero point must be zero for non-int8 integer types");
op, "weight zero point cannot be statically determined");
int64_t inputZpVal = *maybeIZp;
int64_t weightZpVal = *maybeWZp;
if (op.verifyInputZeroPoint(inputZpVal).failed())
return rewriter.notifyMatchFailure(
op, "input zero point must be zero for non-int8 integer types");
if (op.verifyWeightZeroPoint(weightZpVal).failed())
return rewriter.notifyMatchFailure(
op, "weight zero point must be zero for non-int8 integer types");
if (weightZpVal != 0) {
weight = CreateOpAndInferShape<tosa::PadOp>(

View File

@@ -58,6 +58,8 @@ void ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
template <>
void ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
addValue(op.getInput());
addValue(op.getInputZp());
addValue(op.getOutputZp());
addType(op.getAccType());
addValue(op.getOutput());
}

View File

@@ -1,9 +1,9 @@
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -verify-diagnostics
// CHECK-LABEL: @avg_pool2d_with_unsupported_quant_type
func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
// expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}}
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
}

View File

@@ -290,7 +290,9 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
// CHECK: %[[FLT:.+]] = arith.sitofp %[[CAST]]
// CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
// CHECK: linalg.yield %[[DIV]]
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>) -> tensor<1x5x33x62xf32>
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x33x62xf32>
return %0 : tensor<1x5x33x62xf32>
}
@@ -375,7 +377,9 @@ func.func @avg_pool_f16_f32acc(%arg0: tensor<1x6x34x62xf16>) -> (tensor<1x5x33x6
// CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
// CHECK: %[[TRUNC:.+]] = arith.truncf %[[DIV]]
// CHECK: linalg.yield %[[TRUNC]]
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf16>) -> tensor<1x5x33x62xf16>
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x33x62xf16>
return %0 : tensor<1x5x33x62xf16>
}
@@ -416,7 +420,9 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
// CHECK: %[[CLAMP:.+]] = arith.minsi %[[CMAX]], %[[LOW]]
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLAMP]]
// CHECK: linalg.yield %[[TRUNC]]
%0 = tosa.avg_pool2d %arg0 {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>) -> tensor<1x5x33x62xi8>
%input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%output_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x33x62xi8>
return %0 : tensor<1x5x33x62xi8>
}
@@ -439,7 +445,9 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x5x33x62xf32>) -> tensor<?x5x33x62xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x5x33x62xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>) -> tensor<?x5x33x62xf32>
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x5x33x62xf32>
return %0 : tensor<?x5x33x62xf32>
}

View File

@@ -23,18 +23,18 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
// -----
// check that tosa verify kick in
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x7x7x9xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x0x?x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}
// -----
// check that --tosa-to-linalg kick in
func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
// expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}}
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
}

View File

@@ -19,7 +19,9 @@ func.func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> {
func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ]
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}

View File

@@ -56,7 +56,7 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2:
func.func @test_conv2d_input_zp(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3x3x4xf16>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> {
%input_zp = "tosa.const"() <{value = dense<-1.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%weight_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
// expected-error@+1 {{'tosa.conv2d' op non-zero zero point is not allowed for float types}}
// expected-error@+1 {{'tosa.conv2d' op input zero point must be zero for non-int8 integer types}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xf16>, tensor<16x3x3x4xf16>, tensor<16xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x27x27x16xf16>
return %0 : tensor<1x27x27x16xf16>
@@ -67,7 +67,7 @@ func.func @test_conv2d_input_zp(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3
func.func @test_conv2d_weight_zp(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3x3x4xf16>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> {
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%weight_zp = "tosa.const"() <{value = dense<-1.0> : tensor<1xf16>}> : () -> tensor<1xf16>
// expected-error@+1 {{'tosa.conv2d' op non-zero zero point is not allowed for float types}}
// expected-error@+1 {{'tosa.conv2d' op weight zero point must be zero for non-int8 integer types}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xf16>, tensor<16x3x3x4xf16>, tensor<16xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x27x27x16xf16>
return %0 : tensor<1x27x27x16xf16>
@@ -567,19 +567,19 @@ func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<
// -----
func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32> {
func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x7x7x9xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x7x9xf32>'}}
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x0x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}
// -----
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x7x7x9xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x0x?x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}
@@ -1271,6 +1271,50 @@ func.func @test_conv2d_invalid_bias_size(%arg0: tensor<1x4x4x4xf32>, %arg1: tens
// -----
// CHECK-LABEL: test_avg_pool_input_zp_same_element_type
func.func @test_avg_pool_input_zp_same_element_type(%arg0: tensor<1x16x16x8xf16>, %arg1: tensor<1xi8>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xf16> {
// expected-error@+1 {{'tosa.avg_pool2d' op expect both input and its zero point are the same element type, got 'f16' and 'i8'}}
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x16x16x8xf16>, tensor<1xi8>, tensor<1xf16>) -> tensor<1x16x16x8xf16>
return %0 : tensor<1x16x16x8xf16>
}
// -----
// CHECK-LABEL: test_avg_pool_output_zp_same_element_type
func.func @test_avg_pool_output_zp_same_element_type(%arg0: tensor<1x16x16x8xi8>, %arg1: tensor<1xi8>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xi8> {
// expected-error@+1 {{'tosa.avg_pool2d' op expect both output and its zero point are the same element type, got 'i8' and 'f16'}}
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x16x16x8xi8>, tensor<1xi8>, tensor<1xf16>) -> tensor<1x16x16x8xi8>
return %0 : tensor<1x16x16x8xi8>
}
// -----
// CHECK-LABEL: test_avg_pool_input_zp_non_zero
func.func @test_avg_pool_input_zp_non_zero(%arg0: tensor<1x16x16x8xf32>) -> tensor<1x16x16x8xf32> {
%input_zp = "tosa.const"() {value = dense<-1.0> : tensor<1xf32>} : () -> tensor<1xf32>
%output_zp = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
// expected-error@+1 {{'tosa.avg_pool2d' op input zero point must be zero for non-int8 integer types}}
%0 = "tosa.avg_pool2d"(%arg0, %input_zp, %output_zp) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x16x16x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x16x16x8xf32>
return %0 : tensor<1x16x16x8xf32>
}
// -----
// CHECK-LABEL: test_avg_pool_output_zp_non_zero
func.func @test_avg_pool_output_zp_non_zero(%arg0: tensor<1x16x16x8xf32>) -> tensor<1x16x16x8xf32> {
%input_zp = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
%output_zp = "tosa.const"() {value = dense<-1.0> : tensor<1xf32>} : () -> tensor<1xf32>
// expected-error@+1 {{'tosa.avg_pool2d' op output zero point must be zero for non-int8 integer types}}
%0 = "tosa.avg_pool2d"(%arg0, %input_zp, %output_zp) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x16x16x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x16x16x8xf32>
return %0 : tensor<1x16x16x8xf32>
}
// -----
func.func @test_fft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>) {
// expected-error@+1 {{'tosa.fft2d' op requires the same element type for all operands and results}}
%0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>)

View File

@@ -506,74 +506,74 @@ func.func @test_identity_rank_valid(%arg0: tensor<i32>) -> tensor<i32> {
// -----
func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}
// -----
func.func @test_avgpool2d_kernel_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
func.func @test_avgpool2d_kernel_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 1, 8193>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array<i64: 1, 8193>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}
// -----
func.func @test_avgpool2d_stride_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
func.func @test_avgpool2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: stride <= MAX_STRIDE}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 8193, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 8193, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}
// -----
func.func @test_avgpool2d_stride_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
func.func @test_avgpool2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: stride <= MAX_STRIDE}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 8193>, acc_type = f32} :
(tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 8193>, acc_type = f32} :
(tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}
// -----
func.func @test_avgpool2d_pad_top(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
func.func @test_avgpool2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 8193, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array<i64: 1, 1>, pad = array<i64: 8193, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}
// -----
func.func @test_avgpool2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
func.func @test_avgpool2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 8193, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 8193, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}
// -----
func.func @test_avgpool2d_pad_left(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
func.func @test_avgpool2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 8193, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 8193, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}
// -----
func.func @test_avgpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
func.func @test_avgpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 4, 8193>, stride = array<i64: 1, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array<i64: 1, 1>, pad = array<i64: 4, 4, 4, 8193>, stride = array<i64: 1, 1>, acc_type = f32} :
(tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}
@@ -1074,9 +1074,9 @@ func.func @test_resize_tensor_size_invalid(%arg0: tensor<1x23178x23178x1xf32>) {
// -----
func.func @test_avg_pool2d_tensor_size_invalid(%arg0: tensor<1x23178x23178x9xf32>) -> tensor<1x23178x23178x9xf32> {
func.func @test_avg_pool2d_tensor_size_invalid(%arg0: tensor<1x23178x23178x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x23178x23178x9xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x23178x23178x9xf32>) -> tensor<1x23178x23178x9xf32>
%0 = tosa.avg_pool2d %arg0, %arg1, %arg2 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x23178x23178x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x23178x23178x9xf32>
return %0 : tensor<1x23178x23178x9xf32>
}

View File

@@ -12,42 +12,54 @@ func.func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> {
// -----
// CHECK-LABEL: avg_pool2d_f32
func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}
// -----
// CHECK-LABEL: avg_pool2d_f16
func.func @test_avg_pool2d_f16(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
%0 = tosa.avg_pool2d %arg0 {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x7x7x9xf16>
return %0 : tensor<1x7x7x9xf16>
}
// -----
// CHECK-LABEL: avg_pool2d_f16_accumf32
func.func @test_avg_pool2d_f16_accumf32(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x7x7x9xf16>
return %0 : tensor<1x7x7x9xf16>
}
// -----
// CHECK-LABEL: avg_pool2d_i8
func.func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {
%0 = tosa.avg_pool2d %arg0 {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8>
%input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%output_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x7x7x9xi8>
return %0 : tensor<1x7x7x9xi8>
}
// -----
// CHECK-LABEL: avg_pool2d_i16
func.func @test_avg_pool2d_i16(%arg0: tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16> {
%0 = tosa.avg_pool2d %arg0 {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16>
%input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi16>}> : () -> tensor<1xi16>
%output_zp = "tosa.const"() <{value = dense<0> : tensor<1xi16>}> : () -> tensor<1xi16>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<1x7x7x9xi16>
return %0 : tensor<1x7x7x9xi16>
}
// -----
// CHECK-LABEL: avg_pool2d_q8
func.func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
%0 = tosa.avg_pool2d %arg0 {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
%input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%output_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
}

View File

@@ -19,9 +19,9 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
}
// -----
func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x7x7x9xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op illegal: requires [pro_fp] but not enabled in target}}
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
%0 = tosa.avg_pool2d %arg0, %arg1, %arg2 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}

View File

@@ -12,9 +12,9 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
}
// -----
func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x7x7x9xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op illegal: requires [pro_fp] but not enabled in target}}
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
%0 = tosa.avg_pool2d %arg0, %arg1, %arg2 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}

View File

@@ -669,8 +669,11 @@ func.func @scatter_minimum_static(%arg0 : tensor<?x4x?xi32>, %arg1 : tensor<3x?x
// CHECK-LABEL: @test_pool_static
func.func @test_pool_static(%arg0: tensor<3x5x6x7xf32>) {
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: -> tensor<3x2x4x7xf32>
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
// CHECK: -> tensor<3x2x4x7xf32>
%1 = tosa.max_pool2d %arg0 {kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
@@ -699,8 +702,11 @@ func.func @conv2d_dynamic_input(%input: tensor<?x?x?x?xf32>, %weights: tensor<5x
// CHECK-LABEL: @test_pool_dynamic_input
func.func @test_pool_dynamic_input(%arg0: tensor<?x?x?x?xf32>) {
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: -> tensor<?x?x?x?xf32>
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
// CHECK: -> tensor<?x?x?x?xf32>
%1 = tosa.max_pool2d %arg0 {kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
@@ -711,8 +717,11 @@ func.func @test_pool_dynamic_input(%arg0: tensor<?x?x?x?xf32>) {
// CHECK-LABEL: @test_pool_padded
func.func @test_pool_padded(%arg0: tensor<3x5x6x7xf32>) {
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: -> tensor<3x5x11x7xf32>
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 4, 3>, pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array<i64: 4, 3>, pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
// CHECK: -> tensor<3x5x11x7xf32>
%1 = tosa.max_pool2d %arg0 {kernel = array<i64: 4, 3>, pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
@@ -741,8 +750,11 @@ func.func @conv2d_dynamic_bias(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3
// CHECK-LABEL: @test_pool_stride
func.func @test_pool_stride(%arg0: tensor<3x11x12x7xf32>) {
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: -> tensor<3x4x4x7xf32>
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<3x11x12x7xf32>) -> tensor<?x?x?x?xf32>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<3x11x12x7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
// CHECK: -> tensor<3x4x4x7xf32>
%1 = tosa.max_pool2d %arg0 {kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<3x11x12x7xf32>) -> tensor<?x?x?x?xf32>