Add a C API for EDSCs in other languages + python

This CL adds support for calling EDSCs from other languages than C++.
Following the LLVM convention this CL:
1. declares simple opaque types and a C API in mlir-c/Core.h;
2. defines the implementation directly in lib/EDSC/Types.cpp and
lib/EDSC/MLIREmitter.cpp.

Unlike LLVM however the nomenclature for these types and API functions is not
well-defined, naming suggestions are most welcome.

To avoid the need for conversion functions, Types.h and MLIREmitter.h include
mlir-c/Core.h and provide constructors and conversion operators between the
mlir::edsc type and the corresponding C type.

In this first commit, mlir-c/Core.h only contains the types for the C API
to allow EDSCs to work from Python. This includes both a minimal set of core
MLIR
types (mlir_context_t, mlir_type_t, mlir_func_t) as well as the EDSC types
(edsc_mlir_emitter_t, edsc_expr_t, edsc_stmt_t, edsc_indexed_t). This can be
restructured in the future as concrete needs arise.

For now, the API only supports:
1. scalar types;
2. memrefs of scalar types with static or symbolic shapes;
3. functions with input and output of these types.

The C API is not complete wrt ownership semantics. This is in large part due
to the fact that python bindings are written with Pybind11 which allows very
idiomatic C++ bindings. An effort is made to write a large chunk of these
bindings using the C API but some C++isms are used where the design benefits
from this simplication. A fully isolated C API will make more sense once we
also integrate with another language like Swift and have enough use cases to
drive the design.

Lastly, this CL also fixes a bug in mlir::ExecutionEngine were the order of
declaration of llvmContext and the JIT result in an improper order of
destructors (which used to crash before the fix).

PiperOrigin-RevId: 231290250
This commit is contained in:
Nicolas Vasilache
2019-01-28 14:32:00 -08:00
committed by jpienaar
parent eb753f4aec
commit cacf05892e
7 changed files with 686 additions and 102 deletions

248
mlir/include/mlir-c/Core.h Normal file
View File

