diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 3be6c97322b5..642282f8af18 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -24,7 +24,6 @@ def OpenMP_Dialect : Dialect { class OpenMP_Op traits = []> : Op; - //===----------------------------------------------------------------------===// // 2.6 parallel Construct //===----------------------------------------------------------------------===// @@ -81,8 +80,8 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> { of the parallel region. }]; - let arguments = (ins Optional:$if_expr_var, - Optional:$num_threads_var, + let arguments = (ins Optional:$if_expr_var, + Optional:$num_threads_var, OptionalAttr:$default_val, Variadic:$private_vars, Variadic:$firstprivate_vars, diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 3a701018beb5..e44ae976e0dd 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -87,6 +87,8 @@ protected: llvm::IRBuilder<> &builder); virtual LogicalResult convertOmpOperation(Operation &op, llvm::IRBuilder<> &builder); + virtual LogicalResult convertOmpParallel(Operation &op, + llvm::IRBuilder<> &builder); static std::unique_ptr prepareLLVMModule(Operation *m); /// A helper to look up remapped operands in the value remapping table. @@ -100,7 +102,6 @@ private: LogicalResult convertFunctions(); LogicalResult convertGlobals(); LogicalResult convertOneFunction(LLVMFuncOp func); - void connectPHINodes(LLVMFuncOp func); LogicalResult convertBlock(Block &bb, bool ignoreArguments); llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr, diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 657aa84afe1c..0defea6bbbb9 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -25,11 +25,13 @@ #include "llvm/ADT/SetVector.h" #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" using namespace mlir; @@ -304,7 +306,160 @@ ModuleTranslation::ModuleTranslation(Operation *module, assert(satisfiesLLVMModule(mlirModule) && "mlirModule should honor LLVM's module semantics."); } -ModuleTranslation::~ModuleTranslation() {} +ModuleTranslation::~ModuleTranslation() { + if (ompBuilder) + ompBuilder->finalize(); +} + +/// Get the SSA value passed to the current block from the terminator operation +/// of its predecessor. +static Value getPHISourceValue(Block *current, Block *pred, + unsigned numArguments, unsigned index) { + Operation &terminator = *pred->getTerminator(); + if (isa(terminator)) + return terminator.getOperand(index); + + // For conditional branches, we need to check if the current block is reached + // through the "true" or the "false" branch and take the relevant operands. + auto condBranchOp = dyn_cast(terminator); + assert(condBranchOp && + "only branch operations can be terminators of a block that " + "has successors"); + assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) && + "successors with arguments in LLVM conditional branches must be " + "different blocks"); + + return condBranchOp.getSuccessor(0) == current + ? condBranchOp.trueDestOperands()[index] + : condBranchOp.falseDestOperands()[index]; +} + +/// Connect the PHI nodes to the results of preceding blocks. +template +static void +connectPHINodes(T &func, const DenseMap &valueMapping, + const DenseMap &blockMapping) { + // Skip the first block, it cannot be branched to and its arguments correspond + // to the arguments of the LLVM function. + for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) { + Block *bb = &*it; + llvm::BasicBlock *llvmBB = blockMapping.lookup(bb); + auto phis = llvmBB->phis(); + auto numArguments = bb->getNumArguments(); + assert(numArguments == std::distance(phis.begin(), phis.end())); + for (auto &numberedPhiNode : llvm::enumerate(phis)) { + auto &phiNode = numberedPhiNode.value(); + unsigned index = numberedPhiNode.index(); + for (auto *pred : bb->getPredecessors()) { + phiNode.addIncoming(valueMapping.lookup(getPHISourceValue( + bb, pred, numArguments, index)), + blockMapping.lookup(pred)); + } + } + } +} + +// TODO: implement an iterative version +static void topologicalSortImpl(llvm::SetVector &blocks, Block *b) { + blocks.insert(b); + for (Block *bb : b->getSuccessors()) { + if (blocks.count(bb) == 0) + topologicalSortImpl(blocks, bb); + } +} + +/// Sort function blocks topologically. +template +static llvm::SetVector topologicalSort(T &f) { + // For each blocks that has not been visited yet (i.e. that has no + // predecessors), add it to the list and traverse its successors in DFS + // preorder. + llvm::SetVector blocks; + for (Block &b : f) { + if (blocks.count(&b) == 0) + topologicalSortImpl(blocks, &b); + } + assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted"); + + return blocks; +} + +/// Convert the OpenMP parallel Operation to LLVM IR. +LogicalResult +ModuleTranslation::convertOmpParallel(Operation &opInst, + llvm::IRBuilder<> &builder) { + using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; + + auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP, + llvm::BasicBlock &continuationIP) { + llvm::LLVMContext &llvmContext = llvmModule->getContext(); + + llvm::BasicBlock *codeGenIPBB = codeGenIP.getBlock(); + llvm::Instruction *codeGenIPBBTI = codeGenIPBB->getTerminator(); + + builder.SetInsertPoint(codeGenIPBB); + + for (auto ®ion : opInst.getRegions()) { + for (auto &bb : region) { + auto *llvmBB = llvm::BasicBlock::Create( + llvmContext, "omp.par.region", codeGenIP.getBlock()->getParent()); + blockMapping[&bb] = llvmBB; + } + + // Then, convert blocks one by one in topological order to ensure + // defs are converted before uses. + llvm::SetVector blocks = topologicalSort(region); + for (auto indexedBB : llvm::enumerate(blocks)) { + Block *bb = indexedBB.value(); + llvm::BasicBlock *curLLVMBB = blockMapping[bb]; + if (bb->isEntryBlock()) + codeGenIPBBTI->setSuccessor(0, curLLVMBB); + + // TODO: Error not returned up the hierarchy + if (failed( + convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0))) + return; + + // If this block has the terminator then add a jump to + // continuation bb + for (auto &op : *bb) { + if (isa(op)) { + builder.SetInsertPoint(curLLVMBB); + builder.CreateBr(&continuationIP); + } + } + } + // Finally, after all blocks have been traversed and values mapped, + // connect the PHI nodes to the results of preceding blocks. + connectPHINodes(region, valueMapping, blockMapping); + } + }; + + // TODO: Perform appropriate actions according to the data-sharing + // attribute (shared, private, firstprivate, ...) of variables. + // Currently defaults to shared. + auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP, + llvm::Value &vPtr, + llvm::Value *&replacementValue) -> InsertPointTy { + replacementValue = &vPtr; + + return codeGenIP; + }; + + // TODO: Perform finalization actions for variables. This has to be + // called for variables which have destructors/finalizers. + auto finiCB = [&](InsertPointTy codeGenIP) {}; + + // TODO: The various operands of parallel operation are not handled. + // Parallel operation is created with some default options for now. + llvm::Value *ifCond = nullptr; + llvm::Value *numThreads = nullptr; + bool isCancellable = false; + builder.restoreIP(ompBuilder->CreateParallel( + builder, bodyGenCB, privCB, finiCB, ifCond, numThreads, + llvm::omp::OMP_PROC_BIND_default, isCancellable)); + return success(); +} /// Given an OpenMP MLIR operation, create the corresponding LLVM IR /// (including OpenMP runtime calls). @@ -340,6 +495,9 @@ ModuleTranslation::convertOmpOperation(Operation &opInst, ompBuilder->CreateFlush(builder.saveIP()); return success(); }) + .Case([&](omp::TerminatorOp) { return success(); }) + .Case( + [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); }) .Default([&](Operation *inst) { return inst->emitError("unsupported OpenMP operation: ") << inst->getName(); @@ -556,75 +714,6 @@ LogicalResult ModuleTranslation::convertGlobals() { return success(); } -/// Get the SSA value passed to the current block from the terminator operation -/// of its predecessor. -static Value getPHISourceValue(Block *current, Block *pred, - unsigned numArguments, unsigned index) { - auto &terminator = *pred->getTerminator(); - if (isa(terminator)) { - return terminator.getOperand(index); - } - - // For conditional branches, we need to check if the current block is reached - // through the "true" or the "false" branch and take the relevant operands. - auto condBranchOp = dyn_cast(terminator); - assert(condBranchOp && - "only branch operations can be terminators of a block that " - "has successors"); - assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) && - "successors with arguments in LLVM conditional branches must be " - "different blocks"); - - return condBranchOp.getSuccessor(0) == current - ? condBranchOp.trueDestOperands()[index] - : condBranchOp.falseDestOperands()[index]; -} - -void ModuleTranslation::connectPHINodes(LLVMFuncOp func) { - // Skip the first block, it cannot be branched to and its arguments correspond - // to the arguments of the LLVM function. - for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) { - Block *bb = &*it; - llvm::BasicBlock *llvmBB = blockMapping.lookup(bb); - auto phis = llvmBB->phis(); - auto numArguments = bb->getNumArguments(); - assert(numArguments == std::distance(phis.begin(), phis.end())); - for (auto &numberedPhiNode : llvm::enumerate(phis)) { - auto &phiNode = numberedPhiNode.value(); - unsigned index = numberedPhiNode.index(); - for (auto *pred : bb->getPredecessors()) { - phiNode.addIncoming(valueMapping.lookup(getPHISourceValue( - bb, pred, numArguments, index)), - blockMapping.lookup(pred)); - } - } - } -} - -// TODO: implement an iterative version -static void topologicalSortImpl(llvm::SetVector &blocks, Block *b) { - blocks.insert(b); - for (Block *bb : b->getSuccessors()) { - if (blocks.count(bb) == 0) - topologicalSortImpl(blocks, bb); - } -} - -/// Sort function blocks topologically. -static llvm::SetVector topologicalSort(LLVMFuncOp f) { - // For each blocks that has not been visited yet (i.e. that has no - // predecessors), add it to the list and traverse its successors in DFS - // preorder. - llvm::SetVector blocks; - for (Block &b : f) { - if (blocks.count(&b) == 0) - topologicalSortImpl(blocks, &b); - } - assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted"); - - return blocks; -} - /// Attempts to add an attribute identified by `key`, optionally with the given /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the /// attribute has a kind known to LLVM IR, create the attribute of this kind, @@ -772,7 +861,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { // Finally, after all blocks have been traversed and values mapped, connect // the PHI nodes to the results of preceding blocks. - connectPHINodes(func); + connectPHINodes(func, valueMapping, blockMapping); return success(); } diff --git a/mlir/test/Target/openmp-llvm.mlir b/mlir/test/Target/openmp-llvm.mlir index ddfc2a4cf786..c8acd8022b2b 100644 --- a/mlir/test/Target/openmp-llvm.mlir +++ b/mlir/test/Target/openmp-llvm.mlir @@ -32,3 +32,49 @@ llvm.func @test_flush_construct(%arg0: !llvm.i32) { // CHECK-NEXT: ret void llvm.return } + +// CHECK-LABEL: define void @test_omp_parallel_1() +llvm.func @test_omp_parallel_1() -> () { + // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_1:.*]] to {{.*}} + omp.parallel { + omp.barrier + omp.terminator + } + + llvm.return +} + +// CHECK: define internal void @[[OMP_OUTLINED_FN_1]] + // CHECK: call void @__kmpc_barrier + +llvm.func @body(!llvm.i64) + +// CHECK-LABEL: define void @test_omp_parallel_2() +llvm.func @test_omp_parallel_2() -> () { + // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_2:.*]] to {{.*}} + omp.parallel { + ^bb0: + %0 = llvm.mlir.constant(1 : index) : !llvm.i64 + %1 = llvm.mlir.constant(42 : index) : !llvm.i64 + llvm.call @body(%0) : (!llvm.i64) -> () + llvm.call @body(%1) : (!llvm.i64) -> () + llvm.br ^bb1 + + ^bb1: + %2 = llvm.add %0, %1 : !llvm.i64 + llvm.call @body(%2) : (!llvm.i64) -> () + omp.terminator + } + llvm.return +} + +// CHECK: define internal void @[[OMP_OUTLINED_FN_2]] + // CHECK-LABEL: omp.par.region: + // CHECK: br label %omp.par.region1 + // CHECK-LABEL: omp.par.region1: + // CHECK: call void @body(i64 1) + // CHECK: call void @body(i64 42) + // CHECK: br label %omp.par.region2 + // CHECK-LABEL: omp.par.region2: + // CHECK: call void @body(i64 43) + // CHECK: br label %omp.par.pre_finalize