mirror of
https://github.com/intel/llvm.git
synced 2026-01-20 10:58:11 +08:00
[MLIR] Split autogenerated pass declarations & C++ controllable pass options
The pass tablegen backend has been reworked to remove the monolithic nature of the autogenerated declarations. The pass public header can be generated with the -gen-pass-decls option. It contains options structs and registrations: the inclusion of options structs can be controlled individually for each pass by defining the GEN_PASS_DECL_PASSNAME macro; the declaration of the registrations have been kept together and can still be included by defining the GEN_PASS_REGISTRATION macro. The private code used for the pass implementation (i.e. the pass base class and the constructors definitions, if missing from tablegen) can be generated with the -gen-pass-defs option. Similarly to the declarations file, the definitions of each pass can be enabled by defining the GEN_PASS_DEF_PASNAME variable. While doing so, the pass base class has been enriched to also accept a the aformentioned struct of options and copy them to the actual pass options, thus allowing each pass to also be configurable within C++ and not only through command line. Reviewed By: rriddle, mehdi_amini, Mogball, jpienaar Differential Revision: https://reviews.llvm.org/D131839
This commit is contained in:
@@ -828,16 +828,18 @@ def MyPass : Pass<"my-pass", "ModuleOp"> {
|
||||
}
|
||||
```
|
||||
|
||||
Using the `gen-pass-decls` generator, we can generate most of the boilerplate
|
||||
above automatically. This generator takes as an input a `-name` parameter, that
|
||||
provides a tag for the group of passes that are being generated. This generator
|
||||
produces two chunks of output:
|
||||
Using the `gen-pass-decls` and `gen-pass-defs` generators, we can generate most
|
||||
of the boilerplate above automatically.
|
||||
|
||||
The first is a code block for registering the declarative passes with the global
|
||||
registry. For each pass, the generator produces a `registerFooPass` where `Foo`
|
||||
is the name of the definition specified in tablegen. It also generates a
|
||||
`registerGroupPasses`, where `Group` is the tag provided via the `-name` input
|
||||
parameter, that registers all of the passes present.
|
||||
The `gen-pass-decls` generator takes as an input a `-name` parameter, that
|
||||
provides a tag for the group of passes that are being generated. This generator
|
||||
produces code with two purposes:
|
||||
|
||||
The first is to register the declared passes with the global registry. For
|
||||
each pass, the generator produces a `registerPassName` where
|
||||
`PassName` is the name of the definition specified in tablegen. It also
|
||||
generates a `registerGroupPasses`, where `Group` is the tag provided via the
|
||||
`-name` input parameter, that registers all of the passes present.
|
||||
|
||||
```c++
|
||||
// gen-pass-decls -name="Example"
|
||||
@@ -850,19 +852,61 @@ void registerMyPasses() {
|
||||
registerExamplePasses();
|
||||
|
||||
// Register `MyPass` specifically.
|
||||
registerMyPassPass();
|
||||
registerMyPass();
|
||||
}
|
||||
```
|
||||
|
||||
The second is a base class for each of the passes, containing most of the boiler
|
||||
The second is to provide a way to configure the pass options. These classes are
|
||||
named in the form of `MyPassOptions`, where `MyPass` is the name of the pass
|
||||
definition in tablegen. The configurable parameters reflect the options
|
||||
declared in the tablegen file. Differently from the registration hooks, these
|
||||
classes can be enabled on a per-pass basis by defining the
|
||||
`GEN_PASS_DECL_PASSNAME` macro, where `PASSNAME` is the uppercase version of
|
||||
the name specified in tablegen.
|
||||
|
||||
```c++
|
||||
// .h.inc
|
||||
|
||||
#ifdef GEN_PASS_DECL_MYPASS
|
||||
|
||||
struct MyPassOptions {
|
||||
bool option = true;
|
||||
::llvm::ArrayRef<int64_t> listOption;
|
||||
};
|
||||
|
||||
#undef GEN_PASS_DECL_MYPASS
|
||||
#endif // GEN_PASS_DECL_MYPASS
|
||||
```
|
||||
|
||||
If the `constructor` field has not been specified in the tablegen declaration,
|
||||
then autogenerated file will also contain the declarations of the default
|
||||
constructors.
|
||||
|
||||
```c++
|
||||
// .h.inc
|
||||
|
||||
#ifdef GEN_PASS_DECL_MYPASS
|
||||
...
|
||||
|
||||
std::unique_ptr<::mlir::Pass> createMyPass();
|
||||
std::unique_ptr<::mlir::Pass> createMyPass(const MyPassOptions &options);
|
||||
|
||||
#undef GEN_PASS_DECL_MYPASS
|
||||
#endif // GEN_PASS_DECL_MYPASS
|
||||
```
|
||||
|
||||
The `gen-pass-defs` generator produces the definitions to be used for the pass
|
||||
implementation.
|
||||
|
||||
It generates a base class for each of the passes, containing most of the boiler
|
||||
plate related to pass definitions. These classes are named in the form of
|
||||
`MyPassBase`, where `MyPass` is the name of the pass definition in tablegen. We
|
||||
can update the original C++ pass definition as so:
|
||||
|
||||
```c++
|
||||
/// Include the generated base pass class definitions.
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "Passes.h.inc"
|
||||
#define GEN_PASS_DEF_MYPASS
|
||||
#include "Passes.cpp.inc"
|
||||
|
||||
/// Define the main class as deriving from the generated base class.
|
||||
struct MyPass : MyPassBase<MyPass> {
|
||||
@@ -874,13 +918,16 @@ struct MyPass : MyPassBase<MyPass> {
|
||||
/// The definitions of the options and statistics are now generated within
|
||||
/// the base class, but are accessible in the same way.
|
||||
};
|
||||
|
||||
/// Expose this pass to the outside world.
|
||||
std::unique_ptr<Pass> foo::createMyPass() {
|
||||
return std::make_unique<MyPass>();
|
||||
}
|
||||
```
|
||||
|
||||
Similarly to the previous generator, the definitions can be enabled on a
|
||||
per-pass basis by defining the appropriate preprocessor `GEN_PASS_DEF_PASSNAME`
|
||||
macro, with `PASSNAME` equal to the uppercase version of the name of the pass
|
||||
definition in tablegen.
|
||||
If the `constructor` field has not been specified in tablegen, then the default
|
||||
constructors are also defined and expect the name of the actual pass class to
|
||||
be equal to the name defined in tablegen.
|
||||
|
||||
Using the `gen-pass-doc` generator, markdown documentation for each of the
|
||||
passes can be generated. See [Passes.md](Passes.md) for example output of real
|
||||
MLIR passes.
|
||||
|
||||
@@ -76,7 +76,10 @@ class PassBase<string passArg, string base> {
|
||||
string description = "";
|
||||
|
||||
// A C++ constructor call to create an instance of this pass.
|
||||
code constructor = [{}];
|
||||
// If empty, the default constructor declarations and definitions
|
||||
// 'createPassName()' and 'createPassName(const PassNameOptions &options)'
|
||||
// will be generated and the former will be used for the pass instantiation.
|
||||
code constructor = "";
|
||||
|
||||
// A list of dialects this pass may produce entities in.
|
||||
list<string> dependentDialects = [];
|
||||
|
||||
@@ -90,6 +90,7 @@ StringRef Pass::getDescription() const {
|
||||
StringRef Pass::getConstructor() const {
|
||||
return def->getValueAsString("constructor");
|
||||
}
|
||||
|
||||
ArrayRef<StringRef> Pass::getDependentDialects() const {
|
||||
return dependentDialects;
|
||||
}
|
||||
|
||||
@@ -97,8 +97,15 @@ static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
|
||||
Pass pass(def);
|
||||
StringRef defName = pass.getDef()->getName();
|
||||
os << llvm::formatv(passCreateDef, groupName, defName,
|
||||
pass.getConstructor());
|
||||
|
||||
std::string constructorCall;
|
||||
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
|
||||
constructorCall = constructor.str();
|
||||
else
|
||||
constructorCall =
|
||||
llvm::formatv("create{0}Pass()", pass.getDef()->getName()).str();
|
||||
|
||||
os << llvm::formatv(passCreateDef, groupName, defName, constructorCall);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -27,6 +27,161 @@ static llvm::cl::opt<std::string>
|
||||
groupName("name", llvm::cl::desc("The name of this group of passes"),
|
||||
llvm::cl::cat(passGenCat));
|
||||
|
||||
static void emitOldPassDecl(const Pass &pass, raw_ostream &os);
|
||||
|
||||
/// Extract the list of passes from the TableGen records.
|
||||
static std::vector<Pass> getPasses(const llvm::RecordKeeper &recordKeeper) {
|
||||
std::vector<Pass> passes;
|
||||
|
||||
for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
|
||||
passes.emplace_back(def);
|
||||
|
||||
return passes;
|
||||
}
|
||||
|
||||
const char *const passHeader = R"(
|
||||
//===----------------------------------------------------------------------===//
|
||||
// {0}
|
||||
//===----------------------------------------------------------------------===//
|
||||
)";
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GEN: Pass registration generation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// The code snippet used to generate a pass registration.
|
||||
///
|
||||
/// {0}: The def name of the pass record.
|
||||
/// {1}: The pass constructor call.
|
||||
const char *const passRegistrationCode = R"(
|
||||
//===----------------------------------------------------------------------===//
|
||||
// {0} Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
inline void register{0}() {{
|
||||
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
|
||||
return {1};
|
||||
});
|
||||
}
|
||||
|
||||
// Old registration code, kept for temporary backwards compatibility.
|
||||
inline void register{0}Pass() {{
|
||||
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
|
||||
return {1};
|
||||
});
|
||||
}
|
||||
)";
|
||||
|
||||
/// The code snippet used to generate a function to register all passes in a
|
||||
/// group.
|
||||
///
|
||||
/// {0}: The name of the pass group.
|
||||
const char *const passGroupRegistrationCode = R"(
|
||||
//===----------------------------------------------------------------------===//
|
||||
// {0} Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
inline void register{0}Passes() {{
|
||||
)";
|
||||
|
||||
/// Emits the definition of the struct to be used to control the pass options.
|
||||
static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
|
||||
StringRef passName = pass.getDef()->getName();
|
||||
ArrayRef<PassOption> options = pass.getOptions();
|
||||
|
||||
// Emit the struct only if the pass has at least one option.
|
||||
if (options.empty())
|
||||
return;
|
||||
|
||||
os << llvm::formatv("struct {0}Options {{\n", passName);
|
||||
|
||||
for (const PassOption &opt : options) {
|
||||
std::string type = opt.getType().str();
|
||||
|
||||
if (opt.isListOption())
|
||||
type = "::llvm::ArrayRef<" + type + ">";
|
||||
|
||||
os.indent(2) << llvm::formatv("{0} {1}", type, opt.getCppVariableName());
|
||||
|
||||
if (Optional<StringRef> defaultVal = opt.getDefaultValue())
|
||||
os << " = " << defaultVal;
|
||||
|
||||
os << ";\n";
|
||||
}
|
||||
|
||||
os << "};\n";
|
||||
}
|
||||
|
||||
/// Emit the code to be included in the public header of the pass.
|
||||
static void emitPassDecls(const Pass &pass, raw_ostream &os) {
|
||||
StringRef passName = pass.getDef()->getName();
|
||||
std::string enableVarName = "GEN_PASS_DECL_" + passName.upper();
|
||||
|
||||
os << "#ifdef " << enableVarName << "\n";
|
||||
os << llvm::formatv(passHeader, passName);
|
||||
|
||||
emitPassOptionsStruct(pass, os);
|
||||
|
||||
if (StringRef constructor = pass.getConstructor(); constructor.empty()) {
|
||||
// Default constructor declaration.
|
||||
os << "std::unique_ptr<::mlir::Pass> create" << passName << "();\n";
|
||||
|
||||
// Declaration of the constructor with options.
|
||||
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
|
||||
os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}(const "
|
||||
"{0}Options &options);\n",
|
||||
passName);
|
||||
}
|
||||
|
||||
os << "#undef " << enableVarName << "\n";
|
||||
os << "#endif // " << enableVarName << "\n";
|
||||
}
|
||||
|
||||
/// Emit the code for registering each of the given passes with the global
|
||||
/// PassRegistry.
|
||||
static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
|
||||
os << "#ifdef GEN_PASS_REGISTRATION\n";
|
||||
|
||||
for (const Pass &pass : passes) {
|
||||
std::string constructorCall;
|
||||
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
|
||||
constructorCall = constructor.str();
|
||||
else
|
||||
constructorCall =
|
||||
llvm::formatv("create{0}()", pass.getDef()->getName()).str();
|
||||
|
||||
os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
|
||||
constructorCall);
|
||||
}
|
||||
|
||||
os << llvm::formatv(passGroupRegistrationCode, groupName);
|
||||
|
||||
for (const Pass &pass : passes)
|
||||
os << " register" << pass.getDef()->getName() << "();\n";
|
||||
|
||||
os << "}\n";
|
||||
os << "#undef GEN_PASS_REGISTRATION\n";
|
||||
os << "#endif // GEN_PASS_REGISTRATION\n";
|
||||
}
|
||||
|
||||
static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
std::vector<Pass> passes = getPasses(recordKeeper);
|
||||
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
|
||||
|
||||
for (const Pass &pass : passes)
|
||||
emitPassDecls(pass, os);
|
||||
|
||||
emitRegistrations(passes, os);
|
||||
|
||||
// TODO drop old pass declarations
|
||||
// Emit the old code until all the passes have switched to the new design.
|
||||
os << "#ifdef GEN_PASS_CLASSES\n";
|
||||
for (const Pass &pass : passes)
|
||||
emitOldPassDecl(pass, os);
|
||||
os << "#undef GEN_PASS_CLASSES\n";
|
||||
os << "#endif // GEN_PASS_CLASSES\n";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GEN: Pass base class generation
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -38,10 +193,159 @@ static llvm::cl::opt<std::string>
|
||||
/// {2): The command line argument for the pass.
|
||||
/// {3}: The dependent dialects registration.
|
||||
const char *const passDeclBegin = R"(
|
||||
//===----------------------------------------------------------------------===//
|
||||
// {0}
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <typename DerivedT>
|
||||
class {0}Base : public {1} {
|
||||
public:
|
||||
using Base = {0}Base;
|
||||
|
||||
{0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
|
||||
{0}Base(const {0}Base &other) : {1}(other) {{}
|
||||
|
||||
/// Returns the command-line argument attached to this pass.
|
||||
static constexpr ::llvm::StringLiteral getArgumentName() {
|
||||
return ::llvm::StringLiteral("{2}");
|
||||
}
|
||||
::llvm::StringRef getArgument() const override { return "{2}"; }
|
||||
|
||||
::llvm::StringRef getDescription() const override { return "{3}"; }
|
||||
|
||||
/// Returns the derived pass name.
|
||||
static constexpr ::llvm::StringLiteral getPassName() {
|
||||
return ::llvm::StringLiteral("{0}");
|
||||
}
|
||||
::llvm::StringRef getName() const override { return "{0}"; }
|
||||
|
||||
/// Support isa/dyn_cast functionality for the derived pass class.
|
||||
static bool classof(const ::mlir::Pass *pass) {{
|
||||
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
|
||||
}
|
||||
|
||||
/// A clone method to create a copy of this pass.
|
||||
std::unique_ptr<::mlir::Pass> clonePass() const override {{
|
||||
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
|
||||
}
|
||||
|
||||
/// Return the dialect that must be loaded in the context before this pass.
|
||||
void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
|
||||
{4}
|
||||
}
|
||||
|
||||
/// Explicitly declare the TypeID for this class. We declare an explicit private
|
||||
/// instantiation because Pass classes should only be visible by the current
|
||||
/// library.
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)
|
||||
|
||||
)";
|
||||
|
||||
/// Registration for a single dependent dialect, to be inserted for each
|
||||
/// dependent dialect in the `getDependentDialects` above.
|
||||
const char *const dialectRegistrationTemplate = R"(
|
||||
registry.insert<{0}>();
|
||||
)";
|
||||
|
||||
const char *const friendDefaultConstructorTemplate = R"(
|
||||
friend std::unique_ptr<::mlir::Pass> create{0}() {{
|
||||
return std::make_unique<DerivedT>();
|
||||
}
|
||||
)";
|
||||
|
||||
const char *const friendDefaultConstructorWithOptionsTemplate = R"(
|
||||
friend std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
|
||||
return std::make_unique<DerivedT>(options);
|
||||
}
|
||||
)";
|
||||
|
||||
/// Emit the declarations for each of the pass options.
|
||||
static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
|
||||
for (const PassOption &opt : pass.getOptions()) {
|
||||
os.indent(2) << "::mlir::Pass::"
|
||||
<< (opt.isListOption() ? "ListOption" : "Option");
|
||||
|
||||
os << llvm::formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
|
||||
opt.getType(), opt.getCppVariableName(),
|
||||
opt.getArgument(), opt.getDescription());
|
||||
if (Optional<StringRef> defaultVal = opt.getDefaultValue())
|
||||
os << ", ::llvm::cl::init(" << defaultVal << ")";
|
||||
if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags())
|
||||
os << ", " << *additionalFlags;
|
||||
os << "};\n";
|
||||
}
|
||||
}
|
||||
|
||||
/// Emit the declarations for each of the pass statistics.
|
||||
static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
|
||||
for (const PassStatistic &stat : pass.getStatistics()) {
|
||||
os << llvm::formatv(
|
||||
" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
|
||||
stat.getCppVariableName(), stat.getName(), stat.getDescription());
|
||||
}
|
||||
}
|
||||
|
||||
/// Emit the code to be used in the implementation of the pass.
|
||||
static void emitPassDefs(const Pass &pass, raw_ostream &os) {
|
||||
StringRef passName = pass.getDef()->getName();
|
||||
std::string enableVarName = "GEN_PASS_DEF_" + passName.upper();
|
||||
|
||||
os << "#ifdef " << enableVarName << "\n";
|
||||
os << llvm::formatv(passHeader, passName);
|
||||
|
||||
std::string dependentDialectRegistrations;
|
||||
{
|
||||
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
|
||||
for (StringRef dependentDialect : pass.getDependentDialects())
|
||||
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
|
||||
dependentDialect);
|
||||
}
|
||||
|
||||
os << llvm::formatv(passDeclBegin, passName, pass.getBaseClass(),
|
||||
pass.getArgument(), pass.getSummary(),
|
||||
dependentDialectRegistrations);
|
||||
|
||||
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
|
||||
os.indent(2) << llvm::formatv(
|
||||
"{0}Base(const {0}Options &options) : {0}Base() {{\n", passName);
|
||||
|
||||
for (const PassOption &opt : pass.getOptions())
|
||||
os.indent(4) << llvm::formatv("{0} = options.{0};\n",
|
||||
opt.getCppVariableName());
|
||||
|
||||
os.indent(2) << "}\n";
|
||||
}
|
||||
|
||||
// Protected content
|
||||
os << "protected:\n";
|
||||
emitPassOptionDecls(pass, os);
|
||||
emitPassStatisticDecls(pass, os);
|
||||
|
||||
// Private content
|
||||
os << "private:\n";
|
||||
|
||||
if (pass.getConstructor().empty()) {
|
||||
os << llvm::formatv(friendDefaultConstructorTemplate, passName);
|
||||
|
||||
if (!pass.getOptions().empty())
|
||||
os << llvm::formatv(friendDefaultConstructorWithOptionsTemplate,
|
||||
passName);
|
||||
}
|
||||
|
||||
os << "};\n";
|
||||
|
||||
os << "#undef " << enableVarName << "\n";
|
||||
os << "#endif // " << enableVarName << "\n";
|
||||
}
|
||||
|
||||
static void emitDefs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
std::vector<Pass> passes = getPasses(recordKeeper);
|
||||
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
|
||||
|
||||
for (const Pass &pass : passes)
|
||||
emitPassDefs(pass, os);
|
||||
}
|
||||
|
||||
// TODO drop old pass declarations
|
||||
// The old pass base class is being kept until all the passes have switched to
|
||||
// the new decls/defs design.
|
||||
const char *const oldPassDeclBegin = R"(
|
||||
template <typename DerivedT>
|
||||
class {0}Base : public {1} {
|
||||
public:
|
||||
@@ -87,39 +391,8 @@ public:
|
||||
protected:
|
||||
)";
|
||||
|
||||
/// Registration for a single dependent dialect, to be inserted for each
|
||||
/// dependent dialect in the `getDependentDialects` above.
|
||||
const char *const dialectRegistrationTemplate = R"(
|
||||
registry.insert<{0}>();
|
||||
)";
|
||||
|
||||
/// Emit the declarations for each of the pass options.
|
||||
static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
|
||||
for (const PassOption &opt : pass.getOptions()) {
|
||||
os.indent(2) << "::mlir::Pass::"
|
||||
<< (opt.isListOption() ? "ListOption" : "Option");
|
||||
|
||||
os << llvm::formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
|
||||
opt.getType(), opt.getCppVariableName(),
|
||||
opt.getArgument(), opt.getDescription());
|
||||
if (Optional<StringRef> defaultVal = opt.getDefaultValue())
|
||||
os << ", ::llvm::cl::init(" << defaultVal << ")";
|
||||
if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags())
|
||||
os << ", " << *additionalFlags;
|
||||
os << "};\n";
|
||||
}
|
||||
}
|
||||
|
||||
/// Emit the declarations for each of the pass statistics.
|
||||
static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
|
||||
for (const PassStatistic &stat : pass.getStatistics()) {
|
||||
os << llvm::formatv(
|
||||
" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
|
||||
stat.getCppVariableName(), stat.getName(), stat.getDescription());
|
||||
}
|
||||
}
|
||||
|
||||
static void emitPassDecl(const Pass &pass, raw_ostream &os) {
|
||||
/// Emit a backward-compatible declaration of the pass base class.
|
||||
static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
|
||||
StringRef defName = pass.getDef()->getName();
|
||||
std::string dependentDialectRegistrations;
|
||||
{
|
||||
@@ -128,7 +401,7 @@ static void emitPassDecl(const Pass &pass, raw_ostream &os) {
|
||||
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
|
||||
dependentDialect);
|
||||
}
|
||||
os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
|
||||
os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
|
||||
pass.getArgument(), pass.getSummary(),
|
||||
dependentDialectRegistrations);
|
||||
emitPassOptionDecls(pass, os);
|
||||
@@ -136,82 +409,16 @@ static void emitPassDecl(const Pass &pass, raw_ostream &os) {
|
||||
os << "};\n";
|
||||
}
|
||||
|
||||
/// Emit the code for registering each of the given passes with the global
|
||||
/// PassRegistry.
|
||||
static void emitPassDecls(ArrayRef<Pass> passes, raw_ostream &os) {
|
||||
os << "#ifdef GEN_PASS_CLASSES\n";
|
||||
for (const Pass &pass : passes)
|
||||
emitPassDecl(pass, os);
|
||||
os << "#undef GEN_PASS_CLASSES\n";
|
||||
os << "#endif // GEN_PASS_CLASSES\n";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GEN: Pass registration generation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// The code snippet used to generate a pass registration.
|
||||
///
|
||||
/// {0}: The def name of the pass record.
|
||||
/// {1}: The pass constructor call.
|
||||
const char *const passRegistrationCode = R"(
|
||||
//===----------------------------------------------------------------------===//
|
||||
// {0} Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
inline void register{0}Pass() {{
|
||||
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
|
||||
return {1};
|
||||
});
|
||||
}
|
||||
)";
|
||||
|
||||
/// The code snippet used to generate a function to register all passes in a
|
||||
/// group.
|
||||
///
|
||||
/// {0}: The name of the pass group.
|
||||
const char *const passGroupRegistrationCode = R"(
|
||||
//===----------------------------------------------------------------------===//
|
||||
// {0} Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
inline void register{0}Passes() {{
|
||||
)";
|
||||
|
||||
/// Emit the code for registering each of the given passes with the global
|
||||
/// PassRegistry.
|
||||
static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
|
||||
os << "#ifdef GEN_PASS_REGISTRATION\n";
|
||||
for (const Pass &pass : passes) {
|
||||
os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
|
||||
pass.getConstructor());
|
||||
}
|
||||
|
||||
os << llvm::formatv(passGroupRegistrationCode, groupName);
|
||||
for (const Pass &pass : passes)
|
||||
os << " register" << pass.getDef()->getName() << "Pass();\n";
|
||||
os << "}\n";
|
||||
os << "#undef GEN_PASS_REGISTRATION\n";
|
||||
os << "#endif // GEN_PASS_REGISTRATION\n";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GEN: Registration hooks
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
|
||||
std::vector<Pass> passes;
|
||||
for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
|
||||
passes.emplace_back(def);
|
||||
|
||||
emitPassDecls(passes, os);
|
||||
emitRegistration(passes, os);
|
||||
}
|
||||
static mlir::GenRegistration
|
||||
genPassDecls("gen-pass-decls", "Generate pass declarations",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
emitDecls(records, os);
|
||||
return false;
|
||||
});
|
||||
|
||||
static mlir::GenRegistration
|
||||
genRegister("gen-pass-decls", "Generate pass declarations",
|
||||
genPassDefs("gen-pass-defs", "Generate pass definitions",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
emitDecls(records, os);
|
||||
emitDefs(records, os);
|
||||
return false;
|
||||
});
|
||||
|
||||
@@ -5,6 +5,7 @@ add_public_tablegen_target(MLIRTableGenEnumsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS passes.td)
|
||||
mlir_tablegen(PassGenTest.h.inc -gen-pass-decls -name TableGenTest)
|
||||
mlir_tablegen(PassGenTest.cpp.inc -gen-pass-defs -name TableGenTest)
|
||||
add_public_tablegen_target(MLIRTableGenTestPassIncGen)
|
||||
|
||||
add_mlir_unittest(MLIRTableGenTests
|
||||
|
||||
@@ -7,31 +7,36 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
std::unique_ptr<mlir::Pass> createTestPass(int v = 0);
|
||||
std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v = 0);
|
||||
|
||||
#define GEN_PASS_DECL_TESTPASS
|
||||
#define GEN_PASS_DECL_TESTPASSWITHOPTIONS
|
||||
#define GEN_PASS_DECL_TESTPASSWITHCUSTOMCONSTRUCTOR
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "PassGenTest.h.inc"
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "PassGenTest.h.inc"
|
||||
#define GEN_PASS_DEF_TESTPASS
|
||||
#define GEN_PASS_DEF_TESTPASSWITHOPTIONS
|
||||
#define GEN_PASS_DEF_TESTPASSWITHCUSTOMCONSTRUCTOR
|
||||
#include "PassGenTest.cpp.inc"
|
||||
|
||||
struct TestPass : public TestPassBase<TestPass> {
|
||||
explicit TestPass(int v) : extraVal(v) {}
|
||||
using TestPassBase::TestPassBase;
|
||||
|
||||
void runOnOperation() override {}
|
||||
|
||||
std::unique_ptr<mlir::Pass> clone() const {
|
||||
return TestPassBase<TestPass>::clone();
|
||||
}
|
||||
|
||||
int extraVal;
|
||||
};
|
||||
|
||||
std::unique_ptr<mlir::Pass> createTestPass(int v) {
|
||||
return std::make_unique<TestPass>(v);
|
||||
TEST(PassGenTest, defaultGeneratedConstructor) {
|
||||
std::unique_ptr<mlir::Pass> pass = createTestPass();
|
||||
EXPECT_TRUE(pass.get() != nullptr);
|
||||
}
|
||||
|
||||
TEST(PassGenTest, PassClone) {
|
||||
@@ -41,7 +46,74 @@ TEST(PassGenTest, PassClone) {
|
||||
return static_cast<const TestPass *>(pass.get());
|
||||
};
|
||||
|
||||
const auto origPass = createTestPass(10);
|
||||
const auto origPass = createTestPass();
|
||||
const auto clonePass = unwrap(origPass)->clone();
|
||||
|
||||
EXPECT_TRUE(clonePass.get() != nullptr);
|
||||
EXPECT_TRUE(origPass.get() != clonePass.get());
|
||||
}
|
||||
|
||||
struct TestPassWithOptions
|
||||
: public TestPassWithOptionsBase<TestPassWithOptions> {
|
||||
using TestPassWithOptionsBase::TestPassWithOptionsBase;
|
||||
|
||||
void runOnOperation() override {}
|
||||
|
||||
std::unique_ptr<mlir::Pass> clone() const {
|
||||
return TestPassWithOptionsBase<TestPassWithOptions>::clone();
|
||||
}
|
||||
|
||||
unsigned getTestOption() const { return testOption; }
|
||||
|
||||
llvm::ArrayRef<int64_t> getTestListOption() const { return testListOption; }
|
||||
};
|
||||
|
||||
TEST(PassGenTest, PassOptions) {
|
||||
mlir::MLIRContext context;
|
||||
|
||||
TestPassWithOptionsOptions options;
|
||||
options.testOption = 57;
|
||||
|
||||
llvm::SmallVector<int64_t, 2> testListOption = {1, 2};
|
||||
options.testListOption = testListOption;
|
||||
|
||||
const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
|
||||
return static_cast<const TestPassWithOptions *>(pass.get());
|
||||
};
|
||||
|
||||
const auto pass = createTestPassWithOptions(options);
|
||||
|
||||
EXPECT_EQ(unwrap(pass)->getTestOption(), 57);
|
||||
EXPECT_EQ(unwrap(pass)->getTestListOption()[0], 1);
|
||||
EXPECT_EQ(unwrap(pass)->getTestListOption()[1], 2);
|
||||
}
|
||||
|
||||
struct TestPassWithCustomConstructor
|
||||
: public TestPassWithCustomConstructorBase<TestPassWithCustomConstructor> {
|
||||
explicit TestPassWithCustomConstructor(int v) : extraVal(v) {}
|
||||
|
||||
void runOnOperation() override {}
|
||||
|
||||
std::unique_ptr<mlir::Pass> clone() const {
|
||||
return TestPassWithCustomConstructorBase<
|
||||
TestPassWithCustomConstructor>::clone();
|
||||
}
|
||||
|
||||
unsigned int extraVal = 23;
|
||||
};
|
||||
|
||||
std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v) {
|
||||
return std::make_unique<TestPassWithCustomConstructor>(v);
|
||||
}
|
||||
|
||||
TEST(PassGenTest, PassCloneWithCustomConstructor) {
|
||||
mlir::MLIRContext context;
|
||||
|
||||
const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
|
||||
return static_cast<const TestPassWithCustomConstructor *>(pass.get());
|
||||
};
|
||||
|
||||
const auto origPass = createTestPassWithCustomConstructor(10);
|
||||
const auto clonePass = unwrap(origPass)->clone();
|
||||
|
||||
EXPECT_EQ(unwrap(origPass)->extraVal, unwrap(clonePass)->extraVal);
|
||||
|
||||
@@ -12,8 +12,20 @@ include "mlir/Rewrite/PassUtil.td"
|
||||
|
||||
def TestPass : Pass<"test"> {
|
||||
let summary = "Test pass";
|
||||
|
||||
let constructor = "::createTestPass()";
|
||||
|
||||
let options = RewritePassUtils.options;
|
||||
}
|
||||
|
||||
def TestPassWithOptions : Pass<"test"> {
|
||||
let summary = "Test pass with options";
|
||||
|
||||
let options = [
|
||||
Option<"testOption", "testOption", "unsigned", "0", "Test option">,
|
||||
ListOption<"testListOption", "test-list-option", "int64_t",
|
||||
"Test list option">
|
||||
];
|
||||
}
|
||||
|
||||
def TestPassWithCustomConstructor : Pass<"test"> {
|
||||
let summary = "Test pass with custom constructor";
|
||||
|
||||
let constructor = "::createTestPassWithCustomConstructor()";
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user