[OpenMP][IRBuilder] Add support for taskgroup

This patch adds support for generating the taskgroup construct.

Reviewed By: Meinersbur

Differential Revision: https://reviews.llvm.org/D128203
This commit is contained in:
Shraiysh Vaishay
2022-07-18 15:25:54 +05:30
parent c8598fa22f
commit 35fc666877
3 changed files with 210 additions and 1 deletions

View File

@@ -630,6 +630,15 @@ public:
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
bool Tied = true, Value *Final = nullptr);
/// Generator for the taskgroup construct
///
/// Emits a call that starts the taskgroup, lets \p BodyGenCB generate the
/// region, and then emits the call that ends the taskgroup.
///
/// \param Loc The location where the taskgroup construct was encountered.
/// \param AllocaIP The insertion point to be used for alloca instructions;
///                 it is forwarded unchanged to \p BodyGenCB.
/// \param BodyGenCB Callback that will generate the region code.
///
/// \returns The insertion point after the generated construct, or a
///          default-constructed insertion point if \p Loc cannot be used.
InsertPointTy createTaskgroup(const LocationDescription &Loc,
InsertPointTy AllocaIP,
BodyGenCallbackTy BodyGenCB);
/// Functions used to generate reductions. Such functions take two Values
/// representing LHS and RHS of the reduction, respectively, and a reference
/// to the value that is updated to refer to the reduction result.

View File

@@ -1453,7 +1453,36 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
BodyGenCB(TaskAllocaIP, TaskBodyIP);
Builder.SetInsertPoint(TaskExitBB);
Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
return Builder.saveIP();
}
/// Emit a taskgroup region: a call to __kmpc_taskgroup, the body produced by
/// \p BodyGenCB, and a matching call to __kmpc_end_taskgroup.
///
/// \param Loc       Location at which the construct is emitted.
/// \param AllocaIP  Insertion point handed to \p BodyGenCB for allocas.
/// \param BodyGenCB Callback that emits the body of the region.
///
/// \returns The insertion point right after the __kmpc_end_taskgroup call, or
///          a default-constructed InsertPointTy if updateToLocation(Loc)
///          fails.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
                                 InsertPointTy AllocaIP,
                                 BodyGenCallbackTy BodyGenCB) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  // Both runtime calls take the same (ident, thread id) argument pair.
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadID = getOrCreateThreadID(Ident);

  // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
  Function *TaskgroupFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup);
  Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});

  // Split off the exit block first so the body is generated in front of it.
  BasicBlock *TaskgroupExitBB = splitBB(Builder, true, "taskgroup.exit");
  BodyGenCB(AllocaIP, Builder.saveIP());

  // Insert at the beginning of the exit block rather than at its end: the
  // block may already contain instructions (for instance a terminator when
  // the construct is emitted in the middle of an existing block), and the
  // end-of-taskgroup call has to precede them. This mirrors how createTask
  // positions the builder in its own exit block.
  Builder.SetInsertPoint(TaskgroupExitBB, TaskgroupExitBB->begin());

  // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
  Function *EndTaskgroupFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup);
  Builder.CreateCall(EndTaskgroupFn, {Ident, ThreadID});

  return Builder.saveIP();
}

View File

