[MLIR][LLVM] Add llvm.experimental.constrained.fptrunc operation (#86260)

Add operation mapping to the LLVM
`llvm.experimental.constrained.fptrunc.*` intrinsic.

The new operation implements the new
`LLVM::FPExceptionBehaviorOpInterface` and
`LLVM::RoundingModeOpInterface` interfaces.

---------

Signed-off-by: Victor Perez <victor.perez@codeplay.com>
This commit is contained in:
Victor Perez
2024-03-26 11:02:50 +01:00
committed by GitHub
parent 256343a0e9
commit 77cbc9bf60
12 changed files with 328 additions and 0 deletions

View File

@@ -705,4 +705,61 @@ def FramePointerKindEnum : LLVM_EnumAttr<
let cppNamespace = "::mlir::LLVM::framePointerKind";
}
//===----------------------------------------------------------------------===//
// RoundingMode
//===----------------------------------------------------------------------===//
// These values must match llvm::RoundingMode ones.
// See llvm/include/llvm/ADT/FloatingPointMode.h.
def RoundTowardZero
: LLVM_EnumAttrCase<"TowardZero", "towardzero", "TowardZero", 0>;
def RoundNearestTiesToEven
: LLVM_EnumAttrCase<"NearestTiesToEven", "tonearest", "NearestTiesToEven", 1>;
def RoundTowardPositive
: LLVM_EnumAttrCase<"TowardPositive", "upward", "TowardPositive", 2>;
def RoundTowardNegative
: LLVM_EnumAttrCase<"TowardNegative", "downward", "TowardNegative", 3>;
def RoundNearestTiesToAway
: LLVM_EnumAttrCase<"NearestTiesToAway", "tonearestaway", "NearestTiesToAway", 4>;
def RoundDynamic
: LLVM_EnumAttrCase<"Dynamic", "dynamic", "Dynamic", 7>;
// Needed as llvm::RoundingMode defines this.
def RoundInvalid
: LLVM_EnumAttrCase<"Invalid", "invalid", "Invalid", -1>;
// RoundingModeAttr should not be used in operations definitions.
// Use ValidRoundingModeAttr instead.
def RoundingModeAttr : LLVM_EnumAttr<
"RoundingMode",
"::llvm::RoundingMode",
"LLVM Rounding Mode",
[RoundTowardZero, RoundNearestTiesToEven, RoundTowardPositive,
RoundTowardNegative, RoundNearestTiesToAway, RoundDynamic, RoundInvalid]> {
let cppNamespace = "::mlir::LLVM";
}
def ValidRoundingModeAttr : ConfinedAttr<RoundingModeAttr, [IntMinValue<0>]>;
//===----------------------------------------------------------------------===//
// FPExceptionBehavior
//===----------------------------------------------------------------------===//
// These values must match llvm::fp::ExceptionBehavior ones.
// See llvm/include/llvm/IR/FPEnv.h.
def FPExceptionBehaviorIgnore
: LLVM_EnumAttrCase<"Ignore", "ignore", "ebIgnore", 0>;
def FPExceptionBehaviorMayTrap
: LLVM_EnumAttrCase<"MayTrap", "maytrap", "ebMayTrap", 1>;
def FPExceptionBehaviorStrict
: LLVM_EnumAttrCase<"Strict", "strict", "ebStrict", 2>;
def FPExceptionBehaviorAttr : LLVM_EnumAttr<
"FPExceptionBehavior",
"::llvm::fp::ExceptionBehavior",
"LLVM Exception Behavior",
[FPExceptionBehaviorIgnore, FPExceptionBehaviorMayTrap,
FPExceptionBehaviorStrict]> {
let cppNamespace = "::mlir::LLVM";
}
#endif // LLVMIR_ENUMS

View File

@@ -290,6 +290,73 @@ def GetResultPtrElementType : OpInterface<"GetResultPtrElementType"> {
];
}
def FPExceptionBehaviorOpInterface : OpInterface<"FPExceptionBehaviorOpInterface"> {
let description = [{
An interface for operations receiving an exception behavior attribute
controlling FP exception behavior.
}];
let cppNamespace = "::mlir::LLVM";
let methods = [
InterfaceMethod<
/*desc=*/ "Returns a FPExceptionBehavior attribute for the operation",
/*returnType=*/ "FPExceptionBehaviorAttr",
/*methodName=*/ "getFPExceptionBehaviorAttr",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getFpExceptionBehaviorAttr();
}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the FPExceptionBehaviorAttr
attribute for the operation}],
/*returnType=*/ "StringRef",
/*methodName=*/ "getFPExceptionBehaviorAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
return "fpExceptionBehavior";
}]
>
];
}
def RoundingModeOpInterface : OpInterface<"RoundingModeOpInterface"> {
let description = [{
An interface for operations receiving a rounding mode attribute
controlling FP rounding mode.
}];
let cppNamespace = "::mlir::LLVM";
let methods = [
InterfaceMethod<
/*desc=*/ "Returns a RoundingMode attribute for the operation",
/*returnType=*/ "RoundingModeAttr",
/*methodName=*/ "getRoundingModeAttr",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getRoundingmodeAttr();
}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the RoundingModeAttr attribute
for the operation}],
/*returnType=*/ "StringRef",
/*methodName=*/ "getRoundingModeAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
return "roundingmode";
}]
>,
];
}
//===----------------------------------------------------------------------===//
// LLVM dialect type interfaces.

