[mlir][llvm] Add branch weight op interface

This revision adds a branch weight op interface for the call / branch
operations that support branch weights. It can be used in the LLVM IR
import and export to simplify the branch weight conversion. An
additional mapping between call operations and instructions ensures
the actual conversion can be done in the module translation itself,
rather than in the dialect translation interface. It also has the
benefit that downstream users can amend custom metadata to the call
operation during the export to LLVM IR.

Reviewed By: zero9178, definelicht

Differential Revision: https://reviews.llvm.org/D155702
This commit is contained in:
Tobias Gysi
2023-07-20 08:13:17 +00:00
parent 7d4e14c76b
commit 10fa27704b
12 changed files with 131 additions and 73 deletions

View File

@@ -30,7 +30,7 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
auto op = cast<ConcreteOp>(this->getOperation());
return op.getFastmathFlagsAttr();
}]
>,
@@ -48,6 +48,42 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
];
}
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
let description = [{
An interface for operations that can carry branch weights metadata. It
provides setters and getters for the operation's branch weights attribute.
The default implementation of the interface methods expect the operation to
have an attribute of type DenseI32ArrayAttr named branch_weights.
}];
let cppNamespace = "::mlir::LLVM";
let methods = [
InterfaceMethod<
/*desc=*/ "Returns the branch weights attribute or nullptr",
/*returnType=*/ "DenseI32ArrayAttr",
/*methodName=*/ "getBranchWeightsOrNull",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getBranchWeightsAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Sets the branch weights attribute",
/*returnType=*/ "void",
/*methodName=*/ "setBranchWeights",
/*args=*/ (ins "DenseI32ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
op.setBranchWeightsAttr(attr);
}]
>
];
}
def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
let description = [{
An interface for memory operations that can carry access groups metadata.
@@ -67,7 +103,7 @@ def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
auto op = cast<ConcreteOp>(this->getOperation());
return op.getAccessGroupsAttr();
}]
>,
@@ -78,7 +114,7 @@ def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
/*args=*/ (ins "const ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
auto op = cast<ConcreteOp>(this->getOperation());
op.setAccessGroupsAttr(attr);
}]
>
@@ -105,7 +141,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
auto op = cast<ConcreteOp>(this->getOperation());
return op.getAliasScopesAttr();
}]
>,
@@ -116,7 +152,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
/*args=*/ (ins "const ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
auto op = cast<ConcreteOp>(this->getOperation());
op.setAliasScopesAttr(attr);
}]
>,
@@ -127,7 +163,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
auto op = cast<ConcreteOp>(this->getOperation());
return op.getNoaliasScopesAttr();
}]
>,
@@ -138,7 +174,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
/*args=*/ (ins "const ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
auto op = cast<ConcreteOp>(this->getOperation());
op.setNoaliasScopesAttr(attr);
}]
>,
@@ -149,7 +185,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
auto op = cast<ConcreteOp>(this->getOperation());
return op.getTbaaAttr();
}]
>,
@@ -160,7 +196,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
/*args=*/ (ins "const ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
auto op = cast<ConcreteOp>(this->getOperation());
op.setTbaaAttr(attr);
}]
>

View File

