EDSC: introduce min/max only usable inside for upper/lower bounds of a loop

Introduce a type-safe way of building a 'for' loop with max/min bounds in EDSC.
Define new types MaxExpr and MinExpr in C++ EDSC API and expose them to Python
bindings.  Use values of these type to construct 'for' loops with max/min in
newly introduced overloads of the `edsc::For` factory function.  Note that in C
APIs, we still must expose MaxMinFor as a different function because C has no
overloads.  Also note that MaxExpr and MinExpr do _not_ derive from Expr
because they are not allowed to be used in a regular Expr context (which may
produce `affine.apply` instructions not expecting `min` or `max`).

Factory functions `Min` and `Max` in Python can be further overloaded to
produce chains of comparisons and selects on non-index types.  This is not
trivial in C++ since overloaded functions cannot differ by the return type only
(`MaxExpr` or `Expr`) and making `MaxExpr` derive from `Expr` defies the
purpose of type-safe construction.

PiperOrigin-RevId: 234786131
This commit is contained in:
Alex Zinenko
2019-02-20 06:54:53 -08:00
committed by jpienaar
parent d055a4e100
commit 21bd4540f3
5 changed files with 128 additions and 20 deletions

View File

@@ -222,6 +222,22 @@ struct PythonIndexed : public edsc_indexed_t {
operator PythonExpr() { return PythonExpr(base); }
};
struct PythonMaxExpr {
PythonMaxExpr() : expr(nullptr) {}
PythonMaxExpr(const edsc_max_expr_t &e) : expr(e) {}
operator edsc_max_expr_t() { return expr; }
edsc_max_expr_t expr;
};
struct PythonMinExpr {
PythonMinExpr() : expr(nullptr) {}
PythonMinExpr(const edsc_min_expr_t &e) : expr(e) {}
operator edsc_min_expr_t() { return expr; }
edsc_min_expr_t expr;
};
struct MLIRFunctionEmitter {
MLIRFunctionEmitter(PythonFunction f)
: currentFunction(reinterpret_cast<mlir::Function *>(f.function)),
@@ -386,19 +402,23 @@ PYBIND11_MODULE(pybind, m) {
makeCExprs(owningUBs, ubs), makeCExprs(owningSteps, steps),
makeCStmts(owningStmts, stmts)));
});
m.def("Max", [](const py::list &args) {
SmallVector<edsc_expr_t, 8> owning;
return PythonMaxExpr(::Max(makeCExprs(owning, args)));
});
m.def("Min", [](const py::list &args) {
SmallVector<edsc_expr_t, 8> owning;
return PythonMinExpr(::Min(makeCExprs(owning, args)));
});
m.def("For", [](PythonExpr iv, PythonExpr lb, PythonExpr ub, PythonExpr step,
const py::list &stmts) {
SmallVector<edsc_stmt_t, 8> owning;
return PythonStmt(::For(iv, lb, ub, step, makeCStmts(owning, stmts)));
});
m.def("MaxMinFor", [](PythonExpr iv, const py::list &lbs, const py::list &ubs,
PythonExpr step, const py::list &stmts) {
SmallVector<edsc_expr_t, 8> owningLBs;
SmallVector<edsc_expr_t, 8> owningUBs;
SmallVector<edsc_stmt_t, 8> owningStmts;
return PythonStmt(::MaxMinFor(iv, makeCExprs(owningLBs, lbs),
makeCExprs(owningUBs, ubs), step,
makeCStmts(owningStmts, stmts)));
m.def("For", [](PythonExpr iv, PythonMaxExpr lb, PythonMinExpr ub,
PythonExpr step, const py::list &stmts) {
SmallVector<edsc_stmt_t, 8> owning;
return PythonStmt(::MaxMinFor(iv, lb, ub, step, makeCStmts(owning, stmts)));
});
m.def("Select", [](PythonExpr cond, PythonExpr e1, PythonExpr e2) {
return PythonExpr(::Select(cond, e1, e2));
@@ -623,6 +643,11 @@ PYBIND11_MODULE(pybind, m) {
Store(value, instance, makeCExprs(owning, indices)));
},
R"DOC(Returns the Stmt that stores into an Indexed)DOC");
py::class_<PythonMaxExpr>(m, "MaxExpr",
"Wrapping class for mlir::edsc::MaxExpr");
py::class_<PythonMinExpr>(m, "MinExpr",
"Wrapping class for mlir::edsc::MinExpr");
}
} // namespace python

