Pass the pointer of the parent pipeline collection pass to PassInstrumentation::run*Pipeline.

For the cases where there are multiple levels of nested pass managers, the parent thread ID is not enough to distinguish the parent of a given pass pipeline. Passing in the parent pass gives an exact anchor point.

PiperOrigin-RevId: 272105461
This commit is contained in:
River Riddle
2019-09-30 17:44:31 -07:00
committed by A. Unique TensorFlower
parent fb41df9c4a
commit 1c649d5785
3 changed files with 76 additions and 35 deletions

View File

@@ -20,6 +20,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/StringRef.h"
namespace mlir {
@@ -37,21 +38,31 @@ struct PassInstrumentorImpl;
/// before running a pipeline.
class PassInstrumentation {
public:
/// This struct represents information related to the parent pass of pipeline.
/// It includes information that allows for effectively linking pipelines that
/// run on different threads.
struct PipelineParentInfo {
/// The thread of the parent pass that the current pipeline was spawned
/// from. Note: This is acquired from llvm::get_threadid().
uint64_t parentThreadID;
/// The pass that spawned this pipeline.
Pass *parentPass;
};
virtual ~PassInstrumentation() = 0;
/// A callback to run before a pass pipeline is executed. This function takes
/// the name of the operation type being operated on, and a thread id
/// corresponding to the parent thread this pipeline was spawned from.
/// Note: The parent thread id is collected via llvm::get_threadid().
/// the name of the operation type being operated on, and information related
/// to the parent that spawned this pipeline.
virtual void runBeforePipeline(const OperationName &name,
uint64_t parentThreadID) {}
const PipelineParentInfo &parentInfo) {}
/// A callback to run after a pass pipeline has executed. This function takes
/// the name of the operation type being operated on, and a thread id
/// corresponding to the parent thread this pipeline was spawned from.
/// Note: The parent thread id is collected via llvm::get_threadid().
/// the name of the operation type being operated on, and information related
/// to the parent that spawned this pipeline.
virtual void runAfterPipeline(const OperationName &name,
uint64_t parentThreadID) {}
const PipelineParentInfo &parentInfo) {}
/// A callback to run before a pass is executed. This function takes a pointer
/// to the pass to be executed, as well as the current operation being
@@ -92,10 +103,14 @@ public:
~PassInstrumentor();
/// See PassInstrumentation::runBeforePipeline for details.
void runBeforePipeline(const OperationName &name, uint64_t parentThreadID);
void
runBeforePipeline(const OperationName &name,
const PassInstrumentation::PipelineParentInfo &parentInfo);
/// See PassInstrumentation::runAfterPipeline for details.
void runAfterPipeline(const OperationName &name, uint64_t parentThreadID);
void
runAfterPipeline(const OperationName &name,
const PassInstrumentation::PipelineParentInfo &parentInfo);
/// See PassInstrumentation::runBeforePass for details.
void runBeforePass(Pass *pass, Operation *op);
@@ -121,4 +136,27 @@ private:
} // end namespace mlir
namespace llvm {
template <> struct DenseMapInfo<mlir::PassInstrumentation::PipelineParentInfo> {
using T = mlir::PassInstrumentation::PipelineParentInfo;
using PairInfo = DenseMapInfo<std::pair<uint64_t, void *>>;
static T getEmptyKey() {
auto pair = PairInfo::getEmptyKey();
return {pair.first, reinterpret_cast<mlir::Pass *>(pair.second)};
}
static T getTombstoneKey() {
auto pair = PairInfo::getTombstoneKey();
return {pair.first, reinterpret_cast<mlir::Pass *>(pair.second)};
}
static unsigned getHashValue(T val) {
return PairInfo::getHashValue({val.parentThreadID, val.parentPass});
}
static bool isEqual(T lhs, T rhs) {
return lhs.parentThreadID == rhs.parentThreadID &&
lhs.parentPass == rhs.parentPass;
}
};
} // end namespace llvm
#endif // MLIR_PASS_PASSINSTRUMENTATION_H_

View File

