mirror of
https://github.com/intel/llvm.git
synced 2026-01-16 13:35:38 +08:00
Implement a super sketched out pattern match/rewrite framework and a sketched
out canonicalization pass to drive it, and a simple (x-x) === 0 pattern match as a test case. There is a tremendous number of improvements that need to land, and the matcher/rewriter and patterns will be split out of this file, but this is a starting point. PiperOrigin-RevId: 216788604
This commit is contained in:
@@ -293,6 +293,12 @@ public:
|
||||
setInsertionPoint(block, insertPoint);
|
||||
}
|
||||
|
||||
/// Create an ML function builder and set the insertion point to the start of
|
||||
/// the function.
|
||||
MLFuncBuilder(MLFunction *func) : Builder(func->getContext()) {
|
||||
setInsertionPoint(func, func->begin());
|
||||
}
|
||||
|
||||
/// Reset the insertion point to no location. Creating an operation without a
|
||||
/// set insertion point is an error, but this can still be useful when the
|
||||
/// current insertion point a builder refers to is being removed.
|
||||
|
||||
@@ -140,6 +140,11 @@ public:
|
||||
static void build(Builder *builder, OperationState *result, int64_t value,
|
||||
unsigned width);
|
||||
|
||||
/// Build a constant int op producing an integer with the specified type,
|
||||
/// which must be an integer type.
|
||||
static void build(Builder *builder, OperationState *result, int64_t value,
|
||||
Type *type);
|
||||
|
||||
int64_t getValue() const {
|
||||
return getAttrOfType<IntegerAttr>("value")->getValue();
|
||||
}
|
||||
|
||||
@@ -34,6 +34,9 @@ class ModulePass;
|
||||
/// Creates a constant folding pass.
|
||||
FunctionPass *createConstantFoldPass();
|
||||
|
||||
/// Creates an instance of the Canonicalizer pass.
|
||||
FunctionPass *createCanonicalizerPass();
|
||||
|
||||
/// Creates a loop unrolling pass. Default option or command-line options take
|
||||
/// effect if -1 is passed as parameter.
|
||||
MLFunctionPass *createLoopUnrollPass(int unrollFactor = -1,
|
||||
|
||||
@@ -259,6 +259,14 @@ void ConstantIntOp::build(Builder *builder, OperationState *result,
|
||||
builder->getIntegerType(width));
|
||||
}
|
||||
|
||||
/// Build a constant int op producing an integer with the specified type,
|
||||
/// which must be an integer type.
|
||||
void ConstantIntOp::build(Builder *builder, OperationState *result,
|
||||
int64_t value, Type *type) {
|
||||
assert(isa<IntegerType>(type) && "ConstantIntOp can only have integer type");
|
||||
ConstantOp::build(builder, result, builder->getIntegerAttr(value), type);
|
||||
}
|
||||
|
||||
/// ConstantIndexOp only matches values whose result type is Index.
|
||||
bool ConstantIndexOp::isClassFor(const Operation *op) {
|
||||
return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex();
|
||||
|
||||
311
mlir/lib/Transforms/Canonicalizer.cpp
Normal file
311
mlir/lib/Transforms/Canonicalizer.cpp
Normal file
@@ -0,0 +1,311 @@
|
||||
//===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This transformation pass converts operations into their canonical forms by
|
||||
// folding constants, applying operation identity transformations etc.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/StandardOps/StandardOps.h"
|
||||
#include "mlir/Transforms/Pass.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include <memory>
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Definition of Pattern and related types.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(clattner): Move this out of this file when it is ready.
|
||||
|
||||
// TODO(clattner): Define this as a tagged union with proper sentinels.
|
||||
typedef int PatternBenefit;
|
||||
|
||||
/// Pattern state is used by patterns that want to maintain state between their
|
||||
/// match and rewrite phases. Patterns can define a pattern-specific subclass
|
||||
/// of this.
|
||||
class PatternState {
|
||||
public:
|
||||
virtual ~PatternState() {}
|
||||
};
|
||||
|
||||
/// This is the type returned by a pattern match. The first field indicates the
|
||||
/// benefit of the match, the second is a state token that can optionally be
|
||||
/// produced by a pattern match to maintain state between the match and rewrite
|
||||
/// phases.
|
||||
typedef std::pair<PatternBenefit, std::unique_ptr<PatternState>>
|
||||
PatternMatchResult;
|
||||
|
||||
class Pattern {
|
||||
public:
|
||||
// Return the benefit (the inverse of “cost”) of matching this pattern,
|
||||
// if it is statically determinable. The result is an integer if known,
|
||||
// a sentinel if dynamically computed, and another sentinel if the
|
||||
// pattern can never be matched.
|
||||
PatternBenefit getStaticBenefit() const { return staticBenefit; }
|
||||
|
||||
// Return the root node that this pattern matches. Patterns that can
|
||||
// match multiple root types are instantiated once per root.
|
||||
OperationName getRootKind() const { return rootKind; }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Implementation hooks for patterns to implement.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// Attempt to match against code rooted at the specified operation,
|
||||
// which is the same operation code as getRootKind(). On success it
|
||||
// returns the benefit of the match along with an (optional)
|
||||
// pattern-specific state which is passed back into its rewrite
|
||||
// function if this match is selected. On failure, this returns a
|
||||
// sentinel indicating that it didn’t match.
|
||||
virtual PatternMatchResult match(Operation *op) const = 0;
|
||||
|
||||
// Rewrite the IR rooted at the specified operation with the result of
|
||||
// this pattern, generating any new operations with the specified
|
||||
// builder. If an unexpected error is encountered (an internal
|
||||
// compiler error), it is emitted through the normal MLIR diagnostic
|
||||
// hooks and the IR is left in a valid state.
|
||||
virtual void rewrite(Operation *op, std::unique_ptr<PatternState> state,
|
||||
// TODO: Need a generic builder.
|
||||
MLFuncBuilder &builder) const {
|
||||
rewrite(op, builder);
|
||||
}
|
||||
|
||||
// Rewrite the IR rooted at the specified operation with the result of
|
||||
// this pattern, generating any new operations with the specified
|
||||
// builder. If an unexpected error is encountered (an internal
|
||||
// compiler error), it is emitted through the normal MLIR diagnostic
|
||||
// hooks and the IR is left in a valid state.
|
||||
virtual void rewrite(Operation *op,
|
||||
// TODO: Need a generic builder.
|
||||
MLFuncBuilder &builder) const {
|
||||
llvm_unreachable("need to implement one of the rewrite functions!");
|
||||
}
|
||||
|
||||
virtual ~Pattern();
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Helper methods to simplify pattern implementations
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// This method indicates that no match was found.
|
||||
static PatternMatchResult matchFailure() {
|
||||
// TODO: Use a proper sentinel / discriminated union instad of -1 magic
|
||||
// number.
|
||||
return {-1, std::unique_ptr<PatternState>()};
|
||||
}
|
||||
|
||||
static PatternMatchResult matchSuccess(
|
||||
PatternBenefit benefit,
|
||||
std::unique_ptr<PatternState> state = std::unique_ptr<PatternState>()) {
|
||||
return {benefit, std::move(state)};
|
||||
}
|
||||
|
||||
protected:
|
||||
Pattern(PatternBenefit staticBenefit, OperationName rootKind)
|
||||
: staticBenefit(staticBenefit), rootKind(rootKind) {}
|
||||
|
||||
private:
|
||||
const PatternBenefit staticBenefit;
|
||||
const OperationName rootKind;
|
||||
};
|
||||
|
||||
Pattern::~Pattern() {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PatternMatcher class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class manages optimization an execution of a group of patterns, and
|
||||
/// provides an API for finding the best match against a given node.
|
||||
///
|
||||
class PatternMatcher {
|
||||
public:
|
||||
/// Create a PatternMatch with the specified set of patterns. This takes
|
||||
/// ownership of the patterns in question.
|
||||
explicit PatternMatcher(ArrayRef<Pattern *> patterns)
|
||||
: patterns(patterns.begin(), patterns.end()) {}
|
||||
|
||||
typedef std::pair<Pattern *, std::unique_ptr<PatternState>> MatchResult;
|
||||
|
||||
/// Find the highest benefit pattern available in the pattern set for the DAG
|
||||
/// rooted at the specified node. This returns the pattern (and any state it
|
||||
/// needs) if found, or null if there are no matches.
|
||||
MatchResult findMatch(Operation *op);
|
||||
|
||||
~PatternMatcher() {
|
||||
for (auto *p : patterns)
|
||||
delete p;
|
||||
}
|
||||
|
||||
private:
|
||||
PatternMatcher(const PatternMatcher &) = delete;
|
||||
void operator=(const PatternMatcher &) = delete;
|
||||
|
||||
std::vector<Pattern *> patterns;
|
||||
};
|
||||
|
||||
/// Find the highest benefit pattern available in the pattern set for the DAG
|
||||
/// rooted at the specified node. This returns the pattern if found, or null
|
||||
/// if there are no matches.
|
||||
auto PatternMatcher::findMatch(Operation *op) -> MatchResult {
|
||||
// TODO: This is a completely trivial implementation, expand this in the
|
||||
// future.
|
||||
|
||||
// Keep track of the best match, the benefit of it, and any matcher specific
|
||||
// state it is maintaining.
|
||||
MatchResult bestMatch = {nullptr, nullptr};
|
||||
// TODO: eliminate magic numbers.
|
||||
PatternBenefit bestBenefit = -1;
|
||||
|
||||
for (auto *pattern : patterns) {
|
||||
// Ignore patterns that are for the wrong root.
|
||||
if (pattern->getRootKind() != op->getName())
|
||||
continue;
|
||||
|
||||
// If we know the static cost of the pattern is worse than what we've
|
||||
// already found then don't run it.
|
||||
auto staticBenefit = pattern->getStaticBenefit();
|
||||
if (staticBenefit < 0 || staticBenefit < bestBenefit)
|
||||
continue;
|
||||
|
||||
// Check to see if this pattern matches this node.
|
||||
auto result = pattern->match(op);
|
||||
// TODO: magic numbers.
|
||||
if (result.first < 0 || result.first < bestBenefit)
|
||||
continue;
|
||||
|
||||
// Okay we found a match that is better than our previous one, remember it.
|
||||
bestBenefit = result.first;
|
||||
bestMatch = {pattern, std::move(result.second)};
|
||||
}
|
||||
|
||||
// If we found any match, return it.
|
||||
return bestMatch;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Definition of a few patterns for canonicalizing operations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// subi(x,x) -> 0
|
||||
///
|
||||
struct SimplifyXMinusX : public Pattern {
|
||||
SimplifyXMinusX(MLIRContext *context)
|
||||
// FIXME: rename getOperationName and add a proper one.
|
||||
: Pattern(1, OperationName(SubIOp::getOperationName(), context)) {}
|
||||
|
||||
std::pair<PatternBenefit, std::unique_ptr<PatternState>>
|
||||
match(Operation *op) const override {
|
||||
// TODO: Rename getAs -> dyn_cast, and add a cast<> method.
|
||||
auto subi = op->getAs<SubIOp>();
|
||||
assert(subi && "Matcher should have produced this");
|
||||
|
||||
if (subi->getOperand(0) == subi->getOperand(1))
|
||||
return matchSuccess(1);
|
||||
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
// Rewrite the IR rooted at the specified operation with the result of
|
||||
// this pattern, generating any new operations with the specified
|
||||
// builder. If an unexpected error is encountered (an internal
|
||||
// compiler error), it is emitted through the normal MLIR diagnostic
|
||||
// hooks and the IR is left in a valid state.
|
||||
virtual void rewrite(Operation *op, MLFuncBuilder &builder) const override {
|
||||
// TODO: Rename getAs -> dyn_cast, and add a cast<> method.
|
||||
auto subi = op->getAs<SubIOp>();
|
||||
assert(subi && "Matcher should have produced this");
|
||||
|
||||
// TODO: Better "replace and remove" API on Pattern.
|
||||
auto result =
|
||||
builder.create<ConstantIntOp>(op->getLoc(), 0, subi->getType());
|
||||
op->getResult(0)->replaceAllUsesWith(result->getResult());
|
||||
|
||||
cast<OperationStmt>(op)->eraseFromBlock();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The actual Canonicalizer Pass.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO: Canonicalize and unique all constant operations into the entry of the
|
||||
// function.
|
||||
|
||||
namespace {
|
||||
/// Canonicalize operations in functions.
|
||||
struct Canonicalizer : public FunctionPass {
|
||||
PassResult runOnCFGFunction(CFGFunction *f) override;
|
||||
PassResult runOnMLFunction(MLFunction *f) override;
|
||||
|
||||
void simplifyFunction(std::vector<Operation *> &worklist,
|
||||
MLFuncBuilder &builder);
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
PassResult Canonicalizer::runOnCFGFunction(CFGFunction *f) {
|
||||
// TODO: Add this.
|
||||
return success();
|
||||
}
|
||||
|
||||
PassResult Canonicalizer::runOnMLFunction(MLFunction *f) {
|
||||
std::vector<Operation *> worklist;
|
||||
worklist.reserve(64);
|
||||
|
||||
f->walk([&](OperationStmt *stmt) { worklist.push_back(stmt); });
|
||||
|
||||
MLFuncBuilder builder(f);
|
||||
simplifyFunction(worklist, builder);
|
||||
return success();
|
||||
}
|
||||
|
||||
// TODO: This should work on both ML and CFG functions.
|
||||
void Canonicalizer::simplifyFunction(std::vector<Operation *> &worklist,
|
||||
MLFuncBuilder &builder) {
|
||||
// TODO: Instead of a hard coded list of patterns, ask the registered dialects
|
||||
// for their canonicalization patterns.
|
||||
|
||||
PatternMatcher matcher({new SimplifyXMinusX(builder.getContext())});
|
||||
|
||||
while (!worklist.empty()) {
|
||||
auto *op = worklist.back();
|
||||
worklist.pop_back();
|
||||
|
||||
// TODO: If no side effects, and operation has no users, then it is
|
||||
// trivially dead - remove it.
|
||||
|
||||
// TODO: Call the constant folding hook on this operation, and canonicalize
|
||||
// constants into the entry node.
|
||||
|
||||
// Check to see if we have any patterns that match this node.
|
||||
auto match = matcher.findMatch(op);
|
||||
if (!match.first)
|
||||
continue;
|
||||
|
||||
// TODO: Need to be a bit trickier to make sure new instructions get into
|
||||
// the worklist.
|
||||
match.first->rewrite(op, std::move(match.second), builder);
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Canonicalizer pass.
|
||||
FunctionPass *mlir::createCanonicalizerPass() { return new Canonicalizer(); }
|
||||
10
mlir/test/Transforms/canonicalize.mlir
Normal file
10
mlir/test/Transforms/canonicalize.mlir
Normal file
@@ -0,0 +1,10 @@
|
||||
// RUN: mlir-opt %s -canonicalize | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @test_subi_zero
|
||||
mlfunc @test_subi_zero(%x: i32) -> i32 {
|
||||
// CHECK: %c0_i32 = constant 0 : i32
|
||||
// CHECK-NEXT: return %c0
|
||||
%y = subi %x, %x : i32
|
||||
return %y: i32
|
||||
}
|
||||
|
||||
@@ -66,6 +66,7 @@ static cl::opt<bool>
|
||||
cl::init(false));
|
||||
|
||||
enum Passes {
|
||||
Canonicalize,
|
||||
ComposeAffineMaps,
|
||||
ConstantFold,
|
||||
ConvertToCFG,
|
||||
@@ -80,25 +81,26 @@ enum Passes {
|
||||
|
||||
static cl::list<Passes> passList(
|
||||
"", cl::desc("Compiler passes to run"),
|
||||
cl::values(clEnumValN(ComposeAffineMaps, "compose-affine-maps",
|
||||
"Compose affine maps"),
|
||||
clEnumValN(ConstantFold, "constant-fold",
|
||||
"Constant fold operations in functions"),
|
||||
clEnumValN(ConvertToCFG, "convert-to-cfg",
|
||||
"Convert all ML functions in the module to CFG ones"),
|
||||
clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
|
||||
clEnumValN(LoopUnrollAndJam, "loop-unroll-jam",
|
||||
"Unroll and jam loops"),
|
||||
clEnumValN(PipelineDataTransfer, "pipeline-data-transfer",
|
||||
"Pipeline non-blocking data transfers between"
|
||||
"explicitly managed levels of the memory hierarchy"),
|
||||
clEnumValN(PrintCFGGraph, "print-cfg-graph",
|
||||
"Print CFG graph per function"),
|
||||
clEnumValN(SimplifyAffineExpr, "simplify-affine-expr",
|
||||
"Simplify affine expressions"),
|
||||
clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
|
||||
"Dynamic TensorFlow Switch/Match nodes to a CFG"),
|
||||
clEnumValN(XLALower, "xla-lower", "Lower to XLA dialect")));
|
||||
cl::values(
|
||||
clEnumValN(Canonicalize, "canonicalize", "Canonicalize operations"),
|
||||
clEnumValN(ComposeAffineMaps, "compose-affine-maps",
|
||||
"Compose affine maps"),
|
||||
clEnumValN(ConstantFold, "constant-fold",
|
||||
"Constant fold operations in functions"),
|
||||
clEnumValN(ConvertToCFG, "convert-to-cfg",
|
||||
"Convert all ML functions in the module to CFG ones"),
|
||||
clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
|
||||
clEnumValN(LoopUnrollAndJam, "loop-unroll-jam", "Unroll and jam loops"),
|
||||
clEnumValN(PipelineDataTransfer, "pipeline-data-transfer",
|
||||
"Pipeline non-blocking data transfers between"
|
||||
"explicitly managed levels of the memory hierarchy"),
|
||||
clEnumValN(PrintCFGGraph, "print-cfg-graph",
|
||||
"Print CFG graph per function"),
|
||||
clEnumValN(SimplifyAffineExpr, "simplify-affine-expr",
|
||||
"Simplify affine expressions"),
|
||||
clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
|
||||
"Dynamic TensorFlow Switch/Match nodes to a CFG"),
|
||||
clEnumValN(XLALower, "xla-lower", "Lower to XLA dialect")));
|
||||
|
||||
enum OptResult { OptSuccess, OptFailure };
|
||||
|
||||
@@ -174,6 +176,9 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) {
|
||||
auto passKind = passList[i];
|
||||
Pass *pass = nullptr;
|
||||
switch (passKind) {
|
||||
case Canonicalize:
|
||||
pass = createCanonicalizerPass();
|
||||
break;
|
||||
case ComposeAffineMaps:
|
||||
pass = createComposeAffineMapsPass();
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user