From 10fa27704b3165ddc4efbcf7964042b137e7fa7e Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Thu, 20 Jul 2023 08:13:17 +0000 Subject: [PATCH] [mlir][llvm] Add branch weight op interface This revision adds a branch weight op interface for the call / branch operations that support branch weights. It can be used in the LLVM IR import and export to simplify the branch weight conversion. An additional mapping between call operations and instructions ensures the actual conversion can be done in the module translation itself, rather than in the dialect translation interface. It also has the benefit that downstream users can amend custom metadata to the call operation during the export to LLVM IR. Reviewed By: zero9178, definelicht Differential Revision: https://reviews.llvm.org/D155702 --- .../mlir/Dialect/LLVMIR/LLVMInterfaces.td | 54 +++++++++++++++---- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 25 +++++---- .../mlir/Target/LLVMIR/ModuleTranslation.h | 22 ++++++++ .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 8 +-- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 10 ++-- .../LLVMIR/LLVMIRToLLVMTranslation.cpp | 12 ++--- .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 36 +++---------- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 17 ++++++ .../SPIRVToLLVM/control-flow-ops-to-llvm.mlir | 2 +- mlir/test/Dialect/LLVMIR/invalid.mlir | 2 +- .../LLVMIR/Import/metadata-profiling.ll | 8 +-- mlir/test/Target/LLVMIR/llvmir.mlir | 8 +-- 12 files changed, 131 insertions(+), 73 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td index 9f230bf0be87..7b33ec8bb0c3 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td @@ -30,7 +30,7 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> { /*args=*/ (ins), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); return op.getFastmathFlagsAttr(); }] >, @@ -48,6 +48,42 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> { ]; } +def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> { + let description = [{ + An interface for operations that can carry branch weights metadata. It + provides setters and getters for the operation's branch weights attribute. + The default implementation of the interface methods expect the operation to + have an attribute of type DenseI32ArrayAttr named branch_weights. + }]; + + let cppNamespace = "::mlir::LLVM"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Returns the branch weights attribute or nullptr", + /*returnType=*/ "DenseI32ArrayAttr", + /*methodName=*/ "getBranchWeightsOrNull", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + return op.getBranchWeightsAttr(); + }] + >, + InterfaceMethod< + /*desc=*/ "Sets the branch weights attribute", + /*returnType=*/ "void", + /*methodName=*/ "setBranchWeights", + /*args=*/ (ins "DenseI32ArrayAttr":$attr), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + op.setBranchWeightsAttr(attr); + }] + > + ]; +} + def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> { let description = [{ An interface for memory operations that can carry access groups metadata. @@ -67,7 +103,7 @@ def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> { /*args=*/ (ins), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); return op.getAccessGroupsAttr(); }] >, @@ -78,7 +114,7 @@ def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> { /*args=*/ (ins "const ArrayAttr":$attr), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); op.setAccessGroupsAttr(attr); }] > @@ -105,7 +141,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> { /*args=*/ (ins), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); return op.getAliasScopesAttr(); }] >, @@ -116,7 +152,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> { /*args=*/ (ins "const ArrayAttr":$attr), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); op.setAliasScopesAttr(attr); }] >, @@ -127,7 +163,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> { /*args=*/ (ins), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); return op.getNoaliasScopesAttr(); }] >, @@ -138,7 +174,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> { /*args=*/ (ins "const ArrayAttr":$attr), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); op.setNoaliasScopesAttr(attr); }] >, @@ -149,7 +185,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> { /*args=*/ (ins), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); return op.getTbaaAttr(); }] >, @@ -160,7 +196,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> { /*args=*/ (ins "const ArrayAttr":$attr), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); op.setTbaaAttr(attr); }] > diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 01070fe74bff..2d0ca913e813 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -536,12 +536,14 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc", def LLVM_InvokeOp : LLVM_Op<"invoke", [ AttrSizedOperandSegments, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, Terminator]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + Terminator]> { let arguments = (ins OptionalAttr:$callee, Variadic:$callee_operands, Variadic:$normalDestOperands, Variadic:$unwindDestOperands, - OptionalAttr:$branch_weights); + OptionalAttr:$branch_weights); let results = (outs Variadic); let successors = (successor AnySuccessor:$normalDest, AnySuccessor:$unwindDest); @@ -582,7 +584,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> { def LLVM_CallOp : LLVM_MemAccessOpBase<"call", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Call to an LLVM function."; let description = [{ In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect @@ -616,7 +619,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", Variadic, DefaultValuedAttr:$fastmathFlags, - OptionalAttr:$branch_weights); + OptionalAttr:$branch_weights); // Append the aliasing related attributes defined in LLVM_MemAccessOpBase. let arguments = !con(args, aliasAttrs); let results = (outs Optional:$result); @@ -847,12 +850,14 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br", ]; } def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", - [AttrSizedOperandSegments, DeclareOpInterfaceMethods, + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, Pure]> { let arguments = (ins I1:$condition, Variadic:$trueDestOperands, Variadic:$falseDestOperands, - OptionalAttr:$branch_weights, + OptionalAttr:$branch_weights, OptionalAttr:$loop_annotation); let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); let assemblyFormat = [{ @@ -874,7 +879,7 @@ def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", falseOperands); }]>, OpBuilder<(ins "Value":$condition, "ValueRange":$trueOperands, "ValueRange":$falseOperands, - "ElementsAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest), + "DenseI32ArrayAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest), [{ build($_builder, $_state, condition, trueOperands, falseOperands, branchWeights, {}, trueDest, falseDest); @@ -934,7 +939,9 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> { } def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", - [AttrSizedOperandSegments, DeclareOpInterfaceMethods, + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, Pure]> { let arguments = (ins AnyInteger:$value, @@ -942,7 +949,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", VariadicOfVariadic:$caseOperands, OptionalAttr:$case_values, DenseI32ArrayAttr:$case_operand_segments, - OptionalAttr:$branch_weights + OptionalAttr:$branch_weights ); let successors = (successor AnySuccessor:$defaultDestination, diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index da4d43ac9ac8..0d296aac0559 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -118,6 +118,20 @@ public: return branchMapping.lookup(op); } + /// Stores a mapping between an MLIR call operation and a corresponding LLVM + /// call instruction. + void mapCall(Operation *mlir, llvm::CallInst *llvm) { + auto result = callMapping.try_emplace(mlir, llvm); + (void)result; + assert(result.second && "attempting to map a call that is already mapped"); + } + + /// Finds an LLVM call instruction that corresponds to the given MLIR call + /// operation. + llvm::CallInst *lookupCall(Operation *op) const { + return callMapping.lookup(op); + } + /// Removes the mapping for blocks contained in the region and values defined /// in these blocks. void forgetMapping(Region ®ion); @@ -141,6 +155,9 @@ public: /// Sets LLVM TBAA metadata for memory operations that have TBAA attributes. void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst); + /// Sets LLVM profiling metadata for operations that have branch weights. + void setBranchWeightsMetadata(BranchWeightOpInterface op); + /// Sets LLVM loop metadata for branch operations that have a loop annotation /// attribute. void setLoopMetadata(Operation *op, llvm::Instruction *inst); @@ -328,6 +345,11 @@ private: /// values after all operations are converted. DenseMap branchMapping; + /// A mapping between MLIR LLVM dialect call operations and LLVM IR call + /// instructions. This allows for adding branch weights after the operations + /// have been converted. + DenseMap callMapping; + /// Mapping from an alias scope metadata operation to its LLVM metadata. /// This map is populated on module entry. DenseMap aliasScopeMetadataMapping; diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 28e587a066e4..1d32e6e55f6a 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -553,10 +553,12 @@ public: matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // If branch weights exist, map them to 32-bit integer vector. - ElementsAttr branchWeights = nullptr; + DenseI32ArrayAttr branchWeights = nullptr; if (auto weights = op.getBranchWeights()) { - VectorType weightType = VectorType::get(2, rewriter.getI32Type()); - branchWeights = DenseElementsAttr::get(weightType, weights->getValue()); + SmallVector weightValues; + for (auto weight : weights->getAsRange()) + weightValues.push_back(weight.getInt()); + branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues); } rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 8eee3b2afc84..f4d9c95e4179 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -310,11 +310,11 @@ void CondBrOp::build(OpBuilder &builder, OperationState &result, Value condition, Block *trueDest, ValueRange trueOperands, Block *falseDest, ValueRange falseOperands, std::optional> weights) { - ElementsAttr weightsAttr; + DenseI32ArrayAttr weightsAttr; if (weights) weightsAttr = - builder.getI32VectorAttr({static_cast(weights->first), - static_cast(weights->second)}); + builder.getDenseI32ArrayAttr({static_cast(weights->first), + static_cast(weights->second)}); build(builder, result, condition, trueOperands, falseOperands, weightsAttr, /*loop_annotation=*/{}, trueDest, falseDest); @@ -330,9 +330,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, BlockRange caseDestinations, ArrayRef caseOperands, ArrayRef branchWeights) { - ElementsAttr weightsAttr; + DenseI32ArrayAttr weightsAttr; if (!branchWeights.empty()) - weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); + weightsAttr = builder.getDenseI32ArrayAttr(branchWeights); build(builder, result, value, defaultOperands, caseOperands, caseValues, weightsAttr, defaultDestination, caseDestinations); diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp index a6f0ebe54aac..40d8253d822f 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp @@ -125,13 +125,11 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node, branchWeights.push_back(branchWeight->getZExtValue()); } - return TypeSwitch(op) - .Case([&](auto branchWeightOp) { - branchWeightOp.setBranchWeightsAttr( - builder.getI32VectorAttr(branchWeights)); - return success(); - }) - .Default([](auto) { return failure(); }); + if (auto iface = dyn_cast(op)) { + iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights)); + return success(); + } + return failure(); } /// Searches for the attribute that maps to the given TBAA metadata `node` and diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index a044930a0cf8..8f7c5d8b799e 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -124,21 +124,6 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, return success(); } -/// Constructs branch weights metadata if the provided `weights` hold a value, -/// otherwise returns nullptr. -static llvm::MDNode * -convertBranchWeights(std::optional weights, - LLVM::ModuleTranslation &moduleTranslation) { - if (!weights) - return nullptr; - SmallVector weightValues; - weightValues.reserve(weights->size()); - for (APInt weight : llvm::cast(*weights)) - weightValues.push_back(weight.getLimitedValue()); - return llvm::MDBuilder(moduleTranslation.getLLVMContext()) - .createBranchWeights(weightValues); -} - static LogicalResult convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -182,10 +167,6 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, callOp.getArgOperands()), operandsRef.front(), operandsRef.drop_front()); } - llvm::MDNode *branchWeights = - convertBranchWeights(callOp.getBranchWeights(), moduleTranslation); - if (branchWeights) - call->setMetadata(llvm::LLVMContext::MD_prof, branchWeights); moduleTranslation.setAccessGroupsMetadata(callOp, call); moduleTranslation.setAliasScopeMetadata(callOp, call); moduleTranslation.setTBAAMetadata(callOp, call); @@ -196,7 +177,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, return success(); } // Check that LLVM call returns void for 0-result functions. - return success(call->getType()->isVoidTy()); + if (!call->getType()->isVoidTy()) + return failure(); + moduleTranslation.mapCall(callOp, call); + return success(); } if (auto inlineAsmOp = dyn_cast(opInst)) { @@ -274,10 +258,6 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef.drop_front()); } - llvm::MDNode *branchWeights = - convertBranchWeights(invOp.getBranchWeights(), moduleTranslation); - if (branchWeights) - result->setMetadata(llvm::LLVMContext::MD_prof, branchWeights); moduleTranslation.mapBranch(invOp, result); // InvokeOp can only have 0 or 1 result if (invOp->getNumResults() != 0) { @@ -314,23 +294,19 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, return success(); } if (auto condbrOp = dyn_cast(opInst)) { - llvm::MDNode *branchWeights = - convertBranchWeights(condbrOp.getBranchWeights(), moduleTranslation); llvm::BranchInst *branch = builder.CreateCondBr( moduleTranslation.lookupValue(condbrOp.getOperand(0)), moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)), - moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights); + moduleTranslation.lookupBlock(condbrOp.getSuccessor(1))); moduleTranslation.mapBranch(&opInst, branch); moduleTranslation.setLoopMetadata(&opInst, branch); return success(); } if (auto switchOp = dyn_cast(opInst)) { - llvm::MDNode *branchWeights = - convertBranchWeights(switchOp.getBranchWeights(), moduleTranslation); llvm::SwitchInst *switchInst = builder.CreateSwitch( moduleTranslation.lookupValue(switchOp.getValue()), moduleTranslation.lookupBlock(switchOp.getDefaultDestination()), - switchOp.getCaseDestinations().size(), branchWeights); + switchOp.getCaseDestinations().size()); auto *ty = llvm::cast( moduleTranslation.convertType(switchOp.getValue().getType())); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index d363fb8d9186..cd3a645a18c6 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -664,6 +664,10 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments, if (failed(convertOperation(op, builder))) return failure(); + + // Set the branch weight metadata on the translated instruction. + if (auto iface = dyn_cast(op)) + setBranchWeightsMetadata(iface); } return success(); @@ -1183,6 +1187,19 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op, inst->setMetadata(llvm::LLVMContext::MD_tbaa, node); } +void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) { + DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull(); + if (!weightsAttr) + return; + + llvm::Instruction *inst = isa(op) ? lookupCall(op) : lookupBranch(op); + assert(inst && "expected the operation to have a mapping to an instruction"); + SmallVector weights(weightsAttr.asArrayRef()); + inst->setMetadata( + llvm::LLVMContext::MD_prof, + llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights)); +} + LogicalResult ModuleTranslation::createTBAAMetadata() { llvm::LLVMContext &ctx = llvmModule->getContext(); llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64); diff --git a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir index 8c58d59e86d7..54ef71f75f52 100644 --- a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir @@ -68,7 +68,7 @@ spirv.module Logical GLSL450 { } spirv.func @cond_branch_with_weights(%cond: i1) -> () "None" { - // CHECK: llvm.cond_br %{{.*}} weights(dense<[1, 2]> : vector<2xi32>), ^bb1, ^bb2 + // CHECK: llvm.cond_br %{{.*}} weights([1, 2]), ^bb1, ^bb2 spirv.BranchConditional %cond [1, 2], ^true, ^false // CHECK: ^bb1: ^true: diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index da4799d8a263..09bbc5a47396 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -874,7 +874,7 @@ func.func @switch_wrong_number_of_weights(%arg0 : i32) { // expected-error@+1 {{expects number of branch weights to match number of successors: 3 vs 2}} llvm.switch %arg0 : i32, ^bb1 [ 42: ^bb2(%arg0, %arg0 : i32, i32) - ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>} + ] {branch_weights = array} ^bb1: // pred: ^bb0 llvm.return diff --git a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll index 688dd100f982..cc3b47a54dfe 100644 --- a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll +++ b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll @@ -4,7 +4,7 @@ define i64 @cond_br(i1 %arg1, i64 %arg2) { entry: ; CHECK: llvm.cond_br - ; CHECK-SAME: weights(dense<[0, 3]> : vector<2xi32>) + ; CHECK-SAME: weights([0, 3]) br i1 %arg1, label %bb1, label %bb2, !prof !0 bb1: ret i64 %arg2 @@ -19,7 +19,7 @@ bb2: ; CHECK-LABEL: @simple_switch( define i32 @simple_switch(i32 %arg1) { ; CHECK: llvm.switch - ; CHECK: {branch_weights = dense<[42, 3, 5]> : vector<3xi32>} + ; CHECK: {branch_weights = array} switch i32 %arg1, label %bbd [ i32 0, label %bb1 i32 9, label %bb2 @@ -41,7 +41,7 @@ declare void @fn() ; CHECK-LABEL: @call_branch_weights define void @call_branch_weights() { - ; CHECK: llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>} + ; CHECK: llvm.call @fn() {branch_weights = array} call void @fn(), !prof !0 ret void } @@ -55,7 +55,7 @@ declare i32 @__gxx_personality_v0(...) ; CHECK-LABEL: @invoke_branch_weights define i32 @invoke_branch_weights() personality ptr @__gxx_personality_v0 { - ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> () + ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = array} : () -> () invoke void @foo() to label %bb2 unwind label %bb1, !prof !0 bb1: %1 = landingpad { ptr, i32 } cleanup diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 2500de25f498..3f97ebd9aa36 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1802,7 +1802,7 @@ llvm.func @foo() { // Check that branch weight attributes are exported properly as metadata. llvm.func @cond_br_weights(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 { // CHECK: !prof ![[NODE:[0-9]+]] - llvm.cond_br %cond weights(dense<[5, 10]> : vector<2xi32>), ^bb1, ^bb2 + llvm.cond_br %cond weights([5, 10]), ^bb1, ^bb2 ^bb1: // pred: ^bb0 llvm.return %arg0 : i32 ^bb2: // pred: ^bb0 @@ -1818,7 +1818,7 @@ llvm.func @fn() // CHECK-LABEL: @call_branch_weights llvm.func @call_branch_weights() { // CHECK: !prof ![[NODE:[0-9]+]] - llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>} : () -> () + llvm.call @fn() {branch_weights = array} : () -> () llvm.return } @@ -1833,7 +1833,7 @@ llvm.func @__gxx_personality_v0(...) -> i32 llvm.func @invoke_branch_weights() -> i32 attributes {personality = @__gxx_personality_v0} { %0 = llvm.mlir.constant(1 : i32) : i32 // CHECK: !prof ![[NODE:[0-9]+]] - llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> () + llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = array} : () -> () ^bb1: // pred: ^bb0 %1 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)> llvm.br ^bb2 @@ -2062,7 +2062,7 @@ llvm.func @switch_weights(%arg0: i32) -> i32 { llvm.switch %arg0 : i32, ^bb1(%0 : i32) [ 9: ^bb2(%1, %2 : i32, i32), 99: ^bb3 - ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>} + ] {branch_weights = array} ^bb1(%3: i32): // pred: ^bb0 llvm.return %3 : i32