[mlir] OpenMP-to-LLVM: properly set outer alloca insertion point

Previously, the OpenMP to LLVM IR conversion was setting the alloca insertion
point to the same position as the main computation when converting OpenMP
`parallel` operations. This is problematic if, for example, the `parallel`
operation is placed inside a loop: allocas would then be emitted on every
iteration, eventually overflowing the stack.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D101307
Author: Alex Zinenko
Date:   2021-05-10 10:02:18 +02:00
parent 7f78e409d0
commit 72d013dd73
4 changed files with 201 additions and 16 deletions
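
To make the failure mode concrete, below is a minimal standalone sketch (not part of this patch) using the plain llvm::IRBuilder API; the function and value names are invented for illustration. An alloca created at the builder's current insertion point inside a loop body executes on every iteration and keeps growing the stack, while one hoisted to the function's entry block, the fallback position this patch uses, is allocated only once per call.

```cpp
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext ctx;
  Module module("alloca-demo", ctx);
  Function *fn = Function::Create(
      FunctionType::get(Type::getVoidTy(ctx), /*isVarArg=*/false),
      Function::ExternalLinkage, "demo", &module);
  BasicBlock *entry = BasicBlock::Create(ctx, "entry", fn);
  BasicBlock *loop = BasicBlock::Create(ctx, "loop", fn);

  IRBuilder<> builder(entry);
  builder.CreateBr(loop);
  builder.SetInsertPoint(loop);

  // Problematic placement: the alloca sits inside the loop body, so it runs
  // on every iteration and the stack keeps growing.
  builder.CreateAlloca(builder.getInt32Ty(), /*ArraySize=*/nullptr,
                       "per.iteration");

  // Fixed placement: insert at the entry block's first insertion point, the
  // same fallback position findAllocaInsertPoint uses in this patch.
  IRBuilder<> entryBuilder(&fn->getEntryBlock(),
                           fn->getEntryBlock().getFirstInsertionPt());
  entryBuilder.CreateAlloca(builder.getInt32Ty(), /*ArraySize=*/nullptr,
                            "once.per.call");

  builder.CreateBr(loop); // keep looping; enough to show the placement
  module.print(outs(), nullptr);
}
```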


@@ -176,6 +176,82 @@ public:
/// it if it does not exist.
llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
/// Common CRTP base class for ModuleTranslation stack frames.
class StackFrame {
public:
virtual ~StackFrame() {}
TypeID getTypeID() const { return typeID; }
protected:
explicit StackFrame(TypeID typeID) : typeID(typeID) {}
private:
const TypeID typeID;
virtual void anchor();
};
/// Concrete CRTP base class for ModuleTranslation stack frames. When
/// translating operations with regions, users of ModuleTranslation can store
/// state on ModuleTranslation stack before entering the region and inspect
/// it when converting operations nested within that region. Users are
/// expected to derive this class and put any relevant information into fields
/// of the derived class. The usual isa/dyn_cast functionality is available
/// for instances of derived classes.
template <typename Derived>
class StackFrameBase : public StackFrame {
public:
explicit StackFrameBase() : StackFrame(TypeID::get<Derived>()) {}
};
/// Creates a stack frame of type `T` on ModuleTranslation stack. `T` must
/// be derived from `StackFrameBase<T>` and constructible from the provided
/// arguments. Doing this before entering the region of the op being
/// translated makes the frame available when translating ops within that
/// region.
template <typename T, typename... Args>
void stackPush(Args &&... args) {
static_assert(
std::is_base_of<StackFrame, T>::value,
"can only push instances of StackFrame on ModuleTranslation stack");
stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
}
/// Pops the last element from the ModuleTranslation stack.
void stackPop() { stack.pop_back(); }
/// Calls `callback` for every ModuleTranslation stack frame of type `T`
/// starting from the top of the stack.
template <typename T>
WalkResult
stackWalk(llvm::function_ref<WalkResult(const T &)> callback) const {
static_assert(std::is_base_of<StackFrame, T>::value,
"expected T derived from StackFrame");
if (!callback)
return WalkResult::skip();
for (const std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
WalkResult result = callback(*ptr);
if (result.wasInterrupted())
return result;
}
}
return WalkResult::advance();
}
/// RAII object calling stackPush/stackPop on construction/destruction.
template <typename T>
struct SaveStack {
template <typename... Args>
explicit SaveStack(ModuleTranslation &m, Args &&...args)
: moduleTranslation(m) {
moduleTranslation.stackPush<T>(std::forward<Args>(args)...);
}
~SaveStack() { moduleTranslation.stackPop(); }
private:
ModuleTranslation &moduleTranslation;
};
private:
ModuleTranslation(Operation *module,
std::unique_ptr<llvm::Module> llvmModule);
@@ -233,6 +309,10 @@ private:
/// metadata. The metadata is attached to Latch block branches with this
/// attribute.
DenseMap<Attribute, llvm::MDNode *> loopOptionsMetadataMapping;
/// Stack of user-specified state elements, useful when translating operations
/// with regions.
SmallVector<std::unique_ptr<StackFrame>> stack;
};
namespace detail {
@@ -270,4 +350,14 @@ llvm::Value *createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
} // namespace LLVM
} // namespace mlir
namespace llvm {
template <typename T>
struct isa_impl<T, ::mlir::LLVM::ModuleTranslation::StackFrame> {
static inline bool
doit(const ::mlir::LLVM::ModuleTranslation::StackFrame &frame) {
return frame.getTypeID() == ::mlir::TypeID::get<T>();
}
};
} // namespace llvm
#endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
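
As an illustration of the new API (not part of the diff), a translation could derive a frame from StackFrameBase, push it with the RAII SaveStack helper before descending into a region, and query the innermost instance with stackWalk from nested conversions. LoopNestFrame and convertSomeRegionOp below are hypothetical names used only for this sketch.

```cpp
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "llvm/IR/IRBuilder.h"

using namespace mlir;

/// Hypothetical frame type carrying per-region state; any payload works as
/// long as the class derives from StackFrameBase of itself.
class LoopNestFrame
    : public LLVM::ModuleTranslation::StackFrameBase<LoopNestFrame> {
public:
  explicit LoopNestFrame(unsigned depth) : depth(depth) {}
  unsigned depth;
};

static LogicalResult
convertSomeRegionOp(Operation &op, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation) {
  // RAII: the frame is pushed here and popped when `frame` goes out of
  // scope, so it is visible only while ops nested in `op` are translated.
  LLVM::ModuleTranslation::SaveStack<LoopNestFrame> frame(moduleTranslation,
                                                          /*depth=*/1);

  // ... the region of `op` would be translated with `builder` here ...

  // Nested conversions can look up the innermost frame; the walk starts at
  // the top of the stack and we stop at the first match.
  unsigned innermostDepth = 0;
  (void)moduleTranslation.stackWalk<LoopNestFrame>(
      [&](const LoopNestFrame &f) {
        innermostDepth = f.depth;
        return WalkResult::interrupt();
      });
  (void)innermostDepth; // a real conversion would use this value
  return success();
}
```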


@@ -23,6 +23,42 @@
using namespace mlir;
namespace {
/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
/// insertion points for allocas.
class OpenMPAllocaStackFrame
: public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
public:
explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
: allocaInsertPoint(allocaIP) {}
llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
};
} // namespace
/// Find the insertion point for allocas given the current insertion point for
/// normal operations in the builder.
static llvm::OpenMPIRBuilder::InsertPointTy
findAllocaInsertPoint(llvm::IRBuilderBase &builder,
const LLVM::ModuleTranslation &moduleTranslation) {
// If there is an alloca insertion point on stack, i.e. we are in a nested
// operation and a specific point was provided by some surrounding operation,
// use it.
llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
[&](const OpenMPAllocaStackFrame &frame) {
allocaInsertPoint = frame.allocaInsertPoint;
return WalkResult::interrupt();
});
if (walkResult.wasInterrupted())
return allocaInsertPoint;
// Otherwise, insert to the entry block of the surrounding function.
llvm::BasicBlock &funcEntryBlock =
builder.GetInsertBlock()->getParent()->getEntryBlock();
return llvm::OpenMPIRBuilder::InsertPointTy(
&funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
}
/// Converts the given region that appears within an OpenMP dialect operation to
/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
region, and a branch from any block with a successor-less OpenMP terminator
@@ -91,6 +127,11 @@ convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder,
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
llvm::BasicBlock &continuationBlock) {
// Save the alloca insertion point on ModuleTranslation stack for use in
// nested regions.
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
moduleTranslation, allocaIP);
// ParallelOp has only one region associated with it.
auto &region = cast<omp::ParallelOp>(opInst).getRegion();
convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(),
@@ -124,18 +165,14 @@ convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder,
pbKind = llvm::omp::getProcBindKind(bind.getValue());
// TODO: Is the Parallel construct cancellable?
bool isCancellable = false;
// TODO: Determine the actual alloca insertion point, e.g., the function
// entry or the alloca insertion point as provided by the body callback
// above.
llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
if (failed(bodyGenStatus))
return failure();
llvm::OpenMPIRBuilder::LocationDescription ompLoc(
builder.saveIP(), builder.getCurrentDebugLocation());
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel(
ompLoc, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads, pbKind,
isCancellable));
return success();
ompLoc, findAllocaInsertPoint(builder, moduleTranslation), bodyGenCB,
privCB, finiCB, ifCond, numThreads, pbKind, isCancellable));
return bodyGenStatus;
}
/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
@@ -233,7 +270,6 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
// TODO: this currently assumes WsLoop is semantically similar to SCF loop,
// i.e. it has a positive step, uses signed integer semantics. Reconsider
// this code when WsLoop clearly supports more cases.
llvm::BasicBlock *insertBlock = builder.GetInsertBlock();
llvm::CanonicalLoopInfo *loopInfo =
moduleTranslation.getOpenMPBuilder()->createCanonicalLoop(
ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
@@ -241,12 +277,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(bodyGenStatus))
return failure();
// TODO: get the alloca insertion point from the parallel operation builder.
// If we insert them at the top of the current function, they will be passed as
// extra arguments into the function the parallel operation builder outlines.
// Put them at the start of the current block for now.
llvm::OpenMPIRBuilder::InsertPointTy allocaIP(
insertBlock, insertBlock->getFirstInsertionPt());
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::InsertPointTy afterIP;
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
if (isStatic) {
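
For context (not part of the diff), a consumer living in the same translation unit as the file-local findAllocaInsertPoint defined above would typically use it as sketched here; convertOmpNestedOp and the alloca name are hypothetical.

```cpp
static LogicalResult
convertOmpNestedOp(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  // Resolves to the allocaIP saved by an enclosing `omp.parallel` frame if
  // one exists, otherwise to the entry block of the current function. Either
  // way, the alloca is emitted once, not on every loop iteration.
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);

  // Temporarily move the builder to the alloca insertion point; the guard
  // restores the original insertion point on scope exit.
  llvm::IRBuilderBase::InsertPointGuard guard(builder);
  builder.restoreIP(allocaIP);
  builder.CreateAlloca(builder.getInt32Ty(), /*ArraySize=*/nullptr,
                       "omp.example.scratch");
  return success();
}
```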


@@ -755,6 +755,8 @@ ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
return llvmModule->getOrInsertNamedMetadata(name);
}
void ModuleTranslation::StackFrame::anchor() {}
static std::unique_ptr<llvm::Module>
prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
StringRef name) {


@@ -151,6 +151,13 @@ llvm.func @test_omp_parallel_num_threads_3() -> () {
// CHECK: define void @test_omp_parallel_if_1(i32 %[[IF_VAR_1:.*]])
llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () {
// Check that the allocas are emitted by the OpenMPIRBuilder at the top of the
// function, before the condition. Allocas are only emitted by the builder when
// the `if` clause is present. We match specific SSA value names since LLVM
// actually produces those names.
// CHECK: %tid.addr{{.*}} = alloca i32
// CHECK: %zero.addr{{.*}} = alloca i32
// CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0
%0 = llvm.mlir.constant(0 : index) : i32
%1 = llvm.icmp "slt" %arg0, %0 : i32
@@ -184,6 +191,60 @@ llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () {
// CHECK: define internal void @[[OMP_OUTLINED_FN_IF_1]]
// CHECK: call void @__kmpc_barrier
// -----
// CHECK-LABEL: @test_nested_alloca_ip
llvm.func @test_nested_alloca_ip(%arg0: i32) -> () {
// Check that the allocas are emitted by the OpenMPIRBuilder at the top of
// the function, before the condition. Allocas are only emitted by the
// builder when the `if` clause is present. We match specific SSA value names
// since LLVM actually produces those names and ensure they come before the
// "icmp" that is the first operation we emit.
// CHECK: %tid.addr{{.*}} = alloca i32
// CHECK: %zero.addr{{.*}} = alloca i32
// CHECK: icmp slt i32 %{{.*}}, 0
%0 = llvm.mlir.constant(0 : index) : i32
%1 = llvm.icmp "slt" %arg0, %0 : i32
omp.parallel if(%1 : i1) {
// The "parallel" operation will be outlined; check that the function is
// produced. Inside that function, further allocas should be placed before
// another "icmp".
// CHECK: define
// CHECK: %tid.addr{{.*}} = alloca i32
// CHECK: %zero.addr{{.*}} = alloca i32
// CHECK: icmp slt i32 %{{.*}}, 1
%2 = llvm.mlir.constant(1 : index) : i32
%3 = llvm.icmp "slt" %arg0, %2 : i32
omp.parallel if(%3 : i1) {
// One more nesting level.
// CHECK: define
// CHECK: %tid.addr{{.*}} = alloca i32
// CHECK: %zero.addr{{.*}} = alloca i32
// CHECK: icmp slt i32 %{{.*}}, 2
%4 = llvm.mlir.constant(2 : index) : i32
%5 = llvm.icmp "slt" %arg0, %4 : i32
omp.parallel if(%5 : i1) {
omp.barrier
omp.terminator
}
omp.barrier
omp.terminator
}
omp.barrier
omp.terminator
}
llvm.return
}
// -----
// CHECK-LABEL: define void @test_omp_parallel_3()
llvm.func @test_omp_parallel_3() -> () {
// CHECK: [[OMP_THREAD_3_1:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})