View File

@@ -75,7 +75,7 @@ class EdscTest(unittest.TestCase):
step = E.Expr(E.Bindable(self.indexType))
lbs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(4)]))
ubs = list(map(E.Expr, [E.Bindable(self.indexType) for _ in range(3)]))
loop = E.MaxMinFor(i, lbs, ubs, step, [])
loop = E.For(i, E.Max(lbs), E.Min(ubs), step, [])
s = str(loop)
self.assertIn("for($1 = max($3, $4, $5, $6) to min($7, $8, $9) step $2)",
s)

View File

@@ -40,6 +40,10 @@ typedef void *mlir_func_t;
typedef void *edsc_mlir_emitter_t;
/// Opaque C type for mlir::edsc::Expr.
typedef void *edsc_expr_t;
/// Opaque C type for mlir::edsc::MaxExpr.
typedef void *edsc_max_expr_t;
/// Opaque C type for mlir::edsc::MinExpr.
typedef void *edsc_min_expr_t;
/// Opaque C type for mlir::edsc::Stmt.
typedef void *edsc_stmt_t;
/// Opaque C type for mlir::edsc::Block.
@@ -238,12 +242,17 @@ edsc_stmt_t ForNest(edsc_expr_list_t iv, edsc_expr_list_t lb,
edsc_expr_list_t ub, edsc_expr_list_t step,
edsc_stmt_list_t enclosedStmts);
/// Returns an opaque 'max' expression that can be used only inside a for loop.
edsc_max_expr_t Max(edsc_expr_list_t args);
/// Returns an opaque 'min' expression that can be used only inside a for loop.
edsc_min_expr_t Min(edsc_expr_list_t args);
/// Returns an opaque statement for an mlir::AffineForOp with the lower bound
/// `max(lbs)` and the upper bound `min(ubs)`, and with `enclosedStmts` nested
/// below it.
edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_expr_list_t lbs,
edsc_expr_list_t ubs, edsc_expr_t step,
edsc_stmt_list_t enclosedStmts);
edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_max_expr_t lb, edsc_min_expr_t ub,
edsc_expr_t step, edsc_stmt_list_t enclosedStmts);
/// Returns an opaque expression for the corresponding Binary operation.
edsc_expr_t Add(edsc_expr_t e1, edsc_expr_t e2);

View File

@@ -677,6 +677,41 @@ private:
llvm::SmallVector<Expr, 8> indices;
};
struct MaxExpr {
public:
explicit MaxExpr(llvm::ArrayRef<Expr> arguments);
explicit MaxExpr(edsc_max_expr_t st)
: storage(reinterpret_cast<detail::ExprStorage *>(st)) {}
llvm::ArrayRef<Expr> getArguments() const;
operator edsc_max_expr_t() { return storage; }
private:
detail::ExprStorage *storage;
};
struct MinExpr {
public:
explicit MinExpr(llvm::ArrayRef<Expr> arguments);
explicit MinExpr(edsc_min_expr_t st)
: storage(reinterpret_cast<detail::ExprStorage *>(st)) {}
llvm::ArrayRef<Expr> getArguments() const;
operator edsc_min_expr_t() { return storage; }
private:
detail::ExprStorage *storage;
};
Stmt For(const Bindable &idx, MaxExpr lb, MinExpr ub, Expr step,
llvm::ArrayRef<Stmt> enclosedStmts);
Stmt For(llvm::ArrayRef<Expr> idxs, llvm::ArrayRef<MaxExpr> lbs,
llvm::ArrayRef<MinExpr> ubs, llvm::ArrayRef<Expr> steps,
llvm::ArrayRef<Stmt> enclosedStmts);
inline MaxExpr Max(llvm::ArrayRef<Expr> args) { return MaxExpr(args); }
inline MinExpr Min(llvm::ArrayRef<Expr> args) { return MinExpr(args); }
} // namespace edsc
} // namespace mlir

View File

