diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index c41a886b06bb..1da231289d4e 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -184,11 +184,10 @@ public: return op; } - ForStmt *createFor() { - auto stmt = new ForStmt(); - block->getStatements().push_back(stmt); - return stmt; - } + // Creates for statement. When step is not specified, it is set to 1. + ForStmt *createFor(AffineConstantExpr *lowerBound, + AffineConstantExpr *upperBound, + AffineConstantExpr *step = nullptr); IfStmt *createIf() { auto stmt = new IfStmt(); diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h index 902bf352534c..72bcd1173fed 100644 --- a/mlir/include/mlir/IR/Statements.h +++ b/mlir/include/mlir/IR/Statements.h @@ -22,10 +22,11 @@ #ifndef MLIR_IR_STATEMENTS_H #define MLIR_IR_STATEMENTS_H -#include "mlir/Support/LLVM.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Statement.h" #include "mlir/IR/StmtBlock.h" +#include "mlir/Support/LLVM.h" namespace mlir { @@ -50,11 +51,20 @@ public: /// For statement represents an affine loop nest. class ForStmt : public Statement, public StmtBlock { public: - explicit ForStmt() : Statement(Kind::For), StmtBlock(StmtBlockKind::For) {} + // TODO: lower and upper bounds should be affine maps with + // dimension and symbol use lists. + explicit ForStmt(AffineConstantExpr *lowerBound, + AffineConstantExpr *upperBound, AffineConstantExpr *step) + : Statement(Kind::For), StmtBlock(StmtBlockKind::For), + lowerBound(lowerBound), upperBound(upperBound), step(step) {} + //TODO: delete nested statements or assert that they are gone. ~ForStmt() {} - // TODO: represent loop variable, bounds and step + // TODO: represent induction variable + AffineConstantExpr *getLowerBound() const { return lowerBound; } + AffineConstantExpr *getUpperBound() const { return upperBound; } + AffineConstantExpr *getStep() const { return step; } /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Statement *stmt) { @@ -64,9 +74,14 @@ public: static bool classof(const StmtBlock *block) { return block->getStmtBlockKind() == StmtBlockKind::For; } + +private: + AffineConstantExpr *lowerBound; + AffineConstantExpr *upperBound; + AffineConstantExpr *step; }; -/// If clause represents statements contained within then or else clause +/// An if clause represents statements contained within a then or an else clause /// of an if statement. class IfClause : public StmtBlock { public: diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 424487c25167..570ae4923e51 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -594,7 +594,12 @@ void MLFunctionState::print(const Statement *stmt) { void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); } void MLFunctionState::print(const ForStmt *stmt) { - os.indent(numSpaces) << "for {\n"; + os.indent(numSpaces) << "for x = " << *stmt->getLowerBound(); + os << " to " << *stmt->getUpperBound(); + if (stmt->getStep()->getValue() != 1) + os << " step " << *stmt->getStep(); + + os << " {\n"; print(static_cast(stmt)); os.indent(numSpaces) << "}"; } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 8d27991b8fd0..3d7e023c0ac6 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -143,3 +143,17 @@ AffineExpr *Builder::getFloorDivExpr(AffineExpr *lhs, AffineExpr *rhs) { AffineExpr *Builder::getCeilDivExpr(AffineExpr *lhs, AffineExpr *rhs) { return AffineCeilDivExpr::get(lhs, rhs, context); } + +//===----------------------------------------------------------------------===// +// Statements +//===----------------------------------------------------------------------===// + +ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound, + AffineConstantExpr *upperBound, + AffineConstantExpr *step) { + if (!step) + step = getConstantExpr(1); + auto stmt = new ForStmt(lowerBound, upperBound, step); + block->getStatements().push_back(stmt); + return stmt; +} diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 7fbbf9b5b8a1..71b4904572ab 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1581,6 +1581,7 @@ private: MLFuncBuilder builder; ParseResult parseForStmt(); + AffineConstantExpr *parseIntConstant(); ParseResult parseIfStmt(); ParseResult parseElseClause(IfClause *elseClause); ParseResult parseStatements(StmtBlock *block); @@ -1598,10 +1599,11 @@ ParseResult MLFunctionParser::parseFunctionBody() { if (!consumeIf(Token::kw_return)) emitError("ML function must end with return statement"); - // TODO: parse return statement operands - if (!consumeIf(Token::r_brace)) - emitError("expected '}' in ML function"); + // TODO: store return operands in the IR. + SmallVector dummyUseInfo; + if (parseOptionalSSAUseList(Token::r_brace, dummyUseInfo)) + return ParseFailure; getModule()->functionList.push_back(function); @@ -1616,17 +1618,66 @@ ParseResult MLFunctionParser::parseFunctionBody() { ParseResult MLFunctionParser::parseForStmt() { consumeToken(Token::kw_for); - //TODO: parse loop header - ForStmt *stmt = builder.createFor(); + // Parse induction variable + if (getToken().isNot(Token::percent_identifier)) + return emitError("expected SSA identifier for the loop variable"); - // If parsing of the for statement body fails - // MLIR contains for statement with successfully parsed nested statements + // TODO: create SSA value definition from name + StringRef name = getTokenSpelling().drop_front(); + (void)name; + + consumeToken(Token::percent_identifier); + + if (!consumeIf(Token::equal)) + return emitError("expected ="); + + // Parse loop bounds + AffineConstantExpr *lowerBound = parseIntConstant(); + if (!lowerBound) + return ParseFailure; + + if (!consumeIf(Token::kw_to)) + return emitError("expected 'to' between bounds"); + + AffineConstantExpr *upperBound = parseIntConstant(); + if (!upperBound) + return ParseFailure; + + // Parse step + AffineConstantExpr *step = nullptr; + if (consumeIf(Token::kw_step)) { + step = parseIntConstant(); + if (!step) + return ParseFailure; + } + + // Create for statement. + ForStmt *stmt = builder.createFor(lowerBound, upperBound, step); + + // If parsing of the for statement body fails, + // MLIR contains for statement with those nested statements that have been + // successfully parsed. if (parseStmtBlock(static_cast(stmt))) return ParseFailure; return ParseSuccess; } +// This method is temporary workaround to parse simple loop bounds and +// step. +// TODO: remove this method once it's no longer used. +AffineConstantExpr *MLFunctionParser::parseIntConstant() { + if (getToken().isNot(Token::integer)) + return (emitError("expected non-negative integer for now"), nullptr); + + auto val = getToken().getUInt64IntegerValue(); + if (!val.hasValue() || (int64_t)val.getValue() < 0) { + return (emitError("constant too large for affineint"), nullptr); + } + consumeToken(Token::integer); + return builder.getConstantExpr((int64_t)val.getValue()); +} + /// If statement. /// /// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}` @@ -1642,13 +1693,14 @@ ParseResult MLFunctionParser::parseIfStmt() { //TODO: parse condition if (!consumeIf(Token::r_paren)) - return emitError("expected )"); + return emitError("expected ')'"); IfStmt *ifStmt = builder.createIf(); IfClause *thenClause = ifStmt->getThenClause(); - // If parsing of the then or optional else clause fails MLIR contains - // if statement with successfully parsed nested statements. + // When parsing of an if statement body fails, the IR contains + // the if statement with the portion of the body that has been + // successfully parsed. if (parseStmtBlock(thenClause)) return ParseFailure; @@ -1735,7 +1787,10 @@ private: ParseResult parseAffineMapDef(); // Functions. - ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type); + ParseResult parseMLArgumentList(SmallVectorImpl &argTypes, + SmallVectorImpl &argNames); + ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type, + SmallVectorImpl *argNames); ParseResult parseExtFunc(); ParseResult parseCFGFunc(); ParseResult parseMLFunc(); @@ -1769,14 +1824,50 @@ ParseResult ModuleParser::parseAffineMapDef() { return ParseSuccess; } +/// Parse a (possibly empty) list of MLFunction arguments with types. +/// +/// ml-argument ::= ssa-id `:` type +/// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/ +/// +ParseResult +ModuleParser::parseMLArgumentList(SmallVectorImpl &argTypes, + SmallVectorImpl &argNames) { + auto parseElt = [&]() -> ParseResult { + // Parse argument name + if (getToken().isNot(Token::percent_identifier)) + return emitError("expected SSA identifier"); + + StringRef name = getTokenSpelling().drop_front(); + consumeToken(Token::percent_identifier); + argNames.push_back(name); + + if (!consumeIf(Token::colon)) + return emitError("expected ':'"); + + // Parse argument type + auto elt = parseType(); + if (!elt) + return ParseFailure; + argTypes.push_back(elt); + + return ParseSuccess; + }; + + if (!consumeIf(Token::l_paren)) + llvm_unreachable("expected '('"); + + return parseCommaSeparatedList(Token::r_paren, parseElt); +} + /// Parse a function signature, starting with a name and including the parameter /// list. /// -/// argument-list ::= type (`,` type)* | /*empty*/ +/// argument-list ::= type (`,` type)* | /*empty*/ | ml-argument-list /// function-signature ::= function-id `(` argument-list `)` (`->` type-list)? /// -ParseResult ModuleParser::parseFunctionSignature(StringRef &name, - FunctionType *&type) { +ParseResult +ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type, + SmallVectorImpl *argNames) { if (getToken().isNot(Token::at_identifier)) return emitError("expected a function identifier like '@foo'"); @@ -1786,8 +1877,15 @@ ParseResult ModuleParser::parseFunctionSignature(StringRef &name, if (getToken().isNot(Token::l_paren)) return emitError("expected '(' in function signature"); - SmallVector arguments; - if (parseTypeList(arguments)) + SmallVector argTypes; + ParseResult parseResult; + + if (argNames) + parseResult = parseMLArgumentList(argTypes, *argNames); + else + parseResult = parseTypeList(argTypes); + + if (parseResult) return ParseFailure; // Parse the return type if present. @@ -1796,7 +1894,7 @@ ParseResult ModuleParser::parseFunctionSignature(StringRef &name, if (parseTypeList(results)) return ParseFailure; } - type = builder.getFunctionType(arguments, results); + type = builder.getFunctionType(argTypes, results); return ParseSuccess; } @@ -1809,7 +1907,7 @@ ParseResult ModuleParser::parseExtFunc() { StringRef name; FunctionType *type = nullptr; - if (parseFunctionSignature(name, type)) + if (parseFunctionSignature(name, type, /*arguments*/ nullptr)) return ParseFailure; // Okay, the external function definition was parsed correctly. @@ -1826,7 +1924,7 @@ ParseResult ModuleParser::parseCFGFunc() { StringRef name; FunctionType *type = nullptr; - if (parseFunctionSignature(name, type)) + if (parseFunctionSignature(name, type, /*arguments*/ nullptr)) return ParseFailure; // Okay, the CFG function signature was parsed correctly, create the function. @@ -1844,10 +1942,11 @@ ParseResult ModuleParser::parseMLFunc() { StringRef name; FunctionType *type = nullptr; - + SmallVector argNames; // FIXME: Parse ML function signature (args + types) // by passing pointer to SmallVector into parseFunctionSignature - if (parseFunctionSignature(name, type)) + + if (parseFunctionSignature(name, type, &argNames)) return ParseFailure; // Okay, the ML function signature was parsed correctly, create the function. diff --git a/mlir/lib/Parser/Token.h b/mlir/lib/Parser/Token.h index 73baaace04da..f847bcf4efdf 100644 --- a/mlir/lib/Parser/Token.h +++ b/mlir/lib/Parser/Token.h @@ -76,7 +76,7 @@ public: /// return None. Optional getUnsignedIntegerValue() const; - /// For an integer token, return its value as an int64_t. If it doesn't fit, + /// For an integer token, return its value as an uint64_t. If it doesn't fit, /// return None. Optional getUInt64IntegerValue() const; diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index dda5caee215d..de6758cddaf3 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -104,7 +104,9 @@ TOK_KEYWORD(mlfunc) TOK_KEYWORD(mod) TOK_KEYWORD(return) TOK_KEYWORD(size) +TOK_KEYWORD(step) TOK_KEYWORD(tensor) +TOK_KEYWORD(to) TOK_KEYWORD(true) TOK_KEYWORD(vector) diff --git a/mlir/test/IR/parser-errors.mlir b/mlir/test/IR/parser-errors.mlir index 4f67c8e9b79d..49fd2da893f3 100644 --- a/mlir/test/IR/parser-errors.mlir +++ b/mlir/test/IR/parser-errors.mlir @@ -130,12 +130,24 @@ extfunc @illegaltype(i0) // expected-error {{invalid integer width}} // ----- +mlfunc @malformed_for() { + for %i = 1 too 10 { // expected-error {{expected 'to' between bounds}} + } +} + +// ----- + mlfunc @incomplete_for() { - for + for %i = 1 to 10 step 2 } // expected-error {{expected '{' before statement list}} // ----- +mlfunc @nonconstant_step(%1 : i32) { + for %2 = 1 to 5 step %1 { // expected-error {{expected non-negative integer for now}} + +// ----- + mlfunc @non_statement() { asd // expected-error {{expected operation name in quotes}} } @@ -160,7 +172,6 @@ bb40: return } - // ----- cfgfunc @redef() { @@ -168,4 +179,16 @@ bb42: %x = "dim"(){index: 0} : ()->i32 %x = "dim"(){index: 0} : ()->i32 // expected-error {{redefinition of SSA value %x}} return -} \ No newline at end of file +} + +mlfunc @missing_rbrace() { + return %a +mlfunc @d {return} // expected-error {{expected ',' or '}'}} + +// ----- + +mlfunc @malformed_type(%a : intt) { // expected-error {{expected type}} +} + +// ----- + diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 3a7986c7aab7..9b384b4e3bdc 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -96,22 +96,32 @@ mlfunc @emptyMLF() { return // CHECK: return } // CHECK: } +// CHECK-LABEL: mlfunc @mlfunc_with_args(f16) { +mlfunc @mlfunc_with_args(%a : f16) { + return %a // CHECK: return +} + // CHECK-LABEL: cfgfunc @cfgfunc_with_ops() { cfgfunc @cfgfunc_with_ops() { bb0: %t = "getTensor"() : () -> tensor<4x4x?xf32> + // CHECK: dim xxx, 2 : sometype %a = "dim"(%t){index: 2} : (tensor<4x4x?xf32>) -> affineint // CHECK: addf xx, yy : sometype "addf"() : () -> () + + // CHECK: return return } // CHECK-LABEL: mlfunc @loops() { mlfunc @loops() { - for { // CHECK: for { - for { // CHECK: for { + // CHECK: for x = 1 to 100 step 2 { + for %i = 1 to 100 step 2 { + // CHECK: for x = 1 to 200 { + for %j = 1 to 200 { } // CHECK: } } // CHECK: } return // CHECK: return @@ -119,14 +129,14 @@ mlfunc @loops() { // CHECK-LABEL: mlfunc @ifstmt() { mlfunc @ifstmt() { - for { // CHECK for { - if () { // CHECK if () { - } else if () { // CHECK } else if () { - } else { // CHECK } else { - } // CHECK } - } // CHECK } - return // CHECK return -} // CHECK } + for %i = 1 to 10 { // CHECK for x = 1 to 10 { + if () { // CHECK if () { + } else if () { // CHECK } else if () { + } else { // CHECK } else { + } // CHECK } + } // CHECK } + return // CHECK return +} // CHECK } // CHECK-LABEL: cfgfunc @attributes() { cfgfunc @attributes() {