@@ -536,12 +536,14 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc",
def LLVM_InvokeOp : LLVM_Op<"invoke", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<CallOpInterface>, Terminator]> {
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Terminator]> {
let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
Variadic<LLVM_Type>:$normalDestOperands,
Variadic<LLVM_Type>:$unwindDestOperands,
OptionalAttr<ElementsAttr>:$branch_weights);
OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
let results = (outs Variadic<LLVM_Type>);
let successors = (successor AnySuccessor:$normalDest,
AnySuccessor:$unwindDest);
@@ -582,7 +584,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
[DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
let summary = "Call to an LLVM function.";
let description = [{
In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect
@@ -616,7 +619,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
Variadic<LLVM_Type>,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
"{}">:$fastmathFlags,
OptionalAttr<ElementsAttr>:$branch_weights);
OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
let results = (outs Optional<LLVM_Type>:$result);
@@ -847,12 +850,14 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br",
];
}
def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
[AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Pure]> {
let arguments = (ins I1:$condition,
Variadic<LLVM_Type>:$trueDestOperands,
Variadic<LLVM_Type>:$falseDestOperands,
OptionalAttr<ElementsAttr>:$branch_weights,
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
OptionalAttr<LoopAnnotationAttr>:$loop_annotation);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
let assemblyFormat = [{
@@ -874,7 +879,7 @@ def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
falseOperands);
}]>,
OpBuilder<(ins "Value":$condition, "ValueRange":$trueOperands, "ValueRange":$falseOperands,
"ElementsAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest),
"DenseI32ArrayAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest),
[{
build($_builder, $_state, condition, trueOperands, falseOperands, branchWeights,
{}, trueDest, falseDest);
@@ -934,7 +939,9 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> {
}
def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
[AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Pure]> {
let arguments = (ins
AnyInteger:$value,
@@ -942,7 +949,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
OptionalAttr<AnyIntElementsAttr>:$case_values,
DenseI32ArrayAttr:$case_operand_segments,
OptionalAttr<ElementsAttr>:$branch_weights
OptionalAttr<DenseI32ArrayAttr>:$branch_weights
);
let successors = (successor
AnySuccessor:$defaultDestination,

View File

@@ -118,6 +118,20 @@ public:
return branchMapping.lookup(op);
}
/// Stores a mapping between an MLIR call operation and a corresponding LLVM
/// call instruction.
void mapCall(Operation *mlir, llvm::CallInst *llvm) {
auto result = callMapping.try_emplace(mlir, llvm);
(void)result;
assert(result.second && "attempting to map a call that is already mapped");
}
/// Finds an LLVM call instruction that corresponds to the given MLIR call
/// operation.
llvm::CallInst *lookupCall(Operation *op) const {
return callMapping.lookup(op);
}
/// Removes the mapping for blocks contained in the region and values defined
/// in these blocks.
void forgetMapping(Region &region);
@@ -141,6 +155,9 @@ public:
/// Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst);
/// Sets LLVM profiling metadata for operations that have branch weights.
void setBranchWeightsMetadata(BranchWeightOpInterface op);
/// Sets LLVM loop metadata for branch operations that have a loop annotation
/// attribute.
void setLoopMetadata(Operation *op, llvm::Instruction *inst);
@@ -328,6 +345,11 @@ private:
/// values after all operations are converted.
DenseMap<Operation *, llvm::Instruction *> branchMapping;
/// A mapping between MLIR LLVM dialect call operations and LLVM IR call
/// instructions. This allows for adding branch weights after the operations
/// have been converted.
DenseMap<Operation *, llvm::CallInst *> callMapping;
/// Mapping from an alias scope metadata operation to its LLVM metadata.
/// This map is populated on module entry.
DenseMap<Attribute, llvm::MDNode *> aliasScopeMetadataMapping;

View File

@@ -553,10 +553,12 @@ public:
matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// If branch weights exist, map them to 32-bit integer vector.
ElementsAttr branchWeights = nullptr;
DenseI32ArrayAttr branchWeights = nullptr;
if (auto weights = op.getBranchWeights()) {
VectorType weightType = VectorType::get(2, rewriter.getI32Type());
branchWeights = DenseElementsAttr::get(weightType, weights->getValue());
SmallVector<int32_t> weightValues;
for (auto weight : weights->getAsRange<IntegerAttr>())
weightValues.push_back(weight.getInt());
branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
}
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(

View File

@@ -310,11 +310,11 @@ void CondBrOp::build(OpBuilder &builder, OperationState &result,
Value condition, Block *trueDest, ValueRange trueOperands,
Block *falseDest, ValueRange falseOperands,
std::optional<std::pair<uint32_t, uint32_t>> weights) {
ElementsAttr weightsAttr;
DenseI32ArrayAttr weightsAttr;
if (weights)
weightsAttr =
builder.getI32VectorAttr({static_cast<int32_t>(weights->first),
static_cast<int32_t>(weights->second)});
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first),
static_cast<int32_t>(weights->second)});
build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
/*loop_annotation=*/{}, trueDest, falseDest);
@@ -330,9 +330,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
BlockRange caseDestinations,
ArrayRef<ValueRange> caseOperands,
ArrayRef<int32_t> branchWeights) {
ElementsAttr weightsAttr;
DenseI32ArrayAttr weightsAttr;
if (!branchWeights.empty())
weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
weightsAttr = builder.getDenseI32ArrayAttr(branchWeights);
build(builder, result, value, defaultOperands, caseOperands, caseValues,
weightsAttr, defaultDestination, caseDestinations);

View File

@@ -125,13 +125,11 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
branchWeights.push_back(branchWeight->getZExtValue());
}
return TypeSwitch<Operation *, LogicalResult>(op)
.Case<CondBrOp, SwitchOp, CallOp, InvokeOp>([&](auto branchWeightOp) {
branchWeightOp.setBranchWeightsAttr(
builder.getI32VectorAttr(branchWeights));
return success();
})
.Default([](auto) { return failure(); });
if (auto iface = dyn_cast<BranchWeightOpInterface>(op)) {
iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights));
return success();
}
return failure();
}
/// Searches for the attribute that maps to the given TBAA metadata `node` and

