From e7193a70f8217139eed53ebd5b6df7851f9b13b6 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 25 Feb 2019 09:22:04 -0800 Subject: [PATCH] EDSC: support conditional branch instructions Leverage the recently introduced support for multiple argument groups and multiple destination blocks in EDSC Expressions to implement conditional branches in EDSC. Conditional branches have two successors and three argument groups. The first group contains a single expression of i1 type that corresponds to the condition of the branch. The two following groups contain arguments of the two successors of the conditional branch instruction, in the same order as the successors. Expose this instruction to the C API and Python bindings. PiperOrigin-RevId: 235542768 --- mlir/bindings/python/pybind.cpp | 18 ++++++++ mlir/bindings/python/test/test_py2and3.py | 20 +++++++++ mlir/include/mlir-c/Core.h | 10 +++++ mlir/include/mlir/EDSC/Types.h | 5 +++ mlir/lib/EDSC/LowerEDSCTestPass.cpp | 34 +++++++++++++- mlir/lib/EDSC/MLIREmitter.cpp | 7 +-- mlir/lib/EDSC/Types.cpp | 55 ++++++++++++++++++++--- mlir/test/EDSC/for-loops.mlir | 18 ++++++++ 8 files changed, 157 insertions(+), 10 deletions(-) diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index 9de78d865644..2673c8a58b17 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -472,6 +472,24 @@ PYBIND11_MODULE(pybind, m) { return PythonStmt(::Branch(destination, makeCExprs(owning, operands))); }, py::arg("destination"), py::arg("operands") = py::list()); + m.def("CondBranch", + [](PythonExpr condition, PythonBlock trueDestination, + const py::list &trueOperands, PythonBlock falseDestination, + const py::list &falseOperands) { + SmallVector owningTrue; + SmallVector owningFalse; + return PythonStmt(::CondBranch( + condition, trueDestination, makeCExprs(owningTrue, trueOperands), + falseDestination, makeCExprs(owningFalse, falseOperands))); + }); + m.def("CondBranch", [](PythonExpr condition, PythonBlock trueDestination, + PythonBlock falseDestination) { + edsc_expr_list_t emptyList; + emptyList.exprs = nullptr; + emptyList.n = 0; + return PythonStmt(::CondBranch(condition, trueDestination, emptyList, + falseDestination, emptyList)); + }); m.def("For", [](const py::list &ivs, const py::list &lbs, const py::list &ubs, const py::list &steps, const py::list &stmts) { SmallVector owningIVs; diff --git a/mlir/bindings/python/test/test_py2and3.py b/mlir/bindings/python/test/test_py2and3.py index 314e1b55b2e3..4b78402f9300 100644 --- a/mlir/bindings/python/test/test_py2and3.py +++ b/mlir/bindings/python/test/test_py2and3.py @@ -192,6 +192,26 @@ class EdscTest(unittest.TestCase): self.assertIn("^bb1($1):\n" + "$6 = br ^bb2(($1 + 1))", str1) self.assertIn("^bb2($2):\n" + "$8 = br ^bb1($2)", str2) + def testCondBranch(self): + with E.ContextManager(): + cond = E.Expr(E.Bindable(self.boolType)) + b1 = E.Block([]) + b2 = E.Block([]) + b3 = E.Block([E.CondBranch(cond, b1, b2)]) + str = b3.__str__() + self.assertIn("cond_br($1, ^bb1, ^bb2)", str) + + def testCondBranchArgs(self): + with E.ContextManager(): + arg1, arg2, arg3 = (E.Expr(E.Bindable(self.i32Type)) for _ in range(3)) + expr1, expr2, expr3 = (E.Expr(E.Bindable(self.i32Type)) for _ in range(3)) + cond = E.Expr(E.Bindable(self.boolType)) + b1 = E.Block([arg1], []) + b2 = E.Block([arg2, arg3], []) + b3 = E.Block([E.CondBranch(cond, b1, [expr1], b2, [expr2, expr3])]) + str = b3.__str__() + self.assertIn("cond_br($7, ^bb1($4), ^bb2($5, $6))", str) + def testMLIRScalarTypes(self): module = E.MLIRModule() t = module.make_scalar_type("bf16") diff --git a/mlir/include/mlir-c/Core.h b/mlir/include/mlir-c/Core.h index 44ab5cb61beb..56d89087a4ed 100644 --- a/mlir/include/mlir-c/Core.h +++ b/mlir/include/mlir-c/Core.h @@ -240,6 +240,16 @@ edsc_block_t BlockSetBody(edsc_block_t, edsc_stmt_list_t stmts); /// `arguments` as block arguments. edsc_stmt_t Branch(edsc_block_t destination, edsc_expr_list_t arguments); +/// Returns an opaque statement that redirects the control flow to either +/// `trueDestination` or `falseDestination` depending on whether the +/// `condition` expression is true or false. The caller may pass expressions +/// as arguments to the destination blocks using `trueArguments` and +/// `falseArguments`, respectively. +edsc_stmt_t CondBranch(edsc_expr_t condition, edsc_block_t trueDestination, + edsc_expr_list_t trueArguments, + edsc_block_t falseDestination, + edsc_expr_list_t falseArguments); + /// Returns an opaque statement for an mlir::AffineForOp with `enclosedStmts` /// nested below it. edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub, diff --git a/mlir/include/mlir/EDSC/Types.h b/mlir/include/mlir/EDSC/Types.h index 9f29ab95bd0f..4a81239159d0 100644 --- a/mlir/include/mlir/EDSC/Types.h +++ b/mlir/include/mlir/EDSC/Types.h @@ -634,6 +634,11 @@ Expr call(Expr func, llvm::ArrayRef args); Stmt Return(ArrayRef values = {}); Stmt Branch(StmtBlock destination, ArrayRef args = {}); +Stmt CondBranch(Expr condition, StmtBlock trueDestination, + ArrayRef trueArgs, StmtBlock falseDestination, + ArrayRef falseArgs); +Stmt CondBranch(Expr condition, StmtBlock trueDestination, + StmtBlock falseDestination); Stmt For(Expr lb, Expr ub, Expr step, llvm::ArrayRef enclosedStmts); Stmt For(const Bindable &idx, Expr lb, Expr ub, Expr step, diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index 9ea9e1c392d1..67ca430efe86 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -25,7 +25,6 @@ #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" #include "mlir/StandardOps/StandardOps.h" -#include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -81,6 +80,39 @@ PassResult LowerEDSCTestPass::runOnFunction(Function *f) { edsc::MLIREmitter(&builder, f->getLoc()).emitStmt(instr); } + // Inject two EDSC-constructed blocks with arguments and a conditional branch + // instruction that transfers control to these blocks. + if (f->getName().strref() == "cond_branch") { + FuncBuilder builder(f); + edsc::ScopedEDSCContext context; + auto i1 = builder.getIntegerType(1); + auto i32 = builder.getIntegerType(32); + auto i64 = builder.getIntegerType(64); + edsc::Expr arg1(i32), arg2(i64), arg3(i32); + // Declare two blocks with different numbers of arguments. + edsc::StmtBlock b1 = edsc::block({arg1}, {edsc::Return()}), + b2 = edsc::block({arg2, arg3}, {edsc::Return()}); + edsc::Expr funcArg(i1); + + // Inject the conditional branch. + auto condBranch = edsc::CondBranch( + funcArg, b1, {edsc::constantInteger(i32, 32)}, b2, + {edsc::constantInteger(i64, 64), edsc::constantInteger(i32, 42)}); + + assert(f->getNumArguments() == 1 && "cond_branch must have 1 argument"); + assert(f->getArgument(0)->getType() == i1 && + "the argument of cond_branch must have i1 type"); + + // Remove the existing `return` instruction from the entry block of the + // function. It will be replaced by the conditional branch. + f->begin()->clear(); + builder.setInsertionPoint(&*f->begin(), f->begin()->begin()); + edsc::MLIREmitter(&builder, f->getLoc()) + .bind(edsc::Bindable(funcArg), f->getArgument(0)) + .emitStmt(condBranch); + return success(); + } + // Inject a EDSC-constructed `for` loop with bounds coming from function // arguments. if (f->getName().strref() == "dynamic_for_func_args") { diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index d52fdcd57c57..e6b6af134046 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -223,9 +223,10 @@ mlir::edsc::MLIREmitter &mlir::edsc::MLIREmitter::emitStmt(const Stmt &stmt) { assert((stmt.getRHS().is_op() || stmt.getRHS().is_op() || stmt.getRHS().is_op() || stmt.getRHS().is_op() || - stmt.getRHS().is_op()) && - "dealloc, store, return, br, or call_indirect expected as the only " - "0-result ops"); + stmt.getRHS().is_op() || + stmt.getRHS().is_op()) && + "dealloc, store, return, br, cond_br or call_indirect expected as " + "the only 0-result ops"); if (stmt.getRHS().is_op()) { assert( stmt.getRHS().cast().getTypes().empty() && diff --git a/mlir/lib/EDSC/Types.cpp b/mlir/lib/EDSC/Types.cpp index 9e1a9795823d..29abb212f10a 100644 --- a/mlir/lib/EDSC/Types.cpp +++ b/mlir/lib/EDSC/Types.cpp @@ -649,6 +649,24 @@ Stmt mlir::edsc::Branch(StmtBlock destination, ArrayRef args) { return VariadicExpr::make(arguments, {}, {}, {destination}); } +Stmt mlir::edsc::CondBranch(Expr condition, StmtBlock trueDestination, + ArrayRef trueArgs, StmtBlock falseDestination, + ArrayRef falseArgs) { + SmallVector arguments; + arguments.push_back(condition); + arguments.push_back(nullptr); + arguments.append(trueArgs.begin(), trueArgs.end()); + arguments.push_back(nullptr); + arguments.append(falseArgs.begin(), falseArgs.end()); + return VariadicExpr::make(arguments, {}, {}, + {trueDestination, falseDestination}); +} + +Stmt mlir::edsc::CondBranch(Expr condition, StmtBlock trueDestination, + StmtBlock falseDestination) { + return CondBranch(condition, trueDestination, {}, falseDestination, {}); +} + static raw_ostream &printBinaryExpr(raw_ostream &os, BinaryExpr e, StringRef infix) { os << '(' << e.getLHS() << ' ' << infix << ' ' << e.getRHS() << ')'; @@ -765,6 +783,27 @@ edsc_stmt_t Branch(edsc_block_t destination, edsc_expr_list_t arguments) { return mlir::edsc::Branch(StmtBlock(destination), args); } +edsc_stmt_t CondBranch(edsc_expr_t condition, edsc_block_t trueDestination, + edsc_expr_list_t trueArguments, + edsc_block_t falseDestination, + edsc_expr_list_t falseArguments) { + auto trueArgs = makeExprs(trueArguments); + auto falseArgs = makeExprs(falseArguments); + return mlir::edsc::CondBranch(Expr(condition), StmtBlock(trueDestination), + trueArgs, StmtBlock(falseDestination), + falseArgs); +} + +// If `blockArgs` is not empty, print it as a comma-separated parenthesized +// list, otherwise print nothing. +void printOptionalBlockArgs(ArrayRef blockArgs, llvm::raw_ostream &os) { + if (!blockArgs.empty()) + os << '('; + interleaveComma(blockArgs, os); + if (!blockArgs.empty()) + os << ")"; +} + void mlir::edsc::Expr::print(raw_ostream &os) const { if (auto unbound = this->dyn_cast()) { os << "$" << unbound.getId(); @@ -824,12 +863,16 @@ void mlir::edsc::Expr::print(raw_ostream &os) const { } if (narExpr.is_op()) { os << "br ^bb" << narExpr.getSuccessors().front().getId(); - auto blockArgs = getSuccessorArguments(0); - if (!blockArgs.empty()) - os << '('; - interleaveComma(blockArgs, os); - if (!blockArgs.empty()) - os << ")"; + printOptionalBlockArgs(getSuccessorArguments(0), os); + return; + } + if (narExpr.is_op()) { + os << "cond_br(" << getProperArguments()[0] << ", ^bb" + << getSuccessors().front().getId(); + printOptionalBlockArgs(getSuccessorArguments(0), os); + os << ", ^bb" << getSuccessors().back().getId(); + printOptionalBlockArgs(getSuccessorArguments(1), os); + os << ')'; return; } } diff --git a/mlir/test/EDSC/for-loops.mlir b/mlir/test/EDSC/for-loops.mlir index ea28ae1e4242..f116366416ac 100644 --- a/mlir/test/EDSC/for-loops.mlir +++ b/mlir/test/EDSC/for-loops.mlir @@ -25,6 +25,24 @@ func @blocks() { //CHECK-NEXT: } } +// This function will be detected by the test pass that will insert two +// EDSC-constructed blocks with arguments and a conditional branch that goes to +// both of them. +func @cond_branch(%arg0: i1) { + return +// CHECK-LABEL: @cond_branch +// CHECK-NEXT: %c0 = constant 0 : index +// CHECK-NEXT: %c1 = constant 1 : index +// CHECK-NEXT: %c32_i32 = constant 32 : i32 +// CHECK-NEXT: %c64_i64 = constant 64 : i64 +// CHECK-NEXT: %c42_i32 = constant 42 : i32 +// CHECK-NEXT: cond_br %arg0, ^bb1(%c32_i32 : i32), ^bb2(%c64_i64, %c42_i32 : i64, i32) +// CHECK-NEXT: ^bb1(%0: i32): // pred: ^bb0 +// CHECK-NEXT: return +// CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0 +// CHECK-NEXT: return +} + // This function will be detected by the test pass that will insert an // EDSC-constructed empty `for` loop that corresponds to // for %arg0 to %arg1 step 2