diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index d3ffebbb50f8..0356840d0b35 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -222,6 +222,22 @@ struct PythonIndexed : public edsc_indexed_t { operator PythonExpr() { return PythonExpr(base); } }; +struct PythonMaxExpr { + PythonMaxExpr() : expr(nullptr) {} + PythonMaxExpr(const edsc_max_expr_t &e) : expr(e) {} + operator edsc_max_expr_t() { return expr; } + + edsc_max_expr_t expr; +}; + +struct PythonMinExpr { + PythonMinExpr() : expr(nullptr) {} + PythonMinExpr(const edsc_min_expr_t &e) : expr(e) {} + operator edsc_min_expr_t() { return expr; } + + edsc_min_expr_t expr; +}; + struct MLIRFunctionEmitter { MLIRFunctionEmitter(PythonFunction f) : currentFunction(reinterpret_cast(f.function)), @@ -386,19 +402,23 @@ PYBIND11_MODULE(pybind, m) { makeCExprs(owningUBs, ubs), makeCExprs(owningSteps, steps), makeCStmts(owningStmts, stmts))); }); + m.def("Max", [](const py::list &args) { + SmallVector owning; + return PythonMaxExpr(::Max(makeCExprs(owning, args))); + }); + m.def("Min", [](const py::list &args) { + SmallVector owning; + return PythonMinExpr(::Min(makeCExprs(owning, args))); + }); m.def("For", [](PythonExpr iv, PythonExpr lb, PythonExpr ub, PythonExpr step, const py::list &stmts) { SmallVector owning; return PythonStmt(::For(iv, lb, ub, step, makeCStmts(owning, stmts))); }); - m.def("MaxMinFor", [](PythonExpr iv, const py::list &lbs, const py::list &ubs, - PythonExpr step, const py::list &stmts) { - SmallVector owningLBs; - SmallVector owningUBs; - SmallVector owningStmts; - return PythonStmt(::MaxMinFor(iv, makeCExprs(owningLBs, lbs), - makeCExprs(owningUBs, ubs), step, - makeCStmts(owningStmts, stmts))); + m.def("For", [](PythonExpr iv, PythonMaxExpr lb, PythonMinExpr ub, + PythonExpr step, const py::list &stmts) { + SmallVector owning; + return PythonStmt(::MaxMinFor(iv, lb, ub, step, makeCStmts(owning, stmts))); }); m.def("Select", [](PythonExpr cond, PythonExpr e1, PythonExpr e2) { return PythonExpr(::Select(cond, e1, e2)); @@ -623,6 +643,11 @@ PYBIND11_MODULE(pybind, m) { Store(value, instance, makeCExprs(owning, indices))); }, R"DOC(Returns the Stmt that stores into an Indexed)DOC"); + + py::class_(m, "MaxExpr", + "Wrapping class for mlir::edsc::MaxExpr"); + py::class_(m, "MinExpr", + "Wrapping class for mlir::edsc::MinExpr"); } } // namespace python diff --git a/mlir/bindings/python/test/test_py2and3.py b/mlir/bindings/python/test/test_py2and3.py index 564fb716c591..936aa5227336 100644 --- a/mlir/bindings/python/test/test_py2and3.py +++ b/mlir/bindings/python/test/test_py2and3.py @@ -75,7 +75,7 @@ class EdscTest(unittest.TestCase): step = E.Expr(E.Bindable(self.indexType)) lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)])) ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)])) - loop = E.MaxMinFor(i, lbs, ubs, step, []) + loop = E.For(i, E.Max(lbs), E.Min(ubs), step, []) s = str(loop) self.assertIn("for($1 = max($3, $4, $5, $6) to min($7, $8, $9) step $2)", s) diff --git a/mlir/include/mlir-c/Core.h b/mlir/include/mlir-c/Core.h index 8505cdc0ae6b..36caeae47338 100644 --- a/mlir/include/mlir-c/Core.h +++ b/mlir/include/mlir-c/Core.h @@ -40,6 +40,10 @@ typedef void *mlir_func_t; typedef void *edsc_mlir_emitter_t; /// Opaque C type for mlir::edsc::Expr. typedef void *edsc_expr_t; +/// Opaque C type for mlir::edsc::MaxExpr. +typedef void *edsc_max_expr_t; +/// Opaque C type for mlir::edsc::MinExpr. +typedef void *edsc_min_expr_t; /// Opaque C type for mlir::edsc::Stmt. typedef void *edsc_stmt_t; /// Opaque C type for mlir::edsc::Block. @@ -238,12 +242,17 @@ edsc_stmt_t ForNest(edsc_expr_list_t iv, edsc_expr_list_t lb, edsc_expr_list_t ub, edsc_expr_list_t step, edsc_stmt_list_t enclosedStmts); +/// Returns an opaque 'max' expression that can be used only inside a for loop. +edsc_max_expr_t Max(edsc_expr_list_t args); + +/// Returns an opaque 'min' expression that can be used only inside a for loop. +edsc_min_expr_t Min(edsc_expr_list_t args); + /// Returns an opaque statement for an mlir::AffineForOp with the lower bound /// `max(lbs)` and the upper bound `min(ubs)`, and with `enclosedStmts` nested /// below it. -edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_expr_list_t lbs, - edsc_expr_list_t ubs, edsc_expr_t step, - edsc_stmt_list_t enclosedStmts); +edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_max_expr_t lb, edsc_min_expr_t ub, + edsc_expr_t step, edsc_stmt_list_t enclosedStmts); /// Returns an opaque expression for the corresponding Binary operation. edsc_expr_t Add(edsc_expr_t e1, edsc_expr_t e2); diff --git a/mlir/include/mlir/EDSC/Types.h b/mlir/include/mlir/EDSC/Types.h index 599d6e3b1bdb..300fed0d51d9 100644 --- a/mlir/include/mlir/EDSC/Types.h +++ b/mlir/include/mlir/EDSC/Types.h @@ -677,6 +677,41 @@ private: llvm::SmallVector indices; }; +struct MaxExpr { +public: + explicit MaxExpr(llvm::ArrayRef arguments); + explicit MaxExpr(edsc_max_expr_t st) + : storage(reinterpret_cast(st)) {} + llvm::ArrayRef getArguments() const; + + operator edsc_max_expr_t() { return storage; } + +private: + detail::ExprStorage *storage; +}; + +struct MinExpr { +public: + explicit MinExpr(llvm::ArrayRef arguments); + explicit MinExpr(edsc_min_expr_t st) + : storage(reinterpret_cast(st)) {} + llvm::ArrayRef getArguments() const; + + operator edsc_min_expr_t() { return storage; } + +private: + detail::ExprStorage *storage; +}; + +Stmt For(const Bindable &idx, MaxExpr lb, MinExpr ub, Expr step, + llvm::ArrayRef enclosedStmts); +Stmt For(llvm::ArrayRef idxs, llvm::ArrayRef lbs, + llvm::ArrayRef ubs, llvm::ArrayRef steps, + llvm::ArrayRef enclosedStmts); + +inline MaxExpr Max(llvm::ArrayRef args) { return MaxExpr(args); } +inline MinExpr Min(llvm::ArrayRef args) { return MinExpr(args); } + } // namespace edsc } // namespace mlir diff --git a/mlir/lib/EDSC/Types.cpp b/mlir/lib/EDSC/Types.cpp index 451b0917fb22..a07d7649af0a 100644 --- a/mlir/lib/EDSC/Types.cpp +++ b/mlir/lib/EDSC/Types.cpp @@ -369,9 +369,9 @@ Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step, stmts); } -Stmt mlir::edsc::For(ArrayRef indices, ArrayRef lbs, - ArrayRef ubs, ArrayRef steps, - ArrayRef enclosedStmts) { +template +Stmt forNestImpl(ArrayRef indices, ArrayRef lbs, ArrayRef ubs, + ArrayRef steps, ArrayRef enclosedStmts) { assert(!indices.empty()); assert(indices.size() == lbs.size()); assert(indices.size() == ubs.size()); @@ -387,6 +387,24 @@ Stmt mlir::edsc::For(ArrayRef indices, ArrayRef lbs, return curStmt; } +Stmt mlir::edsc::For(ArrayRef indices, ArrayRef lbs, + ArrayRef ubs, ArrayRef steps, + ArrayRef enclosedStmts) { + return forNestImpl(indices, lbs, ubs, steps, enclosedStmts); +} + +Stmt mlir::edsc::For(const Bindable &idx, MaxExpr lb, MinExpr ub, Expr step, + llvm::ArrayRef enclosedStmts) { + return MaxMinFor(idx, lb.getArguments(), ub.getArguments(), step, + enclosedStmts); +} + +Stmt mlir::edsc::For(llvm::ArrayRef idxs, llvm::ArrayRef lbs, + llvm::ArrayRef ubs, llvm::ArrayRef steps, + llvm::ArrayRef enclosedStmts) { + return forNestImpl(idxs, lbs, ubs, steps, enclosedStmts); +} + Stmt mlir::edsc::MaxMinFor(const Bindable &idx, ArrayRef lbs, ArrayRef ubs, Expr step, ArrayRef enclosedStmts) { @@ -405,6 +423,14 @@ Stmt mlir::edsc::MaxMinFor(const Bindable &idx, ArrayRef lbs, return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, exprs), enclosedStmts); } +edsc_max_expr_t Max(edsc_expr_list_t args) { + return mlir::edsc::Max(makeExprs(args)); +} + +edsc_min_expr_t Min(edsc_expr_list_t args) { + return mlir::edsc::Min(makeExprs(args)); +} + edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub, edsc_expr_t step, edsc_stmt_list_t enclosedStmts) { llvm::SmallVector stmts; @@ -422,13 +448,12 @@ edsc_stmt_t ForNest(edsc_expr_list_t ivs, edsc_expr_list_t lbs, makeExprs(steps), stmts)); } -edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_expr_list_t lbs, - edsc_expr_list_t ubs, edsc_expr_t step, - edsc_stmt_list_t enclosedStmts) { +edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_max_expr_t lb, edsc_min_expr_t ub, + edsc_expr_t step, edsc_stmt_list_t enclosedStmts) { llvm::SmallVector stmts; fillStmts(enclosedStmts, &stmts); - return Stmt(MaxMinFor(Expr(iv).cast(), makeExprs(lbs), - makeExprs(ubs), Expr(step), stmts)); + return Stmt(For(Expr(iv).cast(), MaxExpr(lb), MinExpr(ub), + Expr(step), stmts)); } StmtBlock mlir::edsc::block(ArrayRef args, ArrayRef argTypes, @@ -982,6 +1007,20 @@ edsc_indexed_t index(edsc_indexed_t indexed, edsc_expr_list_t indices) { return edsc_indexed_t{indexed.base, indices}; } +MaxExpr::MaxExpr(ArrayRef arguments) { + storage = Expr::globalAllocator()->Allocate(); + new (storage) detail::ExprStorage(ExprKind::Variadic, "", {}, arguments, {}); +} + +ArrayRef MaxExpr::getArguments() const { return storage->operands; } + +MinExpr::MinExpr(ArrayRef arguments) { + storage = Expr::globalAllocator()->Allocate(); + new (storage) detail::ExprStorage(ExprKind::Variadic, "", {}, arguments, {}); +} + +ArrayRef MinExpr::getArguments() const { return storage->operands; } + mlir_type_t makeScalarType(mlir_context_t context, const char *name, unsigned bitwidth) { mlir::MLIRContext *c = reinterpret_cast(context);