@@ -369,9 +369,9 @@ Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step,
stmts);
}
Stmt mlir::edsc::For(ArrayRef<Expr> indices, ArrayRef<Expr> lbs,
ArrayRef<Expr> ubs, ArrayRef<Expr> steps,
ArrayRef<Stmt> enclosedStmts) {
template <typename LB, typename UB>
Stmt forNestImpl(ArrayRef<Expr> indices, ArrayRef<LB> lbs, ArrayRef<UB> ubs,
ArrayRef<Expr> steps, ArrayRef<Stmt> enclosedStmts) {
assert(!indices.empty());
assert(indices.size() == lbs.size());
assert(indices.size() == ubs.size());
@@ -387,6 +387,24 @@ Stmt mlir::edsc::For(ArrayRef<Expr> indices, ArrayRef<Expr> lbs,
return curStmt;
}
Stmt mlir::edsc::For(ArrayRef<Expr> indices, ArrayRef<Expr> lbs,
ArrayRef<Expr> ubs, ArrayRef<Expr> steps,
ArrayRef<Stmt> enclosedStmts) {
return forNestImpl(indices, lbs, ubs, steps, enclosedStmts);
}
Stmt mlir::edsc::For(const Bindable &idx, MaxExpr lb, MinExpr ub, Expr step,
llvm::ArrayRef<Stmt> enclosedStmts) {
return MaxMinFor(idx, lb.getArguments(), ub.getArguments(), step,
enclosedStmts);
}
Stmt mlir::edsc::For(llvm::ArrayRef<Expr> idxs, llvm::ArrayRef<MaxExpr> lbs,
llvm::ArrayRef<MinExpr> ubs, llvm::ArrayRef<Expr> steps,
llvm::ArrayRef<Stmt> enclosedStmts) {
return forNestImpl(idxs, lbs, ubs, steps, enclosedStmts);
}
Stmt mlir::edsc::MaxMinFor(const Bindable &idx, ArrayRef<Expr> lbs,
ArrayRef<Expr> ubs, Expr step,
ArrayRef<Stmt> enclosedStmts) {
@@ -405,6 +423,14 @@ Stmt mlir::edsc::MaxMinFor(const Bindable &idx, ArrayRef<Expr> lbs,
return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, exprs), enclosedStmts);
}
edsc_max_expr_t Max(edsc_expr_list_t args) {
return mlir::edsc::Max(makeExprs(args));
}
edsc_min_expr_t Min(edsc_expr_list_t args) {
return mlir::edsc::Min(makeExprs(args));
}
edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
edsc_expr_t step, edsc_stmt_list_t enclosedStmts) {
llvm::SmallVector<Stmt, 8> stmts;
@@ -422,13 +448,12 @@ edsc_stmt_t ForNest(edsc_expr_list_t ivs, edsc_expr_list_t lbs,
makeExprs(steps), stmts));
}
edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_expr_list_t lbs,
edsc_expr_list_t ubs, edsc_expr_t step,
edsc_stmt_list_t enclosedStmts) {
edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_max_expr_t lb, edsc_min_expr_t ub,
edsc_expr_t step, edsc_stmt_list_t enclosedStmts) {
llvm::SmallVector<Stmt, 8> stmts;
fillStmts(enclosedStmts, &stmts);
return Stmt(MaxMinFor(Expr(iv).cast<Bindable>(), makeExprs(lbs),
makeExprs(ubs), Expr(step), stmts));
return Stmt(For(Expr(iv).cast<Bindable>(), MaxExpr(lb), MinExpr(ub),
Expr(step), stmts));
}
StmtBlock mlir::edsc::block(ArrayRef<Bindable> args, ArrayRef<Type> argTypes,
@@ -982,6 +1007,20 @@ edsc_indexed_t index(edsc_indexed_t indexed, edsc_expr_list_t indices) {
return edsc_indexed_t{indexed.base, indices};
}
MaxExpr::MaxExpr(ArrayRef<Expr> arguments) {
storage = Expr::globalAllocator()->Allocate<detail::ExprStorage>();
new (storage) detail::ExprStorage(ExprKind::Variadic, "", {}, arguments, {});
}
ArrayRef<Expr> MaxExpr::getArguments() const { return storage->operands; }
MinExpr::MinExpr(ArrayRef<Expr> arguments) {
storage = Expr::globalAllocator()->Allocate<detail::ExprStorage>();
new (storage) detail::ExprStorage(ExprKind::Variadic, "", {}, arguments, {});
}
ArrayRef<Expr> MinExpr::getArguments() const { return storage->operands; }
mlir_type_t makeScalarType(mlir_context_t context, const char *name,
unsigned bitwidth) {
mlir::MLIRContext *c = reinterpret_cast<mlir::MLIRContext *>(context);