View File

@@ -124,21 +124,6 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
return success();
}
/// Constructs branch weights metadata if the provided `weights` hold a value,
/// otherwise returns nullptr.
static llvm::MDNode *
convertBranchWeights(std::optional<ElementsAttr> weights,
LLVM::ModuleTranslation &moduleTranslation) {
if (!weights)
return nullptr;
SmallVector<uint32_t> weightValues;
weightValues.reserve(weights->size());
for (APInt weight : llvm::cast<DenseIntElementsAttr>(*weights))
weightValues.push_back(weight.getLimitedValue());
return llvm::MDBuilder(moduleTranslation.getLLVMContext())
.createBranchWeights(weightValues);
}
static LogicalResult
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
@@ -182,10 +167,6 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
callOp.getArgOperands()),
operandsRef.front(), operandsRef.drop_front());
}
llvm::MDNode *branchWeights =
convertBranchWeights(callOp.getBranchWeights(), moduleTranslation);
if (branchWeights)
call->setMetadata(llvm::LLVMContext::MD_prof, branchWeights);
moduleTranslation.setAccessGroupsMetadata(callOp, call);
moduleTranslation.setAliasScopeMetadata(callOp, call);
moduleTranslation.setTBAAMetadata(callOp, call);
@@ -196,7 +177,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}
// Check that LLVM call returns void for 0-result functions.
return success(call->getType()->isVoidTy());
if (!call->getType()->isVoidTy())
return failure();
moduleTranslation.mapCall(callOp, call);
return success();
}
if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
@@ -274,10 +258,6 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
operandsRef.drop_front());
}
llvm::MDNode *branchWeights =
convertBranchWeights(invOp.getBranchWeights(), moduleTranslation);
if (branchWeights)
result->setMetadata(llvm::LLVMContext::MD_prof, branchWeights);
moduleTranslation.mapBranch(invOp, result);
// InvokeOp can only have 0 or 1 result
if (invOp->getNumResults() != 0) {
@@ -314,23 +294,19 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}
if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
llvm::MDNode *branchWeights =
convertBranchWeights(condbrOp.getBranchWeights(), moduleTranslation);
llvm::BranchInst *branch = builder.CreateCondBr(
moduleTranslation.lookupValue(condbrOp.getOperand(0)),
moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)));
moduleTranslation.mapBranch(&opInst, branch);
moduleTranslation.setLoopMetadata(&opInst, branch);
return success();
}
if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
llvm::MDNode *branchWeights =
convertBranchWeights(switchOp.getBranchWeights(), moduleTranslation);
llvm::SwitchInst *switchInst = builder.CreateSwitch(
moduleTranslation.lookupValue(switchOp.getValue()),
moduleTranslation.lookupBlock(switchOp.getDefaultDestination()),
switchOp.getCaseDestinations().size(), branchWeights);
switchOp.getCaseDestinations().size());
auto *ty = llvm::cast<llvm::IntegerType>(
moduleTranslation.convertType(switchOp.getValue().getType()));

View File

@@ -664,6 +664,10 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
if (failed(convertOperation(op, builder)))
return failure();
// Set the branch weight metadata on the translated instruction.
if (auto iface = dyn_cast<BranchWeightOpInterface>(op))
setBranchWeightsMetadata(iface);
}
return success();
@@ -1183,6 +1187,19 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
}
void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
if (!weightsAttr)
return;
llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op);
assert(inst && "expected the operation to have a mapping to an instruction");
SmallVector<uint32_t> weights(weightsAttr.asArrayRef());
inst->setMetadata(
llvm::LLVMContext::MD_prof,
llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights));
}
LogicalResult ModuleTranslation::createTBAAMetadata() {
llvm::LLVMContext &ctx = llvmModule->getContext();
llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64);