@@ -0,0 +1,248 @@
/*===-- mlir-c/Core.h - Core Library C Interface ------------------*- C -*-===*\
|* *|
|* Copyright 2019 The MLIR Authors. *|
|* *|
|* Licensed under the Apache License, Version 2.0 (the "License"); *|
|* you may not use this file except in compliance with the License. *|
|* You may obtain a copy of the License at *|
|* *|
|* http://www.apache.org/licenses/LICENSE-2.0 *|
|* *|
|* Unless required by applicable law or agreed to in writing, software *|
|* distributed under the License is distributed on an "AS IS" BASIS, *|
|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *|
|* See the License for the specific language governing permissions and *|
|* limitations under the License. *|
|* *|
|*===----------------------------------------------------------------------===*|
|* *|
|* This header declares the C interface to MLIR. *|
|* *|
\*===----------------------------------------------------------------------===*/
#ifndef MLIR_C_CORE_H
#define MLIR_C_CORE_H
#ifdef __cplusplus
#include <cstdint>
extern "C" {
#else
#include <stdint.h>
#endif
/// Opaque MLIR types.
/// Opaque C type for mlir::MLIRContext*.
typedef void *mlir_context_t;
/// Opaque C type for mlir::Type.
typedef const void *mlir_type_t;
/// Opaque C type for mlir::Function*.
typedef void *mlir_func_t;
/// Opaque C type for mlir::edsc::MLIREmiter.
typedef void *edsc_mlir_emitter_t;
/// Opaque C type for mlir::edsc::Expr.
typedef void *edsc_expr_t;
/// Opaque C type for mlir::edsc::Stmt.
typedef void *edsc_stmt_t;
/// Simple C lists for non-owning mlir Opaque C types.
/// Recommended usage is construction from the `data()` and `size()` of a scoped
/// owning SmallVectorImpl<...> and passing to one of the C functions declared
/// later in this file.
/// Once the function returns and the proper EDSC has been constructed,
/// resources are freed by exiting the scope.
typedef struct {
int64_t *values;
uint64_t n;
} int64_list_t;
typedef struct {
mlir_type_t *types;
uint64_t n;
} mlir_type_list_t;
typedef struct {
edsc_expr_t *exprs;
uint64_t n;
} edsc_expr_list_t;
typedef struct {
edsc_stmt_t *stmts;
uint64_t n;
} edsc_stmt_list_t;
typedef struct {
edsc_expr_t base;
edsc_expr_list_t indices;
} edsc_indexed_t;
typedef struct {
edsc_indexed_t *list;
uint64_t n;
} edsc_indexed_list_t;
/// Minimal C API for exposing EDSCs to Swift, Python and other languages.
/// Returns a simple scalar mlir::Type using the following convention:
/// - makeScalarType(c, "bf16") return an `mlir::Type::getBF16`
/// - makeScalarType(c, "f16") return an `mlir::Type::getF16`
/// - makeScalarType(c, "f32") return an `mlir::Type::getF32`
/// - makeScalarType(c, "f64") return an `mlir::Type::getF64`
/// - makeScalarType(c, "index") return an `mlir::Type::getIndex`
/// - makeScalarType(c, "i", bitwidth) return an
/// `mlir::Type::getInteger(bitwidth)`
///
/// No other combinations are currently supported.
mlir_type_t makeScalarType(mlir_context_t context, const char *name,
unsigned bitwidth);
/// Returns an `mlir::MemRefType` of the element type `elemType` and shape
/// `sizes`.
mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType,
int64_list_t sizes);
/// Returns an `mlir::FunctionType` of the element type `elemType` and shape
/// `sizes`.
mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
mlir_type_list_t outputs);
/// Returns the arity of `function`.
unsigned getFunctionArity(mlir_func_t function);
/// Returns a new opaque mlir::edsc::Expr that is bound into `emitter` with a
/// constant of the specified type.
edsc_expr_t bindConstantBF16(edsc_mlir_emitter_t emitter, double value);
edsc_expr_t bindConstantF16(edsc_mlir_emitter_t emitter, float value);
edsc_expr_t bindConstantF32(edsc_mlir_emitter_t emitter, float value);
edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value);
edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value,
unsigned bitwidth);
edsc_expr_t bindConstantIndex(edsc_mlir_emitter_t emitter, int64_t value);
/// Returns the rank of the `function` argument at position `pos`.
/// If the argument is of MemRefType, this returns the rank of the MemRef.
/// Otherwise returns `0`.
/// TODO(ntv): support more than MemRefType and scalar Type.
unsigned getRankOfFunctionArgument(mlir_func_t function, unsigned pos);
/// Returns an opaque mlir::Type of the `function` argument at position `pos`.
mlir_type_t getTypeOfFunctionArgument(mlir_func_t function, unsigned pos);
/// Returns an opaque mlir::edsc::Expr that has been bound to the `pos` argument
/// of `function`.
edsc_expr_t bindFunctionArgument(edsc_mlir_emitter_t emitter,
mlir_func_t function, unsigned pos);
/// Fills the preallocated list `result` with opaque mlir::edsc::Expr that have
/// been bound to each argument of `function`.
///
/// Prerequisites:
/// - `result` must have been preallocated with space for exactly the number
/// of arguments of `function`.
void bindFunctionArguments(edsc_mlir_emitter_t emitter, mlir_func_t function,
edsc_expr_list_t *result);
/// Returns the rank of `boundMemRef`. This API function is provided to more
/// easily compose with `bindFunctionArgument`. A similar function could be
/// provided for an mlir_type_t of type MemRefType but it is expected that users
/// of this API either:
/// 1. construct the MemRefType explicitly, in which case they already have
/// access to the rank and shape of the MemRefType;
/// 2. access MemRefs via mlir_function_t *values* in which case they would
/// pass edsc_expr_t bound to an edsc_emitter_t.
///
/// Prerequisites:
/// - `boundMemRef` must be an opaque edsc_expr_t that has alreay been bound
/// in `emitter`.
unsigned getBoundMemRefRank(edsc_mlir_emitter_t emitter,
edsc_expr_t boundMemRef);
/// Fills the preallocated list `result` with opaque mlir::edsc::Expr that have
/// been bound to each dimension of `boundMemRef`.
///
/// Prerequisites:
/// - `result` must have been preallocated with space for exactly the rank of
/// `boundMemRef`;
/// - `boundMemRef` must be an opaque edsc_expr_t that has alreay been bound
/// in `emitter`. This is because symbolic MemRef shapes require an SSAValue
/// that can only be recovered from `emitter`.
void bindMemRefShape(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
edsc_expr_list_t *result);
/// Fills the preallocated lists `resultLbs`, `resultUbs` and `resultSteps` with
/// opaque mlir::edsc::Expr that have been bound to proper values to traverse
/// each dimension of `memRefType`.
/// At the moment:
/// - `resultsLbs` are always bound to the constant index `0`;
/// - `resultsUbs` are always bound to the shape of `memRefType`;
/// - `resultsSteps` are always bound to the constant index `1`.
/// In the future, this will allow non-contiguous MemRef views.
///
/// Prerequisites:
/// - `resultLbs`, `resultUbs` and `resultSteps` must have each been
/// preallocated with space for exactly the rank of `boundMemRef`;
/// - `boundMemRef` must be an opaque edsc_expr_t that has alreay been bound
/// in `emitter`. This is because symbolic MemRef shapes require an SSAValue
/// that can only be recovered from `emitter`.
void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
edsc_expr_list_t *resultLbs, edsc_expr_list_t *resultUbs,
edsc_expr_list_t *resultSteps);
/// Returns an opaque expression for an mlir::edsc::Expr.
edsc_expr_t makeBindable();
/// Returns an opaque expression for an mlir::edsc::Stmt.
edsc_stmt_t makeStmt(edsc_expr_t e);
/// Returns an opaque expression for an mlir::edsc::Indexed.
edsc_indexed_t makeIndexed(edsc_expr_t expr);
/// Returns an indexed opaque expression with indices bound in the structure
/// given an `indexed` and `indices`.
/// Prerequisite:
/// - `indexed` must not have been indexed previously.
edsc_indexed_t index(edsc_indexed_t indexed, edsc_expr_list_t indices);
/// Returns an opaque expression that will emit an mlir::LoadOp.
edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices);
/// Returns an opaque statement for an mlir::StoreOp.
edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed,
edsc_expr_list_t indices);
/// Returns an opaque statement for an mlir::SelectOp.
edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs);
/// Returns an opaque statement for an mlir::ReturnOp.
edsc_stmt_t Return(edsc_expr_list_t values);
/// Returns a single opaque statement that acts as an mlir block. At the moment
/// this is pure syntactic sugar to allow lists of mlir::edsc::Stmt to be
/// specified and emitted. In particular, block arguments are not currently
/// supported.
edsc_stmt_t Block(edsc_stmt_list_t enclosedStmts);
/// Returns an opaque statement for an mlir::ForInst with `enclosedStmts` nested
/// below it.
edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
edsc_expr_t step, edsc_stmt_list_t enclosedStmts);
/// Returns an opaque statement for a perfectly nested set of mlir::ForInst with
/// `enclosedStmts` nested below it.
edsc_stmt_t ForNest(edsc_expr_list_t iv, edsc_expr_list_t lb,
edsc_expr_list_t ub, edsc_expr_list_t step,
edsc_stmt_list_t enclosedStmts);
/// Returns an opaque expression for the corresponding Binary operation.
edsc_expr_t Add(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t Sub(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t Mul(edsc_expr_t e1, edsc_expr_t e2);
// edsc_expr_t Div(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t LT(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t LE(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t GT(edsc_expr_t e1, edsc_expr_t e2);
edsc_expr_t GE(edsc_expr_t e1, edsc_expr_t e2);
#ifdef __cplusplus
} // end extern "C"
#endif
#endif // MLIR_C_CORE_H

View File

@@ -121,7 +121,7 @@ struct MLIREmitter {
///
/// Prerequisite:
/// `memRef` is a Value of type MemRefType.
SmallVector<edsc::Bindable, 8> makeBoundSizes(Value *memRef);
SmallVector<edsc::Expr, 8> makeBoundSizes(Value *memRef);
private:
FuncBuilder *builder;

View File

@@ -24,6 +24,7 @@
#ifndef MLIR_LIB_EDSC_TYPES_H_
#define MLIR_LIB_EDSC_TYPES_H_
#include "mlir-c/Core.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
@@ -171,6 +172,9 @@ public:
Expr() : storage(nullptr) {}
/* implicit */ Expr(ImplType *storage) : storage(storage) {}
explicit Expr(edsc_expr_t expr)
: storage(reinterpret_cast<ImplType *>(expr)) {}
operator edsc_expr_t() { return edsc_expr_t{storage}; }
Expr(const Expr &other) : storage(other.storage) {}
Expr &operator=(Expr other) {
@@ -227,6 +231,8 @@ struct Bindable : public Expr {
Bindable(Expr::ImplType *ptr) : Expr(ptr) {
assert(!ptr || isa<Bindable>() && "expected Bindable");
}
explicit Bindable(const edsc_expr_t &expr) : Expr(expr) {}
operator edsc_expr_t() { return edsc_expr_t{storage}; }
friend struct ScopedEDSCContext;
@@ -352,6 +358,9 @@ struct Stmt {
Stmt(const Bindable &lhs, const Expr &rhs,
llvm::ArrayRef<Stmt> stmts = llvm::ArrayRef<Stmt>());
Stmt &operator=(const Expr &expr);
explicit Stmt(edsc_stmt_t stmt)
: storage(reinterpret_cast<ImplType *>(stmt)) {}
operator edsc_stmt_t() { return edsc_stmt_t{storage}; }
operator Expr() const { return getLHS(); }
@@ -414,7 +423,7 @@ template <typename U> U Expr::dyn_cast() const {
if (isa<U>()) {
return U(storage);
}
return U(nullptr);
return U((Expr::ImplType *)(nullptr));
}
template <typename U> U Expr::cast() const {
assert(isa<U>());
@@ -497,6 +506,9 @@ Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Expr> lbs,
Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Bindable> lbs,
llvm::ArrayRef<Bindable> ubs, llvm::ArrayRef<Bindable> steps,
llvm::ArrayRef<Stmt> enclosedStmts);
Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Bindable> lbs,
llvm::ArrayRef<Expr> ubs, llvm::ArrayRef<Bindable> steps,
llvm::ArrayRef<Stmt> enclosedStmts);
/// This helper class exists purely for sugaring purposes and allows writing
/// expressions such as:
@@ -508,7 +520,8 @@ Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Bindable> lbs,
/// });
/// ```
struct Indexed {
Indexed(Bindable m) : base(m), indices() {}
Indexed(Bindable b) : base(b), indices() {}
Indexed(Expr e) : base(e), indices() {}
/// Returns a new `Indexed`. As a consequence, an Indexed with attached
/// indices can never be reused unless it is captured (e.g. via a Stmt).
@@ -534,7 +547,7 @@ struct Indexed {
Expr operator*(Expr e) const { return static_cast<Expr>(*this) * e; }
private:
Bindable base;
Expr base;
llvm::SmallVector<Expr, 4> indices;
};

View File

@@ -73,9 +73,11 @@ public:
llvm::Error invoke(StringRef name, MutableArrayRef<void *> args);
private:
// Ordering of llvmContext and jit is important for destruction purposes: the
// jit must be destroyed before the context.
llvm::LLVMContext llvmContext;
// Private implementation of the JIT (PIMPL)
std::unique_ptr<impl::OrcJIT> jit;
llvm::LLVMContext llvmContext;
};
template <typename... Args>

View File

@@ -20,6 +20,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir-c/Core.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/EDSC/MLIREmitter.h"
#include "mlir/EDSC/Types.h"
@@ -38,8 +39,9 @@ using llvm::errs;
#define DEBUG_TYPE "edsc"
namespace mlir {
namespace edsc {
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::detail;
// Factors out the boilerplate that is needed to build and answer the
// following simple question:
@@ -89,7 +91,7 @@ static bool isFloatElement(const Value &v) {
return getElementType(v).isa<FloatType>();
}
Value *add(FuncBuilder *builder, Location location, Value *a, Value *b) {
static Value *add(FuncBuilder *builder, Location location, Value *a, Value *b) {
if (isIndexElement(*a)) {
auto *context = builder->getContext();
auto d0 = getAffineDimExpr(0, context);
@@ -103,7 +105,7 @@ Value *add(FuncBuilder *builder, Location location, Value *a, Value *b) {
return builder->create<AddFOp>(location, a, b)->getResult();
}
Value *sub(FuncBuilder *builder, Location location, Value *a, Value *b) {
static Value *sub(FuncBuilder *builder, Location location, Value *a, Value *b) {
if (isIndexElement(*a)) {
auto *context = builder->getContext();
auto d0 = getAffineDimExpr(0, context);
@@ -117,7 +119,7 @@ Value *sub(FuncBuilder *builder, Location location, Value *a, Value *b) {
return builder->create<SubFOp>(location, a, b)->getResult();
}
Value *mul(FuncBuilder *builder, Location location, Value *a, Value *b) {
static Value *mul(FuncBuilder *builder, Location location, Value *a, Value *b) {
if (!isFloatElement(*a)) {
return builder->create<MulIOp>(location, a, b)->getResult();
}
@@ -138,7 +140,7 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) {
}
}
MLIREmitter &MLIREmitter::bind(Bindable e, Value *v) {
MLIREmitter &mlir::edsc::MLIREmitter::bind(Bindable e, Value *v) {
LLVM_DEBUG(printDefininingStatement(llvm::dbgs() << "\nBinding " << e << " @"
<< e.getStoragePtr() << ": ",
*v));
@@ -151,7 +153,7 @@ MLIREmitter &MLIREmitter::bind(Bindable e, Value *v) {
return *this;
}
Value *MLIREmitter::emit(Expr e) {
Value *mlir::edsc::MLIREmitter::emit(Expr e) {
auto it = ssaBindings.find(e);
if (it != ssaBindings.end()) {
return it->second;
@@ -316,7 +318,7 @@ Value *MLIREmitter::emit(Expr e) {
return res;
}
SmallVector<Value *, 8> MLIREmitter::emit(ArrayRef<Expr> exprs) {
SmallVector<Value *, 8> mlir::edsc::MLIREmitter::emit(ArrayRef<Expr> exprs) {
return mlir::functional::map(
[this](Expr e) {
auto *res = this->emit(e);
@@ -327,7 +329,7 @@ SmallVector<Value *, 8> MLIREmitter::emit(ArrayRef<Expr> exprs) {
exprs);
}
void MLIREmitter::emitStmt(const Stmt &stmt) {
void mlir::edsc::MLIREmitter::emitStmt(const Stmt &stmt) {
auto *block = builder->getBlock();
auto ip = builder->getInsertionPoint();
// Blocks are just a containing abstraction, they do not emit their RHS.
@@ -351,7 +353,7 @@ void MLIREmitter::emitStmt(const Stmt &stmt) {
builder->setInsertionPoint(block, ip);
}
void MLIREmitter::emitStmts(ArrayRef<Stmt> stmts) {
void mlir::edsc::MLIREmitter::emitStmts(ArrayRef<Stmt> stmts) {
for (auto &stmt : stmts) {
emitStmt(stmt);
}
@@ -390,7 +392,8 @@ static bool isDynamicSize(int size) { return size < 0; }
/// and returns the vector with {%d0, %c3, %c4, %d3, %c5}.
static SmallVector<Value *, 8> getMemRefSizes(FuncBuilder *b, Location loc,
Value *memRef) {
auto memRefType = memRef->getType().template cast<MemRefType>();
assert(memRef->getType().isa<MemRefType>() && "Expected a MemRef value");
MemRefType memRefType = memRef->getType().cast<MemRefType>();
SmallVector<Value *, 8> res;
res.reserve(memRefType.getShape().size());
const auto &shape = memRefType.getShape();
@@ -404,15 +407,165 @@ static SmallVector<Value *, 8> getMemRefSizes(FuncBuilder *b, Location loc,
return res;
}
SmallVector<edsc::Bindable, 8> MLIREmitter::makeBoundSizes(Value *memRef) {
SmallVector<edsc::Expr, 8>
mlir::edsc::MLIREmitter::makeBoundSizes(Value *memRef) {
assert(memRef->getType().isa<MemRefType>() && "Expected a MemRef value");
MemRefType memRefType = memRef->getType().cast<MemRefType>();
auto memRefSizes = edsc::makeBindables(memRefType.getShape().size());
auto memrefSizeValues = getMemRefSizes(getBuilder(), getLocation(), memRef);
assert(memrefSizeValues.size() == memRefSizes.size());
bindZipRange(llvm::zip(memRefSizes, memrefSizeValues));
return memRefSizes;
SmallVector<edsc::Expr, 8> res(memRefSizes.begin(), memRefSizes.end());
return res;
}
} // namespace edsc
} // namespace mlir
edsc_expr_t bindConstantBF16(edsc_mlir_emitter_t emitter, double value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Bindable b;
e->bindConstant<mlir::ConstantFloatOp>(b, mlir::APFloat(value),
e->getBuilder()->getBF16Type());
return b;
}
edsc_expr_t bindConstantF16(edsc_mlir_emitter_t emitter, float value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Bindable b;
bool unused;
mlir::APFloat val(value);
val.convert(e->getBuilder()->getF16Type().getFloatSemantics(),
mlir::APFloat::rmNearestTiesToEven, &unused);
e->bindConstant<mlir::ConstantFloatOp>(b, val, e->getBuilder()->getF16Type());
return b;
}
edsc_expr_t bindConstantF32(edsc_mlir_emitter_t emitter, float value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Bindable b;
e->bindConstant<mlir::ConstantFloatOp>(b, mlir::APFloat(value),
e->getBuilder()->getF32Type());
return b;
}
edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Bindable b;
e->bindConstant<mlir::ConstantFloatOp>(b, mlir::APFloat(value),
e->getBuilder()->getF64Type());
return b;
}
edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value,
unsigned bitwidth) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Bindable b;
e->bindConstant<mlir::ConstantIntOp>(
b, value, e->getBuilder()->getIntegerType(bitwidth));
return b;
}
edsc_expr_t bindConstantIndex(edsc_mlir_emitter_t emitter, int64_t value) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
Bindable b;
e->bindConstant<mlir::ConstantIndexOp>(b, value);
return b;
}
unsigned getRankOfFunctionArgument(mlir_func_t function, unsigned pos) {
auto *f = reinterpret_cast<mlir::Function *>(function);
assert(pos < f->getNumArguments());
auto *arg = *(f->getArguments().begin() + pos);
if (auto memRefType = arg->getType().dyn_cast<mlir::MemRefType>()) {
return memRefType.getRank();
}
return 0;
}
mlir_type_t getTypeOfFunctionArgument(mlir_func_t function, unsigned pos) {
auto *f = reinterpret_cast<mlir::Function *>(function);
assert(pos < f->getNumArguments());
auto *arg = *(f->getArguments().begin() + pos);
return mlir_type_t{arg->getType().getAsOpaquePointer()};
}
edsc_expr_t bindFunctionArgument(edsc_mlir_emitter_t emitter,
mlir_func_t function, unsigned pos) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
auto *f = reinterpret_cast<mlir::Function *>(function);
assert(pos < f->getNumArguments());
auto *arg = *(f->getArguments().begin() + pos);
Bindable b;
e->bind(b, arg);
return Expr(b);
}
void bindFunctionArguments(edsc_mlir_emitter_t emitter, mlir_func_t function,
edsc_expr_list_t *result) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
auto *f = reinterpret_cast<mlir::Function *>(function);
assert(result->n == f->getNumArguments());
for (unsigned pos = 0; pos < result->n; ++pos) {
auto *arg = *(f->getArguments().begin() + pos);
Bindable b;
e->bind(b, arg);
result->exprs[pos] = Expr(b);
}
}
unsigned getBoundMemRefRank(edsc_mlir_emitter_t emitter,
edsc_expr_t boundMemRef) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
auto *v = e->getValue(mlir::edsc::Expr(boundMemRef));
auto memRefType = v->getType().cast<mlir::MemRefType>();
return memRefType.getRank();
}
void bindMemRefShape(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
edsc_expr_list_t *result) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
auto *v = e->getValue(mlir::edsc::Expr(boundMemRef));
auto memRefType = v->getType().cast<mlir::MemRefType>();
auto rank = memRefType.getRank();
assert(result->n == rank && "Unexpected memref shape binding results count");
auto bindables = e->makeBoundSizes(v);
for (unsigned i = 0; i < rank; ++i) {
result->exprs[i] = bindables[i];
}
}
void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
edsc_expr_list_t *resultLbs, edsc_expr_list_t *resultUbs,
edsc_expr_list_t *resultSteps) {
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
auto *v = e->getValue(mlir::edsc::Expr(boundMemRef));
auto memRefType = v->getType().cast<mlir::MemRefType>();
auto rank = memRefType.getRank();
assert(resultLbs->n == rank && "Unexpected memref binding results count");
assert(resultUbs->n == rank && "Unexpected memref binding results count");
assert(resultSteps->n == rank && "Unexpected memref binding results count");
auto bindables = e->makeBoundSizes(v);
for (unsigned i = 0; i < rank; ++i) {
Bindable zero;
e->bindConstant<mlir::ConstantIndexOp>(zero, 0);
resultLbs->exprs[i] = zero;
resultUbs->exprs[i] = bindables[i];
Bindable one;
e->bindConstant<mlir::ConstantIndexOp>(one, 1);
resultSteps->exprs[i] = one;
}
}
#define DEFINE_EDSL_BINARY_OP(FUN_NAME, OP_SYMBOL) \
edsc_expr_t FUN_NAME(edsc_expr_t e1, edsc_expr_t e2) { \
return Expr(e1) OP_SYMBOL Expr(e2); \
}
DEFINE_EDSL_BINARY_OP(Add, +);
DEFINE_EDSL_BINARY_OP(Sub, -);
DEFINE_EDSL_BINARY_OP(Mul, *);
// DEFINE_EDSL_BINARY_OP(Div, /);
DEFINE_EDSL_BINARY_OP(LT, <);
DEFINE_EDSL_BINARY_OP(LE, <=);
DEFINE_EDSL_BINARY_OP(GT, >);
DEFINE_EDSL_BINARY_OP(GE, >=);
#undef DEFINE_EDSL_BINARY_OP

View File

@@ -16,9 +16,14 @@
// =============================================================================
#include "mlir/EDSC/Types.h"
#include "mlir-c/Core.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
@@ -77,53 +82,57 @@ struct StmtStorage {
};
} // namespace detail
} // namespace edsc
} // namespace mlir
ScopedEDSCContext::ScopedEDSCContext() {
mlir::edsc::ScopedEDSCContext::ScopedEDSCContext() {
Expr::globalAllocator() = &allocator;
Bindable::resetIds();
}
ScopedEDSCContext::~ScopedEDSCContext() { Expr::globalAllocator() = nullptr; }
mlir::edsc::ScopedEDSCContext::~ScopedEDSCContext() {
Expr::globalAllocator() = nullptr;
}
ExprKind Expr::getKind() const { return storage->kind; }
ExprKind mlir::edsc::Expr::getKind() const { return storage->kind; }
Expr Expr::operator+(Expr other) const {
Expr mlir::edsc::Expr::operator+(Expr other) const {
return BinaryExpr(ExprKind::Add, *this, other);
}
Expr Expr::operator-(Expr other) const {
Expr mlir::edsc::Expr::operator-(Expr other) const {
return BinaryExpr(ExprKind::Sub, *this, other);
}
Expr Expr::operator*(Expr other) const {
Expr mlir::edsc::Expr::operator*(Expr other) const {
return BinaryExpr(ExprKind::Mul, *this, other);
}
Expr Expr::operator==(Expr other) const {
Expr mlir::edsc::Expr::operator==(Expr other) const {
return BinaryExpr(ExprKind::EQ, *this, other);
}
Expr Expr::operator!=(Expr other) const {
Expr mlir::edsc::Expr::operator!=(Expr other) const {
return BinaryExpr(ExprKind::NE, *this, other);
}
Expr Expr::operator<(Expr other) const {
Expr mlir::edsc::Expr::operator<(Expr other) const {
return BinaryExpr(ExprKind::LT, *this, other);
}
Expr Expr::operator<=(Expr other) const {
Expr mlir::edsc::Expr::operator<=(Expr other) const {
return BinaryExpr(ExprKind::LE, *this, other);
}
Expr Expr::operator>(Expr other) const {
Expr mlir::edsc::Expr::operator>(Expr other) const {
return BinaryExpr(ExprKind::GT, *this, other);
}
Expr Expr::operator>=(Expr other) const {
Expr mlir::edsc::Expr::operator>=(Expr other) const {
return BinaryExpr(ExprKind::GE, *this, other);
}
Expr Expr::operator&&(Expr other) const {
Expr mlir::edsc::Expr::operator&&(Expr other) const {
return BinaryExpr(ExprKind::And, *this, other);
}
Expr Expr::operator||(Expr other) const {
Expr mlir::edsc::Expr::operator||(Expr other) const {
return BinaryExpr(ExprKind::Or, *this, other);
}
// Free functions.
llvm::SmallVector<Bindable, 8> makeBindables(unsigned n) {
llvm::SmallVector<Bindable, 8> mlir::edsc::makeBindables(unsigned n) {
llvm::SmallVector<Bindable, 8> res;
res.reserve(n);
for (auto i = 0; i < n; ++i) {
@@ -132,7 +141,16 @@ llvm::SmallVector<Bindable, 8> makeBindables(unsigned n) {
return res;
}
llvm::SmallVector<Expr, 8> makeExprs(unsigned n) {
static llvm::SmallVector<Bindable, 8> makeBindables(edsc_expr_list_t exprList) {
llvm::SmallVector<Bindable, 8> exprs;
exprs.reserve(exprList.n);
for (unsigned i = 0; i < exprList.n; ++i) {
exprs.push_back(Expr(exprList.exprs[i]).cast<Bindable>());
}
return exprs;
}
llvm::SmallVector<Expr, 8> mlir::edsc::makeExprs(unsigned n) {
llvm::SmallVector<Expr, 8> res;
res.reserve(n);
for (auto i = 0; i < n; ++i) {
@@ -141,7 +159,7 @@ llvm::SmallVector<Expr, 8> makeExprs(unsigned n) {
return res;
}
llvm::SmallVector<Expr, 8> makeExprs(ArrayRef<Bindable> bindables) {
llvm::SmallVector<Expr, 8> mlir::edsc::makeExprs(ArrayRef<Bindable> bindables) {
llvm::SmallVector<Expr, 8> res;
res.reserve(bindables.size());
for (auto b : bindables) {
@@ -150,29 +168,53 @@ llvm::SmallVector<Expr, 8> makeExprs(ArrayRef<Bindable> bindables) {
return res;
}
Expr alloc(llvm::ArrayRef<Expr> sizes, Type memrefType) {
static llvm::SmallVector<Expr, 8> makeExprs(edsc_expr_list_t exprList) {
llvm::SmallVector<Expr, 8> exprs;
exprs.reserve(exprList.n);
for (unsigned i = 0; i < exprList.n; ++i) {
exprs.push_back(Expr(exprList.exprs[i]));
}
return exprs;
}
static llvm::SmallVector<Stmt, 8> makeStmts(edsc_stmt_list_t enclosedStmts) {
llvm::SmallVector<Stmt, 8> stmts;
stmts.reserve(enclosedStmts.n);
for (unsigned i = 0; i < enclosedStmts.n; ++i) {
stmts.push_back(Stmt(enclosedStmts.stmts[i]));
}
return stmts;
}
Expr mlir::edsc::alloc(llvm::ArrayRef<Expr> sizes, Type memrefType) {
return VariadicExpr(ExprKind::Alloc, sizes, memrefType);
}
Stmt Block(ArrayRef<Stmt> stmts) {
Stmt mlir::edsc::Block(ArrayRef<Stmt> stmts) {
return Stmt(StmtBlockLikeExpr(ExprKind::Block, {}), stmts);
}
Expr dealloc(Expr memref) { return UnaryExpr(ExprKind::Dealloc, memref); }
edsc_stmt_t Block(edsc_stmt_list_t enclosedStmts) {
return Stmt(mlir::edsc::Block(makeStmts(enclosedStmts)));
}
Stmt For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) {
Expr mlir::edsc::dealloc(Expr memref) {
return UnaryExpr(ExprKind::Dealloc, memref);
}
Stmt mlir::edsc::For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) {
Bindable idx;
return For(idx, lb, ub, step, stmts);
}
Stmt For(const Bindable &idx, Expr lb, Expr ub, Expr step,
ArrayRef<Stmt> stmts) {
Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step,
ArrayRef<Stmt> stmts) {
return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, {lb, ub, step}), stmts);
}
Stmt For(MutableArrayRef<Bindable> indices, ArrayRef<Expr> lbs,
ArrayRef<Expr> ubs, ArrayRef<Expr> steps,
ArrayRef<Stmt> enclosedStmts) {
Stmt mlir::edsc::For(MutableArrayRef<Bindable> indices, ArrayRef<Expr> lbs,
ArrayRef<Expr> ubs, ArrayRef<Expr> steps,
ArrayRef<Stmt> enclosedStmts) {
assert(!indices.empty());
assert(indices.size() == lbs.size());
assert(indices.size() == ubs.size());
@@ -185,14 +227,37 @@ Stmt For(MutableArrayRef<Bindable> indices, ArrayRef<Expr> lbs,
return curStmt;
}
Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Bindable> lbs,
llvm::ArrayRef<Bindable> ubs, llvm::ArrayRef<Bindable> steps,
llvm::ArrayRef<Stmt> enclosedStmts) {
Stmt mlir::edsc::For(llvm::MutableArrayRef<Bindable> indices,
llvm::ArrayRef<Bindable> lbs, llvm::ArrayRef<Bindable> ubs,
llvm::ArrayRef<Bindable> steps,
llvm::ArrayRef<Stmt> enclosedStmts) {
return For(indices, SmallVector<Expr, 8>{lbs.begin(), lbs.end()},
SmallVector<Expr, 8>{ubs.begin(), ubs.end()},
SmallVector<Expr, 8>{steps.begin(), steps.end()}, enclosedStmts);
}
Stmt mlir::edsc::For(llvm::MutableArrayRef<Bindable> indices,
llvm::ArrayRef<Bindable> lbs, llvm::ArrayRef<Expr> ubs,
llvm::ArrayRef<Bindable> steps,
llvm::ArrayRef<Stmt> enclosedStmts) {
return For(indices, SmallVector<Expr, 8>{lbs.begin(), lbs.end()}, ubs,
SmallVector<Expr, 8>{steps.begin(), steps.end()}, enclosedStmts);
}
edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
edsc_expr_t step, edsc_stmt_list_t enclosedStmts) {
return Stmt(For(Expr(iv).cast<Bindable>(), Expr(lb), Expr(ub), Expr(step),
makeStmts(enclosedStmts)));
}
edsc_stmt_t ForNest(edsc_expr_list_t ivs, edsc_expr_list_t lbs,
edsc_expr_list_t ubs, edsc_expr_list_t steps,
edsc_stmt_list_t enclosedStmts) {
auto bindables = makeBindables(ivs);
return Stmt(For(bindables, makeExprs(lbs), makeExprs(ubs), makeExprs(steps),
makeStmts(enclosedStmts)));
}
template <typename BindableOrExpr>
static Expr loadBuilder(Expr m, ArrayRef<BindableOrExpr> indices) {
SmallVector<Expr, 8> exprs;
@@ -200,12 +265,24 @@ static Expr loadBuilder(Expr m, ArrayRef<BindableOrExpr> indices) {
exprs.append(indices.begin(), indices.end());
return VariadicExpr(ExprKind::Load, exprs);
}
Expr load(Expr m, Expr index) { return loadBuilder<Expr>(m, {index}); }
Expr load(Expr m, Bindable index) { return loadBuilder<Bindable>(m, {index}); }
Expr load(Expr m, const llvm::SmallVectorImpl<Bindable> &indices) {
Expr mlir::edsc::load(Expr m, Expr index) {
return loadBuilder<Expr>(m, {index});
}
Expr mlir::edsc::load(Expr m, Bindable index) {
return loadBuilder<Bindable>(m, {index});
}
Expr mlir::edsc::load(Expr m, const llvm::SmallVectorImpl<Bindable> &indices) {
return loadBuilder(m, ArrayRef<Bindable>{indices.begin(), indices.end()});
}
Expr load(Expr m, ArrayRef<Expr> indices) { return loadBuilder(m, indices); }
Expr mlir::edsc::load(Expr m, ArrayRef<Expr> indices) {
return loadBuilder(m, indices);
}
edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices) {
Indexed i(Expr(indexed.base).cast<Bindable>());
Expr res = i[makeExprs(indices)];
return res;
}
template <typename BindableOrExpr>
static Expr storeBuilder(Expr val, Expr m, ArrayRef<BindableOrExpr> indices) {
@@ -215,33 +292,49 @@ static Expr storeBuilder(Expr val, Expr m, ArrayRef<BindableOrExpr> indices) {
exprs.append(indices.begin(), indices.end());
return VariadicExpr(ExprKind::Store, exprs);
}
Expr store(Expr val, Expr m, Expr index) {
Expr mlir::edsc::store(Expr val, Expr m, Expr index) {
return storeBuilder<Expr>(val, m, {index});
}
Expr store(Expr val, Expr m, Bindable index) {
Expr mlir::edsc::store(Expr val, Expr m, Bindable index) {
return storeBuilder<Bindable>(val, m, {index});
}
Expr store(Expr val, Expr m, const llvm::SmallVectorImpl<Bindable> &indices) {
Expr mlir::edsc::store(Expr val, Expr m,
const llvm::SmallVectorImpl<Bindable> &indices) {
return storeBuilder(val, m,
ArrayRef<Bindable>{indices.begin(), indices.end()});
}
Expr store(Expr val, Expr m, ArrayRef<Expr> indices) {
Expr mlir::edsc::store(Expr val, Expr m, ArrayRef<Expr> indices) {
return storeBuilder(val, m, indices);
}
Expr select(Expr cond, Expr lhs, Expr rhs) {
edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed,
edsc_expr_list_t indices) {
Indexed i(Expr(indexed.base).cast<Bindable>());
Indexed loc = i[makeExprs(indices)];
return Stmt(loc = Expr(value));
}
Expr mlir::edsc::select(Expr cond, Expr lhs, Expr rhs) {
return TernaryExpr(ExprKind::Select, cond, lhs, rhs);
}
Expr vector_type_cast(Expr memrefExpr, Type memrefType) {
edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs) {
return select(Expr(cond), Expr(lhs), Expr(rhs));
}
Expr mlir::edsc::vector_type_cast(Expr memrefExpr, Type memrefType) {
return VariadicExpr(ExprKind::VectorTypeCast, {memrefExpr}, {memrefType});
}
Stmt Return(ArrayRef<Expr> values) {
Stmt mlir::edsc::Return(ArrayRef<Expr> values) {
return VariadicExpr(ExprKind::Return, values);
}
void Expr::print(raw_ostream &os) const {
edsc_stmt_t Return(edsc_expr_list_t values) {
return Stmt(Return(makeExprs(values)));
}
void mlir::edsc::Expr::print(raw_ostream &os) const {
if (auto unbound = this->dyn_cast<Bindable>()) {
os << "$" << unbound.getId();
return;
@@ -322,73 +415,77 @@ void Expr::print(raw_ostream &os) const {
os << "unknown_kind(" << static_cast<int>(getKind()) << ")";
}
void Expr::dump() const { this->print(llvm::errs()); }
void mlir::edsc::Expr::dump() const { this->print(llvm::errs()); }
std::string Expr::str() const {
std::string mlir::edsc::Expr::str() const {
std::string res;
llvm::raw_string_ostream os(res);
this->print(os);
return res;
}
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Expr &expr) {
llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
const Expr &expr) {
expr.print(os);
return os;
}
Bindable::Bindable()
mlir::edsc::Bindable::Bindable()
: Expr(Expr::globalAllocator()->Allocate<detail::BindableStorage>()) {
// Initialize with placement new.
new (storage) detail::BindableStorage{Bindable::newId()};
}
unsigned Bindable::getId() const {
edsc_expr_t makeBindable() { return Bindable(); }
unsigned mlir::edsc::Bindable::getId() const {
return static_cast<ImplType *>(storage)->id;
}
unsigned &Bindable::newId() {
unsigned &mlir::edsc::Bindable::newId() {
static thread_local unsigned id = 0;
return ++id;
}
UnaryExpr::UnaryExpr(ExprKind kind, Expr expr)
mlir::edsc::UnaryExpr::UnaryExpr(ExprKind kind, Expr expr)
: Expr(Expr::globalAllocator()->Allocate<detail::UnaryExprStorage>()) {
// Initialize with placement new.
new (storage) detail::UnaryExprStorage{kind, expr};
}
Expr UnaryExpr::getExpr() const {
Expr mlir::edsc::UnaryExpr::getExpr() const {
return static_cast<ImplType *>(storage)->expr;
}
BinaryExpr::BinaryExpr(ExprKind kind, Expr lhs, Expr rhs)
mlir::edsc::BinaryExpr::BinaryExpr(ExprKind kind, Expr lhs, Expr rhs)
: Expr(Expr::globalAllocator()->Allocate<detail::BinaryExprStorage>()) {
// Initialize with placement new.
new (storage) detail::BinaryExprStorage{kind, lhs, rhs};
}
Expr BinaryExpr::getLHS() const {
Expr mlir::edsc::BinaryExpr::getLHS() const {
return static_cast<ImplType *>(storage)->lhs;
}
Expr BinaryExpr::getRHS() const {
Expr mlir::edsc::BinaryExpr::getRHS() const {
return static_cast<ImplType *>(storage)->rhs;
}
TernaryExpr::TernaryExpr(ExprKind kind, Expr cond, Expr lhs, Expr rhs)
mlir::edsc::TernaryExpr::TernaryExpr(ExprKind kind, Expr cond, Expr lhs,
Expr rhs)
: Expr(Expr::globalAllocator()->Allocate<detail::TernaryExprStorage>()) {
// Initialize with placement new.
new (storage) detail::TernaryExprStorage{kind, cond, lhs, rhs};
}
Expr TernaryExpr::getCond() const {
Expr mlir::edsc::TernaryExpr::getCond() const {
return static_cast<ImplType *>(storage)->cond;
}
Expr TernaryExpr::getLHS() const {
Expr mlir::edsc::TernaryExpr::getLHS() const {
return static_cast<ImplType *>(storage)->lhs;
}
Expr TernaryExpr::getRHS() const {
Expr mlir::edsc::TernaryExpr::getRHS() const {
return static_cast<ImplType *>(storage)->rhs;
}
VariadicExpr::VariadicExpr(ExprKind kind, ArrayRef<Expr> exprs,
ArrayRef<Type> types)
mlir::edsc::VariadicExpr::VariadicExpr(ExprKind kind, ArrayRef<Expr> exprs,
ArrayRef<Type> types)
: Expr(Expr::globalAllocator()->Allocate<detail::VariadicExprStorage>()) {
// Initialize with placement new.
auto exprStorage = Expr::globalAllocator()->Allocate<Expr>(exprs.size());
@@ -399,15 +496,16 @@ VariadicExpr::VariadicExpr(ExprKind kind, ArrayRef<Expr> exprs,
kind, ArrayRef<Expr>(exprStorage, exprs.size()),
ArrayRef<Type>(typeStorage, types.size())};
}
ArrayRef<Expr> VariadicExpr::getExprs() const {
ArrayRef<Expr> mlir::edsc::VariadicExpr::getExprs() const {
return static_cast<ImplType *>(storage)->exprs;
}
ArrayRef<Type> VariadicExpr::getTypes() const {
ArrayRef<Type> mlir::edsc::VariadicExpr::getTypes() const {
return static_cast<ImplType *>(storage)->types;
}
StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind, ArrayRef<Expr> exprs,
ArrayRef<Type> types)
mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind,
ArrayRef<Expr> exprs,
ArrayRef<Type> types)
: Expr(Expr::globalAllocator()->Allocate<detail::VariadicExprStorage>()) {
// Initialize with placement new.
auto exprStorage = Expr::globalAllocator()->Allocate<Expr>(exprs.size());
@@ -418,15 +516,15 @@ StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind, ArrayRef<Expr> exprs,
kind, ArrayRef<Expr>(exprStorage, exprs.size()),
ArrayRef<Type>(typeStorage, types.size())};
}
ArrayRef<Expr> StmtBlockLikeExpr::getExprs() const {
ArrayRef<Expr> mlir::edsc::StmtBlockLikeExpr::getExprs() const {
return static_cast<ImplType *>(storage)->exprs;
}
ArrayRef<Type> StmtBlockLikeExpr::getTypes() const {
ArrayRef<Type> mlir::edsc::StmtBlockLikeExpr::getTypes() const {
return static_cast<ImplType *>(storage)->types;
}
Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
llvm::ArrayRef<Stmt> enclosedStmts) {
mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
llvm::ArrayRef<Stmt> enclosedStmts) {
storage = Expr::globalAllocator()->Allocate<detail::StmtStorage>();
// Initialize with placement new.
auto enclosedStmtStorage =
@@ -437,24 +535,33 @@ Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
lhs, rhs, ArrayRef<Stmt>(enclosedStmtStorage, enclosedStmts.size())};
}
Stmt::Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> enclosedStmts)
mlir::edsc::Stmt::Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> enclosedStmts)
: Stmt(Bindable(), rhs, enclosedStmts) {}
Stmt &Stmt::operator=(const Expr &expr) {
edsc_stmt_t makeStmt(edsc_expr_t e) {
assert(e && "unexpected empty expression");
return Stmt(Expr(e));
}
Stmt &mlir::edsc::Stmt::operator=(const Expr &expr) {
Stmt res(Bindable(), expr, {});
std::swap(res.storage, this->storage);
return *this;
}
Bindable Stmt::getLHS() const { return static_cast<ImplType *>(storage)->lhs; }
Bindable mlir::edsc::Stmt::getLHS() const {
return static_cast<ImplType *>(storage)->lhs;
}
Expr Stmt::getRHS() const { return static_cast<ImplType *>(storage)->rhs; }
Expr mlir::edsc::Stmt::getRHS() const {
return static_cast<ImplType *>(storage)->rhs;
}
llvm::ArrayRef<Stmt> Stmt::getEnclosedStmts() const {
llvm::ArrayRef<Stmt> mlir::edsc::Stmt::getEnclosedStmts() const {
return storage->enclosedStmts;
}
void Stmt::print(raw_ostream &os, Twine indent) const {
void mlir::edsc::Stmt::print(raw_ostream &os, Twine indent) const {
assert(storage && "Unexpected null storage,stmt must be bound to print");
auto lhs = getLHS();
auto rhs = getRHS();
@@ -491,36 +598,97 @@ void Stmt::print(raw_ostream &os, Twine indent) const {
}
}
void Stmt::dump() const { this->print(llvm::errs()); }
void mlir::edsc::Stmt::dump() const { this->print(llvm::errs()); }
std::string Stmt::str() const {
std::string mlir::edsc::Stmt::str() const {
std::string res;
llvm::raw_string_ostream os(res);
this->print(os);
return res;
}
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Stmt &stmt) {
llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
const Stmt &stmt) {
stmt.print(os);
return os;
}
Indexed Indexed::operator[](llvm::ArrayRef<Expr> indices) const {
Indexed mlir::edsc::Indexed::operator[](llvm::ArrayRef<Expr> indices) const {
Indexed res(base);
res.indices = llvm::SmallVector<Expr, 4>(indices.begin(), indices.end());
return res;
}
Indexed Indexed::operator[](llvm::ArrayRef<Bindable> indices) const {
Indexed mlir::edsc::Indexed::
operator[](llvm::ArrayRef<Bindable> indices) const {
return (*this)[llvm::ArrayRef<Expr>{indices.begin(), indices.end()}];
}
Stmt Indexed::operator=(Expr expr) { // NOLINT: unconventional-assing-operator
// NOLINTNEXTLINE: unconventional-assign-operator
Stmt mlir::edsc::Indexed::operator=(Expr expr) {
assert(!indices.empty() && "Expected attached indices to Indexed");
assert(base);
Stmt stmt(store(expr, base, indices));
indices.clear();
return stmt;
}
} // namespace edsc
} // namespace mlir
edsc_indexed_t makeIndexed(edsc_expr_t expr) {
return edsc_indexed_t{expr, edsc_expr_list_t{nullptr, 0}};
}
edsc_indexed_t index(edsc_indexed_t indexed, edsc_expr_list_t indices) {
return edsc_indexed_t{indexed.base, indices};
}
mlir_type_t makeScalarType(mlir_context_t context, const char *name,
unsigned bitwidth) {
mlir::MLIRContext *c = reinterpret_cast<mlir::MLIRContext *>(context);
mlir_type_t res =
llvm::StringSwitch<mlir_type_t>(name)
.Case("bf16",
mlir_type_t{mlir::Type::getBF16(c).getAsOpaquePointer()})
.Case("f16", mlir_type_t{mlir::Type::getF16(c).getAsOpaquePointer()})
.Case("f32", mlir_type_t{mlir::Type::getF32(c).getAsOpaquePointer()})
.Case("f64", mlir_type_t{mlir::Type::getF64(c).getAsOpaquePointer()})
.Case("index",
mlir_type_t{mlir::Type::getIndex(c).getAsOpaquePointer()})
.Case("i",
mlir_type_t{
mlir::Type::getInteger(bitwidth, c).getAsOpaquePointer()})
.Default(mlir_type_t{nullptr});
if (!res) {
llvm_unreachable("Invalid type specifier");
}
return res;
}
mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType,
int64_list_t sizes) {
auto t = mlir::MemRefType::get(
llvm::ArrayRef<int64_t>(sizes.values, sizes.n),
mlir::Type::getFromOpaquePointer(elemType),
{mlir::AffineMap::getMultiDimIdentityMap(
sizes.n, reinterpret_cast<mlir::MLIRContext *>(context))},
0);
return mlir_type_t{t.getAsOpaquePointer()};
}
mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
mlir_type_list_t outputs) {
llvm::SmallVector<mlir::Type, 8> ins(inputs.n), outs(outputs.n);
for (unsigned i = 0; i < inputs.n; ++i) {
ins[i] = mlir::Type::getFromOpaquePointer(inputs.types[i]);
}
for (unsigned i = 0; i < outputs.n; ++i) {
ins[i] = mlir::Type::getFromOpaquePointer(outputs.types[i]);
}
auto ft = mlir::FunctionType::get(
ins, outs, reinterpret_cast<mlir::MLIRContext *>(context));
return mlir_type_t{ft.getAsOpaquePointer()};
}
unsigned getFunctionArity(mlir_func_t function) {
auto *f = reinterpret_cast<mlir::Function *>(function);
return f->getNumArguments();
}

View File

@@ -278,9 +278,9 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(Module *m) {
setupTargetTriple(llvmModule.get());
packFunctionArguments(llvmModule.get());
engine->jit = std::move(*expectedJIT);
if (auto err = engine->jit->addModule(std::move(llvmModule)))
if (auto err = (*expectedJIT)->addModule(std::move(llvmModule)))
return std::move(err);
engine->jit = std::move(*expectedJIT);
return engine;
}