mirror of
https://github.com/intel/llvm.git
synced 2026-01-16 05:32:28 +08:00
[mlir][tosa] Allow unsigned types for rescale ops during validation (#138253)
This commit allows unsigned types (ui8/ui16/ui32) when checking for valid element types, only for rescale operators. Signed-off-by: Luke Hutton <luke.hutton@arm.com>
This commit is contained in:
@@ -562,7 +562,7 @@ private:
|
||||
|
||||
bool CheckVariable(Operation *op);
|
||||
bool CheckVariableReadOrWrite(Operation *op);
|
||||
bool isValidElementType(Type type);
|
||||
bool isValidElementType(Type type, const bool allowUnsigned = false);
|
||||
|
||||
SmallVector<
|
||||
std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
|
||||
@@ -1176,7 +1176,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
bool TosaValidation::isValidElementType(Type type) {
|
||||
bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
|
||||
if (isa<FloatType>(type)) {
|
||||
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
|
||||
Float8E5M2Type>(type);
|
||||
@@ -1191,6 +1191,13 @@ bool TosaValidation::isValidElementType(Type type) {
|
||||
case 48:
|
||||
return true;
|
||||
}
|
||||
} else if (allowUnsigned && intTy.isUnsigned()) {
|
||||
switch (intTy.getWidth()) {
|
||||
case 8:
|
||||
case 16:
|
||||
case 32:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
} else if (mlir::isa<tosa::shapeType>(type)) {
|
||||
return true;
|
||||
@@ -1209,11 +1216,15 @@ void TosaValidation::runOnOperation() {
|
||||
if (op->getDialect() != tosaDialect)
|
||||
return;
|
||||
|
||||
// perform valid element type check at the beginning to
|
||||
// protect rest of code against quantized element types
|
||||
// validate operator element types:
|
||||
// - rescale operator is allowed to have ui8/ui16/ui32
|
||||
// operands/results
|
||||
// - perform valid element type check at the beginning to
|
||||
// protect rest of code against quantized element types
|
||||
const bool opIsRescale = isa<tosa::RescaleOp>(op);
|
||||
for (Value operand : op->getOperands()) {
|
||||
auto elementTy = getElementTypeOrSelf(operand);
|
||||
if (!isValidElementType(elementTy)) {
|
||||
if (!isValidElementType(elementTy, opIsRescale)) {
|
||||
op->emitOpError() << "is not profile-aligned: element type "
|
||||
<< elementTy << " is not legal";
|
||||
return signalPassFailure();
|
||||
@@ -1221,7 +1232,7 @@ void TosaValidation::runOnOperation() {
|
||||
}
|
||||
for (Type resultTy : op->getResultTypes()) {
|
||||
auto elementTy = getElementTypeOrSelf(resultTy);
|
||||
if (!isValidElementType(elementTy)) {
|
||||
if (!isValidElementType(elementTy, opIsRescale)) {
|
||||
op->emitOpError() << "is not profile-aligned: element type "
|
||||
<< elementTy << " is not legal";
|
||||
return signalPassFailure();
|
||||
|
||||
@@ -1937,3 +1937,27 @@ func.func @test_clamp_min_larger_than_max_fp32(%arg0: tensor<13x21x3xf32>) -> te
|
||||
%0 = tosa.clamp %arg0 {min_val = 2.0 : f32, max_val = -1.1: f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
|
||||
return %0 : tensor<13x21x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_rescale_input_unsigned
|
||||
func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8>) {
|
||||
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
|
||||
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
|
||||
return %r : tensor<1x1xi8>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_rescale_output_unsigned
|
||||
func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
|
||||
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
|
||||
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
|
||||
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
|
||||
return %r : tensor<1x1xui8>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user