View File

@@ -68,7 +68,7 @@ spirv.module Logical GLSL450 {
}
spirv.func @cond_branch_with_weights(%cond: i1) -> () "None" {
// CHECK: llvm.cond_br %{{.*}} weights(dense<[1, 2]> : vector<2xi32>), ^bb1, ^bb2
// CHECK: llvm.cond_br %{{.*}} weights([1, 2]), ^bb1, ^bb2
spirv.BranchConditional %cond [1, 2], ^true, ^false
// CHECK: ^bb1:
^true:

View File

@@ -874,7 +874,7 @@ func.func @switch_wrong_number_of_weights(%arg0 : i32) {
// expected-error@+1 {{expects number of branch weights to match number of successors: 3 vs 2}}
llvm.switch %arg0 : i32, ^bb1 [
42: ^bb2(%arg0, %arg0 : i32, i32)
] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
] {branch_weights = array<i32: 13, 17, 19>}
^bb1: // pred: ^bb0
llvm.return

View File

@@ -4,7 +4,7 @@
define i64 @cond_br(i1 %arg1, i64 %arg2) {
entry:
; CHECK: llvm.cond_br
; CHECK-SAME: weights(dense<[0, 3]> : vector<2xi32>)
; CHECK-SAME: weights([0, 3])
br i1 %arg1, label %bb1, label %bb2, !prof !0
bb1:
ret i64 %arg2
@@ -19,7 +19,7 @@ bb2:
; CHECK-LABEL: @simple_switch(
define i32 @simple_switch(i32 %arg1) {
; CHECK: llvm.switch
; CHECK: {branch_weights = dense<[42, 3, 5]> : vector<3xi32>}
; CHECK: {branch_weights = array<i32: 42, 3, 5>}
switch i32 %arg1, label %bbd [
i32 0, label %bb1
i32 9, label %bb2
@@ -41,7 +41,7 @@ declare void @fn()
; CHECK-LABEL: @call_branch_weights
define void @call_branch_weights() {
; CHECK: llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>}
; CHECK: llvm.call @fn() {branch_weights = array<i32: 42>}
call void @fn(), !prof !0
ret void
}
@@ -55,7 +55,7 @@ declare i32 @__gxx_personality_v0(...)
; CHECK-LABEL: @invoke_branch_weights
define i32 @invoke_branch_weights() personality ptr @__gxx_personality_v0 {
; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> ()
; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = array<i32: 42, 99>} : () -> ()
invoke void @foo() to label %bb2 unwind label %bb1, !prof !0
bb1:
%1 = landingpad { ptr, i32 } cleanup

View File

@@ -1802,7 +1802,7 @@ llvm.func @foo() {
// Check that branch weight attributes are exported properly as metadata.
llvm.func @cond_br_weights(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 {
// CHECK: !prof ![[NODE:[0-9]+]]
llvm.cond_br %cond weights(dense<[5, 10]> : vector<2xi32>), ^bb1, ^bb2
llvm.cond_br %cond weights([5, 10]), ^bb1, ^bb2
^bb1: // pred: ^bb0
llvm.return %arg0 : i32
^bb2: // pred: ^bb0
@@ -1818,7 +1818,7 @@ llvm.func @fn()
// CHECK-LABEL: @call_branch_weights
llvm.func @call_branch_weights() {
// CHECK: !prof ![[NODE:[0-9]+]]
llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>} : () -> ()
llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> ()
llvm.return
}
@@ -1833,7 +1833,7 @@ llvm.func @__gxx_personality_v0(...) -> i32
llvm.func @invoke_branch_weights() -> i32 attributes {personality = @__gxx_personality_v0} {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: !prof ![[NODE:[0-9]+]]
llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> ()
llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = array<i32 : 42, 99>} : () -> ()
^bb1: // pred: ^bb0
%1 = llvm.landingpad cleanup : !llvm.struct<(ptr<i8>, i32)>
llvm.br ^bb2
@@ -2062,7 +2062,7 @@ llvm.func @switch_weights(%arg0: i32) -> i32 {
llvm.switch %arg0 : i32, ^bb1(%0 : i32) [
9: ^bb2(%1, %2 : i32, i32),
99: ^bb3
] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
] {branch_weights = array<i32 : 13, 17, 19>}
^bb1(%3: i32): // pred: ^bb0
llvm.return %3 : i32