@@ -12,10 +12,12 @@
#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "gtest/gtest.h"
@@ -4918,4 +4920,173 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
EXPECT_FALSE(verifyModule(*M, &errs()));
}
// Checks that createTaskgroup wraps the code emitted by the body callback
// between a __kmpc_taskgroup call and a __kmpc_end_taskgroup call, and that
// both runtime calls receive the expected ident / thread-id arguments.
TEST_F(OpenMPIRBuilderTest, CreateTaskgroup) {
  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
  OpenMPIRBuilder OMPBuilder(*M);
  OMPBuilder.initialize();
  F->setName("func");
  IRBuilder<> Builder(BB);

  AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty());
  AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty());
  Value *Val128 =
      Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "bodygen.load");

  // Filled in by the body callback. Start out as nullptr so that a callback
  // which never runs is reported as a test failure instead of reading
  // indeterminate pointers.
  Instruction *ThenTerm = nullptr;
  Instruction *ElseTerm = nullptr;
  Value *InternalStoreInst = nullptr;
  Value *InternalLoad32 = nullptr;
  Value *InternalLoad128 = nullptr;
  Value *InternalIfCmp = nullptr;

  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
    Builder.restoreIP(AllocaIP);
    AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr,
                                                "bodygen.alloca128");

    Builder.restoreIP(CodeGenIP);
    // Loading and storing captured pointer and values
    InternalStoreInst = Builder.CreateStore(Val128, Local128);
    InternalLoad32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32,
                                        "bodygen.load32");

    InternalLoad128 = Builder.CreateLoad(Local128->getAllocatedType(), Local128,
                                         "bodygen.local.load128");
    InternalIfCmp = Builder.CreateICmpNE(
        InternalLoad32,
        Builder.CreateTrunc(InternalLoad128, InternalLoad32->getType()));
    SplitBlockAndInsertIfThenElse(InternalIfCmp,
                                  CodeGenIP.getBlock()->getTerminator(),
                                  &ThenTerm, &ElseTerm);
  };

  BasicBlock *AllocaBB = Builder.GetInsertBlock();
  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
  OpenMPIRBuilder::LocationDescription Loc(
      InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
  Builder.restoreIP(OMPBuilder.createTaskgroup(
      Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()),
      BodyGenCB));
  OMPBuilder.finalize();
  Builder.CreateRetVoid();

  EXPECT_FALSE(verifyModule(*M, &errs()));

  CallInst *TaskgroupCall = dyn_cast<CallInst>(
      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup)
          ->user_back());
  ASSERT_NE(TaskgroupCall, nullptr);
  CallInst *EndTaskgroupCall = dyn_cast<CallInst>(
      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup)
          ->user_back());
  ASSERT_NE(EndTaskgroupCall, nullptr);

  // Verify the Ident argument. dyn_cast (rather than cast) is used so that an
  // operand of an unexpected kind shows up as a failed assertion below.
  GlobalVariable *Ident =
      dyn_cast<GlobalVariable>(TaskgroupCall->getArgOperand(0));
  ASSERT_NE(Ident, nullptr);
  EXPECT_TRUE(Ident->hasInitializer());
  Constant *Initializer = Ident->getInitializer();
  GlobalVariable *SrcStrGlob = dyn_cast<GlobalVariable>(
      Initializer->getOperand(4)->stripPointerCasts());
  ASSERT_NE(SrcStrGlob, nullptr);
  ConstantDataArray *SrcSrc =
      dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
  ASSERT_NE(SrcSrc, nullptr);

  // Verify the thread ID argument: it must be the result of a call to
  // __kmpc_global_thread_num taking a single argument.
  CallInst *GTID = dyn_cast<CallInst>(TaskgroupCall->getArgOperand(1));
  ASSERT_NE(GTID, nullptr);
  EXPECT_EQ(GTID->arg_size(), 1U);
  EXPECT_EQ(GTID->getCalledFunction(), OMPBuilder.getOrCreateRuntimeFunctionPtr(
                                           OMPRTL___kmpc_global_thread_num));

  // Checking the general structure of the IR generated is same as expected.
  Instruction *GeneratedStoreInst = TaskgroupCall->getNextNonDebugInstruction();
  EXPECT_EQ(GeneratedStoreInst, InternalStoreInst);
  Instruction *GeneratedLoad32 =
      GeneratedStoreInst->getNextNonDebugInstruction();
  EXPECT_EQ(GeneratedLoad32, InternalLoad32);
  Instruction *GeneratedLoad128 = GeneratedLoad32->getNextNonDebugInstruction();
  EXPECT_EQ(GeneratedLoad128, InternalLoad128);

  // The terminators are only set if the body callback ran.
  ASSERT_NE(ThenTerm, nullptr);
  ASSERT_NE(ElseTerm, nullptr);

  // Checking the ordering because of the if statements and that
  // `__kmpc_end_taskgroup` call is after the if branching.
  BasicBlock *RefOrder[] = {TaskgroupCall->getParent(), ThenTerm->getParent(),
                            ThenTerm->getSuccessor(0),
                            EndTaskgroupCall->getParent(),
                            ElseTerm->getParent()};
  // NOTE(review): the result of verifyDFSOrder (if it returns one) is not
  // inspected here - confirm whether it should be wrapped in EXPECT_TRUE.
  verifyDFSOrder(F, RefOrder);
}
// Checks that two tasks created from inside a taskgroup body end up between
// the __kmpc_taskgroup and __kmpc_end_taskgroup runtime calls.
TEST_F(OpenMPIRBuilderTest, CreateTaskgroupWithTasks) {
  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
  OpenMPIRBuilder OMPBuilder(*M);
  OMPBuilder.initialize();
  F->setName("func");
  IRBuilder<> Builder(BB);

  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
    Builder.restoreIP(AllocaIP);
    AllocaInst *Alloca32 =
        Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, "bodygen.alloca32");
    AllocaInst *Alloca64 =
        Builder.CreateAlloca(Builder.getInt64Ty(), nullptr, "bodygen.alloca64");
    Builder.restoreIP(CodeGenIP);

    // First task: adds 64 to the 64-bit slot.
    auto TaskBodyGenCB1 = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
      Builder.restoreIP(CodeGenIP);
      LoadInst *LoadValue =
          Builder.CreateLoad(Alloca64->getAllocatedType(), Alloca64);
      Value *AddInst = Builder.CreateAdd(LoadValue, Builder.getInt64(64));
      Builder.CreateStore(AddInst, Alloca64);
    };
    OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
    Builder.restoreIP(OMPBuilder.createTask(Loc, AllocaIP, TaskBodyGenCB1));

    // Second task: adds 32 to the 32-bit slot.
    auto TaskBodyGenCB2 = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
      Builder.restoreIP(CodeGenIP);
      LoadInst *LoadValue =
          Builder.CreateLoad(Alloca32->getAllocatedType(), Alloca32);
      Value *AddInst = Builder.CreateAdd(LoadValue, Builder.getInt32(32));
      Builder.CreateStore(AddInst, Alloca32);
    };
    OpenMPIRBuilder::LocationDescription Loc2(Builder.saveIP(), DL);
    Builder.restoreIP(OMPBuilder.createTask(Loc2, AllocaIP, TaskBodyGenCB2));
  };

  BasicBlock *AllocaBB = Builder.GetInsertBlock();
  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
  OpenMPIRBuilder::LocationDescription Loc(
      InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
  Builder.restoreIP(OMPBuilder.createTaskgroup(
      Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()),
      BodyGenCB));
  OMPBuilder.finalize();
  Builder.CreateRetVoid();

  EXPECT_FALSE(verifyModule(*M, &errs()));

  CallInst *TaskgroupCall = dyn_cast<CallInst>(
      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup)
          ->user_back());
  ASSERT_NE(TaskgroupCall, nullptr);
  CallInst *EndTaskgroupCall = dyn_cast<CallInst>(
      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup)
          ->user_back());
  ASSERT_NE(EndTaskgroupCall, nullptr);

  Function *TaskAllocFn =
      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
  ASSERT_EQ(TaskAllocFn->getNumUses(), 2);

  // Walk the user list with a named iterator. Writing `begin()++` would
  // dereference the iterator value from *before* the increment, so both
  // pointers would refer to the same call.
  auto TaskAllocUserIt = TaskAllocFn->users().begin();
  CallInst *FirstTaskAllocCall = dyn_cast_or_null<CallInst>(*TaskAllocUserIt);
  ++TaskAllocUserIt;
  CallInst *SecondTaskAllocCall = dyn_cast_or_null<CallInst>(*TaskAllocUserIt);
  ASSERT_NE(FirstTaskAllocCall, nullptr);
  ASSERT_NE(SecondTaskAllocCall, nullptr);
  ASSERT_NE(FirstTaskAllocCall, SecondTaskAllocCall);

  // Verify that the tasks have been generated in order and inside taskgroup
  // construct.
  // NOTE(review): which of the two users belongs to the first createTask call
  // depends on the iteration order of the use list - confirm that the
  // First/Second naming matches creation order. Also, the result of
  // verifyDFSOrder (if it returns one) is not inspected here - confirm
  // whether it should be wrapped in EXPECT_TRUE.
  BasicBlock *RefOrder[] = {
      TaskgroupCall->getParent(), FirstTaskAllocCall->getParent(),
      SecondTaskAllocCall->getParent(), EndTaskgroupCall->getParent()};
  verifyDFSOrder(F, RefOrder);
}
} // namespace