Add basic infrastructure for instrumenting pass execution and analysis computation. A virtual class, PassInstrumentation, is provided to allow for different parts of the pass manager infrastructure. The currently available hooks allow for instrumenting:

* before/after pass execution
* after a pass fails
* before/after an analysis is computed

After getting this infrastructure in place, we can start providing common developer utilities like pass timing, IR printing after pass execution, etc.

PiperOrigin-RevId: 237709692
This commit is contained in:
River Riddle
2019-03-10 14:45:47 -07:00
committed by jpienaar
parent 861eb87471
commit 2d2b40bce5
6 changed files with 243 additions and 30 deletions

View File

@@ -19,9 +19,11 @@
#define MLIR_PASS_ANALYSISMANAGER_H
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassInstrumentation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/TypeName.h"
namespace mlir {
/// A special type used by analyses to provide an address that identifies a
@@ -100,20 +102,36 @@ template <typename IRUnitT> class AnalysisMap {
using ConceptMap =
DenseMap<const AnalysisID *, std::unique_ptr<AnalysisConcept>>;
/// Utility to return the name of the given analysis class.
template <typename AnalysisT> static llvm::StringRef getAnalysisName() {
StringRef name = llvm::getTypeName<AnalysisT>();
if (!name.consume_front("mlir::"))
name.consume_front("(anonymous namespace)::");
return name;
}
public:
explicit AnalysisMap(IRUnitT *ir) : ir(ir) {}
/// Get an analysis for the current IR unit, computing it if necessary.
template <typename AnalysisT> AnalysisT &getAnalysis() {
template <typename AnalysisT> AnalysisT &getAnalysis(PassInstrumentor *pi) {
auto *id = AnalysisID::getID<AnalysisT>();
typename ConceptMap::iterator it;
bool wasInserted;
std::tie(it, wasInserted) =
analyses.try_emplace(AnalysisID::getID<AnalysisT>());
std::tie(it, wasInserted) = analyses.try_emplace(id);
// If we don't have a cached analysis for this function, compute it directly
// and add it to the cache.
if (wasInserted)
if (wasInserted) {
if (pi)
pi->runBeforeAnalysis(getAnalysisName<AnalysisT>(), id, ir);
it->second = llvm::make_unique<AnalysisModel<AnalysisT>>(ir);
if (pi)
pi->runAfterAnalysis(getAnalysisName<AnalysisT>(), id, ir);
}
return static_cast<AnalysisModel<AnalysisT> &>(*it->second).analysis;
}
@@ -154,6 +172,7 @@ private:
//===----------------------------------------------------------------------===//
// Analysis Management
//===----------------------------------------------------------------------===//
class ModuleAnalysisManager;
/// An analysis manager for a specific function instance. This class can only be
/// constructed from a ModuleAnalysisManager instance.
@@ -163,13 +182,11 @@ public:
// exist and if it does it may be stale.
template <typename AnalysisT>
llvm::Optional<std::reference_wrapper<AnalysisT>>
getCachedModuleAnalysis() const {
return parentImpl->getCachedAnalysis<AnalysisT>();
}
getCachedModuleAnalysis() const;
// Query for the given analysis for the current function.
template <typename AnalysisT> AnalysisT &getAnalysis() {
return impl->getAnalysis<AnalysisT>();
return impl->getAnalysis<AnalysisT>(getPassInstrumentor());
}
// Query for a cached entry of the given analysis on the current function.
@@ -189,14 +206,17 @@ public:
/// Clear any held analyses.
void clear() { impl->clear(); }
private:
FunctionAnalysisManager(const detail::AnalysisMap<Module> *parentImpl,
detail::AnalysisMap<Function> *impl)
: parentImpl(parentImpl), impl(impl) {}
/// Returns a pass instrumentation object for the current function. This value
/// may be null.
PassInstrumentor *getPassInstrumentor() const;
/// A reference to the analysis map of the parent module within the owning
/// analysis manager.
const detail::AnalysisMap<Module> *parentImpl;
private:
FunctionAnalysisManager(const ModuleAnalysisManager *parent,
detail::AnalysisMap<Function> *impl)
: parent(parent), impl(impl) {}
/// A reference to the parent analysis manager.
const ModuleAnalysisManager *parent;
/// A reference to the impl analysis map within the owning analysis manager.
detail::AnalysisMap<Function> *impl;
@@ -208,7 +228,8 @@ private:
/// An analysis manager for a specific module instance.
class ModuleAnalysisManager {
public:
ModuleAnalysisManager(Module *module) : moduleAnalyses(module) {}
ModuleAnalysisManager(Module *module, PassInstrumentor *passInstrumentor)
: moduleAnalyses(module), passInstrumentor(passInstrumentor) {}
ModuleAnalysisManager(const ModuleAnalysisManager &) = delete;
ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete;
@@ -232,7 +253,7 @@ public:
/// Query for the analysis for the module. The analysis is computed if it does
/// not exist.
template <typename AnalysisT> AnalysisT &getAnalysis() {
return moduleAnalyses.getAnalysis<AnalysisT>();
return moduleAnalyses.getAnalysis<AnalysisT>(getPassInstrumentor());
}
/// Query for a cached analysis for the module, or return null.
@@ -247,14 +268,29 @@ public:
/// Invalidate any non preserved analyses.
void invalidate(const detail::PreservedAnalyses &pa);
/// Returns a pass instrumentation object for the current module. This value
/// may be null.
PassInstrumentor *getPassInstrumentor() const { return passInstrumentor; }
private:
/// The cached analyses for functions within the current module.
DenseMap<Function *, detail::AnalysisMap<Function>> functionAnalyses;
/// The analyses for the owning module.
detail::AnalysisMap<Module> moduleAnalyses;
/// An optional instrumentation object.
PassInstrumentor *passInstrumentor;
};
// Query for a cached analysis on the parent Module. The analysis may not exist
// and if it does it may be stale.
template <typename AnalysisT>
llvm::Optional<std::reference_wrapper<AnalysisT>>
FunctionAnalysisManager::getCachedModuleAnalysis() const {
return parent->getCachedAnalysis<AnalysisT>();
}
} // end namespace mlir
#endif // MLIR_PASS_ANALYSISMANAGER_H

View File

@@ -47,6 +47,9 @@ public:
/// Return the kind of this pass.
Kind getKind() const { return passIDAndKind.getInt(); }
/// Returns the derived pass name.
virtual StringRef getName() = 0;
protected:
Pass(const PassID *passID, Kind kind) : passIDAndKind(passID, kind) {}
@@ -181,8 +184,7 @@ class PassModel : public BasePassT {
protected:
PassModel() : BasePassT(PassID::getID<PassT>()) {}
/// TODO(riverriddle) Provide additional utilities for cloning, getting the
/// derived class name, etc.
/// TODO(riverriddle) Provide additional utilities for cloning, etc.
/// Signal that some invariant was broken when running. The IR is allowed to
/// be in an invalid state.
@@ -214,6 +216,14 @@ protected:
void markAnalysesPreserved(const AnalysisID *id) {
this->getPassState().preservedAnalyses.preserve(id);
}
/// Returns the derived pass name.
StringRef getName() override {
StringRef name = llvm::getTypeName<PassT>();
if (!name.consume_front("mlir::"))
name.consume_front("(anonymous namespace)::");
return name;
}
};
} // end namespace detail

View File

@@ -0,0 +1,118 @@
//===- PassInstrumentation.h ------------------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_PASS_PASSINSTRUMENTATION_H_
#define MLIR_PASS_PASSINSTRUMENTATION_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/Any.h"
#include "llvm/ADT/StringRef.h"
namespace mlir {
struct AnalysisID;
class Pass;
/// PassInstrumentation provdes several entry points into the pass manager
/// infrastructure. Instrumentations should be added directly to a PassManager
/// before running a pipeline.
class PassInstrumentation {
public:
virtual ~PassInstrumentation() = 0;
/// A callback to run before a pass is executed. This function takes a pointer
/// to the pass to be executed, as well as an llvm::Any holding a pointer to
/// the IR unit being transformed on.
virtual void runBeforePass(Pass *pass, const llvm::Any &ir) {}
/// A callback to run after a pass is successfully executed. This function
/// takes a pointer to the pass to be executed, as well as an llvm::Any
/// holding a pointer to the IR unit being transformed on.
virtual void runAfterPass(Pass *pass, const llvm::Any &ir) {}
/// A callback to run when a pass execution fails. This function takes a
/// pointer to the pass that was being executed, as well as an llvm::Any
/// holding a pointer to the IR unit that was being transformed. Note
/// that the ir unit may be in an invalid state.
virtual void runAfterPassFailed(Pass *pass, const llvm::Any &ir) {}
/// A callback to run before an analysis is computed. This function takes the
/// name of the analysis to be computed, its AnalysisID, as well as an
/// llvm::Any holding a pointer to the IR unit being analyzed on.
virtual void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id,
const llvm::Any &ir) {}
/// A callback to run before an analysis is computed. This function takes the
/// name of the analysis that was computed, its AnalysisID, as well as an
/// llvm::Any holding a pointer to the IR unit that was analyzed.
virtual void runAfterAnalysis(llvm::StringRef name, AnalysisID *id,
const llvm::Any &ir) {}
};
/// This class holds a collection of PassInstrumentation objects, and invokes
/// their respective call backs.
class PassInstrumentor {
public:
/// See PassInstrumentation::runBeforePass for details.
template <typename IRUnitT> void runBeforePass(Pass *pass, IRUnitT *ir) {
llvm::Any irAny(ir);
for (auto &instr : instrumentations)
instr->runBeforePass(pass, irAny);
}
/// See PassInstrumentation::runAfterPass for details.
template <typename IRUnitT> void runAfterPass(Pass *pass, IRUnitT *ir) {
llvm::Any irAny(ir);
for (auto &instr : instrumentations)
instr->runAfterPass(pass, irAny);
}
/// See PassInstrumentation::runAfterPassFailed for details.
template <typename IRUnitT> void runAfterPassFailed(Pass *pass, IRUnitT *ir) {
llvm::Any irAny(ir);
for (auto &instr : instrumentations)
instr->runAfterPassFailed(pass, irAny);
}
/// See PassInstrumentation::runBeforeAnalysis for details.
template <typename IRUnitT>
void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT *ir) {
llvm::Any irAny(ir);
for (auto &instr : instrumentations)
instr->runBeforeAnalysis(name, id, irAny);
}
/// See PassInstrumentation::runAfterAnalysis for details.
template <typename IRUnitT>
void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT *ir) {
llvm::Any irAny(ir);
for (auto &instr : instrumentations)
instr->runAfterAnalysis(name, id, irAny);
}
/// Add the given instrumentation to the collection. This takes ownership over
/// the given pointer.
void addInstrumentation(PassInstrumentation *pi) {
instrumentations.emplace_back(pi);
}
private:
std::vector<std::unique_ptr<PassInstrumentation>> instrumentations;
};
} // end namespace mlir
#endif // MLIR_PASS_PASSINSTRUMENTATION_H_

