[mlir] Pass Options ownership modifications (#110582)

This change makes two (related) changes: 

First, it updates the tablegen option for `ListOption` to emit a
`SmallVector` instead of an `ArrayRef`. This brings `ListOption` more
inline with the traditional `Option`, where values are typically
provided using types that have storage. After this change, all options
should be fully owned by a Pass' `Options` object after it has been
fully constructed, unless the underlying type of the `Option` explicitly
indicates otherwise.

Second, it updates the generated constructors for Passes to consume
options by value instead of reference, and prefers moving options into
the pass itself. This should be more efficient for non-trivial options
objects, where the previous interface forced a copy to be materialized.
Now, at worst case the API materializes a copy (no worse than before);
at best-case, all options objects are moved into place. Ideally, we
could update the Pass constructor to take an r-value reference to the
Options object instead, but this approach will require numerous changes
to existing passes and their factory functions.

---------

Authored-by: Nikhil Kalra <nkalra@apple.com>
This commit is contained in:
Nikhil Kalra
2024-10-01 09:48:51 -07:00
committed by GitHub
parent afc0557a04
commit fef3566a25
4 changed files with 16 additions and 14 deletions

View File

@@ -15,6 +15,7 @@
#define MLIR_TRANSFORMS_PASSES_H
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LocationSnapshot.h"
#include "mlir/Transforms/ViewOpGraph.h"

View File

@@ -44,7 +44,8 @@ struct NarrowingPattern : OpRewritePattern<SourceOp> {
NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options,
PatternBenefit benefit = 1)
: OpRewritePattern<SourceOp>(ctx, benefit),
supportedBitwidths(options.bitwidthsSupported) {
supportedBitwidths(options.bitwidthsSupported.begin(),
options.bitwidthsSupported.end()) {
assert(!supportedBitwidths.empty() && "Invalid options");
assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth");
llvm::sort(supportedBitwidths);
@@ -757,7 +758,8 @@ struct ArithIntNarrowingPass final
MLIRContext *ctx = op->getContext();
RewritePatternSet patterns(ctx);
populateArithIntNarrowingPatterns(
patterns, ArithIntNarrowingOptions{bitwidthsSupported});
patterns, ArithIntNarrowingOptions{
llvm::to_vector_of<unsigned>(bitwidthsSupported)});
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
signalPassFailure();
}

View File

@@ -97,7 +97,7 @@ static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
std::string type = opt.getType().str();
if (opt.isListOption())
type = "::llvm::ArrayRef<" + type + ">";
type = "::llvm::SmallVector<" + type + ">";
os.indent(2) << llvm::formatv("{0} {1}", type, opt.getCppVariableName());
@@ -128,8 +128,8 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) {
// Declaration of the constructor with options.
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}(const "
"{0}Options &options);\n",
os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}("
"{0}Options options);\n",
passName);
}
@@ -236,7 +236,7 @@ namespace impl {{
const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
namespace impl {{
std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options);
std::unique_ptr<::mlir::Pass> create{0}({0}Options options);
} // namespace impl
)";
@@ -247,8 +247,8 @@ const char *const friendDefaultConstructorDefTemplate = R"(
)";
const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
friend std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
return std::make_unique<DerivedT>(options);
friend std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
return std::make_unique<DerivedT>(std::move(options));
}
)";
@@ -259,8 +259,8 @@ std::unique_ptr<::mlir::Pass> create{0}() {{
)";
const char *const defaultConstructorWithOptionsDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
return impl::create{0}(options);
std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
return impl::create{0}(std::move(options));
}
)";
@@ -326,10 +326,10 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
os.indent(2) << llvm::formatv(
"{0}Base(const {0}Options &options) : {0}Base() {{\n", passName);
"{0}Base({0}Options options) : {0}Base() {{\n", passName);
for (const PassOption &opt : pass.getOptions())
os.indent(4) << llvm::formatv("{0} = options.{0};\n",
os.indent(4) << llvm::formatv("{0} = std::move(options.{0});\n",
opt.getCppVariableName());
os.indent(2) << "}\n";

View File

@@ -72,8 +72,7 @@ TEST(PassGenTest, PassOptions) {
TestPassWithOptionsOptions options;
options.testOption = 57;
llvm::SmallVector<int64_t, 2> testListOption = {1, 2};
options.testListOption = testListOption;
options.testListOption = {1, 2};
const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
return static_cast<const TestPassWithOptions *>(pass.get());