[mlir][emitc] Add EmitC lowering for arith.trunci, arith.extsi, arith.extui

This commit adds conversion to EmitC for arith dialect casts between integer types (trunc, extsi, extui), excluding indexes for now.
This commit is contained in:
Corentin Ferry
2024-05-22 16:33:37 +02:00
committed by GitHub
parent 267de8543c
commit 7630379156
3 changed files with 162 additions and 0 deletions

View File

@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Tools/PDLL/AST/Types.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@@ -112,6 +113,93 @@ public:
}
};
template <typename ArithOp, bool castToUnsigned>
class CastConversion : public OpConversionPattern<ArithOp> {
public:
using OpConversionPattern<ArithOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type opReturnType = this->getTypeConverter()->convertType(op.getType());
if (!isa_and_nonnull<IntegerType>(opReturnType))
return rewriter.notifyMatchFailure(op, "expected integer result type");
if (adaptor.getOperands().size() != 1) {
return rewriter.notifyMatchFailure(
op, "CastConversion only supports unary ops");
}
Type operandType = adaptor.getIn().getType();
if (!isa_and_nonnull<IntegerType>(operandType))
return rewriter.notifyMatchFailure(op, "expected integer operand type");
// Signed (sign-extending) casts from i1 are not supported.
if (operandType.isInteger(1) && !castToUnsigned)
return rewriter.notifyMatchFailure(op,
"operation not supported on i1 type");
// to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
// equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
// truncation.
if (opReturnType.isInteger(1)) {
auto constOne = rewriter.create<emitc::ConstantOp>(
op.getLoc(), operandType, rewriter.getIntegerAttr(operandType, 1));
auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
op.getLoc(), operandType, adaptor.getIn(), constOne);
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
oneAndOperand);
return success();
}
bool isTruncation = operandType.getIntOrFloatBitWidth() >
opReturnType.getIntOrFloatBitWidth();
bool doUnsigned = castToUnsigned || isTruncation;
Type castType = opReturnType;
// If the op is a ui variant and the type wanted as
// return type isn't unsigned, we need to issue an unsigned type to do
// the conversion.
if (castType.isUnsignedInteger() != doUnsigned) {
castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
/*isSigned=*/!doUnsigned);
}
Value actualOp = adaptor.getIn();
// Adapt the signedness of the operand if necessary
if (operandType.isUnsignedInteger() != doUnsigned) {
Type correctSignednessType =
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
/*isSigned=*/!doUnsigned);
actualOp = rewriter.template create<emitc::CastOp>(
op.getLoc(), correctSignednessType, actualOp);
}
auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
actualOp);
// Cast to the expected output type
if (castType != opReturnType) {
result = rewriter.template create<emitc::CastOp>(op.getLoc(),
opReturnType, result);
}
rewriter.replaceOp(op, result);
return success();
}
};
template <typename ArithOp>
class UnsignedCastConversion : public CastConversion<ArithOp, true> {
using CastConversion<ArithOp, true>::CastConversion;
};
template <typename ArithOp>
class SignedCastConversion : public CastConversion<ArithOp, false> {
using CastConversion<ArithOp, false>::CastConversion;
};
template <typename ArithOp, typename EmitCOp>
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
public:
@@ -313,6 +401,10 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
CmpIOpConversion,
SelectOpConversion,
// Truncation is guaranteed for unsigned types.
UnsignedCastConversion<arith::TruncIOp>,
SignedCastConversion<arith::ExtSIOp>,
UnsignedCastConversion<arith::ExtUIOp>,
ItoFCastOpConversion<arith::SIToFPOp>,
ItoFCastOpConversion<arith::UIToFPOp>,
FtoICastOpConversion<arith::FPToSIOp>,

View File

@@ -63,3 +63,10 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
return %t: i1
}
// -----
func.func @arith_extsi_i1_to_i32(%arg0: i1) {
// expected-error @+1 {{failed to legalize operation 'arith.extsi'}}
%idx = arith.extsi %arg0 : i1 to i32
return
}

View File

@@ -177,3 +177,66 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
return
}
// -----
func.func @arith_trunci(%arg0: i32) -> i8 {
// CHECK-LABEL: arith_trunci
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
// CHECK: %[[CastUI:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
// CHECK: %[[Trunc:.*]] = emitc.cast %[[CastUI]] : ui32 to ui8
// CHECK: emitc.cast %[[Trunc]] : ui8 to i8
%truncd = arith.trunci %arg0 : i32 to i8
return %truncd : i8
}
// -----
func.func @arith_trunci_to_i1(%arg0: i32) -> i1 {
// CHECK-LABEL: arith_trunci_to_i1
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
// CHECK: %[[Const:.*]] = "emitc.constant"
// CHECK-SAME: value = 1
// CHECK: %[[And:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32
// CHECK: emitc.cast %[[And]] : i32 to i1
%truncd = arith.trunci %arg0 : i32 to i1
return %truncd : i1
}
// -----
func.func @arith_extsi(%arg0: i32) {
// CHECK-LABEL: arith_extsi
// CHECK-SAME: ([[Arg0:[^ ]*]]: i32)
// CHECK: emitc.cast [[Arg0]] : i32 to i64
%extd = arith.extsi %arg0 : i32 to i64
return
}
// -----
func.func @arith_extui(%arg0: i32) {
// CHECK-LABEL: arith_extui
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
// CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
// CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to ui64
// CHECK: emitc.cast %[[Conv1]] : ui64 to i64
%extd = arith.extui %arg0 : i32 to i64
return
}
// -----
func.func @arith_extui_i1_to_i32(%arg0: i1) {
// CHECK-LABEL: arith_extui_i1_to_i32
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i1)
// CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i1 to ui1
// CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui1 to ui32
// CHECK: emitc.cast %[[Conv1]] : ui32 to i32
%idx = arith.extui %arg0 : i1 to i32
return
}