[MLIR] Split autogenerated pass declarations & C++ controllable pass options

The pass tablegen backend has been reworked to remove the monolithic nature of the autogenerated declarations.
The pass public header can be generated with the -gen-pass-decls option. It contains options structs and registrations: the inclusion of options structs can be controlled individually for each pass by defining the GEN_PASS_DECL_PASSNAME macro; the declaration of the registrations have been kept together and can still be included by defining the GEN_PASS_REGISTRATION macro.
The private code used for the pass implementation (i.e. the pass base class and the constructors definitions, if missing from tablegen) can be generated with the -gen-pass-defs option. Similarly to the declarations file, the definitions of each pass can be enabled by defining the GEN_PASS_DEF_PASNAME variable.
While doing so, the pass base class has been enriched to also accept a the aformentioned struct of options and copy them to the actual pass options, thus allowing each pass to also be configurable within C++ and not only through command line.

Reviewed By: rriddle, mehdi_amini, Mogball, jpienaar

Differential Revision: https://reviews.llvm.org/D131839
This commit is contained in:
Michele Scuttari
2022-08-24 09:59:50 +02:00
parent 3f20dcbf70
commit 32c5578bcd
8 changed files with 495 additions and 145 deletions

View File

@@ -828,16 +828,18 @@ def MyPass : Pass<"my-pass", "ModuleOp"> {
}
```
Using the `gen-pass-decls` generator, we can generate most of the boilerplate
above automatically. This generator takes as an input a `-name` parameter, that
provides a tag for the group of passes that are being generated. This generator
produces two chunks of output:
Using the `gen-pass-decls` and `gen-pass-defs` generators, we can generate most
of the boilerplate above automatically.
The first is a code block for registering the declarative passes with the global
registry. For each pass, the generator produces a `registerFooPass` where `Foo`
is the name of the definition specified in tablegen. It also generates a
`registerGroupPasses`, where `Group` is the tag provided via the `-name` input
parameter, that registers all of the passes present.
The `gen-pass-decls` generator takes as an input a `-name` parameter, that
provides a tag for the group of passes that are being generated. This generator
produces code with two purposes:
The first is to register the declared passes with the global registry. For
each pass, the generator produces a `registerPassName` where
`PassName` is the name of the definition specified in tablegen. It also
generates a `registerGroupPasses`, where `Group` is the tag provided via the
`-name` input parameter, that registers all of the passes present.
```c++
// gen-pass-decls -name="Example"
@@ -850,19 +852,61 @@ void registerMyPasses() {
registerExamplePasses();
// Register `MyPass` specifically.
registerMyPassPass();
registerMyPass();
}
```
The second is a base class for each of the passes, containing most of the boiler
The second is to provide a way to configure the pass options. These classes are
named in the form of `MyPassOptions`, where `MyPass` is the name of the pass
definition in tablegen. The configurable parameters reflect the options
declared in the tablegen file. Differently from the registration hooks, these
classes can be enabled on a per-pass basis by defining the
`GEN_PASS_DECL_PASSNAME` macro, where `PASSNAME` is the uppercase version of
the name specified in tablegen.
```c++
// .h.inc
#ifdef GEN_PASS_DECL_MYPASS
struct MyPassOptions {
bool option = true;
::llvm::ArrayRef<int64_t> listOption;
};
#undef GEN_PASS_DECL_MYPASS
#endif // GEN_PASS_DECL_MYPASS
```
If the `constructor` field has not been specified in the tablegen declaration,
then autogenerated file will also contain the declarations of the default
constructors.
```c++
// .h.inc
#ifdef GEN_PASS_DECL_MYPASS
...
std::unique_ptr<::mlir::Pass> createMyPass();
std::unique_ptr<::mlir::Pass> createMyPass(const MyPassOptions &options);
#undef GEN_PASS_DECL_MYPASS
#endif // GEN_PASS_DECL_MYPASS
```
The `gen-pass-defs` generator produces the definitions to be used for the pass
implementation.
It generates a base class for each of the passes, containing most of the boiler
plate related to pass definitions. These classes are named in the form of
`MyPassBase`, where `MyPass` is the name of the pass definition in tablegen. We
can update the original C++ pass definition as so:
```c++
/// Include the generated base pass class definitions.
#define GEN_PASS_CLASSES
#include "Passes.h.inc"
#define GEN_PASS_DEF_MYPASS
#include "Passes.cpp.inc"
/// Define the main class as deriving from the generated base class.
struct MyPass : MyPassBase<MyPass> {
@@ -874,13 +918,16 @@ struct MyPass : MyPassBase<MyPass> {
/// The definitions of the options and statistics are now generated within
/// the base class, but are accessible in the same way.
};
/// Expose this pass to the outside world.
std::unique_ptr<Pass> foo::createMyPass() {
return std::make_unique<MyPass>();
}
```
Similarly to the previous generator, the definitions can be enabled on a
per-pass basis by defining the appropriate preprocessor `GEN_PASS_DEF_PASSNAME`
macro, with `PASSNAME` equal to the uppercase version of the name of the pass
definition in tablegen.
If the `constructor` field has not been specified in tablegen, then the default
constructors are also defined and expect the name of the actual pass class to
be equal to the name defined in tablegen.
Using the `gen-pass-doc` generator, markdown documentation for each of the
passes can be generated. See [Passes.md](Passes.md) for example output of real
MLIR passes.

View File

@@ -76,7 +76,10 @@ class PassBase<string passArg, string base> {
string description = "";
// A C++ constructor call to create an instance of this pass.
code constructor = [{}];
// If empty, the default constructor declarations and definitions
// 'createPassName()' and 'createPassName(const PassNameOptions &options)'
// will be generated and the former will be used for the pass instantiation.
code constructor = "";
// A list of dialects this pass may produce entities in.
list<string> dependentDialects = [];

View File

@@ -90,6 +90,7 @@ StringRef Pass::getDescription() const {
StringRef Pass::getConstructor() const {
return def->getValueAsString("constructor");
}
ArrayRef<StringRef> Pass::getDependentDialects() const {
return dependentDialects;
}

View File

@@ -97,8 +97,15 @@ static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) {
for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
Pass pass(def);
StringRef defName = pass.getDef()->getName();
os << llvm::formatv(passCreateDef, groupName, defName,
pass.getConstructor());
std::string constructorCall;
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
constructorCall = constructor.str();
else
constructorCall =
llvm::formatv("create{0}Pass()", pass.getDef()->getName()).str();
os << llvm::formatv(passCreateDef, groupName, defName, constructorCall);
}
return false;
}

View File

@@ -27,6 +27,161 @@ static llvm::cl::opt<std::string>
groupName("name", llvm::cl::desc("The name of this group of passes"),
llvm::cl::cat(passGenCat));
static void emitOldPassDecl(const Pass &pass, raw_ostream &os);
/// Extract the list of passes from the TableGen records.
static std::vector<Pass> getPasses(const llvm::RecordKeeper &recordKeeper) {
std::vector<Pass> passes;
for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
passes.emplace_back(def);
return passes;
}
const char *const passHeader = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)";
//===----------------------------------------------------------------------===//
// GEN: Pass registration generation
//===----------------------------------------------------------------------===//
/// The code snippet used to generate a pass registration.
///
/// {0}: The def name of the pass record.
/// {1}: The pass constructor call.
const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//
inline void register{0}() {{
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
return {1};
});
}
// Old registration code, kept for temporary backwards compatibility.
inline void register{0}Pass() {{
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
return {1};
});
}
)";
/// The code snippet used to generate a function to register all passes in a
/// group.
///
/// {0}: The name of the pass group.
const char *const passGroupRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//
inline void register{0}Passes() {{
)";
/// Emits the definition of the struct to be used to control the pass options.
static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
StringRef passName = pass.getDef()->getName();
ArrayRef<PassOption> options = pass.getOptions();
// Emit the struct only if the pass has at least one option.
if (options.empty())
return;
os << llvm::formatv("struct {0}Options {{\n", passName);
for (const PassOption &opt : options) {
std::string type = opt.getType().str();
if (opt.isListOption())
type = "::llvm::ArrayRef<" + type + ">";
os.indent(2) << llvm::formatv("{0} {1}", type, opt.getCppVariableName());
if (Optional<StringRef> defaultVal = opt.getDefaultValue())
os << " = " << defaultVal;
os << ";\n";
}
os << "};\n";
}
/// Emit the code to be included in the public header of the pass.
static void emitPassDecls(const Pass &pass, raw_ostream &os) {
StringRef passName = pass.getDef()->getName();
std::string enableVarName = "GEN_PASS_DECL_" + passName.upper();
os << "#ifdef " << enableVarName << "\n";
os << llvm::formatv(passHeader, passName);
emitPassOptionsStruct(pass, os);
if (StringRef constructor = pass.getConstructor(); constructor.empty()) {
// Default constructor declaration.
os << "std::unique_ptr<::mlir::Pass> create" << passName << "();\n";
// Declaration of the constructor with options.
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}(const "
"{0}Options &options);\n",
passName);
}
os << "#undef " << enableVarName << "\n";
os << "#endif // " << enableVarName << "\n";
}
/// Emit the code for registering each of the given passes with the global
/// PassRegistry.
static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
os << "#ifdef GEN_PASS_REGISTRATION\n";
for (const Pass &pass : passes) {
std::string constructorCall;
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
constructorCall = constructor.str();
else
constructorCall =
llvm::formatv("create{0}()", pass.getDef()->getName()).str();
os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
constructorCall);
}
os << llvm::formatv(passGroupRegistrationCode, groupName);
for (const Pass &pass : passes)
os << " register" << pass.getDef()->getName() << "();\n";
os << "}\n";
os << "#undef GEN_PASS_REGISTRATION\n";
os << "#endif // GEN_PASS_REGISTRATION\n";
}
static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
std::vector<Pass> passes = getPasses(recordKeeper);
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
for (const Pass &pass : passes)
emitPassDecls(pass, os);
emitRegistrations(passes, os);
// TODO drop old pass declarations
// Emit the old code until all the passes have switched to the new design.
os << "#ifdef GEN_PASS_CLASSES\n";
for (const Pass &pass : passes)
emitOldPassDecl(pass, os);
os << "#undef GEN_PASS_CLASSES\n";
os << "#endif // GEN_PASS_CLASSES\n";
}
//===----------------------------------------------------------------------===//
// GEN: Pass base class generation
//===----------------------------------------------------------------------===//
@@ -38,10 +193,159 @@ static llvm::cl::opt<std::string>
/// {2): The command line argument for the pass.
/// {3}: The dependent dialects registration.
const char *const passDeclBegin = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
template <typename DerivedT>
class {0}Base : public {1} {
public:
using Base = {0}Base;
{0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
{0}Base(const {0}Base &other) : {1}(other) {{}
/// Returns the command-line argument attached to this pass.
static constexpr ::llvm::StringLiteral getArgumentName() {
return ::llvm::StringLiteral("{2}");
}
::llvm::StringRef getArgument() const override { return "{2}"; }
::llvm::StringRef getDescription() const override { return "{3}"; }
/// Returns the derived pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("{0}");
}
::llvm::StringRef getName() const override { return "{0}"; }
/// Support isa/dyn_cast functionality for the derived pass class.
static bool classof(const ::mlir::Pass *pass) {{
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
}
/// A clone method to create a copy of this pass.
std::unique_ptr<::mlir::Pass> clonePass() const override {{
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
/// Return the dialect that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
{4}
}
/// Explicitly declare the TypeID for this class. We declare an explicit private
/// instantiation because Pass classes should only be visible by the current
/// library.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)
)";
/// Registration for a single dependent dialect, to be inserted for each
/// dependent dialect in the `getDependentDialects` above.
const char *const dialectRegistrationTemplate = R"(
registry.insert<{0}>();
)";
const char *const friendDefaultConstructorTemplate = R"(
friend std::unique_ptr<::mlir::Pass> create{0}() {{
return std::make_unique<DerivedT>();
}
)";
const char *const friendDefaultConstructorWithOptionsTemplate = R"(
friend std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
return std::make_unique<DerivedT>(options);
}
)";
/// Emit the declarations for each of the pass options.
static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
for (const PassOption &opt : pass.getOptions()) {
os.indent(2) << "::mlir::Pass::"
<< (opt.isListOption() ? "ListOption" : "Option");
os << llvm::formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
opt.getType(), opt.getCppVariableName(),
opt.getArgument(), opt.getDescription());
if (Optional<StringRef> defaultVal = opt.getDefaultValue())
os << ", ::llvm::cl::init(" << defaultVal << ")";
if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags())
os << ", " << *additionalFlags;
os << "};\n";
}
}
/// Emit the declarations for each of the pass statistics.
static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
for (const PassStatistic &stat : pass.getStatistics()) {
os << llvm::formatv(
" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
stat.getCppVariableName(), stat.getName(), stat.getDescription());
}
}
/// Emit the code to be used in the implementation of the pass.
static void emitPassDefs(const Pass &pass, raw_ostream &os) {
StringRef passName = pass.getDef()->getName();
std::string enableVarName = "GEN_PASS_DEF_" + passName.upper();
os << "#ifdef " << enableVarName << "\n";
os << llvm::formatv(passHeader, passName);
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
for (StringRef dependentDialect : pass.getDependentDialects())
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
dependentDialect);
}
os << llvm::formatv(passDeclBegin, passName, pass.getBaseClass(),
pass.getArgument(), pass.getSummary(),
dependentDialectRegistrations);
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
os.indent(2) << llvm::formatv(
"{0}Base(const {0}Options &options) : {0}Base() {{\n", passName);
for (const PassOption &opt : pass.getOptions())
os.indent(4) << llvm::formatv("{0} = options.{0};\n",
opt.getCppVariableName());
os.indent(2) << "}\n";
}
// Protected content
os << "protected:\n";
emitPassOptionDecls(pass, os);
emitPassStatisticDecls(pass, os);
// Private content
os << "private:\n";
if (pass.getConstructor().empty()) {
os << llvm::formatv(friendDefaultConstructorTemplate, passName);
if (!pass.getOptions().empty())
os << llvm::formatv(friendDefaultConstructorWithOptionsTemplate,
passName);
}
os << "};\n";
os << "#undef " << enableVarName << "\n";
os << "#endif // " << enableVarName << "\n";
}
static void emitDefs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
std::vector<Pass> passes = getPasses(recordKeeper);
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
for (const Pass &pass : passes)
emitPassDefs(pass, os);
}
// TODO drop old pass declarations
// The old pass base class is being kept until all the passes have switched to
// the new decls/defs design.
const char *const oldPassDeclBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
@@ -87,39 +391,8 @@ public:
protected:
)";
/// Registration for a single dependent dialect, to be inserted for each
/// dependent dialect in the `getDependentDialects` above.
const char *const dialectRegistrationTemplate = R"(
registry.insert<{0}>();
)";
/// Emit the declarations for each of the pass options.
static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
for (const PassOption &opt : pass.getOptions()) {
os.indent(2) << "::mlir::Pass::"
<< (opt.isListOption() ? "ListOption" : "Option");
os << llvm::formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
opt.getType(), opt.getCppVariableName(),
opt.getArgument(), opt.getDescription());
if (Optional<StringRef> defaultVal = opt.getDefaultValue())
os << ", ::llvm::cl::init(" << defaultVal << ")";
if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags())
os << ", " << *additionalFlags;
os << "};\n";
}
}
/// Emit the declarations for each of the pass statistics.
static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
for (const PassStatistic &stat : pass.getStatistics()) {
os << llvm::formatv(
" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
stat.getCppVariableName(), stat.getName(), stat.getDescription());
}
}
static void emitPassDecl(const Pass &pass, raw_ostream &os) {
/// Emit a backward-compatible declaration of the pass base class.
static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
StringRef defName = pass.getDef()->getName();
std::string dependentDialectRegistrations;
{
@@ -128,7 +401,7 @@ static void emitPassDecl(const Pass &pass, raw_ostream &os) {
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
dependentDialect);
}
os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
pass.getArgument(), pass.getSummary(),
dependentDialectRegistrations);
emitPassOptionDecls(pass, os);
@@ -136,82 +409,16 @@ static void emitPassDecl(const Pass &pass, raw_ostream &os) {
os << "};\n";
}
/// Emit the code for registering each of the given passes with the global
/// PassRegistry.
static void emitPassDecls(ArrayRef<Pass> passes, raw_ostream &os) {
os << "#ifdef GEN_PASS_CLASSES\n";
for (const Pass &pass : passes)
emitPassDecl(pass, os);
os << "#undef GEN_PASS_CLASSES\n";
os << "#endif // GEN_PASS_CLASSES\n";
}
//===----------------------------------------------------------------------===//
// GEN: Pass registration generation
//===----------------------------------------------------------------------===//
/// The code snippet used to generate a pass registration.
///
/// {0}: The def name of the pass record.
/// {1}: The pass constructor call.
const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//
inline void register{0}Pass() {{
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
return {1};
});
}
)";
/// The code snippet used to generate a function to register all passes in a
/// group.
///
/// {0}: The name of the pass group.
const char *const passGroupRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//
inline void register{0}Passes() {{
)";
/// Emit the code for registering each of the given passes with the global
/// PassRegistry.
static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
os << "#ifdef GEN_PASS_REGISTRATION\n";
for (const Pass &pass : passes) {
os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
pass.getConstructor());
}
os << llvm::formatv(passGroupRegistrationCode, groupName);
for (const Pass &pass : passes)
os << " register" << pass.getDef()->getName() << "Pass();\n";
os << "}\n";
os << "#undef GEN_PASS_REGISTRATION\n";
os << "#endif // GEN_PASS_REGISTRATION\n";
}
//===----------------------------------------------------------------------===//
// GEN: Registration hooks
//===----------------------------------------------------------------------===//
static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
std::vector<Pass> passes;
for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
passes.emplace_back(def);
emitPassDecls(passes, os);
emitRegistration(passes, os);
}
static mlir::GenRegistration
genPassDecls("gen-pass-decls", "Generate pass declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
emitDecls(records, os);
return false;
});
static mlir::GenRegistration
genRegister("gen-pass-decls", "Generate pass declarations",
genPassDefs("gen-pass-defs", "Generate pass definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
emitDecls(records, os);
emitDefs(records, os);
return false;
});

View File

@@ -5,6 +5,7 @@ add_public_tablegen_target(MLIRTableGenEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS passes.td)
mlir_tablegen(PassGenTest.h.inc -gen-pass-decls -name TableGenTest)
mlir_tablegen(PassGenTest.cpp.inc -gen-pass-defs -name TableGenTest)
add_public_tablegen_target(MLIRTableGenTestPassIncGen)
add_mlir_unittest(MLIRTableGenTests

View File

@@ -7,31 +7,36 @@
//===----------------------------------------------------------------------===//
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
#include "gmock/gmock.h"
std::unique_ptr<mlir::Pass> createTestPass(int v = 0);
std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v = 0);
#define GEN_PASS_DECL_TESTPASS
#define GEN_PASS_DECL_TESTPASSWITHOPTIONS
#define GEN_PASS_DECL_TESTPASSWITHCUSTOMCONSTRUCTOR
#define GEN_PASS_REGISTRATION
#include "PassGenTest.h.inc"
#define GEN_PASS_CLASSES
#include "PassGenTest.h.inc"
#define GEN_PASS_DEF_TESTPASS
#define GEN_PASS_DEF_TESTPASSWITHOPTIONS
#define GEN_PASS_DEF_TESTPASSWITHCUSTOMCONSTRUCTOR
#include "PassGenTest.cpp.inc"
struct TestPass : public TestPassBase<TestPass> {
explicit TestPass(int v) : extraVal(v) {}
using TestPassBase::TestPassBase;
void runOnOperation() override {}
std::unique_ptr<mlir::Pass> clone() const {
return TestPassBase<TestPass>::clone();
}
int extraVal;
};
std::unique_ptr<mlir::Pass> createTestPass(int v) {
return std::make_unique<TestPass>(v);
TEST(PassGenTest, defaultGeneratedConstructor) {
std::unique_ptr<mlir::Pass> pass = createTestPass();
EXPECT_TRUE(pass.get() != nullptr);
}
TEST(PassGenTest, PassClone) {
@@ -41,7 +46,74 @@ TEST(PassGenTest, PassClone) {
return static_cast<const TestPass *>(pass.get());
};
const auto origPass = createTestPass(10);
const auto origPass = createTestPass();
const auto clonePass = unwrap(origPass)->clone();
EXPECT_TRUE(clonePass.get() != nullptr);
EXPECT_TRUE(origPass.get() != clonePass.get());
}
struct TestPassWithOptions
: public TestPassWithOptionsBase<TestPassWithOptions> {
using TestPassWithOptionsBase::TestPassWithOptionsBase;
void runOnOperation() override {}
std::unique_ptr<mlir::Pass> clone() const {
return TestPassWithOptionsBase<TestPassWithOptions>::clone();
}
unsigned getTestOption() const { return testOption; }
llvm::ArrayRef<int64_t> getTestListOption() const { return testListOption; }
};
TEST(PassGenTest, PassOptions) {
mlir::MLIRContext context;
TestPassWithOptionsOptions options;
options.testOption = 57;
llvm::SmallVector<int64_t, 2> testListOption = {1, 2};
options.testListOption = testListOption;
const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
return static_cast<const TestPassWithOptions *>(pass.get());
};
const auto pass = createTestPassWithOptions(options);
EXPECT_EQ(unwrap(pass)->getTestOption(), 57);
EXPECT_EQ(unwrap(pass)->getTestListOption()[0], 1);
EXPECT_EQ(unwrap(pass)->getTestListOption()[1], 2);
}
struct TestPassWithCustomConstructor
: public TestPassWithCustomConstructorBase<TestPassWithCustomConstructor> {
explicit TestPassWithCustomConstructor(int v) : extraVal(v) {}
void runOnOperation() override {}
std::unique_ptr<mlir::Pass> clone() const {
return TestPassWithCustomConstructorBase<
TestPassWithCustomConstructor>::clone();
}
unsigned int extraVal = 23;
};
std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v) {
return std::make_unique<TestPassWithCustomConstructor>(v);
}
TEST(PassGenTest, PassCloneWithCustomConstructor) {
mlir::MLIRContext context;
const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
return static_cast<const TestPassWithCustomConstructor *>(pass.get());
};
const auto origPass = createTestPassWithCustomConstructor(10);
const auto clonePass = unwrap(origPass)->clone();
EXPECT_EQ(unwrap(origPass)->extraVal, unwrap(clonePass)->extraVal);

View File

@@ -12,8 +12,20 @@ include "mlir/Rewrite/PassUtil.td"
def TestPass : Pass<"test"> {
let summary = "Test pass";
let constructor = "::createTestPass()";
let options = RewritePassUtils.options;
}
def TestPassWithOptions : Pass<"test"> {
let summary = "Test pass with options";
let options = [
Option<"testOption", "testOption", "unsigned", "0", "Test option">,
ListOption<"testListOption", "test-list-option", "int64_t",
"Test list option">
];
}
def TestPassWithCustomConstructor : Pass<"test"> {
let summary = "Test pass with custom constructor";
let constructor = "::createTestPassWithCustomConstructor()";
}