[mlir-tblgen] Support binding multi-results of NativeCodeCall

We are able to bind NativeCodeCall result as binding operation. To make
table-gen have better understanding in the form of helper function,
we need to specify the number of return values in the NativeCodeCall
template. A VoidNativeCodeCall is added for void case.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D102160
This commit is contained in:
Chia-hung Duan
2021-07-21 11:23:06 +08:00
parent bec4a8157d
commit d7314b3c09
8 changed files with 226 additions and 45 deletions

View File

@@ -377,9 +377,6 @@ template. The string can be an arbitrary C++ expression that evaluates into some
C++ object expected at the `NativeCodeCall` site (here it would be expecting an
array attribute). Typically the string should be a function call.
Note that currently `NativeCodeCall` must return no more than one value or
attribute. This might change in the future.
##### `NativeCodeCall` placeholders
In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N` and `$N...`.
@@ -428,6 +425,30 @@ parameters at the `NativeCodeCall` use site. For example, if we define
`SomeCall : NativeCodeCall<"someFn($1...)">` and use it like `(SomeCall $in0,
$in1, $in2)`, then this will be translated into C++ call `someFn($in1, $in2)`.
##### `NativeCodeCall` binding multi-results
To bind multi-results and access the N-th result with `$<name>__N`, specify the
number of return values in the template. Note that only `Value` type is
supported for multiple results binding. For example,
```tablegen
def PackAttrs : NativeCodeCall<"packAttrs($0, $1)", 2>;
def : Pattern<(TwoResultOp $attr1, $attr2),
[(OneResultOp (PackAttr:$res__0, $attr1, $attr2)),
(OneResultOp $res__1)]>;
```
Use `NativeCodeCallVoid` for case has no return value.
The correct number of returned value specified in NativeCodeCall is important.
It will be used to verify the consistency of the number of result values.
Additionally, `mlir-tblgen` will try to capture the return value of
NativeCodeCall in the generated code so that it will trigger a later compilation
error if a NativeCodeCall that doesn't return a result isn't labeled with 0
returns.
##### Customizing entire op building
`NativeCodeCall` is not only limited to transforming arguments for building an

View File

@@ -2565,11 +2565,20 @@ class Pat<dag pattern, dag result, list<dag> preds = [],
// If used as a DAG node, i.e., `(NativeCodeCall<"..."> <arg0>, ..., <argN>)`,
// then positional placeholders are also supported; placeholder `$N` in the
// wrapped C++ expression will be replaced by `<argN>`.
//
// ## Bind multiple results
//
// To bind multi-results and access the N-th result with `$<name>__N`, specify
// the number of return values in the template. Note that only `Value` type is
// supported for multiple results binding.
class NativeCodeCall<string expr> {
class NativeCodeCall<string expr, int returns = 1> {
string expression = expr;
int numReturns = returns;
}
class NativeCodeCallVoid<string expr> : NativeCodeCall<expr, 0>;
def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($_self->getResult(0), m_Constant(&$0)))">;
//===----------------------------------------------------------------------===//

View File

@@ -100,6 +100,11 @@ public:
// Precondition: isNativeCodeCall()
StringRef getNativeCodeTemplate() const;
// Returns the number of values will be returned by the native helper
// function.
// Precondition: isNativeCodeCall()
int getNumReturnsOfNativeCode() const;
// Returns the string associated with the leaf.
// Precondition: isStringAttr()
std::string getStringAttr() const;
@@ -181,6 +186,11 @@ public:
// Precondition: isNativeCodeCall()
StringRef getNativeCodeTemplate() const;
// Returns the number of values will be returned by the native helper
// function.
// Precondition: isNativeCodeCall()
int getNumReturnsOfNativeCode() const;
void print(raw_ostream &os) const;
private:
@@ -242,30 +252,32 @@ public:
// DagNode and DagLeaf are accessed by value which means it can't be used as
// identifier here. Use an opaque pointer type instead.
using DagAndIndex = std::pair<const void *, int>;
using DagAndConstant = std::pair<const void *, int>;
// What kind of entity this symbol represents:
// * Attr: op attribute
// * Operand: op operand
// * Result: op result
// * Value: a value not attached to an op (e.g., from NativeCodeCall)
enum class Kind : uint8_t { Attr, Operand, Result, Value };
// * MultipleValues: a pack of values not attached to an op (e.g., from
// NativeCodeCall). This kind supports indexing.
enum class Kind : uint8_t { Attr, Operand, Result, Value, MultipleValues };
// Creates a SymbolInfo instance. `dagAndIndex` is only used for `Attr` and
// `Operand` so should be llvm::None for `Result` and `Value` kind.
// Creates a SymbolInfo instance. `dagAndConstant` is only used for `Attr`
// and `Operand` so should be llvm::None for `Result` and `Value` kind.
SymbolInfo(const Operator *op, Kind kind,
Optional<DagAndIndex> dagAndIndex);
Optional<DagAndConstant> dagAndConstant);
// Static methods for creating SymbolInfo.
static SymbolInfo getAttr(const Operator *op, int index) {
return SymbolInfo(op, Kind::Attr, DagAndIndex(nullptr, index));
return SymbolInfo(op, Kind::Attr, DagAndConstant(nullptr, index));
}
static SymbolInfo getAttr() {
return SymbolInfo(nullptr, Kind::Attr, llvm::None);
}
static SymbolInfo getOperand(DagNode node, const Operator *op, int index) {
return SymbolInfo(op, Kind::Operand,
DagAndIndex(node.getAsOpaquePointer(), index));
DagAndConstant(node.getAsOpaquePointer(), index));
}
static SymbolInfo getResult(const Operator *op) {
return SymbolInfo(op, Kind::Result, llvm::None);
@@ -273,6 +285,10 @@ public:
static SymbolInfo getValue() {
return SymbolInfo(nullptr, Kind::Value, llvm::None);
}
static SymbolInfo getMultipleValues(int numValues) {
return SymbolInfo(nullptr, Kind::MultipleValues,
DagAndConstant(nullptr, numValues));
}
// Returns the number of static values this symbol corresponds to.
// A static value is an operand/result declared in ODS. Normally a symbol
@@ -298,13 +314,21 @@ public:
std::string getAllRangeUse(StringRef name, int index, const char *fmt,
const char *separator) const;
// The argument index (for `Attr` and `Operand` only)
int getArgIndex() const { return (*dagAndConstant).second; }
// The number of values in the MultipleValue
int getSize() const { return (*dagAndConstant).second; }
const Operator *op; // The op where the bound entity belongs
Kind kind; // The kind of the bound entity
// The pair of DagNode pointer and argument index (for `Attr` and `Operand`
// only). Note that operands may be bound to the same symbol, use the
// DagNode and index to distinguish them. For `Attr`, the Dag part will be
// nullptr.
Optional<DagAndIndex> dagAndIndex;
// The pair of DagNode pointer and constant value (for `Attr`, `Operand` and
// the size of MultipleValue symbol). Note that operands may be bound to the
// same symbol, use the DagNode and index to distinguish them. For `Attr`
// and MultipleValue, the Dag part will be nullptr.
Optional<DagAndConstant> dagAndConstant;
// Alternative name for the symbol. It is used in case the name
// is not unique. Applicable for `Operand` only.
Optional<std::string> alternativeName;
@@ -331,10 +355,17 @@ public:
// `symbol` is already bound.
bool bindOpResult(StringRef symbol, const Operator &op);
// Registers the given `symbol` as bound to a value. Returns false if `symbol`
// is already bound.
// A helper function for dispatching target value binding functions.
bool bindValues(StringRef symbol, int numValues = 1);
// Registers the given `symbol` as bound to the Value(s). Returns false if
// `symbol` is already bound.
bool bindValue(StringRef symbol);
// Registers the given `symbol` as bound to a MultipleValue. Return false if
// `symbol` is already bound.
bool bindMultipleValues(StringRef symbol, int numValues);
// Registers the given `symbol` as bound to an attr. Returns false if `symbol`
// is already bound.
bool bindAttr(StringRef symbol);

View File

@@ -83,6 +83,11 @@ llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
}
int DagLeaf::getNumReturnsOfNativeCode() const {
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
}
std::string DagLeaf::getStringAttr() const {
assert(isStringAttr() && "the DAG leaf must be string attribute");
return def->getAsUnquotedString();
@@ -119,6 +124,13 @@ llvm::StringRef DagNode::getNativeCodeTemplate() const {
->getValueAsString("expression");
}
int DagNode::getNumReturnsOfNativeCode() const {
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
return cast<llvm::DefInit>(node->getOperator())
->getDef()
->getValueAsInt("numReturns");
}
llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
@@ -193,8 +205,8 @@ StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
}
SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
Optional<DagAndIndex> dagAndIndex)
: op(op), kind(kind), dagAndIndex(dagAndIndex) {}
Optional<DagAndConstant> dagAndConstant)
: op(op), kind(kind), dagAndConstant(dagAndConstant) {}
int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
switch (kind) {
@@ -204,6 +216,8 @@ int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
return 1;
case Kind::Result:
return op->getNumResults();
case Kind::MultipleValues:
return getSize();
}
llvm_unreachable("unknown kind");
}
@@ -217,7 +231,7 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
switch (kind) {
case Kind::Attr: {
if (op) {
auto type = op->getArg((*dagAndIndex).second)
auto type = op->getArg(getArgIndex())
.get<NamedAttribute *>()
->attr.getStorageType();
return std::string(formatv("{0} {1};\n", type, name));
@@ -235,6 +249,14 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
case Kind::Value: {
return std::string(formatv("::mlir::Value {0};\n", name));
}
case Kind::MultipleValues: {
// This is for the variable used in the source pattern. Each named value in
// source pattern will only be bound to a Value. The others in the result
// pattern may be associated with multiple Values as we will use `auto` to
// do the type inference.
return std::string(formatv(
"::mlir::Value {0}_raw; ::mlir::ValueRange {0}({0}_raw);\n", name));
}
case Kind::Result: {
// Use the op itself for captured results.
return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
@@ -255,8 +277,7 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
}
case Kind::Operand: {
assert(index < 0);
auto *operand =
op->getArg((*dagAndIndex).second).get<NamedTypeConstraint *>();
auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
// If this operand is variadic, then return a range. Otherwise, return the
// value itself.
if (operand->isVariableLength()) {
@@ -311,6 +332,21 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
return std::string(repl);
}
case Kind::MultipleValues: {
assert(op == nullptr);
assert(index < getSize());
if (index >= 0) {
std::string repl =
formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
return repl;
}
// If it doesn't specify certain element, unpack them all.
auto repl =
formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
return std::string(repl);
}
}
llvm_unreachable("unknown kind");
}
@@ -353,6 +389,20 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
return std::string(repl);
}
case Kind::MultipleValues: {
assert(op == nullptr);
assert(index < getSize());
if (index >= 0) {
std::string repl =
formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
return repl;
}
auto repl =
formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
return std::string(repl);
}
}
llvm_unreachable("unknown kind");
}
@@ -395,11 +445,25 @@ bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
return symbolInfoMap.count(inserted->first) == 1;
}
bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
std::string name = getValuePackName(symbol).str();
if (numValues > 1)
return bindMultipleValues(name, numValues);
return bindValue(name);
}
bool SymbolInfoMap::bindValue(StringRef symbol) {
auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
return symbolInfoMap.count(inserted->first) == 1;
}
bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
std::string name = getValuePackName(symbol).str();
auto inserted =
symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
return symbolInfoMap.count(inserted->first) == 1;
}
bool SymbolInfoMap::bindAttr(StringRef symbol) {
auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
return symbolInfoMap.count(inserted->first) == 1;
@@ -423,11 +487,9 @@ SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex);
for (auto it = range.first; it != range.second; ++it) {
if (it->second.dagAndIndex == symbolInfo.dagAndIndex) {
for (auto it = range.first; it != range.second; ++it)
if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
return it;
}
}
return symbolInfoMap.end();
}
@@ -633,7 +695,9 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
if (!isSrcPattern) {
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
<< treeName << '\n');
verifyBind(infoMap.bindValue(treeName), treeName);
verifyBind(
infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
treeName);
} else {
PrintFatalError(&def,
formatv("binding symbol '{0}' to NativecodeCall in "

View File

@@ -857,7 +857,7 @@ def OpNativeCodeCall3 : TEST_Op<"native_code_call3"> {
// Test that NativeCodeCall is not ignored if it is not used to directly
// replace the matched root op.
def : Pattern<(OpNativeCodeCall3 $input),
[(NativeCodeCall<"createOpI($_builder, $_loc, $0)"> $input),
[(NativeCodeCallVoid<"createOpI($_builder, $_loc, $0)"> $input),
(OpK)]>;
def OpNativeCodeCall4 : TEST_Op<"native_code_call4"> {
@@ -874,6 +874,19 @@ def BindNativeCodeCallResult : NativeCodeCall<"bindNativeCodeCallResult($0)">;
def : Pat<(OpNativeCodeCall4 (GetFirstI32Result $ret)),
(OpNativeCodeCall5 (BindNativeCodeCallResult:$native $ret), $native)>;
def OpNativeCodeCall6 : TEST_Op<"native_code_call6"> {
let arguments = (ins I32:$input1, I32:$input2);
let results = (outs I32:$output1, I32:$output2);
}
def OpNativeCodeCall7 : TEST_Op<"native_code_call7"> {
let arguments = (ins I32:$input);
let results = (outs I32);
}
def BindMultipleNativeCodeCallResult : NativeCodeCall<"bindMultipleNativeCodeCallResult($0, $1)", 2>;
def : Pattern<(OpNativeCodeCall6 $arg1, $arg2),
[(OpNativeCodeCall7 (BindMultipleNativeCodeCallResult:$native__0 $arg1, $arg2)),
(OpNativeCodeCall7 $native__1)]>;
// Test AllAttrConstraintsOf.
def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> {
let arguments = (ins I64ArrayAttr:$attr);
@@ -1033,7 +1046,7 @@ def OpSymbolBindingNoResult : TEST_Op<"symbol_binding_no_result", []> {
// Test that we can bind to an op without results and reference it later.
def : Pat<(OpSymbolBindingNoResult:$op $operand),
(NativeCodeCall<"handleNoResultOp($_builder, $0)"> $op)>;
(NativeCodeCallVoid<"handleNoResultOp($_builder, $0)"> $op)>;
//===----------------------------------------------------------------------===//
// Test Patterns (Attributes)

View File

@@ -44,6 +44,11 @@ static bool getFirstI32Result(Operation *op, Value &value) {
static Value bindNativeCodeCallResult(Value value) { return value; }
static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
Value input2) {
return SmallVector<Value, 2>({input2, input1});
}
// Test that natives calls are only called once during rewrites.
// OpM_Test will return Pi, increased by 1 for each subsequent calls.
// This let us check the number of times OpM_Test was called by inspecting

View File

@@ -102,6 +102,16 @@ func @verifyNativeCodeCallBinding(%arg0 : i32) -> (i32) {
return %1 : i32
}
// CHECK-LABEL: verifyMultipleNativeCodeCallBinding
func@verifyMultipleNativeCodeCallBinding(%arg0 : i32) -> (i32) {
%0 = "test.op_k"() : () -> (i32)
%1 = "test.op_k"() : () -> (i32)
// CHECK: %[[A:.*]] = "test.native_code_call7"(%1) : (i32) -> i32
// CHECK: %[[A:.*]] = "test.native_code_call7"(%0) : (i32) -> i32
%2, %3 = "test.native_code_call6"(%0, %1) : (i32, i32) -> (i32, i32)
return %2 : i32
}
// CHECK-LABEL: verifyAllAttrConstraintOf
func @verifyAllAttrConstraintOf() -> (i32, i32, i32) {
// CHECK: "test.all_attr_constraint_of2"

View File

@@ -754,7 +754,8 @@ void PatternEmitter::emitRewriteLogic() {
// NativeCodeCall will only be materialized to `os` if it is used. Here
// we are handling auxiliary patterns so we want the side effect even if
// NativeCodeCall is not replacing matched root op's results.
if (resultTree.isNativeCodeCall())
if (resultTree.isNativeCodeCall() &&
resultTree.getNumReturnsOfNativeCode() == 0)
os << val << ";\n";
}
@@ -804,11 +805,8 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
"location directive can only be used with op creation");
}
if (resultTree.isNativeCodeCall()) {
auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth);
symbolInfoMap.bindValue(symbol);
return symbol;
}
if (resultTree.isNativeCodeCall())
return handleReplaceWithNativeCodeCall(resultTree, depth);
if (resultTree.isReplaceWithValue())
return handleReplaceWithValue(resultTree).str();
@@ -948,9 +946,39 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
}
std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs);
if (!tree.getSymbol().empty()) {
os << formatv("auto {0} = {1};\n", tree.getSymbol(), symbol);
symbol = tree.getSymbol().str();
// In general, NativeCodeCall without naming binding don't need this. To
// ensure void helper function has been correctly labeled, i.e., use
// NativeCodeCallVoid, we cache the result to a local variable so that we will
// get a compilation error in the auto-generated file.
// Example.
// // In the td file
// Pat<(...), (NativeCodeCall<Foo> ...)>
//
// ---
//
// // In the auto-generated .cpp
// ...
// // Causes compilation error if Foo() returns void.
// auto nativeVar = Foo();
// ...
if (tree.getNumReturnsOfNativeCode() != 0) {
// Determine the local variable name for return value.
std::string varName =
SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
if (varName.empty()) {
varName = formatv("nativeVar_{0}", nextValueId++);
// Register the local variable for later uses.
symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
}
// Catch the return value of helper function.
os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);
if (!tree.getSymbol().empty())
symbol = tree.getSymbol().str();
else
symbol = varName;
}
return symbol;
@@ -967,8 +995,10 @@ int PatternEmitter::getNodeValueCount(DagNode node) {
// Otherwise this is an unbound op; we will use all its results.
return pattern.getDialectOp(node).getNumResults();
}
// TODO: This considers all NativeCodeCall as returning one
// value. Enhance if multi-value ones are needed.
if (node.isNativeCodeCall())
return node.getNumReturnsOfNativeCode();
return 1;
}
@@ -1191,8 +1221,7 @@ void PatternEmitter::supplyValuesForOpArgs(
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
os << formatv("/*{0}=*/{1}", opArgName,
handleReplaceWithNativeCodeCall(subTree, depth));
os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
} else {
auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
@@ -1233,8 +1262,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
os << formatv(addAttrCmd, opArgName,
handleReplaceWithNativeCodeCall(subTree, depth + 1));
os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
} else {
auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.