mirror of
https://github.com/intel/llvm.git
synced 2026-01-27 06:06:34 +08:00
EDSC: introduce min/max only usable inside for upper/lower bounds of a loop
Introduce a type-safe way of building a 'for' loop with max/min bounds in EDSC. Define new types MaxExpr and MinExpr in C++ EDSC API and expose them to Python bindings. Use values of these type to construct 'for' loops with max/min in newly introduced overloads of the `edsc::For` factory function. Note that in C APIs, we still must expose MaxMinFor as a different function because C has no overloads. Also note that MaxExpr and MinExpr do _not_ derive from Expr because they are not allowed to be used in a regular Expr context (which may produce `affine.apply` instructions not expecting `min` or `max`). Factory functions `Min` and `Max` in Python can be further overloaded to produce chains of comparisons and selects on non-index types. This is not trivial in C++ since overloaded functions cannot differ by the return type only (`MaxExpr` or `Expr`) and making `MaxExpr` derive from `Expr` defies the purpose of type-safe construction. PiperOrigin-RevId: 234786131
This commit is contained in:
@@ -222,6 +222,22 @@ struct PythonIndexed : public edsc_indexed_t {
|
||||
operator PythonExpr() { return PythonExpr(base); }
|
||||
};
|
||||
|
||||
struct PythonMaxExpr {
|
||||
PythonMaxExpr() : expr(nullptr) {}
|
||||
PythonMaxExpr(const edsc_max_expr_t &e) : expr(e) {}
|
||||
operator edsc_max_expr_t() { return expr; }
|
||||
|
||||
edsc_max_expr_t expr;
|
||||
};
|
||||
|
||||
struct PythonMinExpr {
|
||||
PythonMinExpr() : expr(nullptr) {}
|
||||
PythonMinExpr(const edsc_min_expr_t &e) : expr(e) {}
|
||||
operator edsc_min_expr_t() { return expr; }
|
||||
|
||||
edsc_min_expr_t expr;
|
||||
};
|
||||
|
||||
struct MLIRFunctionEmitter {
|
||||
MLIRFunctionEmitter(PythonFunction f)
|
||||
: currentFunction(reinterpret_cast<mlir::Function *>(f.function)),
|
||||
@@ -386,19 +402,23 @@ PYBIND11_MODULE(pybind, m) {
|
||||
makeCExprs(owningUBs, ubs), makeCExprs(owningSteps, steps),
|
||||
makeCStmts(owningStmts, stmts)));
|
||||
});
|
||||
m.def("Max", [](const py::list &args) {
|
||||
SmallVector<edsc_expr_t, 8> owning;
|
||||
return PythonMaxExpr(::Max(makeCExprs(owning, args)));
|
||||
});
|
||||
m.def("Min", [](const py::list &args) {
|
||||
SmallVector<edsc_expr_t, 8> owning;
|
||||
return PythonMinExpr(::Min(makeCExprs(owning, args)));
|
||||
});
|
||||
m.def("For", [](PythonExpr iv, PythonExpr lb, PythonExpr ub, PythonExpr step,
|
||||
const py::list &stmts) {
|
||||
SmallVector<edsc_stmt_t, 8> owning;
|
||||
return PythonStmt(::For(iv, lb, ub, step, makeCStmts(owning, stmts)));
|
||||
});
|
||||
m.def("MaxMinFor", [](PythonExpr iv, const py::list &lbs, const py::list &ubs,
|
||||
PythonExpr step, const py::list &stmts) {
|
||||
SmallVector<edsc_expr_t, 8> owningLBs;
|
||||
SmallVector<edsc_expr_t, 8> owningUBs;
|
||||
SmallVector<edsc_stmt_t, 8> owningStmts;
|
||||
return PythonStmt(::MaxMinFor(iv, makeCExprs(owningLBs, lbs),
|
||||
makeCExprs(owningUBs, ubs), step,
|
||||
makeCStmts(owningStmts, stmts)));
|
||||
m.def("For", [](PythonExpr iv, PythonMaxExpr lb, PythonMinExpr ub,
|
||||
PythonExpr step, const py::list &stmts) {
|
||||
SmallVector<edsc_stmt_t, 8> owning;
|
||||
return PythonStmt(::MaxMinFor(iv, lb, ub, step, makeCStmts(owning, stmts)));
|
||||
});
|
||||
m.def("Select", [](PythonExpr cond, PythonExpr e1, PythonExpr e2) {
|
||||
return PythonExpr(::Select(cond, e1, e2));
|
||||
@@ -623,6 +643,11 @@ PYBIND11_MODULE(pybind, m) {
|
||||
Store(value, instance, makeCExprs(owning, indices)));
|
||||
},
|
||||
R"DOC(Returns the Stmt that stores into an Indexed)DOC");
|
||||
|
||||
py::class_<PythonMaxExpr>(m, "MaxExpr",
|
||||
"Wrapping class for mlir::edsc::MaxExpr");
|
||||
py::class_<PythonMinExpr>(m, "MinExpr",
|
||||
"Wrapping class for mlir::edsc::MinExpr");
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
|
||||
@@ -75,7 +75,7 @@ class EdscTest(unittest.TestCase):
|
||||
step = E.Expr(E.Bindable(self.indexType))
|
||||
lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
|
||||
ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
|
||||
loop = E.MaxMinFor(i, lbs, ubs, step, [])
|
||||
loop = E.For(i, E.Max(lbs), E.Min(ubs), step, [])
|
||||
s = str(loop)
|
||||
self.assertIn("for($1 = max($3, $4, $5, $6) to min($7, $8, $9) step $2)",
|
||||
s)
|
||||
|
||||
@@ -40,6 +40,10 @@ typedef void *mlir_func_t;
|
||||
typedef void *edsc_mlir_emitter_t;
|
||||
/// Opaque C type for mlir::edsc::Expr.
|
||||
typedef void *edsc_expr_t;
|
||||
/// Opaque C type for mlir::edsc::MaxExpr.
|
||||
typedef void *edsc_max_expr_t;
|
||||
/// Opaque C type for mlir::edsc::MinExpr.
|
||||
typedef void *edsc_min_expr_t;
|
||||
/// Opaque C type for mlir::edsc::Stmt.
|
||||
typedef void *edsc_stmt_t;
|
||||
/// Opaque C type for mlir::edsc::Block.
|
||||
@@ -238,12 +242,17 @@ edsc_stmt_t ForNest(edsc_expr_list_t iv, edsc_expr_list_t lb,
|
||||
edsc_expr_list_t ub, edsc_expr_list_t step,
|
||||
edsc_stmt_list_t enclosedStmts);
|
||||
|
||||
/// Returns an opaque 'max' expression that can be used only inside a for loop.
|
||||
edsc_max_expr_t Max(edsc_expr_list_t args);
|
||||
|
||||
/// Returns an opaque 'min' expression that can be used only inside a for loop.
|
||||
edsc_min_expr_t Min(edsc_expr_list_t args);
|
||||
|
||||
/// Returns an opaque statement for an mlir::AffineForOp with the lower bound
|
||||
/// `max(lbs)` and the upper bound `min(ubs)`, and with `enclosedStmts` nested
|
||||
/// below it.
|
||||
edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_expr_list_t lbs,
|
||||
edsc_expr_list_t ubs, edsc_expr_t step,
|
||||
edsc_stmt_list_t enclosedStmts);
|
||||
edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_max_expr_t lb, edsc_min_expr_t ub,
|
||||
edsc_expr_t step, edsc_stmt_list_t enclosedStmts);
|
||||
|
||||
/// Returns an opaque expression for the corresponding Binary operation.
|
||||
edsc_expr_t Add(edsc_expr_t e1, edsc_expr_t e2);
|
||||
|
||||
@@ -677,6 +677,41 @@ private:
|
||||
llvm::SmallVector<Expr, 8> indices;
|
||||
};
|
||||
|
||||
struct MaxExpr {
|
||||
public:
|
||||
explicit MaxExpr(llvm::ArrayRef<Expr> arguments);
|
||||
explicit MaxExpr(edsc_max_expr_t st)
|
||||
: storage(reinterpret_cast<detail::ExprStorage *>(st)) {}
|
||||
llvm::ArrayRef<Expr> getArguments() const;
|
||||
|
||||
operator edsc_max_expr_t() { return storage; }
|
||||
|
||||
private:
|
||||
detail::ExprStorage *storage;
|
||||
};
|
||||
|
||||
struct MinExpr {
|
||||
public:
|
||||
explicit MinExpr(llvm::ArrayRef<Expr> arguments);
|
||||
explicit MinExpr(edsc_min_expr_t st)
|
||||
: storage(reinterpret_cast<detail::ExprStorage *>(st)) {}
|
||||
llvm::ArrayRef<Expr> getArguments() const;
|
||||
|
||||
operator edsc_min_expr_t() { return storage; }
|
||||
|
||||
private:
|
||||
detail::ExprStorage *storage;
|
||||
};
|
||||
|
||||
Stmt For(const Bindable &idx, MaxExpr lb, MinExpr ub, Expr step,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts);
|
||||
Stmt For(llvm::ArrayRef<Expr> idxs, llvm::ArrayRef<MaxExpr> lbs,
|
||||
llvm::ArrayRef<MinExpr> ubs, llvm::ArrayRef<Expr> steps,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts);
|
||||
|
||||
inline MaxExpr Max(llvm::ArrayRef<Expr> args) { return MaxExpr(args); }
|
||||
inline MinExpr Min(llvm::ArrayRef<Expr> args) { return MinExpr(args); }
|
||||
|
||||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -369,9 +369,9 @@ Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step,
|
||||
stmts);
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::For(ArrayRef<Expr> indices, ArrayRef<Expr> lbs,
|
||||
ArrayRef<Expr> ubs, ArrayRef<Expr> steps,
|
||||
ArrayRef<Stmt> enclosedStmts) {
|
||||
template <typename LB, typename UB>
|
||||
Stmt forNestImpl(ArrayRef<Expr> indices, ArrayRef<LB> lbs, ArrayRef<UB> ubs,
|
||||
ArrayRef<Expr> steps, ArrayRef<Stmt> enclosedStmts) {
|
||||
assert(!indices.empty());
|
||||
assert(indices.size() == lbs.size());
|
||||
assert(indices.size() == ubs.size());
|
||||
@@ -387,6 +387,24 @@ Stmt mlir::edsc::For(ArrayRef<Expr> indices, ArrayRef<Expr> lbs,
|
||||
return curStmt;
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::For(ArrayRef<Expr> indices, ArrayRef<Expr> lbs,
|
||||
ArrayRef<Expr> ubs, ArrayRef<Expr> steps,
|
||||
ArrayRef<Stmt> enclosedStmts) {
|
||||
return forNestImpl(indices, lbs, ubs, steps, enclosedStmts);
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::For(const Bindable &idx, MaxExpr lb, MinExpr ub, Expr step,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts) {
|
||||
return MaxMinFor(idx, lb.getArguments(), ub.getArguments(), step,
|
||||
enclosedStmts);
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::For(llvm::ArrayRef<Expr> idxs, llvm::ArrayRef<MaxExpr> lbs,
|
||||
llvm::ArrayRef<MinExpr> ubs, llvm::ArrayRef<Expr> steps,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts) {
|
||||
return forNestImpl(idxs, lbs, ubs, steps, enclosedStmts);
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::MaxMinFor(const Bindable &idx, ArrayRef<Expr> lbs,
|
||||
ArrayRef<Expr> ubs, Expr step,
|
||||
ArrayRef<Stmt> enclosedStmts) {
|
||||
@@ -405,6 +423,14 @@ Stmt mlir::edsc::MaxMinFor(const Bindable &idx, ArrayRef<Expr> lbs,
|
||||
return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, exprs), enclosedStmts);
|
||||
}
|
||||
|
||||
edsc_max_expr_t Max(edsc_expr_list_t args) {
|
||||
return mlir::edsc::Max(makeExprs(args));
|
||||
}
|
||||
|
||||
edsc_min_expr_t Min(edsc_expr_list_t args) {
|
||||
return mlir::edsc::Min(makeExprs(args));
|
||||
}
|
||||
|
||||
edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
|
||||
edsc_expr_t step, edsc_stmt_list_t enclosedStmts) {
|
||||
llvm::SmallVector<Stmt, 8> stmts;
|
||||
@@ -422,13 +448,12 @@ edsc_stmt_t ForNest(edsc_expr_list_t ivs, edsc_expr_list_t lbs,
|
||||
makeExprs(steps), stmts));
|
||||
}
|
||||
|
||||
edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_expr_list_t lbs,
|
||||
edsc_expr_list_t ubs, edsc_expr_t step,
|
||||
edsc_stmt_list_t enclosedStmts) {
|
||||
edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_max_expr_t lb, edsc_min_expr_t ub,
|
||||
edsc_expr_t step, edsc_stmt_list_t enclosedStmts) {
|
||||
llvm::SmallVector<Stmt, 8> stmts;
|
||||
fillStmts(enclosedStmts, &stmts);
|
||||
return Stmt(MaxMinFor(Expr(iv).cast<Bindable>(), makeExprs(lbs),
|
||||
makeExprs(ubs), Expr(step), stmts));
|
||||
return Stmt(For(Expr(iv).cast<Bindable>(), MaxExpr(lb), MinExpr(ub),
|
||||
Expr(step), stmts));
|
||||
}
|
||||
|
||||
StmtBlock mlir::edsc::block(ArrayRef<Bindable> args, ArrayRef<Type> argTypes,
|
||||
@@ -982,6 +1007,20 @@ edsc_indexed_t index(edsc_indexed_t indexed, edsc_expr_list_t indices) {
|
||||
return edsc_indexed_t{indexed.base, indices};
|
||||
}
|
||||
|
||||
MaxExpr::MaxExpr(ArrayRef<Expr> arguments) {
|
||||
storage = Expr::globalAllocator()->Allocate<detail::ExprStorage>();
|
||||
new (storage) detail::ExprStorage(ExprKind::Variadic, "", {}, arguments, {});
|
||||
}
|
||||
|
||||
ArrayRef<Expr> MaxExpr::getArguments() const { return storage->operands; }
|
||||
|
||||
MinExpr::MinExpr(ArrayRef<Expr> arguments) {
|
||||
storage = Expr::globalAllocator()->Allocate<detail::ExprStorage>();
|
||||
new (storage) detail::ExprStorage(ExprKind::Variadic, "", {}, arguments, {});
|
||||
}
|
||||
|
||||
ArrayRef<Expr> MinExpr::getArguments() const { return storage->operands; }
|
||||
|
||||
mlir_type_t makeScalarType(mlir_context_t context, const char *name,
|
||||
unsigned bitwidth) {
|
||||
mlir::MLIRContext *c = reinterpret_cast<mlir::MLIRContext *>(context);
|
||||
|
||||
Reference in New Issue
Block a user