From eb753f4aece37b47a3819467d6245ed6ccb1a2ba Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 28 Jan 2019 14:04:40 -0800 Subject: [PATCH] Add tblgen::Pattern to model Patterns defined in TableGen Similar to other tblgen:: abstractions, tblgen::Pattern hides the native TableGen API and provides a nicer API that is more coherent with the TableGen definitions. PiperOrigin-RevId: 231285143 --- mlir/include/mlir/TableGen/Operator.h | 7 +- mlir/include/mlir/TableGen/Pattern.h | 161 ++++++++++++++++++ mlir/lib/TableGen/Operator.cpp | 8 + mlir/lib/TableGen/Pattern.cpp | 133 +++++++++++++++ mlir/tools/mlir-tblgen/RewriterGen.cpp | 225 +++++++++++-------------- 5 files changed, 401 insertions(+), 133 deletions(-) create mode 100644 mlir/include/mlir/TableGen/Pattern.h create mode 100644 mlir/lib/TableGen/Pattern.cpp diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index f299e2f1e376..57613c2d4193 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -68,10 +68,10 @@ public: // Op attribute accessors. int getNumAttributes() const { return attributes.size(); } + // Returns the total number of native attributes. + int getNumNativeAttributes() const; NamedAttribute &getAttribute(int index) { return attributes[index]; } - const NamedAttribute &getAttribute(int index) const { - return attributes[index]; - } + const NamedAttribute &getAttribute(int index) const; // Op operand iterators. using operand_iterator = Operand *; @@ -87,6 +87,7 @@ public: // Op argument (attribute or operand) accessors. Argument getArg(int index); StringRef getArgName(int index) const; + // Returns the total number of arguments. int getNumArgs() const { return operands.size() + attributes.size(); } // Query functions for the documentation of the operator. diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h new file mode 100644 index 000000000000..4f727c48fc53 --- /dev/null +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -0,0 +1,161 @@ +//===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Pattern wrapper class to simplify using TableGen Record defining a MLIR +// Pattern. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_PATTERN_H_ +#define MLIR_TABLEGEN_PATTERN_H_ + +#include "mlir/TableGen/Argument.h" +#include "mlir/TableGen/Operator.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/TableGen/Error.h" + +namespace llvm { +class Record; +class Init; +class DagInit; +class StringRef; +} // end namespace llvm + +namespace mlir { +namespace tblgen { + +// Mapping from TableGen Record to Operator wrapper object +using RecordOperatorMap = llvm::DenseMap; + +// Wrapper around DAG argument. +struct DagArg { + DagArg(Argument arg, llvm::Init *constraint) + : arg(arg), constraint(constraint) {} + + // Returns true if this DAG argument concerns an operation attribute. + bool isAttr() const; + + Argument arg; + llvm::Init *constraint; +}; + +class Pattern; + +// Wrapper class providing helper methods for accessing TableGen DAG constructs +// used inside Patterns. This class is lightweight and designed to be used like +// values. +// +// A TableGen DAG construct is of the syntax +// `(operator, arg0, arg1, ...)`. +// +// When used inside Patterns, `operator` corresponds to some dialect op, or +// a known list of verbs that defines special transformation actions. This +// `arg*` can be a nested DAG construct. This class provides getters to +// retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper +// methods. +// +// A null DagNode contains a nullptr and converts to false implicitly. +class DagNode { +public: + explicit DagNode(const llvm::DagInit *node) : node(node) {} + + // Implicit bool converter that returns true if this DagNode is not a null + // DagNode. + operator bool() const { return node != nullptr; } + + // Returns the operator wrapper object corresponding to the dialect op matched + // by this DAG. The operator wrapper will be queried from the given `mapper` + // and created in it if not existing. + Operator &getDialectOp(RecordOperatorMap *mapper) const; + + // Returns the number of operations recursively involved in the DAG tree + // rooted from this node. + unsigned getNumOps() const; + + // Returns the number of immediate arguments to this DAG node. + unsigned getNumArgs() const; + + // Returns true if the `index`-th argument is a nested DAG construct. + bool isNestedDagArg(unsigned index) const; + + // Gets the `index`-th argument as a nested DAG construct if possible. Returns + // null DagNode otherwise. + DagNode getArgAsNestedDag(unsigned index) const; + // Gets the `index`-th argument as a TableGen DefInit* if possible. Returns + // nullptr otherwise. + // TODO: This method is exposing raw TableGen object and should be changed. + llvm::DefInit *getArgAsDefInit(unsigned index) const; + + // Returns the specified name of the `index`-th argument. + llvm::StringRef getArgName(unsigned index) const; + + // Collects all recursively bound arguments involved in the DAG tree rooted + // from this node. + void collectBoundArguments(Pattern *pattern) const; + + // Returns true if this DAG construct means to replace with an existing SSA + // value. + bool isReplaceWithValue() const; + +private: + const llvm::DagInit *node; // nullptr means null DagNode +}; + +// Wrapper class providing helper methods for accessing MLIR Pattern defined +// in TableGen. This class should closely reflect what is defined as class +// `Pattern` in TableGen. This class contains maps so it is not intended to be +// used as values. +class Pattern { +public: + explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper); + + // Returns the source pattern to match. + DagNode getSourcePattern() const; + + // Returns the number of results generated by applying this rewrite pattern. + unsigned getNumResults() const; + + // Returns the DAG tree root node of the `index`-th result pattern. + DagNode getResultPattern(unsigned index) const; + + // Checks whether an argument with the given `name` is bound in source + // pattern. Prints fatal error if not; does nothing otherwise. + void ensureArgBoundInSourcePattern(llvm::StringRef name) const; + + // Returns a reference to all the bound arguments in the source pattern. + llvm::StringMap &getSourcePatternBoundArgs(); + + // Returns the op that the root node of the source pattern matches. + const Operator &getSourceRootOp(); + + // Returns the operator wrapper object corresponding to the given `node`'s DAG + // operator. + Operator &getDialectOp(DagNode node); + +private: + // The TableGen definition of this pattern. + const llvm::Record &def; + + RecordOperatorMap *recordOpMap; // All operators + llvm::StringMap boundArguments; // All bound arguments +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_PATTERN_H_ diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 49bf04673566..5bf332b72b2e 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -53,6 +53,14 @@ std::string tblgen::Operator::qualifiedCppClassName() const { return llvm::join(getSplitDefName(), "::"); } +int tblgen::Operator::getNumNativeAttributes() const { + return derivedAttrStart - nativeAttrStart; +} + +const tblgen::NamedAttribute &tblgen::Operator::getAttribute(int index) const { + return attributes[index]; +} + StringRef tblgen::Operator::getArgName(int index) const { DagInit *argumentValues = def.getValueAsDag("arguments"); return argumentValues->getArgName(index)->getValue(); diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp new file mode 100644 index 000000000000..b336f48ba104 --- /dev/null +++ b/mlir/lib/TableGen/Pattern.cpp @@ -0,0 +1,133 @@ +//===- Pattern.cpp - Pattern wrapper class ----------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Pattern wrapper class to simplify using TableGen Record defining a MLIR +// Pattern. +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/Pattern.h" +#include "llvm/ADT/Twine.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; + +using mlir::tblgen::Operator; + +bool tblgen::DagArg::isAttr() const { + return arg.is(); +} + +Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const { + llvm::Record *opDef = cast(node->getOperator())->getDef(); + return mapper->try_emplace(opDef, opDef).first->second; +} + +unsigned tblgen::DagNode::getNumOps() const { + unsigned count = isReplaceWithValue() ? 0 : 1; + for (unsigned i = 0, e = getNumArgs(); i != e; ++i) { + if (auto child = getArgAsNestedDag(i)) + count += child.getNumOps(); + } + return count; +} + +unsigned tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); } + +bool tblgen::DagNode::isNestedDagArg(unsigned index) const { + return isa(node->getArg(index)); +} + +tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const { + return DagNode(dyn_cast_or_null(node->getArg(index))); +} + +llvm::DefInit *tblgen::DagNode::getArgAsDefInit(unsigned index) const { + return dyn_cast(node->getArg(index)); +} + +StringRef tblgen::DagNode::getArgName(unsigned index) const { + return node->getArgNameStr(index); +} + +static void collectBoundArguments(const llvm::DagInit *tree, + tblgen::Pattern *pattern) { + auto &op = pattern->getDialectOp(tblgen::DagNode(tree)); + + // TODO(jpienaar): Expand to multiple matches. + for (unsigned i = 0, e = tree->getNumArgs(); i != e; ++i) { + auto *arg = tree->getArg(i); + + if (auto *argTree = dyn_cast(arg)) { + collectBoundArguments(argTree, pattern); + continue; + } + + StringRef name = tree->getArgNameStr(i); + if (name.empty()) + continue; + + pattern->getSourcePatternBoundArgs().try_emplace(name, op.getArg(i), arg); + } +} + +void tblgen::DagNode::collectBoundArguments(tblgen::Pattern *pattern) const { + ::collectBoundArguments(node, pattern); +} + +bool tblgen::DagNode::isReplaceWithValue() const { + auto *dagOpDef = cast(node->getOperator())->getDef(); + return dagOpDef->getName() == "replaceWithValue"; +} + +tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) + : def(*def), recordOpMap(mapper) { + getSourcePattern().collectBoundArguments(this); +} + +tblgen::DagNode tblgen::Pattern::getSourcePattern() const { + return tblgen::DagNode(def.getValueAsDag("PatternToMatch")); +} + +unsigned tblgen::Pattern::getNumResults() const { + auto *results = def.getValueAsListInit("ResultOps"); + return results->size(); +} + +tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const { + auto *results = def.getValueAsListInit("ResultOps"); + return tblgen::DagNode(cast(results->getElement(index))); +} + +void tblgen::Pattern::ensureArgBoundInSourcePattern( + llvm::StringRef name) const { + if (boundArguments.find(name) == boundArguments.end()) + PrintFatalError(def.getLoc(), + Twine("referencing unbound variable '") + name + "'"); +} + +llvm::StringMap &tblgen::Pattern::getSourcePatternBoundArgs() { + return boundArguments; +} + +const tblgen::Operator &tblgen::Pattern::getSourceRootOp() { + return getSourcePattern().getDialectOp(recordOpMap); +} + +tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) { + return node.getDialectOp(recordOpMap); +} diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 7af3134bcf14..d954db861783 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -22,6 +22,7 @@ #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" +#include "mlir/TableGen/Pattern.h" #include "mlir/TableGen/Predicate.h" #include "mlir/TableGen/Type.h" #include "llvm/ADT/StringExtras.h" @@ -40,9 +41,12 @@ using namespace mlir; using mlir::tblgen::Argument; using mlir::tblgen::Attribute; +using mlir::tblgen::DagNode; using mlir::tblgen::NamedAttribute; using mlir::tblgen::Operand; using mlir::tblgen::Operator; +using mlir::tblgen::Pattern; +using mlir::tblgen::RecordOperatorMap; using mlir::tblgen::Type; namespace { @@ -62,102 +66,65 @@ struct DagArg { bool DagArg::isAttr() { return arg.is(); } namespace { -class Pattern { +class PatternEmitter { public: - static void emit(StringRef rewriteName, Record *p, raw_ostream &os); + static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper, + raw_ostream &os); private: - Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {} + PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os) + : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), os(os) {} - // Emits the rewrite pattern named `rewriteName`. + // Emits the mlir::RewritePattern struct named `rewriteName`. void emit(StringRef rewriteName); - // Emits the matcher. - void emitMatcher(DagInit *tree); + // Emits the match() method. + void emitMatchMethod(DagNode tree); // Emits the rewrite() method. void emitRewriteMethod(); // Emits the C++ statement to replace the matched DAG with an existing value. - void emitReplaceWithExistingValue(DagInit *resultTree); + void emitReplaceWithExistingValue(DagNode resultTree); // Emits the C++ statement to replace the matched DAG with a new op. - void emitReplaceOpWithNewOp(DagInit *resultTree); + void emitReplaceOpWithNewOp(DagNode resultTree); // Emits the value of constant attribute to `os`. void emitAttributeValue(Record *constAttr); - // Collects bound arguments. - void collectBoundArguments(DagInit *tree); + // Emits C++ statements for matching the op constrained by the given DAG + // `tree`. + void emitOpMatch(DagNode tree, int depth); - // Checks whether an argument with the given `name` is bound in source - // pattern. Prints fatal error if not; does nothing otherwise. - void checkArgumentBound(StringRef name) const; - - // Helper function to match patterns. - void matchOp(DagInit *tree, int depth); - - // Returns the Operator stored for the given record. - Operator &getOperator(const llvm::Record *record); - - // Map from bound argument name to DagArg. - StringMap boundArguments; - - // Map from Record* to Operator. - DenseMap opMap; - - // Number of the operations in the input pattern. - int numberOfOpsMatched = 0; - - Record *pattern; +private: + // Pattern instantiation location followed by the location of multiclass + // prototypes used. This is intended to be used as a whole to + // PrintFatalError() on errors. + ArrayRef loc; + // Op's TableGen Record to wrapper object + RecordOperatorMap *opMap; + // Handy wrapper for pattern being emitted + Pattern pattern; raw_ostream &os; }; } // end namespace -// Returns the Operator stored for the given record. -auto Pattern::getOperator(const llvm::Record *record) -> Operator & { - return opMap.try_emplace(record, record).first->second; -} - -void Pattern::emitAttributeValue(Record *constAttr) { +void PatternEmitter::emitAttributeValue(Record *constAttr) { Attribute attr(constAttr->getValueAsDef("attr")); auto value = constAttr->getValue("value"); if (!attr.isConstBuildable()) - PrintFatalError(pattern->getLoc(), - "Attribute " + attr.getTableGenDefName() + - " does not have the 'constBuilderCall' field"); + PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() + + " does not have the 'constBuilderCall' field"); // TODO(jpienaar): Verify the constants here os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter", value->getValue()->getAsUnquotedString()); } -void Pattern::collectBoundArguments(DagInit *tree) { - ++numberOfOpsMatched; - Operator &op = getOperator(cast(tree->getOperator())->getDef()); - // TODO(jpienaar): Expand to multiple matches. - for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { - auto arg = tree->getArg(i); - if (auto argTree = dyn_cast(arg)) { - collectBoundArguments(argTree); - continue; - } - auto name = tree->getArgNameStr(i); - if (name.empty()) - continue; - boundArguments.try_emplace(name, op.getArg(i), arg); - } -} - -void Pattern::checkArgumentBound(StringRef name) const { - if (boundArguments.find(name) == boundArguments.end()) - PrintFatalError(pattern->getLoc(), - Twine("referencing unbound variable '") + name + "'"); -} - // Helper function to match patterns. -void Pattern::matchOp(DagInit *tree, int depth) { - Operator &op = getOperator(cast(tree->getOperator())->getDef()); +void PatternEmitter::emitOpMatch(DagNode tree, int depth) { + Operator &op = tree.getDialectOp(opMap); int indent = 4 + 2 * depth; // Skip the operand matching at depth 0 as the pattern rewriter already does. if (depth != 0) { @@ -167,27 +134,25 @@ void Pattern::matchOp(DagInit *tree, int depth) { "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth, op.qualifiedCppClassName()); } - if (tree->getNumArgs() != op.getNumArgs()) - PrintFatalError(pattern->getLoc(), - Twine("mismatch in number of arguments to op '") + - op.getOperationName() + - "' in pattern and op's definition"); - for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { - auto arg = tree->getArg(i); + if (tree.getNumArgs() != op.getNumArgs()) + PrintFatalError(loc, Twine("mismatch in number of arguments to op '") + + op.getOperationName() + + "' in pattern and op's definition"); + for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { auto opArg = op.getArg(i); - if (auto argTree = dyn_cast(arg)) { + if (DagNode argTree = tree.getArgAsNestedDag(i)) { os.indent(indent) << "{\n"; os.indent(indent + 2) << formatv( "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n", depth + 1, depth, i); - matchOp(argTree, depth + 1); + emitOpMatch(argTree, depth + 1); os.indent(indent) << "}\n"; continue; } // Verify arguments. - if (auto defInit = dyn_cast(arg)) { + if (auto defInit = tree.getArgAsDefInit(i)) { // Verify operands. if (auto *operand = opArg.dyn_cast()) { // Skip verification where not needed due to definition of op. @@ -195,8 +160,7 @@ void Pattern::matchOp(DagInit *tree, int depth) { goto StateCapture; if (!defInit->getDef()->isSubClassOf("Type")) - PrintFatalError(pattern->getLoc(), - "type argument required for operand"); + PrintFatalError(loc, "type argument required for operand"); auto constraint = tblgen::TypeConstraint(*defInit); os.indent(indent) @@ -219,7 +183,7 @@ void Pattern::matchOp(DagInit *tree, int depth) { } StateCapture: - auto name = tree->getArgNameStr(i); + auto name = tree.getArgName(i); if (name.empty()) continue; if (opArg.is()) @@ -234,7 +198,7 @@ void Pattern::matchOp(DagInit *tree, int depth) { } } -void Pattern::emitMatcher(DagInit *tree) { +void PatternEmitter::emitMatchMethod(DagNode tree) { // Emit the heading. os << R"( PatternMatchResult match(OperationInst *op0) const override { @@ -242,28 +206,30 @@ void Pattern::emitMatcher(DagInit *tree) { if (op0->getNumResults() != 1) return matchFailure(); auto state = std::make_unique();)" << "\n"; - matchOp(tree, 0); + emitOpMatch(tree, 0); os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; } -void Pattern::emit(StringRef rewriteName) { - DagInit *tree = pattern->getValueAsDag("PatternToMatch"); - // Collect bound arguments and compute number of ops matched. +void PatternEmitter::emit(StringRef rewriteName) { + // Get the DAG tree for the source pattern + DagNode tree = pattern.getSourcePattern(); + // TODO(jpienaar): the benefit metric is simply number of ops matched at the // moment, revise. - collectBoundArguments(tree); + unsigned benefit = tree.getNumOps(); + + const Operator &rootOp = pattern.getSourceRootOp(); + auto rootName = rootOp.getOperationName(); // Emit RewritePattern for Pattern. - DefInit *root = cast(tree->getOperator()); - auto *rootName = cast(root->getDef()->getValueInit("opName")); os << formatv(R"(struct {0} : public RewritePattern { - {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})", - rewriteName, rootName->getAsString(), numberOfOpsMatched) + {0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})", + rewriteName, rootName, benefit) << "\n"; // Emit matched state. os << " struct MatchedState : public PatternState {\n"; - for (auto &arg : boundArguments) { + for (const auto &arg : pattern.getSourcePatternBoundArgs()) { if (auto namedAttr = arg.second.arg.dyn_cast()) { os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first() << ";\n"; @@ -273,23 +239,22 @@ void Pattern::emit(StringRef rewriteName) { } os << " };\n"; - emitMatcher(tree); + emitMatchMethod(tree); emitRewriteMethod(); os << "};\n"; } -void Pattern::emitRewriteMethod() { - ListInit *resultOps = pattern->getValueAsListInit("ResultOps"); - if (resultOps->size() != 1) +void PatternEmitter::emitRewriteMethod() { + if (pattern.getNumResults() != 1) PrintFatalError("only single result rules supported"); - DagInit *resultTree = cast(resultOps->getElement(0)); + + DagNode resultTree = pattern.getResultPattern(0); // TODO(jpienaar): Expand to multiple results. - for (auto result : resultTree->getArgs()) { - if (isa(result)) - PrintFatalError(pattern->getLoc(), "only single op result supported"); - } + for (unsigned i = 0, e = resultTree.getNumArgs(); i != e; ++i) + if (resultTree.getArgAsNestedDag(i)) + PrintFatalError(loc, "only single op result supported"); os << R"( void rewrite(OperationInst *op, std::unique_ptr state, @@ -297,8 +262,7 @@ void Pattern::emitRewriteMethod() { auto& s = *static_cast(state.get()); )"; - auto *dagOpDef = cast(resultTree->getOperator())->getDef(); - if (dagOpDef->getName() == "replaceWithValue") + if (resultTree.isReplaceWithValue()) emitReplaceWithExistingValue(resultTree); else emitReplaceOpWithNewOp(resultTree); @@ -306,31 +270,29 @@ void Pattern::emitRewriteMethod() { os << " }\n"; } -void Pattern::emitReplaceWithExistingValue(DagInit *resultTree) { - if (resultTree->getNumArgs() != 1) { - PrintFatalError(pattern->getLoc(), - "exactly one argument needed in the result pattern"); +void PatternEmitter::emitReplaceWithExistingValue(DagNode resultTree) { + if (resultTree.getNumArgs() != 1) { + PrintFatalError(loc, "exactly one argument needed in the result pattern"); } - auto name = resultTree->getArgNameStr(0); - checkArgumentBound(name); + auto name = resultTree.getArgName(0); + pattern.ensureArgBoundInSourcePattern(name); os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n"; } -void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) { - DefInit *dagOperator = cast(resultTree->getOperator()); - Operator &resultOp = getOperator(dagOperator->getDef()); - auto resultOperands = dagOperator->getDef()->getValueAsDag("arguments"); +void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) { + Operator &resultOp = resultTree.getDialectOp(opMap); + auto numOpArgs = + resultOp.getNumOperands() + resultOp.getNumNativeAttributes(); os << formatv(R"( rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", resultOp.cppClassName()); - if (resultOperands->getNumArgs() != resultTree->getNumArgs()) { - PrintFatalError(pattern->getLoc(), - Twine("mismatch between arguments of resultant op (") + - Twine(resultOperands->getNumArgs()) + - ") and arguments provided for rewrite (" + - Twine(resultTree->getNumArgs()) + Twine(')')); + if (numOpArgs != resultTree.getNumArgs()) { + PrintFatalError(loc, Twine("mismatch between arguments of resultant op (") + + Twine(numOpArgs) + + ") and arguments provided for rewrite (" + + Twine(resultTree.getNumArgs()) + Twine(')')); } // Create the builder call for the result. @@ -340,8 +302,8 @@ void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) { // Start each operand on its own line. (os << ",\n").indent(6); - auto name = resultTree->getArgNameStr(i); - checkArgumentBound(name); + auto name = resultTree.getArgName(i); + pattern.ensureArgBoundInSourcePattern(name); if (operand.name) os << "/*" << operand.name->getAsUnquotedString() << "=*/"; os << "s." << name; @@ -350,18 +312,18 @@ void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) { } // Add attributes. - for (int e = resultTree->getNumArgs(); i != e; ++i) { + for (int e = resultTree.getNumArgs(); i != e; ++i) { // Start each attribute on its own line. (os << ",\n").indent(6); // The argument in the result DAG pattern. - auto name = resultTree->getArgNameStr(i); + auto argName = resultTree.getArgName(i); auto opName = resultOp.getArgName(i); - auto defInit = dyn_cast(resultTree->getArg(i)); + auto *defInit = resultTree.getArgAsDefInit(i); auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; if (!value) { - checkArgumentBound(name); - auto result = "s." + name; + pattern.ensureArgBoundInSourcePattern(argName); + auto result = "s." + argName; os << "/*" << opName << "=*/"; if (defInit) { auto transform = defInit->getDef(); @@ -380,31 +342,34 @@ void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) { // TODO(jpienaar): Refactor out into map to avoid recomputing these. auto argument = resultOp.getArg(i); if (!argument.is()) - PrintFatalError(pattern->getLoc(), - Twine("expected attribute ") + Twine(i)); + PrintFatalError(loc, Twine("expected attribute ") + Twine(i)); - if (!name.empty()) - os << "/*" << name << "=*/"; + if (!argName.empty()) + os << "/*" << argName << "=*/"; emitAttributeValue(defInit->getDef()); // TODO(jpienaar): verify types } os << "\n );\n"; } -void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) { - Pattern pattern(p, os); - pattern.emit(rewriteName); +void PatternEmitter::emit(StringRef rewriteName, Record *p, + RecordOperatorMap *mapper, raw_ostream &os) { + PatternEmitter(p, mapper, os).emit(rewriteName); } static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { emitSourceFileHeader("Rewriters", os); const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); + // We put the map here because it can be shared among multiple patterns. + RecordOperatorMap recordOpMap; + // Ensure unique patterns simply by appending unique suffix. std::string baseRewriteName = "GeneratedConvert"; int rewritePatternCount = 0; for (Record *p : patterns) { - Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os); + PatternEmitter::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), + p, &recordOpMap, os); } // Emit function to add the generated matchers to the pattern list.