diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md index 493ade1a5908..391f77325d47 100644 --- a/mlir/g3doc/LangRef.md +++ b/mlir/g3doc/LangRef.md @@ -321,11 +321,19 @@ name via a string attribute like [SymbolRefAttr](#symbol-reference-attribute)): function ::= `func` function-signature function-attributes? function-body? function-signature ::= symbol-ref-id `(` argument-list `)` - (`->` function-result-type)? + (`->` function-result-list)? + argument-list ::= (named-argument (`,` named-argument)*) | /*empty*/ argument-list ::= (type attribute-dict? (`,` type attribute-dict?)*) | /*empty*/ named-argument ::= ssa-id `:` type attribute-dict? +function-result-list ::= function-result-list-parens + | non-function-type +function-result-list-parens ::= `(` `)` + | `(` function-result-list-no-parens `)` +function-result-list-no-parens ::= function-result (`,` function-result)* +function-result ::= type attribute-dict? + function-attributes ::= `attributes` attribute-dict function-body ::= region ``` diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index a8004a47d5a2..35fe5155557b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -546,6 +546,10 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func", // Depends on the type attribute being correct as checked by verifyType. unsigned getNumFuncArguments(); + // Hook for OpTrait::FunctionLike, returns the number of function results. + // Depends on the type attribute being correct as checked by verifyType. + unsigned getNumFuncResults(); + // Hook for OpTrait::FunctionLike, called after verifying that the 'type' // attribute is present. This can check for preconditions of the // getNumArguments hook not failing. diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index 810d11c2ef21..bf7db9157187 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -146,6 +146,15 @@ public: unsigned argIndex, NamedAttribute); + /// Verify an attribute from this dialect on the result at 'resultIndex' for + /// the region at 'regionIndex' on the given operation. Returns failure if + /// the verification failed, success otherwise. This hook may optionally be + /// invoked from any operation containing a region. + virtual LogicalResult verifyRegionResultAttribute(Operation *, + unsigned regionIndex, + unsigned resultIndex, + NamedAttribute); + /// Verify an attribute from this dialect on the given operation. Returns /// failure if the verification failed, success otherwise. virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) { diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 95920b38c14b..a7777c639e47 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -136,16 +136,18 @@ public: } private: - // This trait needs access to `getNumFuncArguments` and `verifyType` hooks - // defined below. + // This trait needs access to the hooks defined below. friend class OpTrait::FunctionLike; /// Returns the number of arguments. This is a hook for OpTrait::FunctionLike. unsigned getNumFuncArguments() { return getType().getInputs().size(); } + /// Returns the number of results. This is a hook for OpTrait::FunctionLike. + unsigned getNumFuncResults() { return getType().getResults().size(); } + /// Hook for OpTrait::FunctionLike, called after verifying that the 'type' /// attribute is present and checks if it holds a function type. Ensures - /// getType and getNumFuncArguments can be called safely. + /// getType, getNumFuncArguments, and getNumFuncResults can be called safely. LogicalResult verifyType() { auto type = getTypeAttr().getValue(); if (!type.isa()) diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h index 7fa27ff268b9..ccac4d3862ba 100644 --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -39,6 +39,12 @@ inline StringRef getArgAttrName(unsigned arg, SmallVectorImpl &out) { return ("arg" + Twine(arg)).toStringRef(out); } +/// Return the name of the attribute used for function results. +inline StringRef getResultAttrName(unsigned arg, SmallVectorImpl &out) { + out.clear(); + return ("result" + Twine(arg)).toStringRef(out); +} + /// Returns the dictionary attribute corresponding to the argument at 'index'. /// If there are no argument attributes at 'index', a null attribute is /// returned. @@ -47,12 +53,26 @@ inline DictionaryAttr getArgAttrDict(Operation *op, unsigned index) { return op->getAttrOfType(getArgAttrName(index, nameOut)); } +/// Returns the dictionary attribute corresponding to the result at 'index'. +/// If there are no result attributes at 'index', a null attribute is +/// returned. +inline DictionaryAttr getResultAttrDict(Operation *op, unsigned index) { + SmallString<8> nameOut; + return op->getAttrOfType(getResultAttrName(index, nameOut)); +} + /// Return all of the attributes for the argument at 'index'. inline ArrayRef getArgAttrs(Operation *op, unsigned index) { auto argDict = getArgAttrDict(op, index); return argDict ? argDict.getValue() : llvm::None; } +/// Return all of the attributes for the result at 'index'. +inline ArrayRef getResultAttrs(Operation *op, unsigned index) { + auto resultDict = getResultAttrDict(op, index); + return resultDict ? resultDict.getValue() : llvm::None; +} + /// A named class for passing around the variadic flag. class VariadicFlag { public: @@ -87,7 +107,7 @@ ParseResult parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, /// argument and result types to use while printing. void printFunctionLikeOp(OpAsmPrinter &p, Operation *op, ArrayRef argTypes, bool isVariadic, - ArrayRef results); + ArrayRef resultTypes); } // namespace impl @@ -111,10 +131,13 @@ namespace OpTrait { /// - Concrete ops *must* define a member function `getNumFuncArguments()` that /// returns the number of function arguments based exclusively on type (so that /// it can be called on function declarations). +/// - Concrete ops *must* define a member function `getNumFuncResults()` that +/// returns the number of function results based exclusively on type (so that +/// it can be called on function declarations). /// - To verify that the type respects op-specific invariants, concrete ops may /// redefine the `verifyType()` hook that will be called after verifying the /// presence of the `type` attribute and before any call to -/// `getNumFuncArguments` from the verifier. +/// `getNumFuncArguments`/`getNumFuncResults` from the verifier. template class FunctionLike : public OpTrait::TraitBase { public: @@ -202,6 +225,10 @@ public: return static_cast(this)->getNumFuncArguments(); } + unsigned getNumResults() { + return static_cast(this)->getNumFuncResults(); + } + /// Gets argument. BlockArgument *getArgument(unsigned idx) { return getBlocks().front().getArgument(idx); @@ -278,11 +305,75 @@ public: NamedAttributeList::RemoveResult removeArgAttr(unsigned index, Identifier name); + //===--------------------------------------------------------------------===// + // Result Attributes + //===--------------------------------------------------------------------===// + + /// FunctionLike operations allow for attaching attributes to each of the + /// respective function results. These result attributes are stored as + /// DictionaryAttrs in the main operation attribute dictionary. The name of + /// these entries is `result` followed by the index of the result. These + /// result attribute dictionaries are optional, and will generally only + /// exist if they are non-empty. + + /// Return all of the attributes for the result at 'index'. + ArrayRef getResultAttrs(unsigned index) { + return ::mlir::impl::getResultAttrs(this->getOperation(), index); + } + + /// Return all result attributes of this function. + void getAllResultAttrs(SmallVectorImpl &result) { + for (unsigned i = 0, e = getNumResults(); i != e; ++i) + result.emplace_back(getResultAttrDict(i)); + } + + /// Return the specified attribute, if present, for the result at 'index', + /// null otherwise. + Attribute getResultAttr(unsigned index, Identifier name) { + auto argDict = getResultAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + Attribute getResultAttr(unsigned index, StringRef name) { + auto argDict = getResultAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + + template + AttrClass getResultAttrOfType(unsigned index, Identifier name) { + return getResultAttr(index, name).template dyn_cast_or_null(); + } + template + AttrClass getResultAttrOfType(unsigned index, StringRef name) { + return getResultAttr(index, name).template dyn_cast_or_null(); + } + + /// Set the attributes held by the result at 'index'. + void setResultAttrs(unsigned index, ArrayRef attributes); + void setResultAttrs(unsigned index, NamedAttributeList attributes); + void setAllResultAttrs(ArrayRef attributes) { + assert(attributes.size() == getNumResults()); + for (unsigned i = 0, e = attributes.size(); i != e; ++i) + setResultAttrs(i, attributes[i]); + } + + /// If the an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void setResultAttr(unsigned index, Identifier name, Attribute value); + void setResultAttr(unsigned index, StringRef name, Attribute value) { + setResultAttr(index, + Identifier::get(name, this->getOperation()->getContext()), + value); + } + + /// Remove the attribute 'name' from the result at 'index'. + NamedAttributeList::RemoveResult removeResultAttr(unsigned index, + Identifier name); + protected: /// Returns the attribute entry name for the set of argument attributes at - /// index 'arg'. - static StringRef getArgAttrName(unsigned arg, SmallVectorImpl &out) { - return ::mlir::impl::getArgAttrName(arg, out); + /// 'index'. + static StringRef getArgAttrName(unsigned index, SmallVectorImpl &out) { + return ::mlir::impl::getArgAttrName(index, out); } /// Returns the dictionary attribute corresponding to the argument at 'index'. @@ -293,6 +384,21 @@ protected: return ::mlir::impl::getArgAttrDict(this->getOperation(), index); } + /// Returns the attribute entry name for the set of result attributes at + /// 'index'. + static StringRef getResultAttrName(unsigned index, + SmallVectorImpl &out) { + return ::mlir::impl::getResultAttrName(index, out); + } + + /// Returns the dictionary attribute corresponding to the result at 'index'. + /// If there are no result attributes at 'index', a null attribute is + /// returned. + DictionaryAttr getResultAttrDict(unsigned index) { + assert(index < getNumResults() && "invalid result number"); + return ::mlir::impl::getResultAttrDict(this->getOperation(), index); + } + /// Hook for concrete classes to verify that the type attribute respects /// op-specific invariants. Default implementation always succeeds. LogicalResult verifyType() { return success(); } @@ -326,6 +432,23 @@ LogicalResult FunctionLike::verifyTrait(Operation *op) { } } + for (unsigned i = 0, e = funcOp.getNumResults(); i != e; ++i) { + // Verify that all of the result attributes are dialect attributes, i.e. + // that they contain a dialect prefix in their name. Call the dialect, if + // registered, to verify the attributes themselves. + for (auto attr : funcOp.getResultAttrs(i)) { + if (!attr.first.strref().contains('.')) + return funcOp.emitOpError("results may only have dialect attributes"); + auto dialectNamePair = attr.first.strref().split('.'); + if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) { + if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, + /*resultIndex=*/i, + attr))) + return failure(); + } + } + } + // Check that the op has exactly one region for the body. if (op->getNumRegions() != 1) return funcOp.emitOpError("expects one region"); @@ -354,10 +477,10 @@ void FunctionLike::setArgAttrs( assert(index < getNumArguments() && "invalid argument number"); SmallString<8> nameOut; getArgAttrName(index, nameOut); - Operation *op = this->getOperation(); if (attributes.empty()) return (void)static_cast(this)->removeAttr(nameOut); + Operation *op = this->getOperation(); op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext())); } @@ -400,6 +523,64 @@ FunctionLike::removeArgAttr(unsigned index, Identifier name) { return result; } +//===----------------------------------------------------------------------===// +// Function Result Attribute. +//===----------------------------------------------------------------------===// + +/// Set the attributes held by the result at 'index'. +template +void FunctionLike::setResultAttrs( + unsigned index, ArrayRef attributes) { + assert(index < getNumResults() && "invalid result number"); + SmallString<8> nameOut; + getResultAttrName(index, nameOut); + + if (attributes.empty()) + return (void)static_cast(this)->removeAttr(nameOut); + Operation *op = this->getOperation(); + op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext())); +} + +template +void FunctionLike::setResultAttrs(unsigned index, + NamedAttributeList attributes) { + assert(index < getNumResults() && "invalid result number"); + SmallString<8> nameOut; + if (auto newAttr = attributes.getDictionary()) + return this->getOperation()->setAttr(getResultAttrName(index, nameOut), + newAttr); + static_cast(this)->removeAttr( + getResultAttrName(index, nameOut)); +} + +/// If the an attribute exists with the specified name, change it to the new +/// value. Otherwise, add a new attribute with the specified name/value. +template +void FunctionLike::setResultAttr(unsigned index, Identifier name, + Attribute value) { + auto curAttr = getResultAttrDict(index); + NamedAttributeList attrList(curAttr); + attrList.set(name, value); + + // If the attribute changed, then set the new arg attribute list. + if (curAttr != attrList.getDictionary()) + setResultAttrs(index, attrList); +} + +/// Remove the attribute 'name' from the result at 'index'. +template +NamedAttributeList::RemoveResult +FunctionLike::removeResultAttr(unsigned index, Identifier name) { + // Build an attribute list and remove the attribute at 'name'. + NamedAttributeList attrList(getResultAttrDict(index)); + auto result = attrList.remove(name); + + // If the attribute was removed, then update the result dictionary. + if (result == NamedAttributeList::RemoveResult::Removed) + setResultAttrs(index, attrList); + return result; +} + } // end namespace OpTrait } // end namespace mlir diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 23e3889c0490..618ee231f9e6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1115,6 +1115,21 @@ unsigned LLVMFuncOp::getNumFuncArguments() { return getType().getUnderlyingType()->getFunctionNumParams(); } +// Hook for OpTrait::FunctionLike, returns the number of function results. +// Depends on the type attribute being correct as checked by verifyType +unsigned LLVMFuncOp::getNumFuncResults() { + llvm::FunctionType *funcType = + cast(getType().getUnderlyingType()); + // We model LLVM functions that return void as having zero results, + // and all others as having one result. + // If we modeled a void return as one result, then it would be possible to + // attach an MLIR result attribute to it, and it isn't clear what semantics we + // would assign to that. + if (funcType->getReturnType()->isVoidTy()) + return 0; + return 1; +} + static LogicalResult verify(LLVMFuncOp op) { if (op.isExternal()) return success(); diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index 6a7dcaed0583..f8539c01d977 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -89,6 +89,15 @@ LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned, return success(); } +/// Verify an attribute from this dialect on the result at 'resultIndex' for +/// the region at 'regionIndex' on the given operation. Returns failure if +/// the verification failed, success otherwise. This hook may optionally be +/// invoked from any operation containing a region. +LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned, + unsigned, NamedAttribute) { + return success(); +} + /// Parse an attribute registered to this dialect. Attribute Dialect::parseAttribute(StringRef attrData, Type type, Location loc) const { diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp index 468301e94311..22f207ecf150 100644 --- a/mlir/lib/IR/FunctionSupport.cpp +++ b/mlir/lib/IR/FunctionSupport.cpp @@ -88,6 +88,45 @@ parseArgumentList(OpAsmParser &parser, bool allowVariadic, return success(); } +/// Parse a function result list. +/// +/// function-result-list ::= function-result-list-parens +/// | non-function-type +/// function-result-list-parens ::= `(` `)` +/// | `(` function-result-list-no-parens `)` +/// function-result-list-no-parens ::= function-result (`,` function-result)* +/// function-result ::= type attribute-dict? +/// +static ParseResult parseFunctionResultList( + OpAsmParser &parser, SmallVectorImpl &resultTypes, + SmallVectorImpl> &resultAttrs) { + if (failed(parser.parseOptionalLParen())) { + // We already know that there is no `(`, so parse a type. + // Because there is no `(`, it cannot be a function type. + Type ty; + if (parser.parseType(ty)) + return failure(); + resultTypes.push_back(ty); + resultAttrs.emplace_back(); + return success(); + } + + // Special case for an empty set of parens. + if (succeeded(parser.parseOptionalRParen())) + return success(); + + // Parse individual function results. + do { + resultTypes.emplace_back(); + resultAttrs.emplace_back(); + if (parser.parseType(resultTypes.back()) || + parser.parseOptionalAttributeDict(resultAttrs.back())) { + return failure(); + } + } while (succeeded(parser.parseOptionalComma())); + return parser.parseRParen(); +} + /// Parse a function signature, starting with a name and including the /// parameter list. static ParseResult parseFunctionSignature( @@ -95,12 +134,14 @@ static ParseResult parseFunctionSignature( SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl> &argAttrs, bool &isVariadic, - SmallVectorImpl &results) { + SmallVectorImpl &resultTypes, + SmallVectorImpl> &resultAttrs) { if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs, isVariadic)) return failure(); - // Parse the return types if present. - return parser.parseOptionalArrowTypeList(results); + if (succeeded(parser.parseOptionalArrow())) + return parseFunctionResultList(parser, resultTypes, resultAttrs); + return success(); } /// Parser implementation for function-like operations. Uses `funcTypeBuilder` @@ -111,8 +152,9 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, mlir::impl::FuncTypeBuilder funcTypeBuilder) { SmallVector entryArgs; SmallVector, 4> argAttrs; + SmallVector, 4> resultAttrs; SmallVector argTypes; - SmallVector results; + SmallVector resultTypes; auto &builder = parser.getBuilder(); // Parse the name as a symbol reference attribute. @@ -127,11 +169,11 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, auto signatureLocation = parser.getCurrentLocation(); bool isVariadic = false; if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes, - argAttrs, isVariadic, results)) + argAttrs, isVariadic, resultTypes, resultAttrs)) return failure(); std::string errorMessage; - if (auto type = funcTypeBuilder(builder, argTypes, results, + if (auto type = funcTypeBuilder(builder, argTypes, resultTypes, impl::VariadicFlag(isVariadic), errorMessage)) result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); else @@ -145,12 +187,18 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, return failure(); // Add the attributes to the function arguments. - SmallString<8> argAttrName; + SmallString<8> attrNameBuf; for (unsigned i = 0, e = argTypes.size(); i != e; ++i) if (!argAttrs[i].empty()) - result.addAttribute(getArgAttrName(i, argAttrName), + result.addAttribute(getArgAttrName(i, attrNameBuf), builder.getDictionaryAttr(argAttrs[i])); + // Add the attributes to the function results. + for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) + if (!resultAttrs[i].empty()) + result.addAttribute(getResultAttrName(i, attrNameBuf), + builder.getDictionaryAttr(resultAttrs[i])); + // Parse the optional function body. auto *body = result.addRegion(); if (parser.parseOptionalRegion(*body, entryArgs, @@ -161,11 +209,29 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, return success(); } +// Print a function result list. +static void printFunctionResultList(OpAsmPrinter &p, ArrayRef types, + ArrayRef> attrs) { + assert(!types.empty() && "Should not be called for empty result list."); + auto &os = p.getStream(); + bool needsParens = + types.size() > 1 || types[0].isa() || !attrs[0].empty(); + if (needsParens) + os << '('; + interleaveComma(llvm::zip(types, attrs), os, + [&](const std::tuple> &t) { + p.printType(std::get<0>(t)); + p.printOptionalAttrDict(std::get<1>(t)); + }); + if (needsParens) + os << ')'; +} + /// Print the signature of the function-like operation `op`. Assumes `op` has /// the FunctionLike trait and passed the verification. static void printSignature(OpAsmPrinter &p, Operation *op, ArrayRef argTypes, bool isVariadic, - ArrayRef results) { + ArrayRef resultTypes) { Region &body = op->getRegion(0); bool isExternal = body.empty(); @@ -190,14 +256,21 @@ static void printSignature(OpAsmPrinter &p, Operation *op, } p << ')'; - p.printOptionalArrowTypeList(results); + + if (!resultTypes.empty()) { + p.getStream() << " -> "; + SmallVector, 4> resultAttrs; + for (int i = 0, e = resultTypes.size(); i < e; ++i) + resultAttrs.push_back(::mlir::impl::getResultAttrs(op, i)); + printFunctionResultList(p, resultTypes, resultAttrs); + } } /// Printer implementation for function-like operations. Accepts lists of /// argument and result types to use while printing. void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op, ArrayRef argTypes, bool isVariadic, - ArrayRef results) { + ArrayRef resultTypes) { // Print the operation and the function name. auto funcName = op->getAttrOfType(::mlir::SymbolTable::getSymbolAttrName()) @@ -206,20 +279,28 @@ void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op, p.printSymbolName(funcName); // Print the signature. - printSignature(p, op, argTypes, isVariadic, results); + printSignature(p, op, argTypes, isVariadic, resultTypes); // Print out function attributes, if present. SmallVector ignoredAttrs = { ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()}; + SmallString<8> attrNameBuf; + // Ignore any argument attributes. std::vector> argAttrStorage; - SmallString<8> argAttrName; for (unsigned i = 0, e = argTypes.size(); i != e; ++i) - if (op->getAttr(getArgAttrName(i, argAttrName))) - argAttrStorage.emplace_back(argAttrName); + if (op->getAttr(getArgAttrName(i, attrNameBuf))) + argAttrStorage.emplace_back(attrNameBuf); ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end()); + // Ignore any result attributes. + std::vector> resultAttrStorage; + for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) + if (op->getAttr(getResultAttrName(i, attrNameBuf))) + resultAttrStorage.emplace_back(attrNameBuf); + ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end()); + auto attrs = op->getAttrs(); if (attrs.size() > ignoredAttrs.size()) { p << "\n attributes "; diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir index 16d734d66622..15547369fc7b 100644 --- a/mlir/test/IR/invalid-func-op.mlir +++ b/mlir/test/IR/invalid-func-op.mlir @@ -49,3 +49,27 @@ func @func_op() { } return } + +// ----- + +// expected-error@+1 {{expected non-function type}} +func @f() -> (foo + +// ----- + +// expected-error@+1 {{expected attribute name}} +func @f() -> (i1 {) + +// ----- + +// expected-error@+1 {{invalid to use 'test.invalid_attr'}} +func @f(%arg0: i64 {test.invalid_attr}) { + return +} + +// ----- + +// expected-error@+1 {{invalid to use 'test.invalid_attr'}} +func @f(%arg0: i64) -> (i64 {test.invalid_attr}) { + return %arg0 : i64 +} diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index ea368d7a5510..2f8dcc96b6d1 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -924,6 +924,11 @@ func @invalid_func_arg_attr(i1 {non_dialect_attr = 10}) // ----- +// expected-error @+1 {{results may only have dialect attributes}} +func @invalid_func_result_attr() -> (i1 {non_dialect_attr = 10}) + +// ----- + // expected-error @+1 {{expected '<' in tuple type}} func @invalid_tuple_missing_less(tuple i32>) diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index dc300c7392d6..13b48566ad1f 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -847,6 +847,11 @@ func @func_arg_attrs(%arg0: i1 {dialect.attr = 10 : i64}) { return } +// CHECK-LABEL: func @func_result_attrs({{.*}}) -> (f32 {dialect.attr = 1 : i64}) +func @func_result_attrs(%arg0: f32) -> (f32 {dialect.attr = 1}) { + return %arg0 : f32 +} + // CHECK-LABEL: func @empty_tuple(tuple<>) func @empty_tuple(tuple<>) diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index ee8325fd13ec..2e3d97b473f7 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -114,6 +114,24 @@ TestDialect::TestDialect(MLIRContext *context) allowUnknownOperations(); } +LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, + unsigned regionIndex, + unsigned argIndex, + NamedAttribute namedAttr) { + if (namedAttr.first == "test.invalid_attr") + return op->emitError() << "invalid to use 'test.invalid_attr'"; + return success(); +} + +LogicalResult +TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, + unsigned resultIndex, + NamedAttribute namedAttr) { + if (namedAttr.first == "test.invalid_attr") + return op->emitError() << "invalid to use 'test.invalid_attr'"; + return success(); +} + //===----------------------------------------------------------------------===// // Test IsolatedRegionOp - parse passthrough region arguments. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/TestDialect/TestDialect.h b/mlir/test/lib/TestDialect/TestDialect.h index ffe2a1c50ec7..ade0eb81c40a 100644 --- a/mlir/test/lib/TestDialect/TestDialect.h +++ b/mlir/test/lib/TestDialect/TestDialect.h @@ -40,6 +40,14 @@ public: /// Get the canonical string name of the dialect. static StringRef getDialectName() { return "test"; } + + LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex, + unsigned argIndex, + NamedAttribute) override; + + LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex, + unsigned resultIndex, + NamedAttribute) override; }; #define GET_OP_CLASSES