mirror of
https://github.com/intel/llvm.git
synced 2026-01-17 23:45:25 +08:00
[mlir-tblgen] Support binding multi-results of NativeCodeCall
We are able to bind NativeCodeCall result as binding operation. To make table-gen have better understanding in the form of helper function, we need to specify the number of return values in the NativeCodeCall template. A VoidNativeCodeCall is added for void case. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D102160
This commit is contained in:
@@ -377,9 +377,6 @@ template. The string can be an arbitrary C++ expression that evaluates into some
|
||||
C++ object expected at the `NativeCodeCall` site (here it would be expecting an
|
||||
array attribute). Typically the string should be a function call.
|
||||
|
||||
Note that currently `NativeCodeCall` must return no more than one value or
|
||||
attribute. This might change in the future.
|
||||
|
||||
##### `NativeCodeCall` placeholders
|
||||
|
||||
In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N` and `$N...`.
|
||||
@@ -428,6 +425,30 @@ parameters at the `NativeCodeCall` use site. For example, if we define
|
||||
`SomeCall : NativeCodeCall<"someFn($1...)">` and use it like `(SomeCall $in0,
|
||||
$in1, $in2)`, then this will be translated into C++ call `someFn($in1, $in2)`.
|
||||
|
||||
##### `NativeCodeCall` binding multi-results
|
||||
|
||||
To bind multi-results and access the N-th result with `$<name>__N`, specify the
|
||||
number of return values in the template. Note that only `Value` type is
|
||||
supported for multiple results binding. For example,
|
||||
|
||||
```tablegen
|
||||
|
||||
def PackAttrs : NativeCodeCall<"packAttrs($0, $1)", 2>;
|
||||
def : Pattern<(TwoResultOp $attr1, $attr2),
|
||||
[(OneResultOp (PackAttr:$res__0, $attr1, $attr2)),
|
||||
(OneResultOp $res__1)]>;
|
||||
|
||||
```
|
||||
|
||||
Use `NativeCodeCallVoid` for case has no return value.
|
||||
|
||||
The correct number of returned value specified in NativeCodeCall is important.
|
||||
It will be used to verify the consistency of the number of result values.
|
||||
Additionally, `mlir-tblgen` will try to capture the return value of
|
||||
NativeCodeCall in the generated code so that it will trigger a later compilation
|
||||
error if a NativeCodeCall that doesn't return a result isn't labeled with 0
|
||||
returns.
|
||||
|
||||
##### Customizing entire op building
|
||||
|
||||
`NativeCodeCall` is not only limited to transforming arguments for building an
|
||||
|
||||
@@ -2565,11 +2565,20 @@ class Pat<dag pattern, dag result, list<dag> preds = [],
|
||||
// If used as a DAG node, i.e., `(NativeCodeCall<"..."> <arg0>, ..., <argN>)`,
|
||||
// then positional placeholders are also supported; placeholder `$N` in the
|
||||
// wrapped C++ expression will be replaced by `<argN>`.
|
||||
//
|
||||
// ## Bind multiple results
|
||||
//
|
||||
// To bind multi-results and access the N-th result with `$<name>__N`, specify
|
||||
// the number of return values in the template. Note that only `Value` type is
|
||||
// supported for multiple results binding.
|
||||
|
||||
class NativeCodeCall<string expr> {
|
||||
class NativeCodeCall<string expr, int returns = 1> {
|
||||
string expression = expr;
|
||||
int numReturns = returns;
|
||||
}
|
||||
|
||||
class NativeCodeCallVoid<string expr> : NativeCodeCall<expr, 0>;
|
||||
|
||||
def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($_self->getResult(0), m_Constant(&$0)))">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -100,6 +100,11 @@ public:
|
||||
// Precondition: isNativeCodeCall()
|
||||
StringRef getNativeCodeTemplate() const;
|
||||
|
||||
// Returns the number of values will be returned by the native helper
|
||||
// function.
|
||||
// Precondition: isNativeCodeCall()
|
||||
int getNumReturnsOfNativeCode() const;
|
||||
|
||||
// Returns the string associated with the leaf.
|
||||
// Precondition: isStringAttr()
|
||||
std::string getStringAttr() const;
|
||||
@@ -181,6 +186,11 @@ public:
|
||||
// Precondition: isNativeCodeCall()
|
||||
StringRef getNativeCodeTemplate() const;
|
||||
|
||||
// Returns the number of values will be returned by the native helper
|
||||
// function.
|
||||
// Precondition: isNativeCodeCall()
|
||||
int getNumReturnsOfNativeCode() const;
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
|
||||
private:
|
||||
@@ -242,30 +252,32 @@ public:
|
||||
|
||||
// DagNode and DagLeaf are accessed by value which means it can't be used as
|
||||
// identifier here. Use an opaque pointer type instead.
|
||||
using DagAndIndex = std::pair<const void *, int>;
|
||||
using DagAndConstant = std::pair<const void *, int>;
|
||||
|
||||
// What kind of entity this symbol represents:
|
||||
// * Attr: op attribute
|
||||
// * Operand: op operand
|
||||
// * Result: op result
|
||||
// * Value: a value not attached to an op (e.g., from NativeCodeCall)
|
||||
enum class Kind : uint8_t { Attr, Operand, Result, Value };
|
||||
// * MultipleValues: a pack of values not attached to an op (e.g., from
|
||||
// NativeCodeCall). This kind supports indexing.
|
||||
enum class Kind : uint8_t { Attr, Operand, Result, Value, MultipleValues };
|
||||
|
||||
// Creates a SymbolInfo instance. `dagAndIndex` is only used for `Attr` and
|
||||
// `Operand` so should be llvm::None for `Result` and `Value` kind.
|
||||
// Creates a SymbolInfo instance. `dagAndConstant` is only used for `Attr`
|
||||
// and `Operand` so should be llvm::None for `Result` and `Value` kind.
|
||||
SymbolInfo(const Operator *op, Kind kind,
|
||||
Optional<DagAndIndex> dagAndIndex);
|
||||
Optional<DagAndConstant> dagAndConstant);
|
||||
|
||||
// Static methods for creating SymbolInfo.
|
||||
static SymbolInfo getAttr(const Operator *op, int index) {
|
||||
return SymbolInfo(op, Kind::Attr, DagAndIndex(nullptr, index));
|
||||
return SymbolInfo(op, Kind::Attr, DagAndConstant(nullptr, index));
|
||||
}
|
||||
static SymbolInfo getAttr() {
|
||||
return SymbolInfo(nullptr, Kind::Attr, llvm::None);
|
||||
}
|
||||
static SymbolInfo getOperand(DagNode node, const Operator *op, int index) {
|
||||
return SymbolInfo(op, Kind::Operand,
|
||||
DagAndIndex(node.getAsOpaquePointer(), index));
|
||||
DagAndConstant(node.getAsOpaquePointer(), index));
|
||||
}
|
||||
static SymbolInfo getResult(const Operator *op) {
|
||||
return SymbolInfo(op, Kind::Result, llvm::None);
|
||||
@@ -273,6 +285,10 @@ public:
|
||||
static SymbolInfo getValue() {
|
||||
return SymbolInfo(nullptr, Kind::Value, llvm::None);
|
||||
}
|
||||
static SymbolInfo getMultipleValues(int numValues) {
|
||||
return SymbolInfo(nullptr, Kind::MultipleValues,
|
||||
DagAndConstant(nullptr, numValues));
|
||||
}
|
||||
|
||||
// Returns the number of static values this symbol corresponds to.
|
||||
// A static value is an operand/result declared in ODS. Normally a symbol
|
||||
@@ -298,13 +314,21 @@ public:
|
||||
std::string getAllRangeUse(StringRef name, int index, const char *fmt,
|
||||
const char *separator) const;
|
||||
|
||||
// The argument index (for `Attr` and `Operand` only)
|
||||
int getArgIndex() const { return (*dagAndConstant).second; }
|
||||
|
||||
// The number of values in the MultipleValue
|
||||
int getSize() const { return (*dagAndConstant).second; }
|
||||
|
||||
const Operator *op; // The op where the bound entity belongs
|
||||
Kind kind; // The kind of the bound entity
|
||||
// The pair of DagNode pointer and argument index (for `Attr` and `Operand`
|
||||
// only). Note that operands may be bound to the same symbol, use the
|
||||
// DagNode and index to distinguish them. For `Attr`, the Dag part will be
|
||||
// nullptr.
|
||||
Optional<DagAndIndex> dagAndIndex;
|
||||
|
||||
// The pair of DagNode pointer and constant value (for `Attr`, `Operand` and
|
||||
// the size of MultipleValue symbol). Note that operands may be bound to the
|
||||
// same symbol, use the DagNode and index to distinguish them. For `Attr`
|
||||
// and MultipleValue, the Dag part will be nullptr.
|
||||
Optional<DagAndConstant> dagAndConstant;
|
||||
|
||||
// Alternative name for the symbol. It is used in case the name
|
||||
// is not unique. Applicable for `Operand` only.
|
||||
Optional<std::string> alternativeName;
|
||||
@@ -331,10 +355,17 @@ public:
|
||||
// `symbol` is already bound.
|
||||
bool bindOpResult(StringRef symbol, const Operator &op);
|
||||
|
||||
// Registers the given `symbol` as bound to a value. Returns false if `symbol`
|
||||
// is already bound.
|
||||
// A helper function for dispatching target value binding functions.
|
||||
bool bindValues(StringRef symbol, int numValues = 1);
|
||||
|
||||
// Registers the given `symbol` as bound to the Value(s). Returns false if
|
||||
// `symbol` is already bound.
|
||||
bool bindValue(StringRef symbol);
|
||||
|
||||
// Registers the given `symbol` as bound to a MultipleValue. Return false if
|
||||
// `symbol` is already bound.
|
||||
bool bindMultipleValues(StringRef symbol, int numValues);
|
||||
|
||||
// Registers the given `symbol` as bound to an attr. Returns false if `symbol`
|
||||
// is already bound.
|
||||
bool bindAttr(StringRef symbol);
|
||||
|
||||
@@ -83,6 +83,11 @@ llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
|
||||
return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
|
||||
}
|
||||
|
||||
int DagLeaf::getNumReturnsOfNativeCode() const {
|
||||
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
|
||||
return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
|
||||
}
|
||||
|
||||
std::string DagLeaf::getStringAttr() const {
|
||||
assert(isStringAttr() && "the DAG leaf must be string attribute");
|
||||
return def->getAsUnquotedString();
|
||||
@@ -119,6 +124,13 @@ llvm::StringRef DagNode::getNativeCodeTemplate() const {
|
||||
->getValueAsString("expression");
|
||||
}
|
||||
|
||||
int DagNode::getNumReturnsOfNativeCode() const {
|
||||
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
|
||||
return cast<llvm::DefInit>(node->getOperator())
|
||||
->getDef()
|
||||
->getValueAsInt("numReturns");
|
||||
}
|
||||
|
||||
llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
|
||||
|
||||
Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
|
||||
@@ -193,8 +205,8 @@ StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
|
||||
}
|
||||
|
||||
SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
|
||||
Optional<DagAndIndex> dagAndIndex)
|
||||
: op(op), kind(kind), dagAndIndex(dagAndIndex) {}
|
||||
Optional<DagAndConstant> dagAndConstant)
|
||||
: op(op), kind(kind), dagAndConstant(dagAndConstant) {}
|
||||
|
||||
int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
|
||||
switch (kind) {
|
||||
@@ -204,6 +216,8 @@ int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
|
||||
return 1;
|
||||
case Kind::Result:
|
||||
return op->getNumResults();
|
||||
case Kind::MultipleValues:
|
||||
return getSize();
|
||||
}
|
||||
llvm_unreachable("unknown kind");
|
||||
}
|
||||
@@ -217,7 +231,7 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
|
||||
switch (kind) {
|
||||
case Kind::Attr: {
|
||||
if (op) {
|
||||
auto type = op->getArg((*dagAndIndex).second)
|
||||
auto type = op->getArg(getArgIndex())
|
||||
.get<NamedAttribute *>()
|
||||
->attr.getStorageType();
|
||||
return std::string(formatv("{0} {1};\n", type, name));
|
||||
@@ -235,6 +249,14 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
|
||||
case Kind::Value: {
|
||||
return std::string(formatv("::mlir::Value {0};\n", name));
|
||||
}
|
||||
case Kind::MultipleValues: {
|
||||
// This is for the variable used in the source pattern. Each named value in
|
||||
// source pattern will only be bound to a Value. The others in the result
|
||||
// pattern may be associated with multiple Values as we will use `auto` to
|
||||
// do the type inference.
|
||||
return std::string(formatv(
|
||||
"::mlir::Value {0}_raw; ::mlir::ValueRange {0}({0}_raw);\n", name));
|
||||
}
|
||||
case Kind::Result: {
|
||||
// Use the op itself for captured results.
|
||||
return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
|
||||
@@ -255,8 +277,7 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
||||
}
|
||||
case Kind::Operand: {
|
||||
assert(index < 0);
|
||||
auto *operand =
|
||||
op->getArg((*dagAndIndex).second).get<NamedTypeConstraint *>();
|
||||
auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
|
||||
// If this operand is variadic, then return a range. Otherwise, return the
|
||||
// value itself.
|
||||
if (operand->isVariableLength()) {
|
||||
@@ -311,6 +332,21 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
case Kind::MultipleValues: {
|
||||
assert(op == nullptr);
|
||||
assert(index < getSize());
|
||||
if (index >= 0) {
|
||||
std::string repl =
|
||||
formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
|
||||
return repl;
|
||||
}
|
||||
// If it doesn't specify certain element, unpack them all.
|
||||
auto repl =
|
||||
formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
}
|
||||
llvm_unreachable("unknown kind");
|
||||
}
|
||||
@@ -353,6 +389,20 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
case Kind::MultipleValues: {
|
||||
assert(op == nullptr);
|
||||
assert(index < getSize());
|
||||
if (index >= 0) {
|
||||
std::string repl =
|
||||
formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
|
||||
return repl;
|
||||
}
|
||||
auto repl =
|
||||
formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
|
||||
LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
}
|
||||
llvm_unreachable("unknown kind");
|
||||
}
|
||||
@@ -395,11 +445,25 @@ bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
|
||||
return symbolInfoMap.count(inserted->first) == 1;
|
||||
}
|
||||
|
||||
bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
|
||||
std::string name = getValuePackName(symbol).str();
|
||||
if (numValues > 1)
|
||||
return bindMultipleValues(name, numValues);
|
||||
return bindValue(name);
|
||||
}
|
||||
|
||||
bool SymbolInfoMap::bindValue(StringRef symbol) {
|
||||
auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
|
||||
return symbolInfoMap.count(inserted->first) == 1;
|
||||
}
|
||||
|
||||
bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
|
||||
std::string name = getValuePackName(symbol).str();
|
||||
auto inserted =
|
||||
symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
|
||||
return symbolInfoMap.count(inserted->first) == 1;
|
||||
}
|
||||
|
||||
bool SymbolInfoMap::bindAttr(StringRef symbol) {
|
||||
auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
|
||||
return symbolInfoMap.count(inserted->first) == 1;
|
||||
@@ -423,11 +487,9 @@ SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
|
||||
|
||||
const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex);
|
||||
|
||||
for (auto it = range.first; it != range.second; ++it) {
|
||||
if (it->second.dagAndIndex == symbolInfo.dagAndIndex) {
|
||||
for (auto it = range.first; it != range.second; ++it)
|
||||
if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
|
||||
return it;
|
||||
}
|
||||
}
|
||||
|
||||
return symbolInfoMap.end();
|
||||
}
|
||||
@@ -633,7 +695,9 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||
if (!isSrcPattern) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
|
||||
<< treeName << '\n');
|
||||
verifyBind(infoMap.bindValue(treeName), treeName);
|
||||
verifyBind(
|
||||
infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
|
||||
treeName);
|
||||
} else {
|
||||
PrintFatalError(&def,
|
||||
formatv("binding symbol '{0}' to NativecodeCall in "
|
||||
|
||||
@@ -857,7 +857,7 @@ def OpNativeCodeCall3 : TEST_Op<"native_code_call3"> {
|
||||
// Test that NativeCodeCall is not ignored if it is not used to directly
|
||||
// replace the matched root op.
|
||||
def : Pattern<(OpNativeCodeCall3 $input),
|
||||
[(NativeCodeCall<"createOpI($_builder, $_loc, $0)"> $input),
|
||||
[(NativeCodeCallVoid<"createOpI($_builder, $_loc, $0)"> $input),
|
||||
(OpK)]>;
|
||||
|
||||
def OpNativeCodeCall4 : TEST_Op<"native_code_call4"> {
|
||||
@@ -874,6 +874,19 @@ def BindNativeCodeCallResult : NativeCodeCall<"bindNativeCodeCallResult($0)">;
|
||||
def : Pat<(OpNativeCodeCall4 (GetFirstI32Result $ret)),
|
||||
(OpNativeCodeCall5 (BindNativeCodeCallResult:$native $ret), $native)>;
|
||||
|
||||
def OpNativeCodeCall6 : TEST_Op<"native_code_call6"> {
|
||||
let arguments = (ins I32:$input1, I32:$input2);
|
||||
let results = (outs I32:$output1, I32:$output2);
|
||||
}
|
||||
def OpNativeCodeCall7 : TEST_Op<"native_code_call7"> {
|
||||
let arguments = (ins I32:$input);
|
||||
let results = (outs I32);
|
||||
}
|
||||
def BindMultipleNativeCodeCallResult : NativeCodeCall<"bindMultipleNativeCodeCallResult($0, $1)", 2>;
|
||||
def : Pattern<(OpNativeCodeCall6 $arg1, $arg2),
|
||||
[(OpNativeCodeCall7 (BindMultipleNativeCodeCallResult:$native__0 $arg1, $arg2)),
|
||||
(OpNativeCodeCall7 $native__1)]>;
|
||||
|
||||
// Test AllAttrConstraintsOf.
|
||||
def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> {
|
||||
let arguments = (ins I64ArrayAttr:$attr);
|
||||
@@ -1033,7 +1046,7 @@ def OpSymbolBindingNoResult : TEST_Op<"symbol_binding_no_result", []> {
|
||||
|
||||
// Test that we can bind to an op without results and reference it later.
|
||||
def : Pat<(OpSymbolBindingNoResult:$op $operand),
|
||||
(NativeCodeCall<"handleNoResultOp($_builder, $0)"> $op)>;
|
||||
(NativeCodeCallVoid<"handleNoResultOp($_builder, $0)"> $op)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Patterns (Attributes)
|
||||
|
||||
@@ -44,6 +44,11 @@ static bool getFirstI32Result(Operation *op, Value &value) {
|
||||
|
||||
static Value bindNativeCodeCallResult(Value value) { return value; }
|
||||
|
||||
static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
|
||||
Value input2) {
|
||||
return SmallVector<Value, 2>({input2, input1});
|
||||
}
|
||||
|
||||
// Test that natives calls are only called once during rewrites.
|
||||
// OpM_Test will return Pi, increased by 1 for each subsequent calls.
|
||||
// This let us check the number of times OpM_Test was called by inspecting
|
||||
|
||||
@@ -102,6 +102,16 @@ func @verifyNativeCodeCallBinding(%arg0 : i32) -> (i32) {
|
||||
return %1 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: verifyMultipleNativeCodeCallBinding
|
||||
func@verifyMultipleNativeCodeCallBinding(%arg0 : i32) -> (i32) {
|
||||
%0 = "test.op_k"() : () -> (i32)
|
||||
%1 = "test.op_k"() : () -> (i32)
|
||||
// CHECK: %[[A:.*]] = "test.native_code_call7"(%1) : (i32) -> i32
|
||||
// CHECK: %[[A:.*]] = "test.native_code_call7"(%0) : (i32) -> i32
|
||||
%2, %3 = "test.native_code_call6"(%0, %1) : (i32, i32) -> (i32, i32)
|
||||
return %2 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: verifyAllAttrConstraintOf
|
||||
func @verifyAllAttrConstraintOf() -> (i32, i32, i32) {
|
||||
// CHECK: "test.all_attr_constraint_of2"
|
||||
|
||||
@@ -754,7 +754,8 @@ void PatternEmitter::emitRewriteLogic() {
|
||||
// NativeCodeCall will only be materialized to `os` if it is used. Here
|
||||
// we are handling auxiliary patterns so we want the side effect even if
|
||||
// NativeCodeCall is not replacing matched root op's results.
|
||||
if (resultTree.isNativeCodeCall())
|
||||
if (resultTree.isNativeCodeCall() &&
|
||||
resultTree.getNumReturnsOfNativeCode() == 0)
|
||||
os << val << ";\n";
|
||||
}
|
||||
|
||||
@@ -804,11 +805,8 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
|
||||
"location directive can only be used with op creation");
|
||||
}
|
||||
|
||||
if (resultTree.isNativeCodeCall()) {
|
||||
auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth);
|
||||
symbolInfoMap.bindValue(symbol);
|
||||
return symbol;
|
||||
}
|
||||
if (resultTree.isNativeCodeCall())
|
||||
return handleReplaceWithNativeCodeCall(resultTree, depth);
|
||||
|
||||
if (resultTree.isReplaceWithValue())
|
||||
return handleReplaceWithValue(resultTree).str();
|
||||
@@ -948,9 +946,39 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
|
||||
}
|
||||
|
||||
std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs);
|
||||
if (!tree.getSymbol().empty()) {
|
||||
os << formatv("auto {0} = {1};\n", tree.getSymbol(), symbol);
|
||||
symbol = tree.getSymbol().str();
|
||||
|
||||
// In general, NativeCodeCall without naming binding don't need this. To
|
||||
// ensure void helper function has been correctly labeled, i.e., use
|
||||
// NativeCodeCallVoid, we cache the result to a local variable so that we will
|
||||
// get a compilation error in the auto-generated file.
|
||||
// Example.
|
||||
// // In the td file
|
||||
// Pat<(...), (NativeCodeCall<Foo> ...)>
|
||||
//
|
||||
// ---
|
||||
//
|
||||
// // In the auto-generated .cpp
|
||||
// ...
|
||||
// // Causes compilation error if Foo() returns void.
|
||||
// auto nativeVar = Foo();
|
||||
// ...
|
||||
if (tree.getNumReturnsOfNativeCode() != 0) {
|
||||
// Determine the local variable name for return value.
|
||||
std::string varName =
|
||||
SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
|
||||
if (varName.empty()) {
|
||||
varName = formatv("nativeVar_{0}", nextValueId++);
|
||||
// Register the local variable for later uses.
|
||||
symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
|
||||
}
|
||||
|
||||
// Catch the return value of helper function.
|
||||
os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);
|
||||
|
||||
if (!tree.getSymbol().empty())
|
||||
symbol = tree.getSymbol().str();
|
||||
else
|
||||
symbol = varName;
|
||||
}
|
||||
|
||||
return symbol;
|
||||
@@ -967,8 +995,10 @@ int PatternEmitter::getNodeValueCount(DagNode node) {
|
||||
// Otherwise this is an unbound op; we will use all its results.
|
||||
return pattern.getDialectOp(node).getNumResults();
|
||||
}
|
||||
// TODO: This considers all NativeCodeCall as returning one
|
||||
// value. Enhance if multi-value ones are needed.
|
||||
|
||||
if (node.isNativeCodeCall())
|
||||
return node.getNumReturnsOfNativeCode();
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -1191,8 +1221,7 @@ void PatternEmitter::supplyValuesForOpArgs(
|
||||
if (!subTree.isNativeCodeCall())
|
||||
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
|
||||
"for creating attribute");
|
||||
os << formatv("/*{0}=*/{1}", opArgName,
|
||||
handleReplaceWithNativeCodeCall(subTree, depth));
|
||||
os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
|
||||
} else {
|
||||
auto leaf = node.getArgAsLeaf(argIndex);
|
||||
// The argument in the result DAG pattern.
|
||||
@@ -1233,8 +1262,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
|
||||
if (!subTree.isNativeCodeCall())
|
||||
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
|
||||
"for creating attribute");
|
||||
os << formatv(addAttrCmd, opArgName,
|
||||
handleReplaceWithNativeCodeCall(subTree, depth + 1));
|
||||
os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
|
||||
} else {
|
||||
auto leaf = node.getArgAsLeaf(argIndex);
|
||||
// The argument in the result DAG pattern.
|
||||
|
||||
Reference in New Issue
Block a user