[mlir][LLVM] Add exact flag (#115327)

The implementation is mostly based on the one existing for the nsw and
nuw flags.

If the exact flag is present, the corresponding operation returns a
poison value when the result is not exact. (For a division, if rounding
happens; for a right shift, if a non-zero bit is shifted out.)
This commit is contained in:
lfrenot
2024-11-08 12:56:44 +00:00
committed by GitHub
parent 724b432410
commit afa178d360
8 changed files with 101 additions and 4 deletions

View File

@@ -87,6 +87,33 @@ def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface">
];
}
def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
let description = [{
This interface defines an LLVM operation with an exact flag and
provides a uniform API for accessing it.
}];
let cppNamespace = "::mlir::LLVM";
let methods = [
InterfaceMethod<[{
Get the exact flag for the operation.
}], "bool", "getIsExact", (ins), [{}], [{
return $_op.getProperties().isExact;
}]>,
InterfaceMethod<[{
Set the exact flag for the operation.
}], "void", "setIsExact", (ins "bool":$isExact), [{}], [{
$_op.getProperties().isExact = isExact;
}]>,
StaticInterfaceMethod<[{
Get the attribute name of the isExact property.
}], "StringRef", "getIsExactName", (ins), [{}], [{
return "isExact";
}]>,
];
}
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
let description = [{
An interface for operations that can carry branch weights metadata. It

View File

@@ -76,6 +76,24 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
"$res = builder.Create" # instName #
"($lhs, $rhs, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());";
}
class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
!listconcat([DeclareOpInterfaceMethods<ExactFlagInterface>], traits)> {
let arguments = !con(commonArgs, (ins UnitAttr:$isExact));
string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
moduleImport.setExactFlag(inst, op);
$res = op;
}];
let assemblyFormat = [{
(`exact` $isExact^)? $lhs `,` $rhs custom<LLVMOpAttrs>(attr-dict) `:` type($res)
}];
string llvmBuilder =
"$res = builder.Create" # instName #
"($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
}
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -116,8 +134,8 @@ def LLVM_AddOp : LLVM_IntArithmeticOpWithOverflowFlag<"add", "Add",
def LLVM_SubOp : LLVM_IntArithmeticOpWithOverflowFlag<"sub", "Sub", []>;
def LLVM_MulOp : LLVM_IntArithmeticOpWithOverflowFlag<"mul", "Mul",
[Commutative]>;
def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">;
def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
def LLVM_UDivOp : LLVM_IntArithmeticOpWithExactFlag<"udiv", "UDiv">;
def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">;
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
@@ -128,8 +146,8 @@ def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl", []> {
let hasFolder = 1;
}
def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "AShr">;
def LLVM_LShrOp : LLVM_IntArithmeticOpWithExactFlag<"lshr", "LShr">;
def LLVM_AShrOp : LLVM_IntArithmeticOpWithExactFlag<"ashr", "AShr">;
// Base class for compare operations. A compare operation takes two operands
// of the same type and returns a boolean result. If the operands are

View File

@@ -187,6 +187,11 @@ public:
/// operation does not implement the integer overflow flag interface.
void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const;
/// Sets the exact flag attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the exact flag interface.
void setExactFlag(llvm::Instruction *inst, Operation *op) const;
/// Sets the fastmath flags attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the fastmath interface.

View File

@@ -143,6 +143,9 @@ static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) {
printer.printOptionalAttrDict(
filteredAttrs, /*elidedAttrs=*/{iface.getOverflowFlagsAttrName()});
} else if (auto iface = dyn_cast<ExactFlagInterface>(op)) {
printer.printOptionalAttrDict(filteredAttrs,
/*elidedAttrs=*/{iface.getIsExactName()});
} else {
printer.printOptionalAttrDict(filteredAttrs);
}

View File

@@ -683,6 +683,12 @@ void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst,
iface.setOverflowFlags(value);
}
void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
auto iface = cast<ExactFlagInterface>(op);
iface.setIsExact(inst->isExact());
}
void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
Operation *op) const {
auto iface = cast<FastmathFlagsInterface>(op);

View File

@@ -49,6 +49,16 @@ func.func @ops(%arg0: i32, %arg1: f32,
%mul_flag = llvm.mul %arg0, %arg0 overflow<nsw, nuw> : i32
%shl_flag = llvm.shl %arg0, %arg0 overflow<nuw, nsw> : i32
// Integer exact flag.
// CHECK: {{.*}} = llvm.sdiv exact %[[I32]], %[[I32]] : i32
// CHECK: {{.*}} = llvm.udiv exact %[[I32]], %[[I32]] : i32
// CHECK: {{.*}} = llvm.ashr exact %[[I32]], %[[I32]] : i32
// CHECK: {{.*}} = llvm.lshr exact %[[I32]], %[[I32]] : i32
%sdiv_flag = llvm.sdiv exact %arg0, %arg0 : i32
%udiv_flag = llvm.udiv exact %arg0, %arg0 : i32
%ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
%lshr_flag = llvm.lshr exact %arg0, %arg0 : i32
// Floating point binary operations.
//
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32

View File

@@ -0,0 +1,14 @@
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
; CHECK-LABEL: @exactflag_inst
define void @exactflag_inst(i64 %arg1, i64 %arg2) {
; CHECK: llvm.udiv exact %{{.*}}, %{{.*}} : i64
%1 = udiv exact i64 %arg1, %arg2
; CHECK: llvm.sdiv exact %{{.*}}, %{{.*}} : i64
%2 = sdiv exact i64 %arg1, %arg2
; CHECK: llvm.lshr exact %{{.*}}, %{{.*}} : i64
%3 = lshr exact i64 %arg1, %arg2
; CHECK: llvm.ashr exact %{{.*}}, %{{.*}} : i64
%4 = ashr exact i64 %arg1, %arg2
ret void
}

View File

@@ -0,0 +1,14 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
// CHECK-LABEL: define void @exactflag_func
llvm.func @exactflag_func(%arg0: i64, %arg1: i64) {
// CHECK: %{{.*}} = udiv exact i64 %{{.*}}, %{{.*}}
%0 = llvm.udiv exact %arg0, %arg1 : i64
// CHECK: %{{.*}} = sdiv exact i64 %{{.*}}, %{{.*}}
%1 = llvm.sdiv exact %arg0, %arg1 : i64
// CHECK: %{{.*}} = lshr exact i64 %{{.*}}, %{{.*}}
%2 = llvm.lshr exact %arg0, %arg1 : i64
// CHECK: %{{.*}} = ashr exact i64 %{{.*}}, %{{.*}}
%3 = llvm.ashr exact %arg0, %arg1 : i64
llvm.return
}