mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 10:55:58 +08:00
[mlir][LLVM] Add exact flag (#115327)
The implementation is mostly based on the one existing for the nsw and nuw flags. If the exact flag is present, the corresponding operation returns a poison value when the result is not exact. (For a division, if rounding happens; for a right shift, if a non-zero bit is shifted out.)
This commit is contained in:
@@ -87,6 +87,33 @@ def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface">
|
||||
];
|
||||
}
|
||||
|
||||
def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
|
||||
let description = [{
|
||||
This interface defines an LLVM operation with an exact flag and
|
||||
provides a uniform API for accessing it.
|
||||
}];
|
||||
|
||||
let cppNamespace = "::mlir::LLVM";
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<[{
|
||||
Get the exact flag for the operation.
|
||||
}], "bool", "getIsExact", (ins), [{}], [{
|
||||
return $_op.getProperties().isExact;
|
||||
}]>,
|
||||
InterfaceMethod<[{
|
||||
Set the exact flag for the operation.
|
||||
}], "void", "setIsExact", (ins "bool":$isExact), [{}], [{
|
||||
$_op.getProperties().isExact = isExact;
|
||||
}]>,
|
||||
StaticInterfaceMethod<[{
|
||||
Get the attribute name of the isExact property.
|
||||
}], "StringRef", "getIsExactName", (ins), [{}], [{
|
||||
return "isExact";
|
||||
}]>,
|
||||
];
|
||||
}
|
||||
|
||||
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
|
||||
let description = [{
|
||||
An interface for operations that can carry branch weights metadata. It
|
||||
|
||||
@@ -76,6 +76,24 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
|
||||
"$res = builder.Create" # instName #
|
||||
"($lhs, $rhs, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());";
|
||||
}
|
||||
class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName,
|
||||
list<Trait> traits = []> :
|
||||
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
|
||||
!listconcat([DeclareOpInterfaceMethods<ExactFlagInterface>], traits)> {
|
||||
let arguments = !con(commonArgs, (ins UnitAttr:$isExact));
|
||||
|
||||
string mlirBuilder = [{
|
||||
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
|
||||
moduleImport.setExactFlag(inst, op);
|
||||
$res = op;
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
(`exact` $isExact^)? $lhs `,` $rhs custom<LLVMOpAttrs>(attr-dict) `:` type($res)
|
||||
}];
|
||||
string llvmBuilder =
|
||||
"$res = builder.Create" # instName #
|
||||
"($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
|
||||
}
|
||||
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
|
||||
list<Trait> traits = []> :
|
||||
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
|
||||
@@ -116,8 +134,8 @@ def LLVM_AddOp : LLVM_IntArithmeticOpWithOverflowFlag<"add", "Add",
|
||||
def LLVM_SubOp : LLVM_IntArithmeticOpWithOverflowFlag<"sub", "Sub", []>;
|
||||
def LLVM_MulOp : LLVM_IntArithmeticOpWithOverflowFlag<"mul", "Mul",
|
||||
[Commutative]>;
|
||||
def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">;
|
||||
def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
|
||||
def LLVM_UDivOp : LLVM_IntArithmeticOpWithExactFlag<"udiv", "UDiv">;
|
||||
def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">;
|
||||
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
|
||||
def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
|
||||
def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
|
||||
@@ -128,8 +146,8 @@ def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
|
||||
def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl", []> {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
|
||||
def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "AShr">;
|
||||
def LLVM_LShrOp : LLVM_IntArithmeticOpWithExactFlag<"lshr", "LShr">;
|
||||
def LLVM_AShrOp : LLVM_IntArithmeticOpWithExactFlag<"ashr", "AShr">;
|
||||
|
||||
// Base class for compare operations. A compare operation takes two operands
|
||||
// of the same type and returns a boolean result. If the operands are
|
||||
|
||||
@@ -187,6 +187,11 @@ public:
|
||||
/// operation does not implement the integer overflow flag interface.
|
||||
void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const;
|
||||
|
||||
/// Sets the exact flag attribute for the imported operation `op` given
|
||||
/// the original instruction `inst`. Asserts if the operation does not
|
||||
/// implement the exact flag interface.
|
||||
void setExactFlag(llvm::Instruction *inst, Operation *op) const;
|
||||
|
||||
/// Sets the fastmath flags attribute for the imported operation `op` given
|
||||
/// the original instruction `inst`. Asserts if the operation does not
|
||||
/// implement the fastmath interface.
|
||||
|
||||
@@ -143,6 +143,9 @@ static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
|
||||
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) {
|
||||
printer.printOptionalAttrDict(
|
||||
filteredAttrs, /*elidedAttrs=*/{iface.getOverflowFlagsAttrName()});
|
||||
} else if (auto iface = dyn_cast<ExactFlagInterface>(op)) {
|
||||
printer.printOptionalAttrDict(filteredAttrs,
|
||||
/*elidedAttrs=*/{iface.getIsExactName()});
|
||||
} else {
|
||||
printer.printOptionalAttrDict(filteredAttrs);
|
||||
}
|
||||
|
||||
@@ -683,6 +683,12 @@ void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst,
|
||||
iface.setOverflowFlags(value);
|
||||
}
|
||||
|
||||
void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
|
||||
auto iface = cast<ExactFlagInterface>(op);
|
||||
|
||||
iface.setIsExact(inst->isExact());
|
||||
}
|
||||
|
||||
void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
|
||||
Operation *op) const {
|
||||
auto iface = cast<FastmathFlagsInterface>(op);
|
||||
|
||||
@@ -49,6 +49,16 @@ func.func @ops(%arg0: i32, %arg1: f32,
|
||||
%mul_flag = llvm.mul %arg0, %arg0 overflow<nsw, nuw> : i32
|
||||
%shl_flag = llvm.shl %arg0, %arg0 overflow<nuw, nsw> : i32
|
||||
|
||||
// Integer exact flag.
|
||||
// CHECK: {{.*}} = llvm.sdiv exact %[[I32]], %[[I32]] : i32
|
||||
// CHECK: {{.*}} = llvm.udiv exact %[[I32]], %[[I32]] : i32
|
||||
// CHECK: {{.*}} = llvm.ashr exact %[[I32]], %[[I32]] : i32
|
||||
// CHECK: {{.*}} = llvm.lshr exact %[[I32]], %[[I32]] : i32
|
||||
%sdiv_flag = llvm.sdiv exact %arg0, %arg0 : i32
|
||||
%udiv_flag = llvm.udiv exact %arg0, %arg0 : i32
|
||||
%ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
|
||||
%lshr_flag = llvm.lshr exact %arg0, %arg0 : i32
|
||||
|
||||
// Floating point binary operations.
|
||||
//
|
||||
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
|
||||
|
||||
14
mlir/test/Target/LLVMIR/Import/exact.ll
Normal file
14
mlir/test/Target/LLVMIR/Import/exact.ll
Normal file
@@ -0,0 +1,14 @@
|
||||
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
|
||||
|
||||
; CHECK-LABEL: @exactflag_inst
|
||||
define void @exactflag_inst(i64 %arg1, i64 %arg2) {
|
||||
; CHECK: llvm.udiv exact %{{.*}}, %{{.*}} : i64
|
||||
%1 = udiv exact i64 %arg1, %arg2
|
||||
; CHECK: llvm.sdiv exact %{{.*}}, %{{.*}} : i64
|
||||
%2 = sdiv exact i64 %arg1, %arg2
|
||||
; CHECK: llvm.lshr exact %{{.*}}, %{{.*}} : i64
|
||||
%3 = lshr exact i64 %arg1, %arg2
|
||||
; CHECK: llvm.ashr exact %{{.*}}, %{{.*}} : i64
|
||||
%4 = ashr exact i64 %arg1, %arg2
|
||||
ret void
|
||||
}
|
||||
14
mlir/test/Target/LLVMIR/exact.mlir
Normal file
14
mlir/test/Target/LLVMIR/exact.mlir
Normal file
@@ -0,0 +1,14 @@
|
||||
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: define void @exactflag_func
|
||||
llvm.func @exactflag_func(%arg0: i64, %arg1: i64) {
|
||||
// CHECK: %{{.*}} = udiv exact i64 %{{.*}}, %{{.*}}
|
||||
%0 = llvm.udiv exact %arg0, %arg1 : i64
|
||||
// CHECK: %{{.*}} = sdiv exact i64 %{{.*}}, %{{.*}}
|
||||
%1 = llvm.sdiv exact %arg0, %arg1 : i64
|
||||
// CHECK: %{{.*}} = lshr exact i64 %{{.*}}, %{{.*}}
|
||||
%2 = llvm.lshr exact %arg0, %arg1 : i64
|
||||
// CHECK: %{{.*}} = ashr exact i64 %{{.*}}, %{{.*}}
|
||||
%3 = llvm.ashr exact %arg0, %arg1 : i64
|
||||
llvm.return
|
||||
}
|
||||
Reference in New Issue
Block a user