View File

@@ -311,6 +311,91 @@ def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
"qualified(type($ptr))";
}
// Constrained Floating-Point Intrinsics.
class LLVM_ConstrainedIntr<string mnem, int numArgs,
bit overloadedResult, list<int> overloadedOperands,
bit hasRoundingMode>
: LLVM_OneResultIntrOp<"experimental.constrained." # mnem,
/*overloadedResults=*/
!cond(!gt(overloadedResult, 0) : [0],
true : []),
overloadedOperands,
/*traits=*/[Pure, DeclareOpInterfaceMethods<FPExceptionBehaviorOpInterface>]
# !cond(
!gt(hasRoundingMode, 0) : [DeclareOpInterfaceMethods<RoundingModeOpInterface>],
true : []),
/*requiresFastmath=*/0,
/*immArgPositions=*/[],
/*immArgAttrNames=*/[]> {
dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i));
dag attrArgs = !con(!cond(!gt(hasRoundingMode, 0) : (ins ValidRoundingModeAttr:$roundingmode),
true : (ins)),
(ins FPExceptionBehaviorAttr:$fpExceptionBehavior));
let arguments = !con(regularArgs, attrArgs);
let llvmBuilder = [{
SmallVector<llvm::Value *> args =
moduleTranslation.lookupValues(opInst.getOperands());
SmallVector<llvm::Type *> overloadedTypes; }] #
!cond(!gt(overloadedResult, 0) : [{
// Take into account overloaded result type.
overloadedTypes.push_back($_resultType); }],
// No overloaded result type.
true : "") # [{
llvm::transform(ArrayRef<unsigned>}] # overloadedOperandsCpp # [{,
std::back_inserter(overloadedTypes),
[&args](unsigned index) { return args[index]->getType(); });
llvm::Module *module = builder.GetInsertBlock()->getModule();
llvm::Function *callee =
llvm::Intrinsic::getDeclaration(module,
llvm::Intrinsic::experimental_constrained_}] #
mnem # [{, overloadedTypes); }] #
!cond(!gt(hasRoundingMode, 0) : [{
// Get rounding mode using interface.
llvm::RoundingMode rounding =
moduleTranslation.translateRoundingMode($roundingmode); }],
true : [{
// No rounding mode.
std::optional<llvm::RoundingMode> rounding; }]) # [{
llvm::fp::ExceptionBehavior except =
moduleTranslation.translateFPExceptionBehavior($fpExceptionBehavior);
$res = builder.CreateConstrainedFPCall(callee, args, "", rounding, except);
}];
let mlirBuilder = [{
SmallVector<Value> mlirOperands;
SmallVector<NamedAttribute> mlirAttrs;
if (failed(moduleImport.convertIntrinsicArguments(
llvmOperands.take_front( }] # numArgs # [{),
{}, {}, mlirOperands, mlirAttrs))) {
return failure();
}
FPExceptionBehaviorAttr fpExceptionBehaviorAttr =
$_fpExceptionBehavior_attr($fpExceptionBehavior);
mlirAttrs.push_back(
$_builder.getNamedAttr(
$_qualCppClassName::getFPExceptionBehaviorAttrName(),
fpExceptionBehaviorAttr)); }] #
!cond(!gt(hasRoundingMode, 0) : [{
RoundingModeAttr roundingModeAttr = $_roundingMode_attr($roundingmode);
mlirAttrs.push_back(
$_builder.getNamedAttr($_qualCppClassName::getRoundingModeAttrName(),
roundingModeAttr));
}], true : "") # [{
$res = $_builder.create<$_qualCppClassName>($_location,
$_resultType, mlirOperands, mlirAttrs);
}];
}
def LLVM_ConstrainedFPTruncIntr
: LLVM_ConstrainedIntr<"fptrunc", /*numArgs=*/1,
/*overloadedResult=*/1, /*overloadedOperands=*/[0],
/*hasRoundingMode=*/1> {
let assemblyFormat = [{
$arg_0 $roundingmode $fpExceptionBehavior attr-dict `:` type($arg_0) `to` type(results)
}];
}
// Intrinsics with multiple returns.
class LLVM_ArithWithOverflowOp<string mnem>