@@ -314,7 +314,8 @@ OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr)
/// Run the held pipeline over all nested operations.
void OpToOpPassAdaptor::runOnOperation() {
auto am = getAnalysisManager();
auto currentThreadID = llvm::get_threadid();
PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
this};
auto *instrumentor = am.getPassInstrumentor();
for (auto &region : getOperation()->getRegions()) {
for (auto &block : region) {
@@ -325,10 +326,10 @@ void OpToOpPassAdaptor::runOnOperation() {
// Run the held pipeline over the current operation.
if (instrumentor)
instrumentor->runBeforePipeline(mgr->getOpName(), currentThreadID);
instrumentor->runBeforePipeline(mgr->getOpName(), parentInfo);
auto result = runPipeline(*mgr, &op, am.slice(&op));
if (instrumentor)
instrumentor->runAfterPipeline(mgr->getOpName(), currentThreadID);
instrumentor->runAfterPipeline(mgr->getOpName(), parentInfo);
if (failed(result))
return signalPassFailure();
@@ -380,7 +381,8 @@ void OpToOpPassAdaptorParallel::runOnOperation() {
std::atomic<unsigned> opIt(0);
// Get the current thread for this adaptor.
auto parentThreadID = llvm::get_threadid();
PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
this};
auto *instrumentor = am.getPassInstrumentor();
// An atomic failure variable for the async executors.
@@ -405,10 +407,10 @@ void OpToOpPassAdaptorParallel::runOnOperation() {
assert(pm && "expected valid pass manager for operation");
if (instrumentor)
instrumentor->runBeforePipeline(pm->getOpName(), parentThreadID);
instrumentor->runBeforePipeline(pm->getOpName(), parentInfo);
auto pipelineResult = runPipeline(*pm, it.first, it.second);
if (instrumentor)
instrumentor->runAfterPipeline(pm->getOpName(), parentThreadID);
instrumentor->runAfterPipeline(pm->getOpName(), parentInfo);
// Drop this thread from being tracked by the diagnostic handler.
// After this task has finished, the thread may be used outside of
@@ -553,19 +555,21 @@ PassInstrumentor::PassInstrumentor() : impl(new PassInstrumentorImpl()) {}
PassInstrumentor::~PassInstrumentor() {}
/// See PassInstrumentation::runBeforePipeline for details.
void PassInstrumentor::runBeforePipeline(const OperationName &name,
uint64_t parentThreadID) {
void PassInstrumentor::runBeforePipeline(
const OperationName &name,
const PassInstrumentation::PipelineParentInfo &parentInfo) {
llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
for (auto &instr : impl->instrumentations)
instr->runBeforePipeline(name, parentThreadID);
instr->runBeforePipeline(name, parentInfo);
}
/// See PassInstrumentation::runAfterPipeline for details.
void PassInstrumentor::runAfterPipeline(const OperationName &name,
uint64_t parentThreadID) {
void PassInstrumentor::runAfterPipeline(
const OperationName &name,
const PassInstrumentation::PipelineParentInfo &parentInfo) {
llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
for (auto &instr : impl->instrumentations)
instr->runAfterPipeline(name, parentThreadID);
for (auto &instr : llvm::reverse(impl->instrumentations))
instr->runAfterPipeline(name, parentInfo);
}
/// See PassInstrumentation::runBeforePass for details.

View File

@@ -93,8 +93,8 @@ struct Timer {
/// Returns the total time for this timer in seconds.
TimeRecord getTotalTime() {
// If we have a valid wall time, then we directly compute the seconds.
if (wallTime.count()) {
// If this is a pass or analysis timer, use the recorded time directly.
if (kind == TimerKind::PassOrAnalysis) {
return TimeRecord(
std::chrono::duration_cast<std::chrono::duration<double>>(wallTime)
.count(),
@@ -174,9 +174,9 @@ struct PassTiming : public PassInstrumentation {
/// Setup the instrumentation hooks.
void runBeforePipeline(const OperationName &name,
uint64_t parentThreadID) override;
const PipelineParentInfo &parentInfo) override;
void runAfterPipeline(const OperationName &name,
uint64_t parentThreadID) override;
const PipelineParentInfo &parentInfo) override;
void runBeforePass(Pass *pass, Operation *) override { startPassTimer(pass); }
void runAfterPass(Pass *pass, Operation *) override;
void runAfterPassFailed(Pass *pass, Operation *op) override {
@@ -245,15 +245,14 @@ struct PassTiming : public PassInstrumentation {
PassTimingDisplayMode displayMode;
/// A mapping of pipeline timers that need to be merged into the parent
/// collection. The timers are mapped to the thread id of the parent thread to
/// merge into.
DenseMap<uint64_t, SmallVector<Timer::ChildrenMap::value_type, 4>>
/// collection. The timers are mapped to the parent info to merge into.
DenseMap<PipelineParentInfo, SmallVector<Timer::ChildrenMap::value_type, 4>>
pipelinesToMerge;
};
} // end anonymous namespace
void PassTiming::runBeforePipeline(const OperationName &name,
uint64_t parentThreadID) {
const PipelineParentInfo &parentInfo) {
// We don't actually want to time the piplelines, they gather their total
// from their held passes.
getTimer(name.getAsOpaquePointer(), TimerKind::Pipeline,
@@ -261,7 +260,7 @@ void PassTiming::runBeforePipeline(const OperationName &name,
}
void PassTiming::runAfterPipeline(const OperationName &name,
uint64_t parentThreadID) {
const PipelineParentInfo &parentInfo) {
// Pop the timer for the pipeline.
auto tid = llvm::get_threadid();
auto &activeTimers = activeThreadTimers[tid];
@@ -270,7 +269,7 @@ void PassTiming::runAfterPipeline(const OperationName &name,
// If the current thread is the same as the parent, there is nothing left to
// do.
if (tid == parentThreadID)
if (tid == parentInfo.parentThreadID)
return;
// Otherwise, mark the pipeline timer for merging into the correct parent
@@ -280,7 +279,7 @@ void PassTiming::runAfterPipeline(const OperationName &name,
assert(parentTimer->children.size() == 1 &&
parentTimer->children.count(name.getAsOpaquePointer()) &&
"expected a single pipeline timer");
pipelinesToMerge[parentThreadID].push_back(
pipelinesToMerge[parentInfo].push_back(
std::move(*parentTimer->children.begin()));
rootTimers.erase(tid);
}
@@ -322,7 +321,7 @@ void PassTiming::runAfterPass(Pass *pass, Operation *) {
// If this is an OpToOpPassAdaptorParallel, then we need to merge in the
// timing data for the pipelines running on other threads.
if (isa<OpToOpPassAdaptorParallel>(pass)) {
auto toMerge = pipelinesToMerge.find(llvm::get_threadid());
auto toMerge = pipelinesToMerge.find({llvm::get_threadid(), pass});
if (toMerge != pipelinesToMerge.end()) {
for (auto &it : toMerge->second)
timer->mergeChild(std::move(it));