diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index ef2a4d92dcc8..9de78d865644 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -241,6 +241,8 @@ struct PythonBlock { return StmtBlock(blk).str(); } + PythonBlock set(const py::list &stmts); + edsc_block_t blk; }; @@ -317,6 +319,14 @@ static edsc_expr_list_t makeCExprs(llvm::SmallVectorImpl &owning, return edsc_expr_list_t{owning.data(), owning.size()}; } +static mlir_type_list_t makeCTypes(llvm::SmallVectorImpl &owning, + const py::list &types) { + for (auto &inp : types) { + owning.push_back(mlir_type_t{inp.cast()}); + } + return mlir_type_list_t{owning.data(), owning.size()}; +} + PythonExpr::PythonExpr(const PythonBindable &bindable) : expr{bindable.expr} {} PythonExpr MLIRFunctionEmitter::bindConstantBF16(double value) { @@ -410,6 +420,12 @@ void MLIRFunctionEmitter::emitBlockBody(PythonBlock block) { emitter.emitStmts(StmtBlock(block).getBody()); } +PythonBlock PythonBlock::set(const py::list &stmts) { + SmallVector owning; + ::BlockSetBody(blk, makeCStmts(owning, stmts)); + return *this; +} + PythonExpr dispatchCall(py::args args, py::kwargs kwargs) { assert(args.size() != 0); llvm::SmallVector exprs; @@ -438,10 +454,24 @@ PYBIND11_MODULE(pybind, m) { m.def("deleteContext", [](void *ctx) { delete reinterpret_cast(ctx); }); + m.def("Block", [](const py::list &args, const py::list &stmts) { + SmallVector owning; + SmallVector owningArgs; + return PythonBlock( + ::Block(makeCExprs(owningArgs, args), makeCStmts(owning, stmts))); + }); m.def("Block", [](const py::list &stmts) { SmallVector owning; - return PythonBlock(::Block(makeCStmts(owning, stmts))); + edsc_expr_list_t args{nullptr, 0}; + return PythonBlock(::Block(args, makeCStmts(owning, stmts))); }); + m.def( + "Branch", + [](PythonBlock destination, const py::list &operands) { + SmallVector owning; + return PythonStmt(::Branch(destination, makeCExprs(owning, operands))); + }, + py::arg("destination"), py::arg("operands") = py::list()); m.def("For", [](const py::list &ivs, const py::list &lbs, const py::list &ubs, const py::list &steps, const py::list &stmts) { SmallVector owningIVs; @@ -527,6 +557,7 @@ PYBIND11_MODULE(pybind, m) { py::class_(m, "StmtBlock", "Wrapping class for mlir::edsc::StmtBlock") .def(py::init()) + .def("set", &PythonBlock::set) .def("__str__", &PythonBlock::str); py::class_(m, "Type", "Wrapping class for mlir::Type.") diff --git a/mlir/bindings/python/test/test_py2and3.py b/mlir/bindings/python/test/test_py2and3.py index c6ab3ff25714..314e1b55b2e3 100644 --- a/mlir/bindings/python/test/test_py2and3.py +++ b/mlir/bindings/python/test/test_py2and3.py @@ -151,10 +151,47 @@ class EdscTest(unittest.TestCase): i, j = list(map(E.Expr, [E.Bindable(self.f32Type) for _ in range(2)])) stmt = E.Block([E.Stmt(i + j), E.Stmt(i - j)]) str = stmt.__str__() - self.assertIn("^bb:", str) + self.assertIn("^bb", str) self.assertIn(" = ($1 + $2)", str) self.assertIn(" = ($1 - $2)", str) + def testBlockArgs(self): + with E.ContextManager(): + module = E.MLIRModule() + t = module.make_scalar_type("i", 32) + i, j = list(map(E.Expr, [E.Bindable(t) for _ in range(2)])) + stmt = E.Block([i, j], [E.Stmt(i + j)]) + str = stmt.__str__() + self.assertIn("^bb", str) + self.assertIn("($1, $2):", str) + self.assertIn("($1 + $2)", str) + + def testBranch(self): + with E.ContextManager(): + i, j = list(map(E.Expr, [E.Bindable(self.i32Type) for _ in range(2)])) + b1 = E.Block([E.Stmt(i + j)]) + b2 = E.Block([E.Branch(b1)]) + str1 = b1.__str__() + str2 = b2.__str__() + self.assertIn("^bb1:\n" + "$4 = ($1 + $2)", str1) + self.assertIn("^bb2:\n" + "$6 = br ^bb1", str2) + + def testBranchArgs(self): + with E.ContextManager(): + b1arg, b2arg = (E.Expr(E.Bindable(self.i32Type)) for _ in range(2)) + # Declare empty blocks with arguments and bind those arguments. + b1 = E.Block([b1arg], []) + b2 = E.Block([b2arg], []) + one = E.ConstantInteger(self.i32Type, 1) + # Make blocks branch to each other in a sort of infinite loop. + # This checks that the EDSC implementation does not fall into such loop. + b1.set([E.Branch(b2, [b1arg + one])]) + b2.set([E.Branch(b1, [b2arg])]) + str1 = b1.__str__() + str2 = b2.__str__() + self.assertIn("^bb1($1):\n" + "$6 = br ^bb2(($1 + 1))", str1) + self.assertIn("^bb2($2):\n" + "$8 = br ^bb1($2)", str2) + def testMLIRScalarTypes(self): module = E.MLIRModule() t = module.make_scalar_type("bf16") diff --git a/mlir/include/mlir-c/Core.h b/mlir/include/mlir-c/Core.h index 09dc5c850e27..44ab5cb61beb 100644 --- a/mlir/include/mlir-c/Core.h +++ b/mlir/include/mlir-c/Core.h @@ -230,8 +230,15 @@ edsc_expr_t ConstantInteger(mlir_type_t type, int64_t value); edsc_stmt_t Return(edsc_expr_list_t values); /// Returns an opaque expression for an mlir::edsc::StmtBlock containing the -/// given list of statements. Block arguments are not currently supported. -edsc_block_t Block(edsc_stmt_list_t enclosedStmts); +/// given list of statements. +edsc_block_t Block(edsc_expr_list_t arguments, edsc_stmt_list_t enclosedStmts); + +/// Set the body of the block to the given statements and return the block. +edsc_block_t BlockSetBody(edsc_block_t, edsc_stmt_list_t stmts); + +/// Returns an opaque statement branching to `destination` and passing +/// `arguments` as block arguments. +edsc_stmt_t Branch(edsc_block_t destination, edsc_expr_list_t arguments); /// Returns an opaque statement for an mlir::AffineForOp with `enclosedStmts` /// nested below it. diff --git a/mlir/include/mlir/EDSC/Types.h b/mlir/include/mlir/EDSC/Types.h index bb333a93938c..9f29ab95bd0f 100644 --- a/mlir/include/mlir/EDSC/Types.h +++ b/mlir/include/mlir/EDSC/Types.h @@ -48,6 +48,8 @@ struct StmtBlockStorage; } // namespace detail +class StmtBlock; + /// EDSC Types closely mirror the core MLIR and uses an abstraction similar to /// AffineExpr: /// 1. a set of composable structs; @@ -164,7 +166,20 @@ public: ArrayRef getResultTypes() const; /// Returns the list of expressions used as arguments of this expression. - ArrayRef getChildExpressions() const; + ArrayRef getProperArguments() const; + + /// Returns the list of lists of expressions used as arguments of successors + /// of this expression (i.e., arguments passed to destination basic blocks in + /// terminator statements). + SmallVector, 4> getSuccessorArguments() const; + + /// Returns the list of expressions used as arguments of the `index`-th + /// successor of this expression. + ArrayRef getSuccessorArguments(int index) const; + + /// Returns the list of argument groups (includes the proper argument group, + /// followed by successor/block argument groups). + SmallVector, 4> getAllArgumentGroups() const; /// Returns the list of attributes of this expression. ArrayRef getAttributes() const; @@ -172,9 +187,13 @@ public: /// Returns the attribute with the given name, if any. Attribute getAttribute(StringRef name) const; + /// Returns the list of successors (StmtBlocks) of this expression. + ArrayRef getSuccessors() const; + /// Build the IR corresponding to this expression. SmallVector - build(FuncBuilder &b, const llvm::DenseMap &ssaBindings) const; + build(FuncBuilder &b, const llvm::DenseMap &ssaBindings, + const llvm::DenseMap &blockBindings) const; void print(raw_ostream &os) const; void dump() const; @@ -267,15 +286,18 @@ struct VariadicExpr : public Expr { friend class Expr; VariadicExpr(StringRef name, llvm::ArrayRef exprs, llvm::ArrayRef types = {}, - ArrayRef attrs = {}); + llvm::ArrayRef attrs = {}, + llvm::ArrayRef succ = {}); llvm::ArrayRef getExprs() const; llvm::ArrayRef getTypes() const; + llvm::ArrayRef getSuccessors() const; template static VariadicExpr make(llvm::ArrayRef exprs, llvm::ArrayRef types = {}, - llvm::ArrayRef attrs = {}) { - return VariadicExpr(T::getOperationName(), exprs, types, attrs); + llvm::ArrayRef attrs = {}, + llvm::ArrayRef succ = {}) { + return VariadicExpr(T::getOperationName(), exprs, types, attrs, succ); } protected: @@ -289,18 +311,6 @@ struct StmtBlockLikeExpr : public Expr { StmtBlockLikeExpr(ExprKind kind, llvm::ArrayRef exprs, llvm::ArrayRef types = {}); - /// Get the list of subexpressions. - /// StmtBlockLikeExprs can contain multiple groups of subexpressions separated - /// by null expressions and the result of this call will include them. - llvm::ArrayRef getExprs() const; - - /// Get the list of subexpression groups. - /// StmtBlockLikeExprs can contain multiple groups of subexpressions separated - /// by null expressions. This will identify those groups and return a list - /// of lists of subexpressions split around null expressions. Two null - /// expressions in a row identify an empty group. - SmallVector, 4> getExprGroups() const; - protected: StmtBlockLikeExpr(Expr::ImplType *ptr) : Expr(ptr) { assert(!ptr || isa() && "expected StmtBlockLikeExpr"); @@ -399,19 +409,23 @@ public: explicit StmtBlock(edsc_block_t st) : storage(reinterpret_cast(st)) {} StmtBlock(const StmtBlock &other) = default; - StmtBlock(llvm::ArrayRef stmts = {}); - StmtBlock(llvm::ArrayRef args, llvm::ArrayRef argTypes, - llvm::ArrayRef stmts = {}); + StmtBlock(llvm::ArrayRef stmts); + StmtBlock(llvm::ArrayRef args, llvm::ArrayRef stmts = {}); llvm::ArrayRef getArguments() const; llvm::ArrayRef getArgumentTypes() const; llvm::ArrayRef getBody() const; + uint64_t getId() const; void print(llvm::raw_ostream &os, Twine indent) const; std::string str() const; operator edsc_block_t() { return edsc_block_t{storage}; } + /// Reset the body of this block with the given list of statements. + StmtBlock &operator=(llvm::ArrayRef stmts); + void set(llvm::ArrayRef stmts) { *this = stmts; } + ImplType *getStoragePtr() const { return storage; } private: @@ -619,6 +633,8 @@ Expr call(Expr func, Type result, llvm::ArrayRef args); Expr call(Expr func, llvm::ArrayRef args); Stmt Return(ArrayRef values = {}); +Stmt Branch(StmtBlock destination, ArrayRef args = {}); + Stmt For(Expr lb, Expr ub, Expr step, llvm::ArrayRef enclosedStmts); Stmt For(const Bindable &idx, Expr lb, Expr ub, Expr step, llvm::ArrayRef enclosedStmts); @@ -633,11 +649,13 @@ Stmt For(llvm::ArrayRef indices, llvm::ArrayRef lbs, Stmt MaxMinFor(const Bindable &idx, ArrayRef lbs, ArrayRef ubs, Expr step, ArrayRef enclosedStmts); -StmtBlock block(llvm::ArrayRef args, llvm::ArrayRef argTypes, - llvm::ArrayRef stmts); -inline StmtBlock block(llvm::ArrayRef stmts) { - return block({}, {}, stmts); -} +/// Define an MLIR Block and bind its arguments to `args`. The types of block +/// arguments are those of `args`, each of which must have exactly one result +/// type. The body of the block may be empty and can be reset later. +StmtBlock block(llvm::ArrayRef args, llvm::ArrayRef stmts); +/// Define an MLIR Block without arguments. The body of the block can be empty +/// and can be reset later. +inline StmtBlock block(llvm::ArrayRef stmts) { return block({}, stmts); } /// This helper class exists purely for sugaring purposes and allows writing /// expressions such as: diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index 67d4fb380807..9ea9e1c392d1 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -42,21 +42,43 @@ struct LowerEDSCTestPass : public FunctionPass { #include "mlir/EDSC/reference-impl.inc" PassResult LowerEDSCTestPass::runOnFunction(Function *f) { - // Inject a EDSC-constructed list of blocks. + // Inject a EDSC-constructed infinite loop implemented by mutual branching + // between two blocks, following the pattern: + // + // br ^bb1 + // ^bb1: + // br ^bb2 + // ^bb2: + // br ^bb1 + // + // Use blocks with arguments. if (f->getName().strref() == "blocks") { using namespace edsc::op; FuncBuilder builder(f); edsc::ScopedEDSCContext context; + // Declare two blocks. Note that we must declare the blocks before creating + // branches to them. auto type = builder.getIntegerType(32); - edsc::Expr arg1(type), arg2(type), arg3(type), arg4(type); + edsc::Expr arg1(type), arg2(type), arg3(type), arg4(type), r(type); + edsc::StmtBlock b1 = edsc::block({arg1, arg2}, {}), + b2 = edsc::block({arg3, arg4}, {}); + auto c1 = edsc::constantInteger(type, 42); + auto c2 = edsc::constantInteger(type, 1234); - auto b1 = - edsc::block({arg1, arg2}, {type, type}, {arg1 + arg2, edsc::Return()}); - auto b2 = - edsc::block({arg3, arg4}, {type, type}, {arg3 - arg4, edsc::Return()}); + // Make an infinite loops by branching between the blocks. Note that copy- + // assigning a block won't work well with branches, update the body instead. + b1.set({r = arg1 + arg2, edsc::Branch(b2, {arg1, r})}); + b2.set({edsc::Branch(b1, {arg3, arg4})}); + auto instr = edsc::Branch(b2, {c1, c2}); - edsc::MLIREmitter(&builder, f->getLoc()).emitBlock(b1).emitBlock(b2); + // Remove the existing 'return' from the function, reset the builder after + // the instruction iterator invalidation and emit a branch to b2. This + // should also emit blocks b2 and b1 that appear as successors to the + // current block after the branch instruction is insterted. + f->begin()->clear(); + builder.setInsertionPoint(&*f->begin(), f->begin()->begin()); + edsc::MLIREmitter(&builder, f->getLoc()).emitStmt(instr); } // Inject a EDSC-constructed `for` loop with bounds coming from function diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index aa4ca47e6609..d52fdcd57c57 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -129,7 +129,16 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { bool expectedEmpty = false; if (e.isa() || e.isa() || e.isa() || e.isa()) { - auto results = e.build(*builder, ssaBindings); + // Emit any successors before the instruction with successors. At this + // point, all values defined by the current block must have been bound, the + // current instruction with successors cannot define new values, so the + // successor can use those values. + assert(e.getSuccessors().empty() || e.getResultTypes().empty() && + "an operation with successors must " + "not have results and vice versa"); + for (StmtBlock block : e.getSuccessors()) + emitBlock(block); + auto results = e.build(*builder, ssaBindings, blockBindings); assert(results.size() <= 1 && "2+-result exprs are not supported"); expectedEmpty = results.empty(); if (!results.empty()) @@ -138,7 +147,7 @@ Value *mlir::edsc::MLIREmitter::emitExpr(Expr e) { if (auto expr = e.dyn_cast()) { if (expr.getKind() == ExprKind::For) { - auto exprGroups = expr.getExprGroups(); + auto exprGroups = expr.getAllArgumentGroups(); assert(exprGroups.size() == 3 && "expected 3 expr groups in `for`"); assert(!exprGroups[0].empty() && "expected at least one lower bound"); assert(!exprGroups[1].empty() && "expected at least one upper bound"); @@ -213,8 +222,9 @@ mlir::edsc::MLIREmitter &mlir::edsc::MLIREmitter::emitStmt(const Stmt &stmt) { if (!val) { assert((stmt.getRHS().is_op() || stmt.getRHS().is_op() || stmt.getRHS().is_op() || - stmt.getRHS().is_op()) && - "dealloc, store, return or call_indirect expected as the only " + stmt.getRHS().is_op() || + stmt.getRHS().is_op()) && + "dealloc, store, return, br, or call_indirect expected as the only " "0-result ops"); if (stmt.getRHS().is_op()) { assert( diff --git a/mlir/lib/EDSC/Types.cpp b/mlir/lib/EDSC/Types.cpp index 39f09be8c5f6..9e1a9795823d 100644 --- a/mlir/lib/EDSC/Types.cpp +++ b/mlir/lib/EDSC/Types.cpp @@ -64,17 +64,23 @@ struct ExprStorage { unsigned id; StringRef opName; + + // Exprs can contain multiple groups of operands separated by null + // expressions. Two null expressions in a row identify an empty group. ArrayRef operands; + ArrayRef resultTypes; ArrayRef attributes; + ArrayRef successors; ExprStorage(ExprKind kind, StringRef name, ArrayRef results, ArrayRef children, ArrayRef attrs, - StringRef descr = "", unsigned exprId = Expr::newId()) + ArrayRef succ = {}, unsigned exprId = Expr::newId()) : kind(kind), id(exprId) { operands = copyIntoExprAllocator(children); resultTypes = copyIntoExprAllocator(results); attributes = copyIntoExprAllocator(attrs); + successors = copyIntoExprAllocator(succ); if (!name.empty()) { auto nameStorage = Expr::globalAllocator()->Allocate(name.size()); std::uninitialized_copy(name.begin(), name.end(), nameStorage); @@ -94,11 +100,24 @@ struct StmtStorage { struct StmtBlockStorage { StmtBlockStorage(ArrayRef args, ArrayRef argTypes, ArrayRef stmts) { + id = nextId(); arguments = copyIntoExprAllocator(args); argumentTypes = copyIntoExprAllocator(argTypes); statements = copyIntoExprAllocator(stmts); } + void replaceStmts(ArrayRef stmts) { + Expr::globalAllocator()->Deallocate(statements.data(), statements.size()); + statements = copyIntoExprAllocator(stmts); + } + + static uint64_t &nextId() { + static thread_local uint64_t next = 0; + return ++next; + } + static void resetIds() { nextId() = 0; } + + uint64_t id; ArrayRef arguments; ArrayRef argumentTypes; ArrayRef statements; @@ -111,6 +130,7 @@ struct StmtBlockStorage { mlir::edsc::ScopedEDSCContext::ScopedEDSCContext() { Expr::globalAllocator() = &allocator; Bindable::resetIds(); + StmtBlockStorage::resetIds(); } mlir::edsc::ScopedEDSCContext::~ScopedEDSCContext() { @@ -138,10 +158,6 @@ ArrayRef mlir::edsc::Expr::getResultTypes() const { return storage->resultTypes; } -ArrayRef mlir::edsc::Expr::getChildExpressions() const { - return storage->operands; -} - ArrayRef mlir::edsc::Expr::getAttributes() const { return storage->attributes; } @@ -153,26 +169,38 @@ Attribute mlir::edsc::Expr::getAttribute(StringRef name) const { return {}; } +ArrayRef mlir::edsc::Expr::getSuccessors() const { + return storage->successors; +} + StringRef mlir::edsc::Expr::getName() const { return static_cast(storage)->opName; } SmallVector -Expr::build(FuncBuilder &b, - const llvm::DenseMap &ssaBindings) const { +buildExprs(ArrayRef exprs, FuncBuilder &b, + const llvm::DenseMap &ssaBindings, + const llvm::DenseMap &blockBindings) { + SmallVector values; + values.reserve(exprs.size()); + for (auto child : exprs) { + auto subResults = child.build(b, ssaBindings, blockBindings); + assert(subResults.size() == 1 && + "expected single-result expression as operand"); + values.push_back(subResults.front()); + } + return values; +} + +SmallVector +Expr::build(FuncBuilder &b, const llvm::DenseMap &ssaBindings, + const llvm::DenseMap &blockBindings) const { auto it = ssaBindings.find(*this); if (it != ssaBindings.end()) return {it->second}; - auto *impl = static_cast(storage); - SmallVector operandValues; - operandValues.reserve(impl->operands.size()); - for (auto child : impl->operands) { - auto subResults = child.build(b, ssaBindings); - assert(subResults.size() == 1 && - "expected single-result expression as operand"); - operandValues.push_back(subResults.front()); - } + SmallVector operandValues = + buildExprs(getProperArguments(), b, ssaBindings, blockBindings); // Special case for emitting composed affine.applies. // FIXME: this should not be a special case, instead, define composed form as @@ -185,12 +213,24 @@ Expr::build(FuncBuilder &b, return {affInstr->getResult()}; } - auto state = OperationState(b.getContext(), b.getUnknownLoc(), impl->opName); + auto state = OperationState(b.getContext(), b.getUnknownLoc(), getName()); state.addOperands(operandValues); - state.addTypes(impl->resultTypes); - for (const auto &attr : impl->attributes) + state.addTypes(getResultTypes()); + for (const auto &attr : getAttributes()) state.addAttribute(attr.first, attr.second); + auto successors = getSuccessors(); + auto successorArgs = getSuccessorArguments(); + assert(successors.size() == successorArgs.size() && + "expected all successors to have a corresponding operand group"); + for (int i = 0, e = successors.size(); i < e; ++i) { + StmtBlock block = successors[i]; + assert(blockBindings.count(block) != 0 && "successor block does not exist"); + state.addSuccessor( + blockBindings.lookup(block), + buildExprs(successorArgs[i], b, ssaBindings, blockBindings)); + } + Instruction *inst = b.createOperation(state); return llvm::to_vector<4>(inst->getResults()); } @@ -499,17 +539,26 @@ edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_max_expr_t lb, edsc_min_expr_t ub, Expr(step), stmts)); } -StmtBlock mlir::edsc::block(ArrayRef args, ArrayRef argTypes, - ArrayRef stmts) { - assert(args.size() == argTypes.size() && - "mismatching number of arguments and argument types"); - return StmtBlock(args, argTypes, stmts); +StmtBlock mlir::edsc::block(ArrayRef args, ArrayRef stmts) { + return StmtBlock(args, stmts); } -edsc_block_t Block(edsc_stmt_list_t enclosedStmts) { +edsc_block_t Block(edsc_expr_list_t arguments, edsc_stmt_list_t enclosedStmts) { llvm::SmallVector stmts; fillStmts(enclosedStmts, &stmts); - return StmtBlock(stmts); + + llvm::SmallVector args; + for (uint64_t i = 0; i < arguments.n; ++i) + args.emplace_back(Expr(arguments.exprs[i])); + + return StmtBlock(args, stmts); +} + +edsc_block_t BlockSetBody(edsc_block_t block, edsc_stmt_list_t stmts) { + llvm::SmallVector body; + fillStmts(stmts, &body); + StmtBlock(block).set(body); + return block; } Expr mlir::edsc::load(Expr m, ArrayRef indices) { @@ -593,6 +642,13 @@ edsc_stmt_t Return(edsc_expr_list_t values) { return Stmt(Return(makeExprs(values))); } +Stmt mlir::edsc::Branch(StmtBlock destination, ArrayRef args) { + SmallVector arguments; + arguments.push_back(nullptr); + arguments.insert(arguments.end(), args.begin(), args.end()); + return VariadicExpr::make(arguments, {}, {}, {destination}); +} + static raw_ostream &printBinaryExpr(raw_ostream &os, BinaryExpr e, StringRef infix) { os << '(' << e.getLHS() << ' ' << infix << ' ' << e.getRHS() << ')'; @@ -701,7 +757,12 @@ void printAffineApply(raw_ostream &os, mlir::edsc::Expr e) { assert(mapAttr && "expected a map in an affine apply expr"); printAffineMap(os, mapAttr.cast().getValue(), - e.getChildExpressions()); + e.getProperArguments()); +} + +edsc_stmt_t Branch(edsc_block_t destination, edsc_expr_list_t arguments) { + auto args = makeExprs(arguments); + return mlir::edsc::Branch(StmtBlock(destination), args); } void mlir::edsc::Expr::print(raw_ostream &os) const { @@ -737,15 +798,15 @@ void mlir::edsc::Expr::print(raw_ostream &os) const { // Handle known variadic ops with pretty forms. if (auto narExpr = this->dyn_cast()) { if (narExpr.is_op()) { - os << narExpr.getName() << '(' << getChildExpressions().front() << '['; - interleaveComma(getChildExpressions().drop_front(), os); + os << narExpr.getName() << '(' << getProperArguments().front() << '['; + interleaveComma(getProperArguments().drop_front(), os); os << "])"; return; } if (narExpr.is_op()) { - os << narExpr.getName() << '(' << getChildExpressions().front() << ", " - << getChildExpressions()[1] << '['; - interleaveComma(getChildExpressions().drop_front(2), os); + os << narExpr.getName() << '(' << getProperArguments().front() << ", " + << getProperArguments()[1] << '['; + interleaveComma(getProperArguments().drop_front(2), os); os << "])"; return; } @@ -756,11 +817,21 @@ void mlir::edsc::Expr::print(raw_ostream &os) const { return; } if (narExpr.is_op()) { - os << '@' << getChildExpressions().front() << '('; - interleaveComma(getChildExpressions().drop_front(), os); + os << '@' << getProperArguments().front() << '('; + interleaveComma(getProperArguments().drop_front(), os); os << ')'; return; } + if (narExpr.is_op()) { + os << "br ^bb" << narExpr.getSuccessors().front().getId(); + auto blockArgs = getSuccessorArguments(0); + if (!blockArgs.empty()) + os << '('; + interleaveComma(blockArgs, os); + if (!blockArgs.empty()) + os << ")"; + return; + } } // Special case for integer constants that are printed as is. Use @@ -778,7 +849,26 @@ void mlir::edsc::Expr::print(raw_ostream &os) const { if (this->isa() || this->isa() || this->isa() || this->isa()) { os << (getName().empty() ? "##unknown##" : getName()) << '('; - interleaveComma(getChildExpressions(), os); + interleaveComma(getProperArguments(), os); + auto successors = getSuccessors(); + if (!successors.empty()) { + os << '['; + interleave( + llvm::zip(successors, getSuccessorArguments()), + [&os](const std::tuple &> + &pair) { + const auto &block = std::get<0>(pair); + ArrayRef operands = std::get<1>(pair); + os << "^bb" << block.getId(); + if (!operands.empty()) { + os << '('; + interleaveComma(operands, os); + os << ')'; + } + }, + [&os]() { os << ", "; }); + os << ']'; + } auto attrs = getAttributes(); if (!attrs.empty()) { os << '{'; @@ -797,7 +887,7 @@ void mlir::edsc::Expr::print(raw_ostream &os) const { // We only print the lb, ub and step here, which are the StmtBlockLike // part of the `for` StmtBlockLikeExpr. case ExprKind::For: { - auto exprGroups = stmtLikeExpr.getExprGroups(); + auto exprGroups = stmtLikeExpr.getAllArgumentGroups(); assert(exprGroups.size() == 3 && "For StmtBlockLikeExpr expected 3 groups"); assert(exprGroups[2].size() == 1 && "expected 1 expr for loop step"); @@ -885,17 +975,21 @@ Expr mlir::edsc::TernaryExpr::getRHS() const { mlir::edsc::VariadicExpr::VariadicExpr(StringRef name, ArrayRef exprs, ArrayRef types, - ArrayRef attrs) + ArrayRef attrs, + ArrayRef succ) : Expr(Expr::globalAllocator()->Allocate()) { // Initialize with placement new. new (storage) - detail::ExprStorage(ExprKind::Variadic, name, types, exprs, attrs); + detail::ExprStorage(ExprKind::Variadic, name, types, exprs, attrs, succ); } ArrayRef mlir::edsc::VariadicExpr::getExprs() const { - return static_cast(storage)->operands; + return storage->operands; } ArrayRef mlir::edsc::VariadicExpr::getTypes() const { - return static_cast(storage)->resultTypes; + return storage->resultTypes; +} +ArrayRef mlir::edsc::VariadicExpr::getSuccessors() const { + return storage->successors; } mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind, @@ -905,24 +999,56 @@ mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind, // Initialize with placement new. new (storage) detail::ExprStorage(kind, "", types, exprs, {}); } -ArrayRef mlir::edsc::StmtBlockLikeExpr::getExprs() const { - return static_cast(storage)->operands; -} -SmallVector, 4> -mlir::edsc::StmtBlockLikeExpr::getExprGroups() const { - SmallVector, 4> groups; - ArrayRef exprs = getExprs(); - int start = 0; - for (int i = 0, e = exprs.size(); i < e; ++i) { - if (!exprs[i]) { - groups.push_back(exprs.slice(start, i - start)); - start = i + 1; - } + +static ArrayRef getOneArgumentGroupStartingFrom(int start, + ExprStorage *storage) { + for (int i = start, e = storage->operands.size(); i < e; ++i) { + if (!storage->operands[i]) + return storage->operands.slice(start, i - start); + } + return storage->operands.drop_front(start); +} + +static SmallVector, 4> +getAllArgumentGroupsStartingFrom(int start, ExprStorage *storage) { + SmallVector, 4> groups; + while (start < storage->operands.size()) { + auto group = getOneArgumentGroupStartingFrom(start, storage); + start += group.size() + 1; + groups.push_back(group); } - groups.push_back(exprs.slice(start, exprs.size() - start)); return groups; } +ArrayRef mlir::edsc::Expr::getProperArguments() const { + return getOneArgumentGroupStartingFrom(0, storage); +} + +SmallVector, 4> mlir::edsc::Expr::getSuccessorArguments() const { + // Skip the first group containing proper arguments. + // Note that +1 to size is necessary to step over the nullptrs in the list. + int start = getOneArgumentGroupStartingFrom(0, storage).size() + 1; + return getAllArgumentGroupsStartingFrom(start, storage); +} + +ArrayRef mlir::edsc::Expr::getSuccessorArguments(int index) const { + assert(index >= 0 && "argument group index is out of bounds"); + assert(!storage->operands.empty() && "argument list is empty"); + + // Skip over the first index + 1 groups (also includes proper arguments). + int start = 0; + for (int i = 0, e = index + 1; i < e; ++i) { + assert(start < storage->operands.size() && + "argument group index is out of bounds"); + start += getOneArgumentGroupStartingFrom(start, storage).size() + 1; + } + return getOneArgumentGroupStartingFrom(start, storage); +} + +SmallVector, 4> mlir::edsc::Expr::getAllArgumentGroups() const { + return getAllArgumentGroupsStartingFrom(0, storage); +} + mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs, llvm::ArrayRef enclosedStmts) { storage = Expr::globalAllocator()->Allocate(); @@ -1012,15 +1138,29 @@ llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os, } mlir::edsc::StmtBlock::StmtBlock(llvm::ArrayRef stmts) - : StmtBlock({}, {}, stmts) {} + : StmtBlock({}, stmts) {} mlir::edsc::StmtBlock::StmtBlock(llvm::ArrayRef args, - llvm::ArrayRef argTypes, llvm::ArrayRef stmts) { + // Extract block argument types from bindable types. + // Bindables must have a single type. + llvm::SmallVector argTypes; + argTypes.reserve(args.size()); + for (Bindable arg : args) { + auto argResults = arg.getResultTypes(); + assert(argResults.size() == 1 && + "only single-result expressions are supported"); + argTypes.push_back(argResults.front()); + } storage = Expr::globalAllocator()->Allocate(); new (storage) detail::StmtBlockStorage(args, argTypes, stmts); } +mlir::edsc::StmtBlock &mlir::edsc::StmtBlock::operator=(ArrayRef stmts) { + storage->replaceStmts(stmts); + return *this; +} + ArrayRef mlir::edsc::StmtBlock::getArguments() const { return storage->arguments; } @@ -1033,17 +1173,20 @@ ArrayRef mlir::edsc::StmtBlock::getBody() const { return storage->statements; } +uint64_t mlir::edsc::StmtBlock::getId() const { return storage->id; } + void mlir::edsc::StmtBlock::print(llvm::raw_ostream &os, Twine indent) const { - os << indent << "^bb"; + os << indent << "^bb" << storage->id; if (!getArgumentTypes().empty()) os << '('; interleaveComma(getArguments(), os); if (!getArgumentTypes().empty()) os << ')'; os << ":\n"; - - for (auto stmt : getBody()) + for (auto stmt : getBody()) { stmt.print(os, indent + " "); + os << '\n'; + } } std::string mlir::edsc::StmtBlock::str() const { diff --git a/mlir/test/EDSC/for-loops.mlir b/mlir/test/EDSC/for-loops.mlir index 7032b160316e..ea28ae1e4242 100644 --- a/mlir/test/EDSC/for-loops.mlir +++ b/mlir/test/EDSC/for-loops.mlir @@ -10,16 +10,19 @@ // CHECK-DAG: #[[id2dmap:.*]] = (d0, d1) -> (d0, d1) // This function will be detected by the test pass that will insert -// EDSC-constructed blocks with arguments. +// EDSC-constructed blocks with arguments forming an infinite loop. // CHECK-LABEL: @blocks func @blocks() { return -//CHECK: ^bb1(%0: i32, %1: i32): // no predecessors -//CHECK-NEXT: %2 = addi %0, %1 : i32 -//CHECK-NEXT: return -//CHECK: ^bb2(%3: i32, %4: i32): // no predecessors -//CHECK-NEXT: %5 = subi %3, %4 : i32 -//CHECK-NEXT: return +//CHECK: %c42_i32 = constant 42 : i32 +//CHECK-NEXT: %c1234_i32 = constant 1234 : i32 +//CHECK-NEXT: br ^bb1(%c42_i32, %c1234_i32 : i32, i32) +//CHECK-NEXT: ^bb1(%0: i32, %1: i32): // 2 preds: ^bb0, ^bb2 +//CHECK-NEXT: br ^bb2(%0, %1 : i32, i32) +//CHECK-NEXT: ^bb2(%2: i32, %3: i32): // pred: ^bb1 +//CHECK-NEXT: %4 = addi %2, %3 : i32 +//CHECK-NEXT: br ^bb1(%2, %4 : i32, i32) +//CHECK-NEXT: } } // This function will be detected by the test pass that will insert an