View File

@@ -170,6 +170,10 @@ class LLVM_OpBase<Dialect dialect, string mnemonic, list<Trait> traits = []> :
// - $_float_attr - substituted by a call to a float attribute matcher;
// - $_var_attr - substituted by a call to a variable attribute matcher;
// - $_label_attr - substituted by a call to a label attribute matcher;
// - $_roundingMode_attr - substituted by a call to a rounding mode
// attribute matcher;
// - $_fpExceptionBehavior_attr - substituted by a call to a FP exception
// behavior attribute matcher;
// - $_resultType - substituted with the MLIR result type;
// - $_location - substituted with the MLIR location;
// - $_builder - substituted with the MLIR builder;

View File

@@ -152,6 +152,14 @@ public:
/// Converts `value` to a label attribute. Asserts if the matching fails.
DILabelAttr matchLabelAttr(llvm::Value *value);
/// Converts `value` to a FP exception behavior attribute. Asserts if the
/// matching fails.
FPExceptionBehaviorAttr matchFPExceptionBehaviorAttr(llvm::Value *value);
/// Converts `value` to a rounding mode attribute. Asserts if the matching
/// fails.
RoundingModeAttr matchRoundingModeAttr(llvm::Value *value);
/// Converts `value` to an array of alias scopes or returns failure if the
/// conversion fails.
FailureOr<SmallVector<AliasScopeAttr>>

View File

@@ -201,6 +201,13 @@ public:
/// Translates the given LLVM debug info metadata.
llvm::Metadata *translateDebugInfo(LLVM::DINodeAttr attr);
/// Translates the given LLVM rounding mode metadata.
llvm::RoundingMode translateRoundingMode(LLVM::RoundingMode rounding);
/// Translates the given LLVM FP exception behavior metadata.
llvm::fp::ExceptionBehavior
translateFPExceptionBehavior(LLVM::FPExceptionBehavior exceptionBehavior);
/// Translates the contents of the given block to LLVM IR using this
/// translator. The LLVM IR basic block corresponding to the given block is
/// expected to exist in the mapping of this translator. Uses `builder` to

View File

@@ -1290,6 +1290,27 @@ DILabelAttr ModuleImport::matchLabelAttr(llvm::Value *value) {
return debugImporter->translate(node);
}
FPExceptionBehaviorAttr
ModuleImport::matchFPExceptionBehaviorAttr(llvm::Value *value) {
auto *metadata = cast<llvm::MetadataAsValue>(value);
auto *mdstr = cast<llvm::MDString>(metadata->getMetadata());
std::optional<llvm::fp::ExceptionBehavior> optLLVM =
llvm::convertStrToExceptionBehavior(mdstr->getString());
assert(optLLVM && "Expecting FP exception behavior");
return builder.getAttr<FPExceptionBehaviorAttr>(
convertFPExceptionBehaviorFromLLVM(*optLLVM));
}
RoundingModeAttr ModuleImport::matchRoundingModeAttr(llvm::Value *value) {
auto *metadata = cast<llvm::MetadataAsValue>(value);
auto *mdstr = cast<llvm::MDString>(metadata->getMetadata());
std::optional<llvm::RoundingMode> optLLVM =
llvm::convertStrToRoundingMode(mdstr->getString());
assert(optLLVM && "Expecting rounding mode");
return builder.getAttr<RoundingModeAttr>(
convertRoundingModeFromLLVM(*optLLVM));
}
FailureOr<SmallVector<AliasScopeAttr>>
ModuleImport::matchAliasScopeAttrs(llvm::Value *value) {
auto *nodeAsVal = cast<llvm::MetadataAsValue>(value);

View File

@@ -1721,6 +1721,16 @@ llvm::Metadata *ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr) {
return debugTranslation->translate(attr);
}
llvm::RoundingMode
ModuleTranslation::translateRoundingMode(LLVM::RoundingMode rounding) {
return convertRoundingModeToLLVM(rounding);
}
llvm::fp::ExceptionBehavior ModuleTranslation::translateFPExceptionBehavior(
LLVM::FPExceptionBehavior exceptionBehavior) {
return convertFPExceptionBehaviorToLLVM(exceptionBehavior);
}
llvm::NamedMDNode *
ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
return llvmModule->getOrInsertNamedMetadata(name);

View File

