From 356bd2c9605761121b49f318a187560ec306718e Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Thu, 8 May 2025 12:33:37 +0200 Subject: [PATCH] [mlir][tosa] Allow unsigned types for rescale ops during validation (#138253) This commit allows unsigned types (ui8/ui16/ui32) when checking for valid element types, only for rescale operators. Signed-off-by: Luke Hutton --- .../Tosa/Transforms/TosaValidation.cpp | 23 +++++++++++++----- mlir/test/Dialect/Tosa/invalid.mlir | 24 +++++++++++++++++++ 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index e8b52d48347a..feedc5057bea 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -562,7 +562,7 @@ private: bool CheckVariable(Operation *op); bool CheckVariableReadOrWrite(Operation *op); - bool isValidElementType(Type type); + bool isValidElementType(Type type, const bool allowUnsigned = false); SmallVector< std::function> @@ -1176,7 +1176,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { return success(); } -bool TosaValidation::isValidElementType(Type type) { +bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { if (isa(type)) { return isa(type); @@ -1191,6 +1191,13 @@ bool TosaValidation::isValidElementType(Type type) { case 48: return true; } + } else if (allowUnsigned && intTy.isUnsigned()) { + switch (intTy.getWidth()) { + case 8: + case 16: + case 32: + return true; + } } } else if (mlir::isa(type)) { return true; @@ -1209,11 +1216,15 @@ void TosaValidation::runOnOperation() { if (op->getDialect() != tosaDialect) return; - // perform valid element type check at the beginning to - // protect rest of code against quantized element types + // validate operator element types: + // - rescale operator is allowed to have ui8/ui16/ui32 + // operands/results + // - perform valid element type check at the beginning to + // protect rest of code against quantized element types + const bool opIsRescale = isa(op); for (Value operand : op->getOperands()) { auto elementTy = getElementTypeOrSelf(operand); - if (!isValidElementType(elementTy)) { + if (!isValidElementType(elementTy, opIsRescale)) { op->emitOpError() << "is not profile-aligned: element type " << elementTy << " is not legal"; return signalPassFailure(); @@ -1221,7 +1232,7 @@ void TosaValidation::runOnOperation() { } for (Type resultTy : op->getResultTypes()) { auto elementTy = getElementTypeOrSelf(resultTy); - if (!isValidElementType(elementTy)) { + if (!isValidElementType(elementTy, opIsRescale)) { op->emitOpError() << "is not profile-aligned: element type " << elementTy << " is not legal"; return signalPassFailure(); diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 56d76585be71..732c980f3ab9 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -1937,3 +1937,27 @@ func.func @test_clamp_min_larger_than_max_fp32(%arg0: tensor<13x21x3xf32>) -> te %0 = tosa.clamp %arg0 {min_val = 2.0 : f32, max_val = -1.1: f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } + +// ----- + +// CHECK-LABEL: test_rescale_input_unsigned +func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8>) { + %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8> + %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> + %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8> + %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> + %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8> + return %r : tensor<1x1xi8> +} + +// ----- + +// CHECK-LABEL: test_rescale_output_unsigned +func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) { + %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8> + %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> + %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8> + %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> + %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8> + return %r : tensor<1x1xui8> +}