View File

@@ -26,6 +26,8 @@ class FunctionPassBase;
class Module;
class ModulePassBase;
class Pass;
class PassInstrumentation;
class PassInstrumentor;
namespace detail {
class PassExecutor;
@@ -56,6 +58,10 @@ public:
LLVM_NODISCARD
LogicalResult run(Module *module);
/// Add the provided instrumentation to the pass manager. This takes ownership
/// over the given pointer.
void addInstrumentation(PassInstrumentation *pi);
private:
/// A stack of nested pass executors on sub-module IR units, e.g. function.
llvm::SmallVector<detail::PassExecutor *, 1> nestedExecutorStack;
@@ -65,6 +71,9 @@ private:
/// Flag that specifies if the IR should be verified after each pass has run.
bool verifyPasses;
/// A manager for pass instrumentations.
std::unique_ptr<PassInstrumentor> instrumentor;
};
} // end namespace mlir

View File

@@ -40,15 +40,28 @@ LogicalResult FunctionPassBase::run(Function *fn,
// Initialize the pass state.
passState.emplace(fn, fam);
// Instrument before the pass has run.
auto pi = fam.getPassInstrumentor();
if (pi)
pi->runBeforePass(this, fn);
// Invoke the virtual runOnFunction function.
runOnFunction();
// Invalidate any non preserved analyses.
fam.invalidate(passState->preservedAnalyses);
// Return false if the pass signaled a failure.
return passState->irAndPassFailed.getInt() ? LogicalResult::failure()
: LogicalResult::success();
// Instrument after the pass has run.
bool passFailed = passState->irAndPassFailed.getInt();
if (pi) {
if (passFailed)
pi->runAfterPassFailed(this, fn);
else
pi->runAfterPass(this, fn);
}
// Return if the pass signaled a failure.
return passFailed ? LogicalResult::failure() : LogicalResult::success();
}
/// Forwarding function to execute this pass.
@@ -56,15 +69,28 @@ LogicalResult ModulePassBase::run(Module *module, ModuleAnalysisManager &mam) {
// Initialize the pass state.
passState.emplace(module, mam);
// Instrument before the pass has run.
auto pi = mam.getPassInstrumentor();
if (pi)
pi->runBeforePass(this, module);
// Invoke the virtual runOnModule function.
runOnModule();
// Invalidate any non preserved analyses.
mam.invalidate(passState->preservedAnalyses);
// Return false if the pass signaled a failure.
return passState->irAndPassFailed.getInt() ? LogicalResult::failure()
: LogicalResult::success();
// Instrument after the pass has run.
bool passFailed = passState->irAndPassFailed.getInt();
if (pi) {
if (passFailed)
pi->runAfterPassFailed(this, module);
else
pi->runAfterPass(this, module);
}
// Return if the pass signaled a failure.
return passFailed ? LogicalResult::failure() : LogicalResult::success();
}
//===----------------------------------------------------------------------===//
@@ -290,20 +316,34 @@ void PassManager::addPass(FunctionPassBase *pass) {
/// Run the passes within this manager on the provided module.
LogicalResult PassManager::run(Module *module) {
ModuleAnalysisManager mam(module);
ModuleAnalysisManager mam(module, instrumentor.get());
return mpe->run(module, mam);
}
/// Add the provided instrumentation to the pass manager. This takes ownership
/// over the given pointer.
void PassManager::addInstrumentation(PassInstrumentation *pi) {
if (!instrumentor)
instrumentor.reset(new PassInstrumentor());
instrumentor->addInstrumentation(pi);
}
//===----------------------------------------------------------------------===//
// AnalysisManager
//===----------------------------------------------------------------------===//
/// Returns a pass instrumentation object for the current function.
PassInstrumentor *FunctionAnalysisManager::getPassInstrumentor() const {
return parent->getPassInstrumentor();
}
/// Create an analysis slice for the given child function.
FunctionAnalysisManager ModuleAnalysisManager::slice(Function *function) {
assert(function->getModule() == moduleAnalyses.getIRUnit() &&
"function has a different parent module");
auto it = functionAnalyses.try_emplace(function, function);
return {&moduleAnalyses, &it.first->second};
return {this, &it.first->second};
}
/// Invalidate any non preserved analyses.

View File

@@ -38,7 +38,7 @@ TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
// Test fine grain invalidation of the module analysis manager.
std::unique_ptr<Module> module(new Module(&context));
ModuleAnalysisManager mam(&*module);
ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr);
// Query two different analyses, but only preserve one before invalidating.
mam.getAnalysis<MyAnalysis>();
@@ -65,7 +65,7 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
module->getFunctions().push_back(func1);
// Test fine grain invalidation of the function analysis manager.
ModuleAnalysisManager mam(&*module);
ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr);
FunctionAnalysisManager fam = mam.slice(func1);
// Query two different analyses, but only preserve one before invalidating.
@@ -94,7 +94,7 @@ TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) {
// Test fine grain invalidation of a function analysis from within a module
// analysis manager.
ModuleAnalysisManager mam(&*module);
ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr);
// Query two different analyses, but only preserve one before invalidating.
mam.getFunctionAnalysis<MyAnalysis>(func1);