From afa178d36017ab565c33a8639be16355a054b95b Mon Sep 17 00:00:00 2001 From: lfrenot Date: Fri, 8 Nov 2024 12:56:44 +0000 Subject: [PATCH] [mlir][LLVM] Add exact flag (#115327) The implementation is mostly based on the one existing for the nsw and nuw flags. If the exact flag is present, the corresponding operation returns a poison value when the result is not exact. (For a division, if rounding happens; for a right shift, if a non-zero bit is shifted out.) --- .../mlir/Dialect/LLVMIR/LLVMInterfaces.td | 27 +++++++++++++++++++ mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 26 +++++++++++++++--- .../include/mlir/Target/LLVMIR/ModuleImport.h | 5 ++++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 3 +++ mlir/lib/Target/LLVMIR/ModuleImport.cpp | 6 +++++ mlir/test/Dialect/LLVMIR/roundtrip.mlir | 10 +++++++ mlir/test/Target/LLVMIR/Import/exact.ll | 14 ++++++++++ mlir/test/Target/LLVMIR/exact.mlir | 14 ++++++++++ 8 files changed, 101 insertions(+), 4 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/Import/exact.ll create mode 100644 mlir/test/Target/LLVMIR/exact.mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td index 7e38e0b27fd9..12c430df2089 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td @@ -87,6 +87,33 @@ def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> ]; } +def ExactFlagInterface : OpInterface<"ExactFlagInterface"> { + let description = [{ + This interface defines an LLVM operation with an exact flag and + provides a uniform API for accessing it. + }]; + + let cppNamespace = "::mlir::LLVM"; + + let methods = [ + InterfaceMethod<[{ + Get the exact flag for the operation. + }], "bool", "getIsExact", (ins), [{}], [{ + return $_op.getProperties().isExact; + }]>, + InterfaceMethod<[{ + Set the exact flag for the operation. + }], "void", "setIsExact", (ins "bool":$isExact), [{}], [{ + $_op.getProperties().isExact = isExact; + }]>, + StaticInterfaceMethod<[{ + Get the attribute name of the isExact property. + }], "StringRef", "getIsExactName", (ins), [{}], [{ + return "isExact"; + }]>, + ]; +} + def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> { let description = [{ An interface for operations that can carry branch weights metadata. It diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index d5def510a904..315af2594047 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -76,6 +76,24 @@ class LLVM_IntArithmeticOpWithOverflowFlag traits = []> : + LLVM_ArithmeticOpBase], traits)> { + let arguments = !con(commonArgs, (ins UnitAttr:$isExact)); + + string mlirBuilder = [{ + auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + moduleImport.setExactFlag(inst, op); + $res = op; + }]; + let assemblyFormat = [{ + (`exact` $isExact^)? $lhs `,` $rhs custom(attr-dict) `:` type($res) + }]; + string llvmBuilder = + "$res = builder.Create" # instName # + "($lhs, $rhs, /*Name=*/\"\", op.getIsExact());"; +} class LLVM_FloatArithmeticOp traits = []> : LLVM_ArithmeticOpBase; def LLVM_MulOp : LLVM_IntArithmeticOpWithOverflowFlag<"mul", "Mul", [Commutative]>; -def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">; -def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">; +def LLVM_UDivOp : LLVM_IntArithmeticOpWithExactFlag<"udiv", "UDiv">; +def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">; def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">; def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">; def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">; @@ -128,8 +146,8 @@ def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">; def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl", []> { let hasFolder = 1; } -def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">; -def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "AShr">; +def LLVM_LShrOp : LLVM_IntArithmeticOpWithExactFlag<"lshr", "LShr">; +def LLVM_AShrOp : LLVM_IntArithmeticOpWithExactFlag<"ashr", "AShr">; // Base class for compare operations. A compare operation takes two operands // of the same type and returns a boolean result. If the operands are diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index bbb7af58d273..6c3a500f20e3 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -187,6 +187,11 @@ public: /// operation does not implement the integer overflow flag interface. void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const; + /// Sets the exact flag attribute for the imported operation `op` given + /// the original instruction `inst`. Asserts if the operation does not + /// implement the exact flag interface. + void setExactFlag(llvm::Instruction *inst, Operation *op) const; + /// Sets the fastmath flags attribute for the imported operation `op` given /// the original instruction `inst`. Asserts if the operation does not /// implement the fastmath interface. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index c9bc9533ca2a..6b2d8943bf48 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -143,6 +143,9 @@ static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, if (auto iface = dyn_cast(op)) { printer.printOptionalAttrDict( filteredAttrs, /*elidedAttrs=*/{iface.getOverflowFlagsAttrName()}); + } else if (auto iface = dyn_cast(op)) { + printer.printOptionalAttrDict(filteredAttrs, + /*elidedAttrs=*/{iface.getIsExactName()}); } else { printer.printOptionalAttrDict(filteredAttrs); } diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 3a61a373ecae..12145f7a2217 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -683,6 +683,12 @@ void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst, iface.setOverflowFlags(value); } +void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const { + auto iface = cast(op); + + iface.setIsExact(inst->isExact()); +} + void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst, Operation *op) const { auto iface = cast(op); diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index b8ce7db795a1..682780c5f0a7 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -49,6 +49,16 @@ func.func @ops(%arg0: i32, %arg1: f32, %mul_flag = llvm.mul %arg0, %arg0 overflow : i32 %shl_flag = llvm.shl %arg0, %arg0 overflow : i32 +// Integer exact flag. +// CHECK: {{.*}} = llvm.sdiv exact %[[I32]], %[[I32]] : i32 +// CHECK: {{.*}} = llvm.udiv exact %[[I32]], %[[I32]] : i32 +// CHECK: {{.*}} = llvm.ashr exact %[[I32]], %[[I32]] : i32 +// CHECK: {{.*}} = llvm.lshr exact %[[I32]], %[[I32]] : i32 + %sdiv_flag = llvm.sdiv exact %arg0, %arg0 : i32 + %udiv_flag = llvm.udiv exact %arg0, %arg0 : i32 + %ashr_flag = llvm.ashr exact %arg0, %arg0 : i32 + %lshr_flag = llvm.lshr exact %arg0, %arg0 : i32 + // Floating point binary operations. // // CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32 diff --git a/mlir/test/Target/LLVMIR/Import/exact.ll b/mlir/test/Target/LLVMIR/Import/exact.ll new file mode 100644 index 000000000000..528fee5091d2 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/exact.ll @@ -0,0 +1,14 @@ +; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s + +; CHECK-LABEL: @exactflag_inst +define void @exactflag_inst(i64 %arg1, i64 %arg2) { + ; CHECK: llvm.udiv exact %{{.*}}, %{{.*}} : i64 + %1 = udiv exact i64 %arg1, %arg2 + ; CHECK: llvm.sdiv exact %{{.*}}, %{{.*}} : i64 + %2 = sdiv exact i64 %arg1, %arg2 + ; CHECK: llvm.lshr exact %{{.*}}, %{{.*}} : i64 + %3 = lshr exact i64 %arg1, %arg2 + ; CHECK: llvm.ashr exact %{{.*}}, %{{.*}} : i64 + %4 = ashr exact i64 %arg1, %arg2 + ret void +} diff --git a/mlir/test/Target/LLVMIR/exact.mlir b/mlir/test/Target/LLVMIR/exact.mlir new file mode 100644 index 000000000000..b6c378c2fdcc --- /dev/null +++ b/mlir/test/Target/LLVMIR/exact.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: define void @exactflag_func +llvm.func @exactflag_func(%arg0: i64, %arg1: i64) { + // CHECK: %{{.*}} = udiv exact i64 %{{.*}}, %{{.*}} + %0 = llvm.udiv exact %arg0, %arg1 : i64 + // CHECK: %{{.*}} = sdiv exact i64 %{{.*}}, %{{.*}} + %1 = llvm.sdiv exact %arg0, %arg1 : i64 + // CHECK: %{{.*}} = lshr exact i64 %{{.*}}, %{{.*}} + %2 = llvm.lshr exact %arg0, %arg1 : i64 + // CHECK: %{{.*}} = ashr exact i64 %{{.*}}, %{{.*}} + %3 = llvm.ashr exact %arg0, %arg1 : i64 + llvm.return +}