@@ -647,3 +647,18 @@ llvm.func @experimental_noalias_scope_decl() {
llvm.intr.experimental.noalias.scope.decl #alias_scope
llvm.return
}
// CHECK-LABEL: @experimental_constrained_fptrunc
llvm.func @experimental_constrained_fptrunc(%in: f64) {
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
%0 = llvm.intr.experimental.constrained.fptrunc %in towardzero ignore : f64 to f32
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
%1 = llvm.intr.experimental.constrained.fptrunc %in tonearest maytrap : f64 to f32
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
%2 = llvm.intr.experimental.constrained.fptrunc %in upward strict : f64 to f32
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
%3 = llvm.intr.experimental.constrained.fptrunc %in downward ignore : f64 to f32
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
%4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
llvm.return
}

View File

@@ -894,6 +894,23 @@ define float @ssa_copy(float %0) {
ret float %2
}
; CHECK-LABEL: experimental_constrained_fptrunc
define void @experimental_constrained_fptrunc(double %s, <4 x double> %v) {
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
%1 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.towardzero", metadata !"fpexcept.ignore")
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
%2 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearest", metadata !"fpexcept.maytrap")
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
%3 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.upward", metadata !"fpexcept.strict")
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
%4 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.downward", metadata !"fpexcept.ignore")
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
%5 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : vector<4xf64> to vector<4xf16>
%6 = call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double> %v, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
ret void
}
declare float @llvm.fmuladd.f32(float, float, float)
declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>)
declare float @llvm.fma.f32(float, float, float)
@@ -1120,3 +1137,5 @@ declare void @llvm.assume(i1)
declare float @llvm.ssa.copy.f32(float returned)
declare <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float>, <4 x float>, i64)
declare <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float>, i64)
declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double>, metadata, metadata)
declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata)

View File

@@ -964,6 +964,35 @@ llvm.func @ssa_copy(%arg: f32) -> f32 {
llvm.return %0 : f32
}
// CHECK-LABEL: @experimental_constrained_fptrunc
llvm.func @experimental_constrained_fptrunc(%s: f64, %v: vector<4xf32>) {
// CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
// CHECK: metadata !"round.towardzero"
// CHECK: metadata !"fpexcept.ignore"
%0 = llvm.intr.experimental.constrained.fptrunc %s towardzero ignore : f64 to f32
// CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
// CHECK: metadata !"round.tonearest"
// CHECK: metadata !"fpexcept.maytrap"
%1 = llvm.intr.experimental.constrained.fptrunc %s tonearest maytrap : f64 to f32
// CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
// CHECK: metadata !"round.upward"
// CHECK: metadata !"fpexcept.strict"
%2 = llvm.intr.experimental.constrained.fptrunc %s upward strict : f64 to f32
// CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
// CHECK: metadata !"round.downward"
// CHECK: metadata !"fpexcept.ignore"
%3 = llvm.intr.experimental.constrained.fptrunc %s downward ignore : f64 to f32
// CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
// CHECK: metadata !"round.tonearestaway"
// CHECK: metadata !"fpexcept.ignore"
%4 = llvm.intr.experimental.constrained.fptrunc %s tonearestaway ignore : f64 to f32
// CHECK: call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f32(
// CHECK: metadata !"round.upward"
// CHECK: metadata !"fpexcept.strict"
%5 = llvm.intr.experimental.constrained.fptrunc %v upward strict : vector<4xf32> to vector<4xf16>
llvm.return
}
// Check that intrinsics are declared with appropriate types.
// CHECK-DAG: declare float @llvm.fma.f32(float, float, float)
// CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
@@ -1126,3 +1155,5 @@ llvm.func @ssa_copy(%arg: f32) -> f32 {
// CHECK-DAG: declare ptr addrspace(1) @llvm.stacksave.p1()
// CHECK-DAG: declare void @llvm.stackrestore.p0(ptr)
// CHECK-DAG: declare void @llvm.stackrestore.p1(ptr addrspace(1))
// CHECK-DAG: declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata)
// CHECK-DAG: declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f32(<4 x float>, metadata, metadata)

View File

@@ -272,6 +272,10 @@ static LogicalResult emitOneMLIRBuilder(const Record &record, raw_ostream &os,
bs << "moduleImport.matchLocalVariableAttr";
} else if (name == "_label_attr") {
bs << "moduleImport.matchLabelAttr";
} else if (name == "_fpExceptionBehavior_attr") {
bs << "moduleImport.matchFPExceptionBehaviorAttr";
} else if (name == "_roundingMode_attr") {
bs << "moduleImport.matchRoundingModeAttr";
} else if (name == "_resultType") {
bs << "moduleImport.convertType(inst->getType())";
} else if (name == "_location") {