From cacf05892e58b20f135e5c326466535e67114bce Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 28 Jan 2019 14:32:00 -0800 Subject: [PATCH] Add a C API for EDSCs in other languages + python This CL adds support for calling EDSCs from other languages than C++. Following the LLVM convention this CL: 1. declares simple opaque types and a C API in mlir-c/Core.h; 2. defines the implementation directly in lib/EDSC/Types.cpp and lib/EDSC/MLIREmitter.cpp. Unlike LLVM however the nomenclature for these types and API functions is not well-defined, naming suggestions are most welcome. To avoid the need for conversion functions, Types.h and MLIREmitter.h include mlir-c/Core.h and provide constructors and conversion operators between the mlir::edsc type and the corresponding C type. In this first commit, mlir-c/Core.h only contains the types for the C API to allow EDSCs to work from Python. This includes both a minimal set of core MLIR types (mlir_context_t, mlir_type_t, mlir_func_t) as well as the EDSC types (edsc_mlir_emitter_t, edsc_expr_t, edsc_stmt_t, edsc_indexed_t). This can be restructured in the future as concrete needs arise. For now, the API only supports: 1. scalar types; 2. memrefs of scalar types with static or symbolic shapes; 3. functions with input and output of these types. The C API is not complete wrt ownership semantics. This is in large part due to the fact that python bindings are written with Pybind11 which allows very idiomatic C++ bindings. An effort is made to write a large chunk of these bindings using the C API but some C++isms are used where the design benefits from this simplication. A fully isolated C API will make more sense once we also integrate with another language like Swift and have enough use cases to drive the design. Lastly, this CL also fixes a bug in mlir::ExecutionEngine were the order of declaration of llvmContext and the JIT result in an improper order of destructors (which used to crash before the fix). PiperOrigin-RevId: 231290250 --- mlir/include/mlir-c/Core.h | 248 +++++++++++++ mlir/include/mlir/EDSC/MLIREmitter.h | 2 +- mlir/include/mlir/EDSC/Types.h | 19 +- .../mlir/ExecutionEngine/ExecutionEngine.h | 4 +- mlir/lib/EDSC/MLIREmitter.cpp | 183 +++++++++- mlir/lib/EDSC/Types.cpp | 328 +++++++++++++----- mlir/lib/ExecutionEngine/ExecutionEngine.cpp | 4 +- 7 files changed, 686 insertions(+), 102 deletions(-) create mode 100644 mlir/include/mlir-c/Core.h diff --git a/mlir/include/mlir-c/Core.h b/mlir/include/mlir-c/Core.h new file mode 100644 index 000000000000..5992ec9eb628 --- /dev/null +++ b/mlir/include/mlir-c/Core.h @@ -0,0 +1,248 @@ +/*===-- mlir-c/Core.h - Core Library C Interface ------------------*- C -*-===*\ +|* *| +|* Copyright 2019 The MLIR Authors. *| +|* *| +|* Licensed under the Apache License, Version 2.0 (the "License"); *| +|* you may not use this file except in compliance with the License. *| +|* You may obtain a copy of the License at *| +|* *| +|* http://www.apache.org/licenses/LICENSE-2.0 *| +|* *| +|* Unless required by applicable law or agreed to in writing, software *| +|* distributed under the License is distributed on an "AS IS" BASIS, *| +|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *| +|* See the License for the specific language governing permissions and *| +|* limitations under the License. *| +|* *| +|*===----------------------------------------------------------------------===*| +|* *| +|* This header declares the C interface to MLIR. *| +|* *| +\*===----------------------------------------------------------------------===*/ +#ifndef MLIR_C_CORE_H +#define MLIR_C_CORE_H + +#ifdef __cplusplus +#include +extern "C" { +#else +#include +#endif + +/// Opaque MLIR types. +/// Opaque C type for mlir::MLIRContext*. +typedef void *mlir_context_t; +/// Opaque C type for mlir::Type. +typedef const void *mlir_type_t; +/// Opaque C type for mlir::Function*. +typedef void *mlir_func_t; +/// Opaque C type for mlir::edsc::MLIREmiter. +typedef void *edsc_mlir_emitter_t; +/// Opaque C type for mlir::edsc::Expr. +typedef void *edsc_expr_t; +/// Opaque C type for mlir::edsc::Stmt. +typedef void *edsc_stmt_t; + +/// Simple C lists for non-owning mlir Opaque C types. +/// Recommended usage is construction from the `data()` and `size()` of a scoped +/// owning SmallVectorImpl<...> and passing to one of the C functions declared +/// later in this file. +/// Once the function returns and the proper EDSC has been constructed, +/// resources are freed by exiting the scope. +typedef struct { + int64_t *values; + uint64_t n; +} int64_list_t; + +typedef struct { + mlir_type_t *types; + uint64_t n; +} mlir_type_list_t; + +typedef struct { + edsc_expr_t *exprs; + uint64_t n; +} edsc_expr_list_t; + +typedef struct { + edsc_stmt_t *stmts; + uint64_t n; +} edsc_stmt_list_t; + +typedef struct { + edsc_expr_t base; + edsc_expr_list_t indices; +} edsc_indexed_t; + +typedef struct { + edsc_indexed_t *list; + uint64_t n; +} edsc_indexed_list_t; + +/// Minimal C API for exposing EDSCs to Swift, Python and other languages. + +/// Returns a simple scalar mlir::Type using the following convention: +/// - makeScalarType(c, "bf16") return an `mlir::Type::getBF16` +/// - makeScalarType(c, "f16") return an `mlir::Type::getF16` +/// - makeScalarType(c, "f32") return an `mlir::Type::getF32` +/// - makeScalarType(c, "f64") return an `mlir::Type::getF64` +/// - makeScalarType(c, "index") return an `mlir::Type::getIndex` +/// - makeScalarType(c, "i", bitwidth) return an +/// `mlir::Type::getInteger(bitwidth)` +/// +/// No other combinations are currently supported. +mlir_type_t makeScalarType(mlir_context_t context, const char *name, + unsigned bitwidth); + +/// Returns an `mlir::MemRefType` of the element type `elemType` and shape +/// `sizes`. +mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType, + int64_list_t sizes); + +/// Returns an `mlir::FunctionType` of the element type `elemType` and shape +/// `sizes`. +mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs, + mlir_type_list_t outputs); + +/// Returns the arity of `function`. +unsigned getFunctionArity(mlir_func_t function); + +/// Returns a new opaque mlir::edsc::Expr that is bound into `emitter` with a +/// constant of the specified type. +edsc_expr_t bindConstantBF16(edsc_mlir_emitter_t emitter, double value); +edsc_expr_t bindConstantF16(edsc_mlir_emitter_t emitter, float value); +edsc_expr_t bindConstantF32(edsc_mlir_emitter_t emitter, float value); +edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value); +edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value, + unsigned bitwidth); +edsc_expr_t bindConstantIndex(edsc_mlir_emitter_t emitter, int64_t value); + +/// Returns the rank of the `function` argument at position `pos`. +/// If the argument is of MemRefType, this returns the rank of the MemRef. +/// Otherwise returns `0`. +/// TODO(ntv): support more than MemRefType and scalar Type. +unsigned getRankOfFunctionArgument(mlir_func_t function, unsigned pos); + +/// Returns an opaque mlir::Type of the `function` argument at position `pos`. +mlir_type_t getTypeOfFunctionArgument(mlir_func_t function, unsigned pos); + +/// Returns an opaque mlir::edsc::Expr that has been bound to the `pos` argument +/// of `function`. +edsc_expr_t bindFunctionArgument(edsc_mlir_emitter_t emitter, + mlir_func_t function, unsigned pos); + +/// Fills the preallocated list `result` with opaque mlir::edsc::Expr that have +/// been bound to each argument of `function`. +/// +/// Prerequisites: +/// - `result` must have been preallocated with space for exactly the number +/// of arguments of `function`. +void bindFunctionArguments(edsc_mlir_emitter_t emitter, mlir_func_t function, + edsc_expr_list_t *result); + +/// Returns the rank of `boundMemRef`. This API function is provided to more +/// easily compose with `bindFunctionArgument`. A similar function could be +/// provided for an mlir_type_t of type MemRefType but it is expected that users +/// of this API either: +/// 1. construct the MemRefType explicitly, in which case they already have +/// access to the rank and shape of the MemRefType; +/// 2. access MemRefs via mlir_function_t *values* in which case they would +/// pass edsc_expr_t bound to an edsc_emitter_t. +/// +/// Prerequisites: +/// - `boundMemRef` must be an opaque edsc_expr_t that has alreay been bound +/// in `emitter`. +unsigned getBoundMemRefRank(edsc_mlir_emitter_t emitter, + edsc_expr_t boundMemRef); + +/// Fills the preallocated list `result` with opaque mlir::edsc::Expr that have +/// been bound to each dimension of `boundMemRef`. +/// +/// Prerequisites: +/// - `result` must have been preallocated with space for exactly the rank of +/// `boundMemRef`; +/// - `boundMemRef` must be an opaque edsc_expr_t that has alreay been bound +/// in `emitter`. This is because symbolic MemRef shapes require an SSAValue +/// that can only be recovered from `emitter`. +void bindMemRefShape(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef, + edsc_expr_list_t *result); + +/// Fills the preallocated lists `resultLbs`, `resultUbs` and `resultSteps` with +/// opaque mlir::edsc::Expr that have been bound to proper values to traverse +/// each dimension of `memRefType`. +/// At the moment: +/// - `resultsLbs` are always bound to the constant index `0`; +/// - `resultsUbs` are always bound to the shape of `memRefType`; +/// - `resultsSteps` are always bound to the constant index `1`. +/// In the future, this will allow non-contiguous MemRef views. +/// +/// Prerequisites: +/// - `resultLbs`, `resultUbs` and `resultSteps` must have each been +/// preallocated with space for exactly the rank of `boundMemRef`; +/// - `boundMemRef` must be an opaque edsc_expr_t that has alreay been bound +/// in `emitter`. This is because symbolic MemRef shapes require an SSAValue +/// that can only be recovered from `emitter`. +void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef, + edsc_expr_list_t *resultLbs, edsc_expr_list_t *resultUbs, + edsc_expr_list_t *resultSteps); + +/// Returns an opaque expression for an mlir::edsc::Expr. +edsc_expr_t makeBindable(); + +/// Returns an opaque expression for an mlir::edsc::Stmt. +edsc_stmt_t makeStmt(edsc_expr_t e); + +/// Returns an opaque expression for an mlir::edsc::Indexed. +edsc_indexed_t makeIndexed(edsc_expr_t expr); + +/// Returns an indexed opaque expression with indices bound in the structure +/// given an `indexed` and `indices`. +/// Prerequisite: +/// - `indexed` must not have been indexed previously. +edsc_indexed_t index(edsc_indexed_t indexed, edsc_expr_list_t indices); + +/// Returns an opaque expression that will emit an mlir::LoadOp. +edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices); + +/// Returns an opaque statement for an mlir::StoreOp. +edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed, + edsc_expr_list_t indices); + +/// Returns an opaque statement for an mlir::SelectOp. +edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs); + +/// Returns an opaque statement for an mlir::ReturnOp. +edsc_stmt_t Return(edsc_expr_list_t values); + +/// Returns a single opaque statement that acts as an mlir block. At the moment +/// this is pure syntactic sugar to allow lists of mlir::edsc::Stmt to be +/// specified and emitted. In particular, block arguments are not currently +/// supported. +edsc_stmt_t Block(edsc_stmt_list_t enclosedStmts); + +/// Returns an opaque statement for an mlir::ForInst with `enclosedStmts` nested +/// below it. +edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub, + edsc_expr_t step, edsc_stmt_list_t enclosedStmts); + +/// Returns an opaque statement for a perfectly nested set of mlir::ForInst with +/// `enclosedStmts` nested below it. +edsc_stmt_t ForNest(edsc_expr_list_t iv, edsc_expr_list_t lb, + edsc_expr_list_t ub, edsc_expr_list_t step, + edsc_stmt_list_t enclosedStmts); + +/// Returns an opaque expression for the corresponding Binary operation. +edsc_expr_t Add(edsc_expr_t e1, edsc_expr_t e2); +edsc_expr_t Sub(edsc_expr_t e1, edsc_expr_t e2); +edsc_expr_t Mul(edsc_expr_t e1, edsc_expr_t e2); +// edsc_expr_t Div(edsc_expr_t e1, edsc_expr_t e2); +edsc_expr_t LT(edsc_expr_t e1, edsc_expr_t e2); +edsc_expr_t LE(edsc_expr_t e1, edsc_expr_t e2); +edsc_expr_t GT(edsc_expr_t e1, edsc_expr_t e2); +edsc_expr_t GE(edsc_expr_t e1, edsc_expr_t e2); + +#ifdef __cplusplus +} // end extern "C" +#endif + +#endif // MLIR_C_CORE_H diff --git a/mlir/include/mlir/EDSC/MLIREmitter.h b/mlir/include/mlir/EDSC/MLIREmitter.h index fbd5a544a30d..f6e399ed8033 100644 --- a/mlir/include/mlir/EDSC/MLIREmitter.h +++ b/mlir/include/mlir/EDSC/MLIREmitter.h @@ -121,7 +121,7 @@ struct MLIREmitter { /// /// Prerequisite: /// `memRef` is a Value of type MemRefType. - SmallVector makeBoundSizes(Value *memRef); + SmallVector makeBoundSizes(Value *memRef); private: FuncBuilder *builder; diff --git a/mlir/include/mlir/EDSC/Types.h b/mlir/include/mlir/EDSC/Types.h index 57bdb6bc1c3d..96d257eca4cc 100644 --- a/mlir/include/mlir/EDSC/Types.h +++ b/mlir/include/mlir/EDSC/Types.h @@ -24,6 +24,7 @@ #ifndef MLIR_LIB_EDSC_TYPES_H_ #define MLIR_LIB_EDSC_TYPES_H_ +#include "mlir-c/Core.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" @@ -171,6 +172,9 @@ public: Expr() : storage(nullptr) {} /* implicit */ Expr(ImplType *storage) : storage(storage) {} + explicit Expr(edsc_expr_t expr) + : storage(reinterpret_cast(expr)) {} + operator edsc_expr_t() { return edsc_expr_t{storage}; } Expr(const Expr &other) : storage(other.storage) {} Expr &operator=(Expr other) { @@ -227,6 +231,8 @@ struct Bindable : public Expr { Bindable(Expr::ImplType *ptr) : Expr(ptr) { assert(!ptr || isa() && "expected Bindable"); } + explicit Bindable(const edsc_expr_t &expr) : Expr(expr) {} + operator edsc_expr_t() { return edsc_expr_t{storage}; } friend struct ScopedEDSCContext; @@ -352,6 +358,9 @@ struct Stmt { Stmt(const Bindable &lhs, const Expr &rhs, llvm::ArrayRef stmts = llvm::ArrayRef()); Stmt &operator=(const Expr &expr); + explicit Stmt(edsc_stmt_t stmt) + : storage(reinterpret_cast(stmt)) {} + operator edsc_stmt_t() { return edsc_stmt_t{storage}; } operator Expr() const { return getLHS(); } @@ -414,7 +423,7 @@ template U Expr::dyn_cast() const { if (isa()) { return U(storage); } - return U(nullptr); + return U((Expr::ImplType *)(nullptr)); } template U Expr::cast() const { assert(isa()); @@ -497,6 +506,9 @@ Stmt For(llvm::MutableArrayRef indices, llvm::ArrayRef lbs, Stmt For(llvm::MutableArrayRef indices, llvm::ArrayRef lbs, llvm::ArrayRef ubs, llvm::ArrayRef steps, llvm::ArrayRef enclosedStmts); +Stmt For(llvm::MutableArrayRef indices, llvm::ArrayRef lbs, + llvm::ArrayRef ubs, llvm::ArrayRef steps, + llvm::ArrayRef enclosedStmts); /// This helper class exists purely for sugaring purposes and allows writing /// expressions such as: @@ -508,7 +520,8 @@ Stmt For(llvm::MutableArrayRef indices, llvm::ArrayRef lbs, /// }); /// ``` struct Indexed { - Indexed(Bindable m) : base(m), indices() {} + Indexed(Bindable b) : base(b), indices() {} + Indexed(Expr e) : base(e), indices() {} /// Returns a new `Indexed`. As a consequence, an Indexed with attached /// indices can never be reused unless it is captured (e.g. via a Stmt). @@ -534,7 +547,7 @@ struct Indexed { Expr operator*(Expr e) const { return static_cast(*this) * e; } private: - Bindable base; + Expr base; llvm::SmallVector indices; }; diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h index 9774825bda3c..c725fdb9a554 100644 --- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h +++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h @@ -73,9 +73,11 @@ public: llvm::Error invoke(StringRef name, MutableArrayRef args); private: + // Ordering of llvmContext and jit is important for destruction purposes: the + // jit must be destroyed before the context. + llvm::LLVMContext llvmContext; // Private implementation of the JIT (PIMPL) std::unique_ptr jit; - llvm::LLVMContext llvmContext; }; template diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index 2becf90c9440..f0efdb5081a2 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -20,6 +20,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "mlir-c/Core.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/EDSC/MLIREmitter.h" #include "mlir/EDSC/Types.h" @@ -38,8 +39,9 @@ using llvm::errs; #define DEBUG_TYPE "edsc" -namespace mlir { -namespace edsc { +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::detail; // Factors out the boilerplate that is needed to build and answer the // following simple question: @@ -89,7 +91,7 @@ static bool isFloatElement(const Value &v) { return getElementType(v).isa(); } -Value *add(FuncBuilder *builder, Location location, Value *a, Value *b) { +static Value *add(FuncBuilder *builder, Location location, Value *a, Value *b) { if (isIndexElement(*a)) { auto *context = builder->getContext(); auto d0 = getAffineDimExpr(0, context); @@ -103,7 +105,7 @@ Value *add(FuncBuilder *builder, Location location, Value *a, Value *b) { return builder->create(location, a, b)->getResult(); } -Value *sub(FuncBuilder *builder, Location location, Value *a, Value *b) { +static Value *sub(FuncBuilder *builder, Location location, Value *a, Value *b) { if (isIndexElement(*a)) { auto *context = builder->getContext(); auto d0 = getAffineDimExpr(0, context); @@ -117,7 +119,7 @@ Value *sub(FuncBuilder *builder, Location location, Value *a, Value *b) { return builder->create(location, a, b)->getResult(); } -Value *mul(FuncBuilder *builder, Location location, Value *a, Value *b) { +static Value *mul(FuncBuilder *builder, Location location, Value *a, Value *b) { if (!isFloatElement(*a)) { return builder->create(location, a, b)->getResult(); } @@ -138,7 +140,7 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) { } } -MLIREmitter &MLIREmitter::bind(Bindable e, Value *v) { +MLIREmitter &mlir::edsc::MLIREmitter::bind(Bindable e, Value *v) { LLVM_DEBUG(printDefininingStatement(llvm::dbgs() << "\nBinding " << e << " @" << e.getStoragePtr() << ": ", *v)); @@ -151,7 +153,7 @@ MLIREmitter &MLIREmitter::bind(Bindable e, Value *v) { return *this; } -Value *MLIREmitter::emit(Expr e) { +Value *mlir::edsc::MLIREmitter::emit(Expr e) { auto it = ssaBindings.find(e); if (it != ssaBindings.end()) { return it->second; @@ -316,7 +318,7 @@ Value *MLIREmitter::emit(Expr e) { return res; } -SmallVector MLIREmitter::emit(ArrayRef exprs) { +SmallVector mlir::edsc::MLIREmitter::emit(ArrayRef exprs) { return mlir::functional::map( [this](Expr e) { auto *res = this->emit(e); @@ -327,7 +329,7 @@ SmallVector MLIREmitter::emit(ArrayRef exprs) { exprs); } -void MLIREmitter::emitStmt(const Stmt &stmt) { +void mlir::edsc::MLIREmitter::emitStmt(const Stmt &stmt) { auto *block = builder->getBlock(); auto ip = builder->getInsertionPoint(); // Blocks are just a containing abstraction, they do not emit their RHS. @@ -351,7 +353,7 @@ void MLIREmitter::emitStmt(const Stmt &stmt) { builder->setInsertionPoint(block, ip); } -void MLIREmitter::emitStmts(ArrayRef stmts) { +void mlir::edsc::MLIREmitter::emitStmts(ArrayRef stmts) { for (auto &stmt : stmts) { emitStmt(stmt); } @@ -390,7 +392,8 @@ static bool isDynamicSize(int size) { return size < 0; } /// and returns the vector with {%d0, %c3, %c4, %d3, %c5}. static SmallVector getMemRefSizes(FuncBuilder *b, Location loc, Value *memRef) { - auto memRefType = memRef->getType().template cast(); + assert(memRef->getType().isa() && "Expected a MemRef value"); + MemRefType memRefType = memRef->getType().cast(); SmallVector res; res.reserve(memRefType.getShape().size()); const auto &shape = memRefType.getShape(); @@ -404,15 +407,165 @@ static SmallVector getMemRefSizes(FuncBuilder *b, Location loc, return res; } -SmallVector MLIREmitter::makeBoundSizes(Value *memRef) { +SmallVector +mlir::edsc::MLIREmitter::makeBoundSizes(Value *memRef) { assert(memRef->getType().isa() && "Expected a MemRef value"); MemRefType memRefType = memRef->getType().cast(); auto memRefSizes = edsc::makeBindables(memRefType.getShape().size()); auto memrefSizeValues = getMemRefSizes(getBuilder(), getLocation(), memRef); assert(memrefSizeValues.size() == memRefSizes.size()); bindZipRange(llvm::zip(memRefSizes, memrefSizeValues)); - return memRefSizes; + SmallVector res(memRefSizes.begin(), memRefSizes.end()); + return res; } -} // namespace edsc -} // namespace mlir +edsc_expr_t bindConstantBF16(edsc_mlir_emitter_t emitter, double value) { + auto *e = reinterpret_cast(emitter); + Bindable b; + e->bindConstant(b, mlir::APFloat(value), + e->getBuilder()->getBF16Type()); + return b; +} + +edsc_expr_t bindConstantF16(edsc_mlir_emitter_t emitter, float value) { + auto *e = reinterpret_cast(emitter); + Bindable b; + bool unused; + mlir::APFloat val(value); + val.convert(e->getBuilder()->getF16Type().getFloatSemantics(), + mlir::APFloat::rmNearestTiesToEven, &unused); + e->bindConstant(b, val, e->getBuilder()->getF16Type()); + return b; +} + +edsc_expr_t bindConstantF32(edsc_mlir_emitter_t emitter, float value) { + auto *e = reinterpret_cast(emitter); + Bindable b; + e->bindConstant(b, mlir::APFloat(value), + e->getBuilder()->getF32Type()); + return b; +} + +edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value) { + auto *e = reinterpret_cast(emitter); + Bindable b; + e->bindConstant(b, mlir::APFloat(value), + e->getBuilder()->getF64Type()); + return b; +} + +edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value, + unsigned bitwidth) { + auto *e = reinterpret_cast(emitter); + Bindable b; + e->bindConstant( + b, value, e->getBuilder()->getIntegerType(bitwidth)); + return b; +} + +edsc_expr_t bindConstantIndex(edsc_mlir_emitter_t emitter, int64_t value) { + auto *e = reinterpret_cast(emitter); + Bindable b; + e->bindConstant(b, value); + return b; +} + +unsigned getRankOfFunctionArgument(mlir_func_t function, unsigned pos) { + auto *f = reinterpret_cast(function); + assert(pos < f->getNumArguments()); + auto *arg = *(f->getArguments().begin() + pos); + if (auto memRefType = arg->getType().dyn_cast()) { + return memRefType.getRank(); + } + return 0; +} + +mlir_type_t getTypeOfFunctionArgument(mlir_func_t function, unsigned pos) { + auto *f = reinterpret_cast(function); + assert(pos < f->getNumArguments()); + auto *arg = *(f->getArguments().begin() + pos); + return mlir_type_t{arg->getType().getAsOpaquePointer()}; +} + +edsc_expr_t bindFunctionArgument(edsc_mlir_emitter_t emitter, + mlir_func_t function, unsigned pos) { + auto *e = reinterpret_cast(emitter); + auto *f = reinterpret_cast(function); + assert(pos < f->getNumArguments()); + auto *arg = *(f->getArguments().begin() + pos); + Bindable b; + e->bind(b, arg); + return Expr(b); +} + +void bindFunctionArguments(edsc_mlir_emitter_t emitter, mlir_func_t function, + edsc_expr_list_t *result) { + auto *e = reinterpret_cast(emitter); + auto *f = reinterpret_cast(function); + assert(result->n == f->getNumArguments()); + for (unsigned pos = 0; pos < result->n; ++pos) { + auto *arg = *(f->getArguments().begin() + pos); + Bindable b; + e->bind(b, arg); + result->exprs[pos] = Expr(b); + } +} + +unsigned getBoundMemRefRank(edsc_mlir_emitter_t emitter, + edsc_expr_t boundMemRef) { + auto *e = reinterpret_cast(emitter); + auto *v = e->getValue(mlir::edsc::Expr(boundMemRef)); + auto memRefType = v->getType().cast(); + return memRefType.getRank(); +} + +void bindMemRefShape(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef, + edsc_expr_list_t *result) { + auto *e = reinterpret_cast(emitter); + auto *v = e->getValue(mlir::edsc::Expr(boundMemRef)); + auto memRefType = v->getType().cast(); + auto rank = memRefType.getRank(); + assert(result->n == rank && "Unexpected memref shape binding results count"); + auto bindables = e->makeBoundSizes(v); + for (unsigned i = 0; i < rank; ++i) { + result->exprs[i] = bindables[i]; + } +} + +void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef, + edsc_expr_list_t *resultLbs, edsc_expr_list_t *resultUbs, + edsc_expr_list_t *resultSteps) { + auto *e = reinterpret_cast(emitter); + auto *v = e->getValue(mlir::edsc::Expr(boundMemRef)); + auto memRefType = v->getType().cast(); + auto rank = memRefType.getRank(); + assert(resultLbs->n == rank && "Unexpected memref binding results count"); + assert(resultUbs->n == rank && "Unexpected memref binding results count"); + assert(resultSteps->n == rank && "Unexpected memref binding results count"); + auto bindables = e->makeBoundSizes(v); + for (unsigned i = 0; i < rank; ++i) { + Bindable zero; + e->bindConstant(zero, 0); + resultLbs->exprs[i] = zero; + resultUbs->exprs[i] = bindables[i]; + Bindable one; + e->bindConstant(one, 1); + resultSteps->exprs[i] = one; + } +} + +#define DEFINE_EDSL_BINARY_OP(FUN_NAME, OP_SYMBOL) \ + edsc_expr_t FUN_NAME(edsc_expr_t e1, edsc_expr_t e2) { \ + return Expr(e1) OP_SYMBOL Expr(e2); \ + } + +DEFINE_EDSL_BINARY_OP(Add, +); +DEFINE_EDSL_BINARY_OP(Sub, -); +DEFINE_EDSL_BINARY_OP(Mul, *); +// DEFINE_EDSL_BINARY_OP(Div, /); +DEFINE_EDSL_BINARY_OP(LT, <); +DEFINE_EDSL_BINARY_OP(LE, <=); +DEFINE_EDSL_BINARY_OP(GT, >); +DEFINE_EDSL_BINARY_OP(GE, >=); + +#undef DEFINE_EDSL_BINARY_OP diff --git a/mlir/lib/EDSC/Types.cpp b/mlir/lib/EDSC/Types.cpp index a6554b5a806a..e34609360454 100644 --- a/mlir/lib/EDSC/Types.cpp +++ b/mlir/lib/EDSC/Types.cpp @@ -16,9 +16,14 @@ // ============================================================================= #include "mlir/EDSC/Types.h" +#include "mlir-c/Core.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/StandardTypes.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" @@ -77,53 +82,57 @@ struct StmtStorage { }; } // namespace detail +} // namespace edsc +} // namespace mlir -ScopedEDSCContext::ScopedEDSCContext() { +mlir::edsc::ScopedEDSCContext::ScopedEDSCContext() { Expr::globalAllocator() = &allocator; Bindable::resetIds(); } -ScopedEDSCContext::~ScopedEDSCContext() { Expr::globalAllocator() = nullptr; } +mlir::edsc::ScopedEDSCContext::~ScopedEDSCContext() { + Expr::globalAllocator() = nullptr; +} -ExprKind Expr::getKind() const { return storage->kind; } +ExprKind mlir::edsc::Expr::getKind() const { return storage->kind; } -Expr Expr::operator+(Expr other) const { +Expr mlir::edsc::Expr::operator+(Expr other) const { return BinaryExpr(ExprKind::Add, *this, other); } -Expr Expr::operator-(Expr other) const { +Expr mlir::edsc::Expr::operator-(Expr other) const { return BinaryExpr(ExprKind::Sub, *this, other); } -Expr Expr::operator*(Expr other) const { +Expr mlir::edsc::Expr::operator*(Expr other) const { return BinaryExpr(ExprKind::Mul, *this, other); } -Expr Expr::operator==(Expr other) const { +Expr mlir::edsc::Expr::operator==(Expr other) const { return BinaryExpr(ExprKind::EQ, *this, other); } -Expr Expr::operator!=(Expr other) const { +Expr mlir::edsc::Expr::operator!=(Expr other) const { return BinaryExpr(ExprKind::NE, *this, other); } -Expr Expr::operator<(Expr other) const { +Expr mlir::edsc::Expr::operator<(Expr other) const { return BinaryExpr(ExprKind::LT, *this, other); } -Expr Expr::operator<=(Expr other) const { +Expr mlir::edsc::Expr::operator<=(Expr other) const { return BinaryExpr(ExprKind::LE, *this, other); } -Expr Expr::operator>(Expr other) const { +Expr mlir::edsc::Expr::operator>(Expr other) const { return BinaryExpr(ExprKind::GT, *this, other); } -Expr Expr::operator>=(Expr other) const { +Expr mlir::edsc::Expr::operator>=(Expr other) const { return BinaryExpr(ExprKind::GE, *this, other); } -Expr Expr::operator&&(Expr other) const { +Expr mlir::edsc::Expr::operator&&(Expr other) const { return BinaryExpr(ExprKind::And, *this, other); } -Expr Expr::operator||(Expr other) const { +Expr mlir::edsc::Expr::operator||(Expr other) const { return BinaryExpr(ExprKind::Or, *this, other); } // Free functions. -llvm::SmallVector makeBindables(unsigned n) { +llvm::SmallVector mlir::edsc::makeBindables(unsigned n) { llvm::SmallVector res; res.reserve(n); for (auto i = 0; i < n; ++i) { @@ -132,7 +141,16 @@ llvm::SmallVector makeBindables(unsigned n) { return res; } -llvm::SmallVector makeExprs(unsigned n) { +static llvm::SmallVector makeBindables(edsc_expr_list_t exprList) { + llvm::SmallVector exprs; + exprs.reserve(exprList.n); + for (unsigned i = 0; i < exprList.n; ++i) { + exprs.push_back(Expr(exprList.exprs[i]).cast()); + } + return exprs; +} + +llvm::SmallVector mlir::edsc::makeExprs(unsigned n) { llvm::SmallVector res; res.reserve(n); for (auto i = 0; i < n; ++i) { @@ -141,7 +159,7 @@ llvm::SmallVector makeExprs(unsigned n) { return res; } -llvm::SmallVector makeExprs(ArrayRef bindables) { +llvm::SmallVector mlir::edsc::makeExprs(ArrayRef bindables) { llvm::SmallVector res; res.reserve(bindables.size()); for (auto b : bindables) { @@ -150,29 +168,53 @@ llvm::SmallVector makeExprs(ArrayRef bindables) { return res; } -Expr alloc(llvm::ArrayRef sizes, Type memrefType) { +static llvm::SmallVector makeExprs(edsc_expr_list_t exprList) { + llvm::SmallVector exprs; + exprs.reserve(exprList.n); + for (unsigned i = 0; i < exprList.n; ++i) { + exprs.push_back(Expr(exprList.exprs[i])); + } + return exprs; +} + +static llvm::SmallVector makeStmts(edsc_stmt_list_t enclosedStmts) { + llvm::SmallVector stmts; + stmts.reserve(enclosedStmts.n); + for (unsigned i = 0; i < enclosedStmts.n; ++i) { + stmts.push_back(Stmt(enclosedStmts.stmts[i])); + } + return stmts; +} + +Expr mlir::edsc::alloc(llvm::ArrayRef sizes, Type memrefType) { return VariadicExpr(ExprKind::Alloc, sizes, memrefType); } -Stmt Block(ArrayRef stmts) { +Stmt mlir::edsc::Block(ArrayRef stmts) { return Stmt(StmtBlockLikeExpr(ExprKind::Block, {}), stmts); } -Expr dealloc(Expr memref) { return UnaryExpr(ExprKind::Dealloc, memref); } +edsc_stmt_t Block(edsc_stmt_list_t enclosedStmts) { + return Stmt(mlir::edsc::Block(makeStmts(enclosedStmts))); +} -Stmt For(Expr lb, Expr ub, Expr step, ArrayRef stmts) { +Expr mlir::edsc::dealloc(Expr memref) { + return UnaryExpr(ExprKind::Dealloc, memref); +} + +Stmt mlir::edsc::For(Expr lb, Expr ub, Expr step, ArrayRef stmts) { Bindable idx; return For(idx, lb, ub, step, stmts); } -Stmt For(const Bindable &idx, Expr lb, Expr ub, Expr step, - ArrayRef stmts) { +Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step, + ArrayRef stmts) { return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, {lb, ub, step}), stmts); } -Stmt For(MutableArrayRef indices, ArrayRef lbs, - ArrayRef ubs, ArrayRef steps, - ArrayRef enclosedStmts) { +Stmt mlir::edsc::For(MutableArrayRef indices, ArrayRef lbs, + ArrayRef ubs, ArrayRef steps, + ArrayRef enclosedStmts) { assert(!indices.empty()); assert(indices.size() == lbs.size()); assert(indices.size() == ubs.size()); @@ -185,14 +227,37 @@ Stmt For(MutableArrayRef indices, ArrayRef lbs, return curStmt; } -Stmt For(llvm::MutableArrayRef indices, llvm::ArrayRef lbs, - llvm::ArrayRef ubs, llvm::ArrayRef steps, - llvm::ArrayRef enclosedStmts) { +Stmt mlir::edsc::For(llvm::MutableArrayRef indices, + llvm::ArrayRef lbs, llvm::ArrayRef ubs, + llvm::ArrayRef steps, + llvm::ArrayRef enclosedStmts) { return For(indices, SmallVector{lbs.begin(), lbs.end()}, SmallVector{ubs.begin(), ubs.end()}, SmallVector{steps.begin(), steps.end()}, enclosedStmts); } +Stmt mlir::edsc::For(llvm::MutableArrayRef indices, + llvm::ArrayRef lbs, llvm::ArrayRef ubs, + llvm::ArrayRef steps, + llvm::ArrayRef enclosedStmts) { + return For(indices, SmallVector{lbs.begin(), lbs.end()}, ubs, + SmallVector{steps.begin(), steps.end()}, enclosedStmts); +} + +edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub, + edsc_expr_t step, edsc_stmt_list_t enclosedStmts) { + return Stmt(For(Expr(iv).cast(), Expr(lb), Expr(ub), Expr(step), + makeStmts(enclosedStmts))); +} + +edsc_stmt_t ForNest(edsc_expr_list_t ivs, edsc_expr_list_t lbs, + edsc_expr_list_t ubs, edsc_expr_list_t steps, + edsc_stmt_list_t enclosedStmts) { + auto bindables = makeBindables(ivs); + return Stmt(For(bindables, makeExprs(lbs), makeExprs(ubs), makeExprs(steps), + makeStmts(enclosedStmts))); +} + template static Expr loadBuilder(Expr m, ArrayRef indices) { SmallVector exprs; @@ -200,12 +265,24 @@ static Expr loadBuilder(Expr m, ArrayRef indices) { exprs.append(indices.begin(), indices.end()); return VariadicExpr(ExprKind::Load, exprs); } -Expr load(Expr m, Expr index) { return loadBuilder(m, {index}); } -Expr load(Expr m, Bindable index) { return loadBuilder(m, {index}); } -Expr load(Expr m, const llvm::SmallVectorImpl &indices) { +Expr mlir::edsc::load(Expr m, Expr index) { + return loadBuilder(m, {index}); +} +Expr mlir::edsc::load(Expr m, Bindable index) { + return loadBuilder(m, {index}); +} +Expr mlir::edsc::load(Expr m, const llvm::SmallVectorImpl &indices) { return loadBuilder(m, ArrayRef{indices.begin(), indices.end()}); } -Expr load(Expr m, ArrayRef indices) { return loadBuilder(m, indices); } +Expr mlir::edsc::load(Expr m, ArrayRef indices) { + return loadBuilder(m, indices); +} + +edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices) { + Indexed i(Expr(indexed.base).cast()); + Expr res = i[makeExprs(indices)]; + return res; +} template static Expr storeBuilder(Expr val, Expr m, ArrayRef indices) { @@ -215,33 +292,49 @@ static Expr storeBuilder(Expr val, Expr m, ArrayRef indices) { exprs.append(indices.begin(), indices.end()); return VariadicExpr(ExprKind::Store, exprs); } -Expr store(Expr val, Expr m, Expr index) { +Expr mlir::edsc::store(Expr val, Expr m, Expr index) { return storeBuilder(val, m, {index}); } -Expr store(Expr val, Expr m, Bindable index) { +Expr mlir::edsc::store(Expr val, Expr m, Bindable index) { return storeBuilder(val, m, {index}); } -Expr store(Expr val, Expr m, const llvm::SmallVectorImpl &indices) { +Expr mlir::edsc::store(Expr val, Expr m, + const llvm::SmallVectorImpl &indices) { return storeBuilder(val, m, ArrayRef{indices.begin(), indices.end()}); } -Expr store(Expr val, Expr m, ArrayRef indices) { +Expr mlir::edsc::store(Expr val, Expr m, ArrayRef indices) { return storeBuilder(val, m, indices); } -Expr select(Expr cond, Expr lhs, Expr rhs) { +edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed, + edsc_expr_list_t indices) { + Indexed i(Expr(indexed.base).cast()); + Indexed loc = i[makeExprs(indices)]; + return Stmt(loc = Expr(value)); +} + +Expr mlir::edsc::select(Expr cond, Expr lhs, Expr rhs) { return TernaryExpr(ExprKind::Select, cond, lhs, rhs); } -Expr vector_type_cast(Expr memrefExpr, Type memrefType) { +edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs) { + return select(Expr(cond), Expr(lhs), Expr(rhs)); +} + +Expr mlir::edsc::vector_type_cast(Expr memrefExpr, Type memrefType) { return VariadicExpr(ExprKind::VectorTypeCast, {memrefExpr}, {memrefType}); } -Stmt Return(ArrayRef values) { +Stmt mlir::edsc::Return(ArrayRef values) { return VariadicExpr(ExprKind::Return, values); } -void Expr::print(raw_ostream &os) const { +edsc_stmt_t Return(edsc_expr_list_t values) { + return Stmt(Return(makeExprs(values))); +} + +void mlir::edsc::Expr::print(raw_ostream &os) const { if (auto unbound = this->dyn_cast()) { os << "$" << unbound.getId(); return; @@ -322,73 +415,77 @@ void Expr::print(raw_ostream &os) const { os << "unknown_kind(" << static_cast(getKind()) << ")"; } -void Expr::dump() const { this->print(llvm::errs()); } +void mlir::edsc::Expr::dump() const { this->print(llvm::errs()); } -std::string Expr::str() const { +std::string mlir::edsc::Expr::str() const { std::string res; llvm::raw_string_ostream os(res); this->print(os); return res; } -llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Expr &expr) { +llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os, + const Expr &expr) { expr.print(os); return os; } -Bindable::Bindable() +mlir::edsc::Bindable::Bindable() : Expr(Expr::globalAllocator()->Allocate()) { // Initialize with placement new. new (storage) detail::BindableStorage{Bindable::newId()}; } -unsigned Bindable::getId() const { +edsc_expr_t makeBindable() { return Bindable(); } + +unsigned mlir::edsc::Bindable::getId() const { return static_cast(storage)->id; } -unsigned &Bindable::newId() { +unsigned &mlir::edsc::Bindable::newId() { static thread_local unsigned id = 0; return ++id; } -UnaryExpr::UnaryExpr(ExprKind kind, Expr expr) +mlir::edsc::UnaryExpr::UnaryExpr(ExprKind kind, Expr expr) : Expr(Expr::globalAllocator()->Allocate()) { // Initialize with placement new. new (storage) detail::UnaryExprStorage{kind, expr}; } -Expr UnaryExpr::getExpr() const { +Expr mlir::edsc::UnaryExpr::getExpr() const { return static_cast(storage)->expr; } -BinaryExpr::BinaryExpr(ExprKind kind, Expr lhs, Expr rhs) +mlir::edsc::BinaryExpr::BinaryExpr(ExprKind kind, Expr lhs, Expr rhs) : Expr(Expr::globalAllocator()->Allocate()) { // Initialize with placement new. new (storage) detail::BinaryExprStorage{kind, lhs, rhs}; } -Expr BinaryExpr::getLHS() const { +Expr mlir::edsc::BinaryExpr::getLHS() const { return static_cast(storage)->lhs; } -Expr BinaryExpr::getRHS() const { +Expr mlir::edsc::BinaryExpr::getRHS() const { return static_cast(storage)->rhs; } -TernaryExpr::TernaryExpr(ExprKind kind, Expr cond, Expr lhs, Expr rhs) +mlir::edsc::TernaryExpr::TernaryExpr(ExprKind kind, Expr cond, Expr lhs, + Expr rhs) : Expr(Expr::globalAllocator()->Allocate()) { // Initialize with placement new. new (storage) detail::TernaryExprStorage{kind, cond, lhs, rhs}; } -Expr TernaryExpr::getCond() const { +Expr mlir::edsc::TernaryExpr::getCond() const { return static_cast(storage)->cond; } -Expr TernaryExpr::getLHS() const { +Expr mlir::edsc::TernaryExpr::getLHS() const { return static_cast(storage)->lhs; } -Expr TernaryExpr::getRHS() const { +Expr mlir::edsc::TernaryExpr::getRHS() const { return static_cast(storage)->rhs; } -VariadicExpr::VariadicExpr(ExprKind kind, ArrayRef exprs, - ArrayRef types) +mlir::edsc::VariadicExpr::VariadicExpr(ExprKind kind, ArrayRef exprs, + ArrayRef types) : Expr(Expr::globalAllocator()->Allocate()) { // Initialize with placement new. auto exprStorage = Expr::globalAllocator()->Allocate(exprs.size()); @@ -399,15 +496,16 @@ VariadicExpr::VariadicExpr(ExprKind kind, ArrayRef exprs, kind, ArrayRef(exprStorage, exprs.size()), ArrayRef(typeStorage, types.size())}; } -ArrayRef VariadicExpr::getExprs() const { +ArrayRef mlir::edsc::VariadicExpr::getExprs() const { return static_cast(storage)->exprs; } -ArrayRef VariadicExpr::getTypes() const { +ArrayRef mlir::edsc::VariadicExpr::getTypes() const { return static_cast(storage)->types; } -StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind, ArrayRef exprs, - ArrayRef types) +mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind, + ArrayRef exprs, + ArrayRef types) : Expr(Expr::globalAllocator()->Allocate()) { // Initialize with placement new. auto exprStorage = Expr::globalAllocator()->Allocate(exprs.size()); @@ -418,15 +516,15 @@ StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind, ArrayRef exprs, kind, ArrayRef(exprStorage, exprs.size()), ArrayRef(typeStorage, types.size())}; } -ArrayRef StmtBlockLikeExpr::getExprs() const { +ArrayRef mlir::edsc::StmtBlockLikeExpr::getExprs() const { return static_cast(storage)->exprs; } -ArrayRef StmtBlockLikeExpr::getTypes() const { +ArrayRef mlir::edsc::StmtBlockLikeExpr::getTypes() const { return static_cast(storage)->types; } -Stmt::Stmt(const Bindable &lhs, const Expr &rhs, - llvm::ArrayRef enclosedStmts) { +mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs, + llvm::ArrayRef enclosedStmts) { storage = Expr::globalAllocator()->Allocate(); // Initialize with placement new. auto enclosedStmtStorage = @@ -437,24 +535,33 @@ Stmt::Stmt(const Bindable &lhs, const Expr &rhs, lhs, rhs, ArrayRef(enclosedStmtStorage, enclosedStmts.size())}; } -Stmt::Stmt(const Expr &rhs, llvm::ArrayRef enclosedStmts) +mlir::edsc::Stmt::Stmt(const Expr &rhs, llvm::ArrayRef enclosedStmts) : Stmt(Bindable(), rhs, enclosedStmts) {} -Stmt &Stmt::operator=(const Expr &expr) { +edsc_stmt_t makeStmt(edsc_expr_t e) { + assert(e && "unexpected empty expression"); + return Stmt(Expr(e)); +} + +Stmt &mlir::edsc::Stmt::operator=(const Expr &expr) { Stmt res(Bindable(), expr, {}); std::swap(res.storage, this->storage); return *this; } -Bindable Stmt::getLHS() const { return static_cast(storage)->lhs; } +Bindable mlir::edsc::Stmt::getLHS() const { + return static_cast(storage)->lhs; +} -Expr Stmt::getRHS() const { return static_cast(storage)->rhs; } +Expr mlir::edsc::Stmt::getRHS() const { + return static_cast(storage)->rhs; +} -llvm::ArrayRef Stmt::getEnclosedStmts() const { +llvm::ArrayRef mlir::edsc::Stmt::getEnclosedStmts() const { return storage->enclosedStmts; } -void Stmt::print(raw_ostream &os, Twine indent) const { +void mlir::edsc::Stmt::print(raw_ostream &os, Twine indent) const { assert(storage && "Unexpected null storage,stmt must be bound to print"); auto lhs = getLHS(); auto rhs = getRHS(); @@ -491,36 +598,97 @@ void Stmt::print(raw_ostream &os, Twine indent) const { } } -void Stmt::dump() const { this->print(llvm::errs()); } +void mlir::edsc::Stmt::dump() const { this->print(llvm::errs()); } -std::string Stmt::str() const { +std::string mlir::edsc::Stmt::str() const { std::string res; llvm::raw_string_ostream os(res); this->print(os); return res; } -llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Stmt &stmt) { +llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os, + const Stmt &stmt) { stmt.print(os); return os; } -Indexed Indexed::operator[](llvm::ArrayRef indices) const { +Indexed mlir::edsc::Indexed::operator[](llvm::ArrayRef indices) const { Indexed res(base); res.indices = llvm::SmallVector(indices.begin(), indices.end()); return res; } -Indexed Indexed::operator[](llvm::ArrayRef indices) const { +Indexed mlir::edsc::Indexed:: +operator[](llvm::ArrayRef indices) const { return (*this)[llvm::ArrayRef{indices.begin(), indices.end()}]; } -Stmt Indexed::operator=(Expr expr) { // NOLINT: unconventional-assing-operator +// NOLINTNEXTLINE: unconventional-assign-operator +Stmt mlir::edsc::Indexed::operator=(Expr expr) { assert(!indices.empty() && "Expected attached indices to Indexed"); assert(base); Stmt stmt(store(expr, base, indices)); indices.clear(); return stmt; } -} // namespace edsc -} // namespace mlir + +edsc_indexed_t makeIndexed(edsc_expr_t expr) { + return edsc_indexed_t{expr, edsc_expr_list_t{nullptr, 0}}; +} + +edsc_indexed_t index(edsc_indexed_t indexed, edsc_expr_list_t indices) { + return edsc_indexed_t{indexed.base, indices}; +} + +mlir_type_t makeScalarType(mlir_context_t context, const char *name, + unsigned bitwidth) { + mlir::MLIRContext *c = reinterpret_cast(context); + mlir_type_t res = + llvm::StringSwitch(name) + .Case("bf16", + mlir_type_t{mlir::Type::getBF16(c).getAsOpaquePointer()}) + .Case("f16", mlir_type_t{mlir::Type::getF16(c).getAsOpaquePointer()}) + .Case("f32", mlir_type_t{mlir::Type::getF32(c).getAsOpaquePointer()}) + .Case("f64", mlir_type_t{mlir::Type::getF64(c).getAsOpaquePointer()}) + .Case("index", + mlir_type_t{mlir::Type::getIndex(c).getAsOpaquePointer()}) + .Case("i", + mlir_type_t{ + mlir::Type::getInteger(bitwidth, c).getAsOpaquePointer()}) + .Default(mlir_type_t{nullptr}); + if (!res) { + llvm_unreachable("Invalid type specifier"); + } + return res; +} + +mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType, + int64_list_t sizes) { + auto t = mlir::MemRefType::get( + llvm::ArrayRef(sizes.values, sizes.n), + mlir::Type::getFromOpaquePointer(elemType), + {mlir::AffineMap::getMultiDimIdentityMap( + sizes.n, reinterpret_cast(context))}, + 0); + return mlir_type_t{t.getAsOpaquePointer()}; +} + +mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs, + mlir_type_list_t outputs) { + llvm::SmallVector ins(inputs.n), outs(outputs.n); + for (unsigned i = 0; i < inputs.n; ++i) { + ins[i] = mlir::Type::getFromOpaquePointer(inputs.types[i]); + } + for (unsigned i = 0; i < outputs.n; ++i) { + ins[i] = mlir::Type::getFromOpaquePointer(outputs.types[i]); + } + auto ft = mlir::FunctionType::get( + ins, outs, reinterpret_cast(context)); + return mlir_type_t{ft.getAsOpaquePointer()}; +} + +unsigned getFunctionArity(mlir_func_t function) { + auto *f = reinterpret_cast(function); + return f->getNumArguments(); +} diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index e344c86799ac..5aac31db94b2 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -278,9 +278,9 @@ Expected> ExecutionEngine::create(Module *m) { setupTargetTriple(llvmModule.get()); packFunctionArguments(llvmModule.get()); - engine->jit = std::move(*expectedJIT); - if (auto err = engine->jit->addModule(std::move(llvmModule))) + if (auto err = (*expectedJIT)->addModule(std::move(llvmModule))) return std::move(err); + engine->jit = std::move(*expectedJIT); return engine; }