mirror of
https://github.com/intel/llvm.git
synced 2026-01-24 17:01:00 +08:00
[mlir] OpenMP-to-LLVM: properly set outer alloca insertion point
Previously, the OpenMP to LLVM IR conversion was setting the alloca insertion point to the same position as the main compuation when converting OpenMP `parallel` operations. This is problematic if, for example, the `parallel` operation is placed inside a loop and would keep allocating on stack on each iteration leading to stack overflow. Reviewed By: kiranchandramohan Differential Revision: https://reviews.llvm.org/D101307
This commit is contained in:
@@ -176,6 +176,82 @@ public:
|
||||
/// it if it does not exist.
|
||||
llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
|
||||
|
||||
/// Common CRTP base class for ModuleTranslation stack frames.
|
||||
class StackFrame {
|
||||
public:
|
||||
virtual ~StackFrame() {}
|
||||
TypeID getTypeID() const { return typeID; }
|
||||
|
||||
protected:
|
||||
explicit StackFrame(TypeID typeID) : typeID(typeID) {}
|
||||
|
||||
private:
|
||||
const TypeID typeID;
|
||||
virtual void anchor();
|
||||
};
|
||||
|
||||
/// Concrete CRTP base class for ModuleTranslation stack frames. When
|
||||
/// translating operations with regions, users of ModuleTranslation can store
|
||||
/// state on ModuleTranslation stack before entering the region and inspect
|
||||
/// it when converting operations nested within that region. Users are
|
||||
/// expected to derive this class and put any relevant information into fields
|
||||
/// of the derived class. The usual isa/dyn_cast functionality is available
|
||||
/// for instances of derived classes.
|
||||
template <typename Derived>
|
||||
class StackFrameBase : public StackFrame {
|
||||
public:
|
||||
explicit StackFrameBase() : StackFrame(TypeID::get<Derived>()) {}
|
||||
};
|
||||
|
||||
/// Creates a stack frame of type `T` on ModuleTranslation stack. `T` must
|
||||
/// be derived from `StackFrameBase<T>` and constructible from the provided
|
||||
/// arguments. Doing this before entering the region of the op being
|
||||
/// translated makes the frame available when translating ops within that
|
||||
/// region.
|
||||
template <typename T, typename... Args>
|
||||
void stackPush(Args &&... args) {
|
||||
static_assert(
|
||||
std::is_base_of<StackFrame, T>::value,
|
||||
"can only push instances of StackFrame on ModuleTranslation stack");
|
||||
stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
/// Pops the last element from the ModuleTranslation stack.
|
||||
void stackPop() { stack.pop_back(); }
|
||||
|
||||
/// Calls `callback` for every ModuleTranslation stack frame of type `T`
|
||||
/// starting from the top of the stack.
|
||||
template <typename T>
|
||||
WalkResult
|
||||
stackWalk(llvm::function_ref<WalkResult(const T &)> callback) const {
|
||||
static_assert(std::is_base_of<StackFrame, T>::value,
|
||||
"expected T derived from StackFrame");
|
||||
if (!callback)
|
||||
return WalkResult::skip();
|
||||
for (const std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
|
||||
if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
|
||||
WalkResult result = callback(*ptr);
|
||||
if (result.wasInterrupted())
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
}
|
||||
|
||||
/// RAII object calling stackPush/stackPop on construction/destruction.
|
||||
template <typename T>
|
||||
struct SaveStack {
|
||||
template <typename... Args>
|
||||
explicit SaveStack(ModuleTranslation &m, Args &&...args)
|
||||
: moduleTranslation(m) {
|
||||
moduleTranslation.stackPush<T>(std::forward<Args>(args)...);
|
||||
}
|
||||
~SaveStack() { moduleTranslation.stackPop(); }
|
||||
|
||||
private:
|
||||
ModuleTranslation &moduleTranslation;
|
||||
};
|
||||
|
||||
private:
|
||||
ModuleTranslation(Operation *module,
|
||||
std::unique_ptr<llvm::Module> llvmModule);
|
||||
@@ -233,6 +309,10 @@ private:
|
||||
/// metadata. The metadata is attached to Latch block branches with this
|
||||
/// attribute.
|
||||
DenseMap<Attribute, llvm::MDNode *> loopOptionsMetadataMapping;
|
||||
|
||||
/// Stack of user-specified state elements, useful when translating operations
|
||||
/// with regions.
|
||||
SmallVector<std::unique_ptr<StackFrame>> stack;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
@@ -270,4 +350,14 @@ llvm::Value *createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
|
||||
} // namespace LLVM
|
||||
} // namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
template <typename T>
|
||||
struct isa_impl<T, ::mlir::LLVM::ModuleTranslation::StackFrame> {
|
||||
static inline bool
|
||||
doit(const ::mlir::LLVM::ModuleTranslation::StackFrame &frame) {
|
||||
return frame.getTypeID() == ::mlir::TypeID::get<T>();
|
||||
}
|
||||
};
|
||||
} // namespace llvm
|
||||
|
||||
#endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
|
||||
|
||||
@@ -23,6 +23,42 @@
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
|
||||
/// insertion points for allocas.
|
||||
class OpenMPAllocaStackFrame
|
||||
: public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
|
||||
public:
|
||||
explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
|
||||
: allocaInsertPoint(allocaIP) {}
|
||||
llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Find the insertion point for allocas given the current insertion point for
|
||||
/// normal operations in the builder.
|
||||
static llvm::OpenMPIRBuilder::InsertPointTy
|
||||
findAllocaInsertPoint(llvm::IRBuilderBase &builder,
|
||||
const LLVM::ModuleTranslation &moduleTranslation) {
|
||||
// If there is an alloca insertion point on stack, i.e. we are in a nested
|
||||
// operation and a specific point was provided by some surrounding operation,
|
||||
// use it.
|
||||
llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
|
||||
WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
|
||||
[&](const OpenMPAllocaStackFrame &frame) {
|
||||
allocaInsertPoint = frame.allocaInsertPoint;
|
||||
return WalkResult::interrupt();
|
||||
});
|
||||
if (walkResult.wasInterrupted())
|
||||
return allocaInsertPoint;
|
||||
|
||||
// Otherwise, insert to the entry block of the surrounding function.
|
||||
llvm::BasicBlock &funcEntryBlock =
|
||||
builder.GetInsertBlock()->getParent()->getEntryBlock();
|
||||
return llvm::OpenMPIRBuilder::InsertPointTy(
|
||||
&funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
|
||||
}
|
||||
|
||||
/// Converts the given region that appears within an OpenMP dialect operation to
|
||||
/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
|
||||
/// region, and a branch from any block with an successor-less OpenMP terminator
|
||||
@@ -91,6 +127,11 @@ convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder,
|
||||
|
||||
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
|
||||
llvm::BasicBlock &continuationBlock) {
|
||||
// Save the alloca insertion point on ModuleTranslation stack for use in
|
||||
// nested regions.
|
||||
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
|
||||
moduleTranslation, allocaIP);
|
||||
|
||||
// ParallelOp has only one region associated with it.
|
||||
auto ®ion = cast<omp::ParallelOp>(opInst).getRegion();
|
||||
convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(),
|
||||
@@ -124,18 +165,14 @@ convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder,
|
||||
pbKind = llvm::omp::getProcBindKind(bind.getValue());
|
||||
// TODO: Is the Parallel construct cancellable?
|
||||
bool isCancellable = false;
|
||||
// TODO: Determine the actual alloca insertion point, e.g., the function
|
||||
// entry or the alloca insertion point as provided by the body callback
|
||||
// above.
|
||||
llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
|
||||
if (failed(bodyGenStatus))
|
||||
return failure();
|
||||
|
||||
llvm::OpenMPIRBuilder::LocationDescription ompLoc(
|
||||
builder.saveIP(), builder.getCurrentDebugLocation());
|
||||
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel(
|
||||
ompLoc, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads, pbKind,
|
||||
isCancellable));
|
||||
return success();
|
||||
ompLoc, findAllocaInsertPoint(builder, moduleTranslation), bodyGenCB,
|
||||
privCB, finiCB, ifCond, numThreads, pbKind, isCancellable));
|
||||
|
||||
return bodyGenStatus;
|
||||
}
|
||||
|
||||
/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
|
||||
@@ -233,7 +270,6 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
|
||||
// TODO: this currently assumes WsLoop is semantically similar to SCF loop,
|
||||
// i.e. it has a positive step, uses signed integer semantics. Reconsider
|
||||
// this code when WsLoop clearly supports more cases.
|
||||
llvm::BasicBlock *insertBlock = builder.GetInsertBlock();
|
||||
llvm::CanonicalLoopInfo *loopInfo =
|
||||
moduleTranslation.getOpenMPBuilder()->createCanonicalLoop(
|
||||
ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
|
||||
@@ -241,12 +277,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
|
||||
if (failed(bodyGenStatus))
|
||||
return failure();
|
||||
|
||||
// TODO: get the alloca insertion point from the parallel operation builder.
|
||||
// If we insert the at the top of the current function, they will be passed as
|
||||
// extra arguments into the function the parallel operation builder outlines.
|
||||
// Put them at the start of the current block for now.
|
||||
llvm::OpenMPIRBuilder::InsertPointTy allocaIP(
|
||||
insertBlock, insertBlock->getFirstInsertionPt());
|
||||
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
|
||||
findAllocaInsertPoint(builder, moduleTranslation);
|
||||
llvm::OpenMPIRBuilder::InsertPointTy afterIP;
|
||||
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
|
||||
if (isStatic) {
|
||||
|
||||
@@ -755,6 +755,8 @@ ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
|
||||
return llvmModule->getOrInsertNamedMetadata(name);
|
||||
}
|
||||
|
||||
void ModuleTranslation::StackFrame::anchor() {}
|
||||
|
||||
static std::unique_ptr<llvm::Module>
|
||||
prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
|
||||
StringRef name) {
|
||||
|
||||
@@ -151,6 +151,13 @@ llvm.func @test_omp_parallel_num_threads_3() -> () {
|
||||
// CHECK: define void @test_omp_parallel_if_1(i32 %[[IF_VAR_1:.*]])
|
||||
llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () {
|
||||
|
||||
// Check that the allocas are emitted by the OpenMPIRBuilder at the top of the
|
||||
// function, before the condition. Allocas are only emitted by the builder when
|
||||
// the `if` clause is present. We match specific SSA value names since LLVM
|
||||
// actually produces those names.
|
||||
// CHECK: %tid.addr{{.*}} = alloca i32
|
||||
// CHECK: %zero.addr{{.*}} = alloca i32
|
||||
|
||||
// CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0
|
||||
%0 = llvm.mlir.constant(0 : index) : i32
|
||||
%1 = llvm.icmp "slt" %arg0, %0 : i32
|
||||
@@ -184,6 +191,60 @@ llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () {
|
||||
// CHECK: define internal void @[[OMP_OUTLINED_FN_IF_1]]
|
||||
// CHECK: call void @__kmpc_barrier
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_nested_alloca_ip
|
||||
llvm.func @test_nested_alloca_ip(%arg0: i32) -> () {
|
||||
|
||||
// Check that the allocas are emitted by the OpenMPIRBuilder at the top of
|
||||
// the function, before the condition. Allocas are only emitted by the
|
||||
// builder when the `if` clause is present. We match specific SSA value names
|
||||
// since LLVM actually produces those names and ensure they come before the
|
||||
// "icmp" that is the first operation we emit.
|
||||
// CHECK: %tid.addr{{.*}} = alloca i32
|
||||
// CHECK: %zero.addr{{.*}} = alloca i32
|
||||
// CHECK: icmp slt i32 %{{.*}}, 0
|
||||
%0 = llvm.mlir.constant(0 : index) : i32
|
||||
%1 = llvm.icmp "slt" %arg0, %0 : i32
|
||||
|
||||
omp.parallel if(%1 : i1) {
|
||||
// The "parallel" operation will be outlined, check the the function is
|
||||
// produced. Inside that function, further allocas should be placed before
|
||||
// another "icmp".
|
||||
// CHECK: define
|
||||
// CHECK: %tid.addr{{.*}} = alloca i32
|
||||
// CHECK: %zero.addr{{.*}} = alloca i32
|
||||
// CHECK: icmp slt i32 %{{.*}}, 1
|
||||
%2 = llvm.mlir.constant(1 : index) : i32
|
||||
%3 = llvm.icmp "slt" %arg0, %2 : i32
|
||||
|
||||
omp.parallel if(%3 : i1) {
|
||||
// One more nesting level.
|
||||
// CHECK: define
|
||||
// CHECK: %tid.addr{{.*}} = alloca i32
|
||||
// CHECK: %zero.addr{{.*}} = alloca i32
|
||||
// CHECK: icmp slt i32 %{{.*}}, 2
|
||||
|
||||
%4 = llvm.mlir.constant(2 : index) : i32
|
||||
%5 = llvm.icmp "slt" %arg0, %4 : i32
|
||||
|
||||
omp.parallel if(%5 : i1) {
|
||||
omp.barrier
|
||||
omp.terminator
|
||||
}
|
||||
|
||||
omp.barrier
|
||||
omp.terminator
|
||||
}
|
||||
omp.barrier
|
||||
omp.terminator
|
||||
}
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: define void @test_omp_parallel_3()
|
||||
llvm.func @test_omp_parallel_3() -> () {
|
||||
// CHECK: [[OMP_THREAD_3_1:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}})
|
||||
|
||||
Reference in New Issue
Block a user