mirror of
https://github.com/intel/llvm.git
synced 2026-01-24 00:20:25 +08:00
Add support for function result attributes.
This allows dialect-specific attributes to be attached to func results. (or more specifically, FunctionLike ops).
For example:
```
func @f() -> (i32 {my_dialect.some_attr = 3})
```
This attaches my_dialect.some_attr with value 3 to the first result of func @f.
Another more complex example:
```
func @g() -> (i32, f32 {my_dialect.some_attr = "foo", other_dialect.some_other_attr = [1,2,3]}, i1)
```
Here, the second result has two attributes attached.
PiperOrigin-RevId: 275564165
This commit is contained in:
committed by
A. Unique TensorFlower
parent
9e7e297da3
commit
9c9a7e9268
@@ -321,11 +321,19 @@ name via a string attribute like [SymbolRefAttr](#symbol-reference-attribute)):
|
||||
function ::= `func` function-signature function-attributes? function-body?
|
||||
|
||||
function-signature ::= symbol-ref-id `(` argument-list `)`
|
||||
(`->` function-result-type)?
|
||||
(`->` function-result-list)?
|
||||
|
||||
argument-list ::= (named-argument (`,` named-argument)*) | /*empty*/
|
||||
argument-list ::= (type attribute-dict? (`,` type attribute-dict?)*) | /*empty*/
|
||||
named-argument ::= ssa-id `:` type attribute-dict?
|
||||
|
||||
function-result-list ::= function-result-list-parens
|
||||
| non-function-type
|
||||
function-result-list-parens ::= `(` `)`
|
||||
| `(` function-result-list-no-parens `)`
|
||||
function-result-list-no-parens ::= function-result (`,` function-result)*
|
||||
function-result ::= type attribute-dict?
|
||||
|
||||
function-attributes ::= `attributes` attribute-dict
|
||||
function-body ::= region
|
||||
```
|
||||
|
||||
@@ -546,6 +546,10 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func",
|
||||
// Depends on the type attribute being correct as checked by verifyType.
|
||||
unsigned getNumFuncArguments();
|
||||
|
||||
// Hook for OpTrait::FunctionLike, returns the number of function results.
|
||||
// Depends on the type attribute being correct as checked by verifyType.
|
||||
unsigned getNumFuncResults();
|
||||
|
||||
// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
|
||||
// attribute is present. This can check for preconditions of the
|
||||
// getNumArguments hook not failing.
|
||||
|
||||
@@ -146,6 +146,15 @@ public:
|
||||
unsigned argIndex,
|
||||
NamedAttribute);
|
||||
|
||||
/// Verify an attribute from this dialect on the result at 'resultIndex' for
|
||||
/// the region at 'regionIndex' on the given operation. Returns failure if
|
||||
/// the verification failed, success otherwise. This hook may optionally be
|
||||
/// invoked from any operation containing a region.
|
||||
virtual LogicalResult verifyRegionResultAttribute(Operation *,
|
||||
unsigned regionIndex,
|
||||
unsigned resultIndex,
|
||||
NamedAttribute);
|
||||
|
||||
/// Verify an attribute from this dialect on the given operation. Returns
|
||||
/// failure if the verification failed, success otherwise.
|
||||
virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) {
|
||||
|
||||
@@ -136,16 +136,18 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
// This trait needs access to `getNumFuncArguments` and `verifyType` hooks
|
||||
// defined below.
|
||||
// This trait needs access to the hooks defined below.
|
||||
friend class OpTrait::FunctionLike<FuncOp>;
|
||||
|
||||
/// Returns the number of arguments. This is a hook for OpTrait::FunctionLike.
|
||||
unsigned getNumFuncArguments() { return getType().getInputs().size(); }
|
||||
|
||||
/// Returns the number of results. This is a hook for OpTrait::FunctionLike.
|
||||
unsigned getNumFuncResults() { return getType().getResults().size(); }
|
||||
|
||||
/// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
|
||||
/// attribute is present and checks if it holds a function type. Ensures
|
||||
/// getType and getNumFuncArguments can be called safely.
|
||||
/// getType, getNumFuncArguments, and getNumFuncResults can be called safely.
|
||||
LogicalResult verifyType() {
|
||||
auto type = getTypeAttr().getValue();
|
||||
if (!type.isa<FunctionType>())
|
||||
|
||||
@@ -39,6 +39,12 @@ inline StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
|
||||
return ("arg" + Twine(arg)).toStringRef(out);
|
||||
}
|
||||
|
||||
/// Return the name of the attribute used for function results.
|
||||
inline StringRef getResultAttrName(unsigned arg, SmallVectorImpl<char> &out) {
|
||||
out.clear();
|
||||
return ("result" + Twine(arg)).toStringRef(out);
|
||||
}
|
||||
|
||||
/// Returns the dictionary attribute corresponding to the argument at 'index'.
|
||||
/// If there are no argument attributes at 'index', a null attribute is
|
||||
/// returned.
|
||||
@@ -47,12 +53,26 @@ inline DictionaryAttr getArgAttrDict(Operation *op, unsigned index) {
|
||||
return op->getAttrOfType<DictionaryAttr>(getArgAttrName(index, nameOut));
|
||||
}
|
||||
|
||||
/// Returns the dictionary attribute corresponding to the result at 'index'.
|
||||
/// If there are no result attributes at 'index', a null attribute is
|
||||
/// returned.
|
||||
inline DictionaryAttr getResultAttrDict(Operation *op, unsigned index) {
|
||||
SmallString<8> nameOut;
|
||||
return op->getAttrOfType<DictionaryAttr>(getResultAttrName(index, nameOut));
|
||||
}
|
||||
|
||||
/// Return all of the attributes for the argument at 'index'.
|
||||
inline ArrayRef<NamedAttribute> getArgAttrs(Operation *op, unsigned index) {
|
||||
auto argDict = getArgAttrDict(op, index);
|
||||
return argDict ? argDict.getValue() : llvm::None;
|
||||
}
|
||||
|
||||
/// Return all of the attributes for the result at 'index'.
|
||||
inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
|
||||
auto resultDict = getResultAttrDict(op, index);
|
||||
return resultDict ? resultDict.getValue() : llvm::None;
|
||||
}
|
||||
|
||||
/// A named class for passing around the variadic flag.
|
||||
class VariadicFlag {
|
||||
public:
|
||||
@@ -87,7 +107,7 @@ ParseResult parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
|
||||
/// argument and result types to use while printing.
|
||||
void printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
|
||||
ArrayRef<Type> argTypes, bool isVariadic,
|
||||
ArrayRef<Type> results);
|
||||
ArrayRef<Type> resultTypes);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
@@ -111,10 +131,13 @@ namespace OpTrait {
|
||||
/// - Concrete ops *must* define a member function `getNumFuncArguments()` that
|
||||
/// returns the number of function arguments based exclusively on type (so that
|
||||
/// it can be called on function declarations).
|
||||
/// - Concrete ops *must* define a member function `getNumFuncResults()` that
|
||||
/// returns the number of function results based exclusively on type (so that
|
||||
/// it can be called on function declarations).
|
||||
/// - To verify that the type respects op-specific invariants, concrete ops may
|
||||
/// redefine the `verifyType()` hook that will be called after verifying the
|
||||
/// presence of the `type` attribute and before any call to
|
||||
/// `getNumFuncArguments` from the verifier.
|
||||
/// `getNumFuncArguments`/`getNumFuncResults` from the verifier.
|
||||
template <typename ConcreteType>
|
||||
class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
|
||||
public:
|
||||
@@ -202,6 +225,10 @@ public:
|
||||
return static_cast<ConcreteType *>(this)->getNumFuncArguments();
|
||||
}
|
||||
|
||||
unsigned getNumResults() {
|
||||
return static_cast<ConcreteType *>(this)->getNumFuncResults();
|
||||
}
|
||||
|
||||
/// Gets argument.
|
||||
BlockArgument *getArgument(unsigned idx) {
|
||||
return getBlocks().front().getArgument(idx);
|
||||
@@ -278,11 +305,75 @@ public:
|
||||
NamedAttributeList::RemoveResult removeArgAttr(unsigned index,
|
||||
Identifier name);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Result Attributes
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// FunctionLike operations allow for attaching attributes to each of the
|
||||
/// respective function results. These result attributes are stored as
|
||||
/// DictionaryAttrs in the main operation attribute dictionary. The name of
|
||||
/// these entries is `result` followed by the index of the result. These
|
||||
/// result attribute dictionaries are optional, and will generally only
|
||||
/// exist if they are non-empty.
|
||||
|
||||
/// Return all of the attributes for the result at 'index'.
|
||||
ArrayRef<NamedAttribute> getResultAttrs(unsigned index) {
|
||||
return ::mlir::impl::getResultAttrs(this->getOperation(), index);
|
||||
}
|
||||
|
||||
/// Return all result attributes of this function.
|
||||
void getAllResultAttrs(SmallVectorImpl<NamedAttributeList> &result) {
|
||||
for (unsigned i = 0, e = getNumResults(); i != e; ++i)
|
||||
result.emplace_back(getResultAttrDict(i));
|
||||
}
|
||||
|
||||
/// Return the specified attribute, if present, for the result at 'index',
|
||||
/// null otherwise.
|
||||
Attribute getResultAttr(unsigned index, Identifier name) {
|
||||
auto argDict = getResultAttrDict(index);
|
||||
return argDict ? argDict.get(name) : nullptr;
|
||||
}
|
||||
Attribute getResultAttr(unsigned index, StringRef name) {
|
||||
auto argDict = getResultAttrDict(index);
|
||||
return argDict ? argDict.get(name) : nullptr;
|
||||
}
|
||||
|
||||
template <typename AttrClass>
|
||||
AttrClass getResultAttrOfType(unsigned index, Identifier name) {
|
||||
return getResultAttr(index, name).template dyn_cast_or_null<AttrClass>();
|
||||
}
|
||||
template <typename AttrClass>
|
||||
AttrClass getResultAttrOfType(unsigned index, StringRef name) {
|
||||
return getResultAttr(index, name).template dyn_cast_or_null<AttrClass>();
|
||||
}
|
||||
|
||||
/// Set the attributes held by the result at 'index'.
|
||||
void setResultAttrs(unsigned index, ArrayRef<NamedAttribute> attributes);
|
||||
void setResultAttrs(unsigned index, NamedAttributeList attributes);
|
||||
void setAllResultAttrs(ArrayRef<NamedAttributeList> attributes) {
|
||||
assert(attributes.size() == getNumResults());
|
||||
for (unsigned i = 0, e = attributes.size(); i != e; ++i)
|
||||
setResultAttrs(i, attributes[i]);
|
||||
}
|
||||
|
||||
/// If the an attribute exists with the specified name, change it to the new
|
||||
/// value. Otherwise, add a new attribute with the specified name/value.
|
||||
void setResultAttr(unsigned index, Identifier name, Attribute value);
|
||||
void setResultAttr(unsigned index, StringRef name, Attribute value) {
|
||||
setResultAttr(index,
|
||||
Identifier::get(name, this->getOperation()->getContext()),
|
||||
value);
|
||||
}
|
||||
|
||||
/// Remove the attribute 'name' from the result at 'index'.
|
||||
NamedAttributeList::RemoveResult removeResultAttr(unsigned index,
|
||||
Identifier name);
|
||||
|
||||
protected:
|
||||
/// Returns the attribute entry name for the set of argument attributes at
|
||||
/// index 'arg'.
|
||||
static StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
|
||||
return ::mlir::impl::getArgAttrName(arg, out);
|
||||
/// 'index'.
|
||||
static StringRef getArgAttrName(unsigned index, SmallVectorImpl<char> &out) {
|
||||
return ::mlir::impl::getArgAttrName(index, out);
|
||||
}
|
||||
|
||||
/// Returns the dictionary attribute corresponding to the argument at 'index'.
|
||||
@@ -293,6 +384,21 @@ protected:
|
||||
return ::mlir::impl::getArgAttrDict(this->getOperation(), index);
|
||||
}
|
||||
|
||||
/// Returns the attribute entry name for the set of result attributes at
|
||||
/// 'index'.
|
||||
static StringRef getResultAttrName(unsigned index,
|
||||
SmallVectorImpl<char> &out) {
|
||||
return ::mlir::impl::getResultAttrName(index, out);
|
||||
}
|
||||
|
||||
/// Returns the dictionary attribute corresponding to the result at 'index'.
|
||||
/// If there are no result attributes at 'index', a null attribute is
|
||||
/// returned.
|
||||
DictionaryAttr getResultAttrDict(unsigned index) {
|
||||
assert(index < getNumResults() && "invalid result number");
|
||||
return ::mlir::impl::getResultAttrDict(this->getOperation(), index);
|
||||
}
|
||||
|
||||
/// Hook for concrete classes to verify that the type attribute respects
|
||||
/// op-specific invariants. Default implementation always succeeds.
|
||||
LogicalResult verifyType() { return success(); }
|
||||
@@ -326,6 +432,23 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0, e = funcOp.getNumResults(); i != e; ++i) {
|
||||
// Verify that all of the result attributes are dialect attributes, i.e.
|
||||
// that they contain a dialect prefix in their name. Call the dialect, if
|
||||
// registered, to verify the attributes themselves.
|
||||
for (auto attr : funcOp.getResultAttrs(i)) {
|
||||
if (!attr.first.strref().contains('.'))
|
||||
return funcOp.emitOpError("results may only have dialect attributes");
|
||||
auto dialectNamePair = attr.first.strref().split('.');
|
||||
if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
|
||||
if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
|
||||
/*resultIndex=*/i,
|
||||
attr)))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check that the op has exactly one region for the body.
|
||||
if (op->getNumRegions() != 1)
|
||||
return funcOp.emitOpError("expects one region");
|
||||
@@ -354,10 +477,10 @@ void FunctionLike<ConcreteType>::setArgAttrs(
|
||||
assert(index < getNumArguments() && "invalid argument number");
|
||||
SmallString<8> nameOut;
|
||||
getArgAttrName(index, nameOut);
|
||||
Operation *op = this->getOperation();
|
||||
|
||||
if (attributes.empty())
|
||||
return (void)static_cast<ConcreteType *>(this)->removeAttr(nameOut);
|
||||
Operation *op = this->getOperation();
|
||||
op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
|
||||
}
|
||||
|
||||
@@ -400,6 +523,64 @@ FunctionLike<ConcreteType>::removeArgAttr(unsigned index, Identifier name) {
|
||||
return result;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Function Result Attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Set the attributes held by the result at 'index'.
|
||||
template <typename ConcreteType>
|
||||
void FunctionLike<ConcreteType>::setResultAttrs(
|
||||
unsigned index, ArrayRef<NamedAttribute> attributes) {
|
||||
assert(index < getNumResults() && "invalid result number");
|
||||
SmallString<8> nameOut;
|
||||
getResultAttrName(index, nameOut);
|
||||
|
||||
if (attributes.empty())
|
||||
return (void)static_cast<ConcreteType *>(this)->removeAttr(nameOut);
|
||||
Operation *op = this->getOperation();
|
||||
op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
|
||||
}
|
||||
|
||||
template <typename ConcreteType>
|
||||
void FunctionLike<ConcreteType>::setResultAttrs(unsigned index,
|
||||
NamedAttributeList attributes) {
|
||||
assert(index < getNumResults() && "invalid result number");
|
||||
SmallString<8> nameOut;
|
||||
if (auto newAttr = attributes.getDictionary())
|
||||
return this->getOperation()->setAttr(getResultAttrName(index, nameOut),
|
||||
newAttr);
|
||||
static_cast<ConcreteType *>(this)->removeAttr(
|
||||
getResultAttrName(index, nameOut));
|
||||
}
|
||||
|
||||
/// If the an attribute exists with the specified name, change it to the new
|
||||
/// value. Otherwise, add a new attribute with the specified name/value.
|
||||
template <typename ConcreteType>
|
||||
void FunctionLike<ConcreteType>::setResultAttr(unsigned index, Identifier name,
|
||||
Attribute value) {
|
||||
auto curAttr = getResultAttrDict(index);
|
||||
NamedAttributeList attrList(curAttr);
|
||||
attrList.set(name, value);
|
||||
|
||||
// If the attribute changed, then set the new arg attribute list.
|
||||
if (curAttr != attrList.getDictionary())
|
||||
setResultAttrs(index, attrList);
|
||||
}
|
||||
|
||||
/// Remove the attribute 'name' from the result at 'index'.
|
||||
template <typename ConcreteType>
|
||||
NamedAttributeList::RemoveResult
|
||||
FunctionLike<ConcreteType>::removeResultAttr(unsigned index, Identifier name) {
|
||||
// Build an attribute list and remove the attribute at 'name'.
|
||||
NamedAttributeList attrList(getResultAttrDict(index));
|
||||
auto result = attrList.remove(name);
|
||||
|
||||
// If the attribute was removed, then update the result dictionary.
|
||||
if (result == NamedAttributeList::RemoveResult::Removed)
|
||||
setResultAttrs(index, attrList);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // end namespace OpTrait
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
@@ -1115,6 +1115,21 @@ unsigned LLVMFuncOp::getNumFuncArguments() {
|
||||
return getType().getUnderlyingType()->getFunctionNumParams();
|
||||
}
|
||||
|
||||
// Hook for OpTrait::FunctionLike, returns the number of function results.
|
||||
// Depends on the type attribute being correct as checked by verifyType
|
||||
unsigned LLVMFuncOp::getNumFuncResults() {
|
||||
llvm::FunctionType *funcType =
|
||||
cast<llvm::FunctionType>(getType().getUnderlyingType());
|
||||
// We model LLVM functions that return void as having zero results,
|
||||
// and all others as having one result.
|
||||
// If we modeled a void return as one result, then it would be possible to
|
||||
// attach an MLIR result attribute to it, and it isn't clear what semantics we
|
||||
// would assign to that.
|
||||
if (funcType->getReturnType()->isVoidTy())
|
||||
return 0;
|
||||
return 1;
|
||||
}
|
||||
|
||||
static LogicalResult verify(LLVMFuncOp op) {
|
||||
if (op.isExternal())
|
||||
return success();
|
||||
|
||||
@@ -89,6 +89,15 @@ LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Verify an attribute from this dialect on the result at 'resultIndex' for
|
||||
/// the region at 'regionIndex' on the given operation. Returns failure if
|
||||
/// the verification failed, success otherwise. This hook may optionally be
|
||||
/// invoked from any operation containing a region.
|
||||
LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
|
||||
unsigned, NamedAttribute) {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Parse an attribute registered to this dialect.
|
||||
Attribute Dialect::parseAttribute(StringRef attrData, Type type,
|
||||
Location loc) const {
|
||||
|
||||
@@ -88,6 +88,45 @@ parseArgumentList(OpAsmParser &parser, bool allowVariadic,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Parse a function result list.
|
||||
///
|
||||
/// function-result-list ::= function-result-list-parens
|
||||
/// | non-function-type
|
||||
/// function-result-list-parens ::= `(` `)`
|
||||
/// | `(` function-result-list-no-parens `)`
|
||||
/// function-result-list-no-parens ::= function-result (`,` function-result)*
|
||||
/// function-result ::= type attribute-dict?
|
||||
///
|
||||
static ParseResult parseFunctionResultList(
|
||||
OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
|
||||
SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs) {
|
||||
if (failed(parser.parseOptionalLParen())) {
|
||||
// We already know that there is no `(`, so parse a type.
|
||||
// Because there is no `(`, it cannot be a function type.
|
||||
Type ty;
|
||||
if (parser.parseType(ty))
|
||||
return failure();
|
||||
resultTypes.push_back(ty);
|
||||
resultAttrs.emplace_back();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Special case for an empty set of parens.
|
||||
if (succeeded(parser.parseOptionalRParen()))
|
||||
return success();
|
||||
|
||||
// Parse individual function results.
|
||||
do {
|
||||
resultTypes.emplace_back();
|
||||
resultAttrs.emplace_back();
|
||||
if (parser.parseType(resultTypes.back()) ||
|
||||
parser.parseOptionalAttributeDict(resultAttrs.back())) {
|
||||
return failure();
|
||||
}
|
||||
} while (succeeded(parser.parseOptionalComma()));
|
||||
return parser.parseRParen();
|
||||
}
|
||||
|
||||
/// Parse a function signature, starting with a name and including the
|
||||
/// parameter list.
|
||||
static ParseResult parseFunctionSignature(
|
||||
@@ -95,12 +134,14 @@ static ParseResult parseFunctionSignature(
|
||||
SmallVectorImpl<OpAsmParser::OperandType> &argNames,
|
||||
SmallVectorImpl<Type> &argTypes,
|
||||
SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, bool &isVariadic,
|
||||
SmallVectorImpl<Type> &results) {
|
||||
SmallVectorImpl<Type> &resultTypes,
|
||||
SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs) {
|
||||
if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs,
|
||||
isVariadic))
|
||||
return failure();
|
||||
// Parse the return types if present.
|
||||
return parser.parseOptionalArrowTypeList(results);
|
||||
if (succeeded(parser.parseOptionalArrow()))
|
||||
return parseFunctionResultList(parser, resultTypes, resultAttrs);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Parser implementation for function-like operations. Uses `funcTypeBuilder`
|
||||
@@ -111,8 +152,9 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
|
||||
mlir::impl::FuncTypeBuilder funcTypeBuilder) {
|
||||
SmallVector<OpAsmParser::OperandType, 4> entryArgs;
|
||||
SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
|
||||
SmallVector<SmallVector<NamedAttribute, 2>, 4> resultAttrs;
|
||||
SmallVector<Type, 4> argTypes;
|
||||
SmallVector<Type, 4> results;
|
||||
SmallVector<Type, 4> resultTypes;
|
||||
auto &builder = parser.getBuilder();
|
||||
|
||||
// Parse the name as a symbol reference attribute.
|
||||
@@ -127,11 +169,11 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
|
||||
auto signatureLocation = parser.getCurrentLocation();
|
||||
bool isVariadic = false;
|
||||
if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes,
|
||||
argAttrs, isVariadic, results))
|
||||
argAttrs, isVariadic, resultTypes, resultAttrs))
|
||||
return failure();
|
||||
|
||||
std::string errorMessage;
|
||||
if (auto type = funcTypeBuilder(builder, argTypes, results,
|
||||
if (auto type = funcTypeBuilder(builder, argTypes, resultTypes,
|
||||
impl::VariadicFlag(isVariadic), errorMessage))
|
||||
result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
|
||||
else
|
||||
@@ -145,12 +187,18 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
|
||||
return failure();
|
||||
|
||||
// Add the attributes to the function arguments.
|
||||
SmallString<8> argAttrName;
|
||||
SmallString<8> attrNameBuf;
|
||||
for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
|
||||
if (!argAttrs[i].empty())
|
||||
result.addAttribute(getArgAttrName(i, argAttrName),
|
||||
result.addAttribute(getArgAttrName(i, attrNameBuf),
|
||||
builder.getDictionaryAttr(argAttrs[i]));
|
||||
|
||||
// Add the attributes to the function results.
|
||||
for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
|
||||
if (!resultAttrs[i].empty())
|
||||
result.addAttribute(getResultAttrName(i, attrNameBuf),
|
||||
builder.getDictionaryAttr(resultAttrs[i]));
|
||||
|
||||
// Parse the optional function body.
|
||||
auto *body = result.addRegion();
|
||||
if (parser.parseOptionalRegion(*body, entryArgs,
|
||||
@@ -161,11 +209,29 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
|
||||
return success();
|
||||
}
|
||||
|
||||
// Print a function result list.
|
||||
static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
|
||||
ArrayRef<ArrayRef<NamedAttribute>> attrs) {
|
||||
assert(!types.empty() && "Should not be called for empty result list.");
|
||||
auto &os = p.getStream();
|
||||
bool needsParens =
|
||||
types.size() > 1 || types[0].isa<FunctionType>() || !attrs[0].empty();
|
||||
if (needsParens)
|
||||
os << '(';
|
||||
interleaveComma(llvm::zip(types, attrs), os,
|
||||
[&](const std::tuple<Type, ArrayRef<NamedAttribute>> &t) {
|
||||
p.printType(std::get<0>(t));
|
||||
p.printOptionalAttrDict(std::get<1>(t));
|
||||
});
|
||||
if (needsParens)
|
||||
os << ')';
|
||||
}
|
||||
|
||||
/// Print the signature of the function-like operation `op`. Assumes `op` has
|
||||
/// the FunctionLike trait and passed the verification.
|
||||
static void printSignature(OpAsmPrinter &p, Operation *op,
|
||||
ArrayRef<Type> argTypes, bool isVariadic,
|
||||
ArrayRef<Type> results) {
|
||||
ArrayRef<Type> resultTypes) {
|
||||
Region &body = op->getRegion(0);
|
||||
bool isExternal = body.empty();
|
||||
|
||||
@@ -190,14 +256,21 @@ static void printSignature(OpAsmPrinter &p, Operation *op,
|
||||
}
|
||||
|
||||
p << ')';
|
||||
p.printOptionalArrowTypeList(results);
|
||||
|
||||
if (!resultTypes.empty()) {
|
||||
p.getStream() << " -> ";
|
||||
SmallVector<ArrayRef<NamedAttribute>, 4> resultAttrs;
|
||||
for (int i = 0, e = resultTypes.size(); i < e; ++i)
|
||||
resultAttrs.push_back(::mlir::impl::getResultAttrs(op, i));
|
||||
printFunctionResultList(p, resultTypes, resultAttrs);
|
||||
}
|
||||
}
|
||||
|
||||
/// Printer implementation for function-like operations. Accepts lists of
|
||||
/// argument and result types to use while printing.
|
||||
void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
|
||||
ArrayRef<Type> argTypes, bool isVariadic,
|
||||
ArrayRef<Type> results) {
|
||||
ArrayRef<Type> resultTypes) {
|
||||
// Print the operation and the function name.
|
||||
auto funcName =
|
||||
op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName())
|
||||
@@ -206,20 +279,28 @@ void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
|
||||
p.printSymbolName(funcName);
|
||||
|
||||
// Print the signature.
|
||||
printSignature(p, op, argTypes, isVariadic, results);
|
||||
printSignature(p, op, argTypes, isVariadic, resultTypes);
|
||||
|
||||
// Print out function attributes, if present.
|
||||
SmallVector<StringRef, 2> ignoredAttrs = {
|
||||
::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()};
|
||||
|
||||
SmallString<8> attrNameBuf;
|
||||
|
||||
// Ignore any argument attributes.
|
||||
std::vector<SmallString<8>> argAttrStorage;
|
||||
SmallString<8> argAttrName;
|
||||
for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
|
||||
if (op->getAttr(getArgAttrName(i, argAttrName)))
|
||||
argAttrStorage.emplace_back(argAttrName);
|
||||
if (op->getAttr(getArgAttrName(i, attrNameBuf)))
|
||||
argAttrStorage.emplace_back(attrNameBuf);
|
||||
ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end());
|
||||
|
||||
// Ignore any result attributes.
|
||||
std::vector<SmallString<8>> resultAttrStorage;
|
||||
for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
|
||||
if (op->getAttr(getResultAttrName(i, attrNameBuf)))
|
||||
resultAttrStorage.emplace_back(attrNameBuf);
|
||||
ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end());
|
||||
|
||||
auto attrs = op->getAttrs();
|
||||
if (attrs.size() > ignoredAttrs.size()) {
|
||||
p << "\n attributes ";
|
||||
|
||||
@@ -49,3 +49,27 @@ func @func_op() {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{expected non-function type}}
|
||||
func @f() -> (foo
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{expected attribute name}}
|
||||
func @f() -> (i1 {)
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{invalid to use 'test.invalid_attr'}}
|
||||
func @f(%arg0: i64 {test.invalid_attr}) {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{invalid to use 'test.invalid_attr'}}
|
||||
func @f(%arg0: i64) -> (i64 {test.invalid_attr}) {
|
||||
return %arg0 : i64
|
||||
}
|
||||
|
||||
@@ -924,6 +924,11 @@ func @invalid_func_arg_attr(i1 {non_dialect_attr = 10})
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{results may only have dialect attributes}}
|
||||
func @invalid_func_result_attr() -> (i1 {non_dialect_attr = 10})
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected '<' in tuple type}}
|
||||
func @invalid_tuple_missing_less(tuple i32>)
|
||||
|
||||
|
||||
@@ -847,6 +847,11 @@ func @func_arg_attrs(%arg0: i1 {dialect.attr = 10 : i64}) {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @func_result_attrs({{.*}}) -> (f32 {dialect.attr = 1 : i64})
|
||||
func @func_result_attrs(%arg0: f32) -> (f32 {dialect.attr = 1}) {
|
||||
return %arg0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @empty_tuple(tuple<>)
|
||||
func @empty_tuple(tuple<>)
|
||||
|
||||
|
||||
@@ -114,6 +114,24 @@ TestDialect::TestDialect(MLIRContext *context)
|
||||
allowUnknownOperations();
|
||||
}
|
||||
|
||||
LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
|
||||
unsigned regionIndex,
|
||||
unsigned argIndex,
|
||||
NamedAttribute namedAttr) {
|
||||
if (namedAttr.first == "test.invalid_attr")
|
||||
return op->emitError() << "invalid to use 'test.invalid_attr'";
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
|
||||
unsigned resultIndex,
|
||||
NamedAttribute namedAttr) {
|
||||
if (namedAttr.first == "test.invalid_attr")
|
||||
return op->emitError() << "invalid to use 'test.invalid_attr'";
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test IsolatedRegionOp - parse passthrough region arguments.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -40,6 +40,14 @@ public:
|
||||
|
||||
/// Get the canonical string name of the dialect.
|
||||
static StringRef getDialectName() { return "test"; }
|
||||
|
||||
LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex,
|
||||
unsigned argIndex,
|
||||
NamedAttribute) override;
|
||||
|
||||
LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex,
|
||||
unsigned resultIndex,
|
||||
NamedAttribute) override;
|
||||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
||||
Reference in New Issue
Block a user