mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
Add a C API for EDSCs in other languages + python
This CL adds support for calling EDSCs from other languages than C++. Following the LLVM convention this CL: 1. declares simple opaque types and a C API in mlir-c/Core.h; 2. defines the implementation directly in lib/EDSC/Types.cpp and lib/EDSC/MLIREmitter.cpp. Unlike LLVM however the nomenclature for these types and API functions is not well-defined, naming suggestions are most welcome. To avoid the need for conversion functions, Types.h and MLIREmitter.h include mlir-c/Core.h and provide constructors and conversion operators between the mlir::edsc type and the corresponding C type. In this first commit, mlir-c/Core.h only contains the types for the C API to allow EDSCs to work from Python. This includes both a minimal set of core MLIR types (mlir_context_t, mlir_type_t, mlir_func_t) as well as the EDSC types (edsc_mlir_emitter_t, edsc_expr_t, edsc_stmt_t, edsc_indexed_t). This can be restructured in the future as concrete needs arise. For now, the API only supports: 1. scalar types; 2. memrefs of scalar types with static or symbolic shapes; 3. functions with input and output of these types. The C API is not complete wrt ownership semantics. This is in large part due to the fact that python bindings are written with Pybind11 which allows very idiomatic C++ bindings. An effort is made to write a large chunk of these bindings using the C API but some C++isms are used where the design benefits from this simplication. A fully isolated C API will make more sense once we also integrate with another language like Swift and have enough use cases to drive the design. Lastly, this CL also fixes a bug in mlir::ExecutionEngine were the order of declaration of llvmContext and the JIT result in an improper order of destructors (which used to crash before the fix). PiperOrigin-RevId: 231290250
This commit is contained in:
committed by
jpienaar
parent
eb753f4aec
commit
cacf05892e
248
mlir/include/mlir-c/Core.h
Normal file
248
mlir/include/mlir-c/Core.h
Normal file
@@ -0,0 +1,248 @@
|
||||
/*===-- mlir-c/Core.h - Core Library C Interface ------------------*- C -*-===*\
|
||||
|* *|
|
||||
|* Copyright 2019 The MLIR Authors. *|
|
||||
|* *|
|
||||
|* Licensed under the Apache License, Version 2.0 (the "License"); *|
|
||||
|* you may not use this file except in compliance with the License. *|
|
||||
|* You may obtain a copy of the License at *|
|
||||
|* *|
|
||||
|* http://www.apache.org/licenses/LICENSE-2.0 *|
|
||||
|* *|
|
||||
|* Unless required by applicable law or agreed to in writing, software *|
|
||||
|* distributed under the License is distributed on an "AS IS" BASIS, *|
|
||||
|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *|
|
||||
|* See the License for the specific language governing permissions and *|
|
||||
|* limitations under the License. *|
|
||||
|* *|
|
||||
|*===----------------------------------------------------------------------===*|
|
||||
|* *|
|
||||
|* This header declares the C interface to MLIR. *|
|
||||
|* *|
|
||||
\*===----------------------------------------------------------------------===*/
|
||||
#ifndef MLIR_C_CORE_H
|
||||
#define MLIR_C_CORE_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
#include <cstdint>
|
||||
extern "C" {
|
||||
#else
|
||||
#include <stdint.h>
|
||||
#endif
|
||||
|
||||
/// Opaque MLIR types.
|
||||
/// Opaque C type for mlir::MLIRContext*.
|
||||
typedef void *mlir_context_t;
|
||||
/// Opaque C type for mlir::Type.
|
||||
typedef const void *mlir_type_t;
|
||||
/// Opaque C type for mlir::Function*.
|
||||
typedef void *mlir_func_t;
|
||||
/// Opaque C type for mlir::edsc::MLIREmiter.
|
||||
typedef void *edsc_mlir_emitter_t;
|
||||
/// Opaque C type for mlir::edsc::Expr.
|
||||
typedef void *edsc_expr_t;
|
||||
/// Opaque C type for mlir::edsc::Stmt.
|
||||
typedef void *edsc_stmt_t;
|
||||
|
||||
/// Simple C lists for non-owning mlir Opaque C types.
|
||||
/// Recommended usage is construction from the `data()` and `size()` of a scoped
|
||||
/// owning SmallVectorImpl<...> and passing to one of the C functions declared
|
||||
/// later in this file.
|
||||
/// Once the function returns and the proper EDSC has been constructed,
|
||||
/// resources are freed by exiting the scope.
|
||||
typedef struct {
|
||||
int64_t *values;
|
||||
uint64_t n;
|
||||
} int64_list_t;
|
||||
|
||||
typedef struct {
|
||||
mlir_type_t *types;
|
||||
uint64_t n;
|
||||
} mlir_type_list_t;
|
||||
|
||||
typedef struct {
|
||||
edsc_expr_t *exprs;
|
||||
uint64_t n;
|
||||
} edsc_expr_list_t;
|
||||
|
||||
typedef struct {
|
||||
edsc_stmt_t *stmts;
|
||||
uint64_t n;
|
||||
} edsc_stmt_list_t;
|
||||
|
||||
typedef struct {
|
||||
edsc_expr_t base;
|
||||
edsc_expr_list_t indices;
|
||||
} edsc_indexed_t;
|
||||
|
||||
typedef struct {
|
||||
edsc_indexed_t *list;
|
||||
uint64_t n;
|
||||
} edsc_indexed_list_t;
|
||||
|
||||
/// Minimal C API for exposing EDSCs to Swift, Python and other languages.
|
||||
|
||||
/// Returns a simple scalar mlir::Type using the following convention:
|
||||
/// - makeScalarType(c, "bf16") return an `mlir::Type::getBF16`
|
||||
/// - makeScalarType(c, "f16") return an `mlir::Type::getF16`
|
||||
/// - makeScalarType(c, "f32") return an `mlir::Type::getF32`
|
||||
/// - makeScalarType(c, "f64") return an `mlir::Type::getF64`
|
||||
/// - makeScalarType(c, "index") return an `mlir::Type::getIndex`
|
||||
/// - makeScalarType(c, "i", bitwidth) return an
|
||||
/// `mlir::Type::getInteger(bitwidth)`
|
||||
///
|
||||
/// No other combinations are currently supported.
|
||||
mlir_type_t makeScalarType(mlir_context_t context, const char *name,
|
||||
unsigned bitwidth);
|
||||
|
||||
/// Returns an `mlir::MemRefType` of the element type `elemType` and shape
|
||||
/// `sizes`.
|
||||
mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType,
|
||||
int64_list_t sizes);
|
||||
|
||||
/// Returns an `mlir::FunctionType` of the element type `elemType` and shape
|
||||
/// `sizes`.
|
||||
mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
|
||||
mlir_type_list_t outputs);
|
||||
|
||||
/// Returns the arity of `function`.
|
||||
unsigned getFunctionArity(mlir_func_t function);
|
||||
|
||||
/// Returns a new opaque mlir::edsc::Expr that is bound into `emitter` with a
|
||||
/// constant of the specified type.
|
||||
edsc_expr_t bindConstantBF16(edsc_mlir_emitter_t emitter, double value);
|
||||
edsc_expr_t bindConstantF16(edsc_mlir_emitter_t emitter, float value);
|
||||
edsc_expr_t bindConstantF32(edsc_mlir_emitter_t emitter, float value);
|
||||
edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value);
|
||||
edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value,
|
||||
unsigned bitwidth);
|
||||
edsc_expr_t bindConstantIndex(edsc_mlir_emitter_t emitter, int64_t value);
|
||||
|
||||
/// Returns the rank of the `function` argument at position `pos`.
|
||||
/// If the argument is of MemRefType, this returns the rank of the MemRef.
|
||||
/// Otherwise returns `0`.
|
||||
/// TODO(ntv): support more than MemRefType and scalar Type.
|
||||
unsigned getRankOfFunctionArgument(mlir_func_t function, unsigned pos);
|
||||
|
||||
/// Returns an opaque mlir::Type of the `function` argument at position `pos`.
|
||||
mlir_type_t getTypeOfFunctionArgument(mlir_func_t function, unsigned pos);
|
||||
|
||||
/// Returns an opaque mlir::edsc::Expr that has been bound to the `pos` argument
|
||||
/// of `function`.
|
||||
edsc_expr_t bindFunctionArgument(edsc_mlir_emitter_t emitter,
|
||||
mlir_func_t function, unsigned pos);
|
||||
|
||||
/// Fills the preallocated list `result` with opaque mlir::edsc::Expr that have
|
||||
/// been bound to each argument of `function`.
|
||||
///
|
||||
/// Prerequisites:
|
||||
/// - `result` must have been preallocated with space for exactly the number
|
||||
/// of arguments of `function`.
|
||||
void bindFunctionArguments(edsc_mlir_emitter_t emitter, mlir_func_t function,
|
||||
edsc_expr_list_t *result);
|
||||
|
||||
/// Returns the rank of `boundMemRef`. This API function is provided to more
|
||||
/// easily compose with `bindFunctionArgument`. A similar function could be
|
||||
/// provided for an mlir_type_t of type MemRefType but it is expected that users
|
||||
/// of this API either:
|
||||
/// 1. construct the MemRefType explicitly, in which case they already have
|
||||
/// access to the rank and shape of the MemRefType;
|
||||
/// 2. access MemRefs via mlir_function_t *values* in which case they would
|
||||
/// pass edsc_expr_t bound to an edsc_emitter_t.
|
||||
///
|
||||
/// Prerequisites:
|
||||
/// - `boundMemRef` must be an opaque edsc_expr_t that has alreay been bound
|
||||
/// in `emitter`.
|
||||
unsigned getBoundMemRefRank(edsc_mlir_emitter_t emitter,
|
||||
edsc_expr_t boundMemRef);
|
||||
|
||||
/// Fills the preallocated list `result` with opaque mlir::edsc::Expr that have
|
||||
/// been bound to each dimension of `boundMemRef`.
|
||||
///
|
||||
/// Prerequisites:
|
||||
/// - `result` must have been preallocated with space for exactly the rank of
|
||||
/// `boundMemRef`;
|
||||
/// - `boundMemRef` must be an opaque edsc_expr_t that has alreay been bound
|
||||
/// in `emitter`. This is because symbolic MemRef shapes require an SSAValue
|
||||
/// that can only be recovered from `emitter`.
|
||||
void bindMemRefShape(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
|
||||
edsc_expr_list_t *result);
|
||||
|
||||
/// Fills the preallocated lists `resultLbs`, `resultUbs` and `resultSteps` with
|
||||
/// opaque mlir::edsc::Expr that have been bound to proper values to traverse
|
||||
/// each dimension of `memRefType`.
|
||||
/// At the moment:
|
||||
/// - `resultsLbs` are always bound to the constant index `0`;
|
||||
/// - `resultsUbs` are always bound to the shape of `memRefType`;
|
||||
/// - `resultsSteps` are always bound to the constant index `1`.
|
||||
/// In the future, this will allow non-contiguous MemRef views.
|
||||
///
|
||||
/// Prerequisites:
|
||||
/// - `resultLbs`, `resultUbs` and `resultSteps` must have each been
|
||||
/// preallocated with space for exactly the rank of `boundMemRef`;
|
||||
/// - `boundMemRef` must be an opaque edsc_expr_t that has alreay been bound
|
||||
/// in `emitter`. This is because symbolic MemRef shapes require an SSAValue
|
||||
/// that can only be recovered from `emitter`.
|
||||
void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
|
||||
edsc_expr_list_t *resultLbs, edsc_expr_list_t *resultUbs,
|
||||
edsc_expr_list_t *resultSteps);
|
||||
|
||||
/// Returns an opaque expression for an mlir::edsc::Expr.
|
||||
edsc_expr_t makeBindable();
|
||||
|
||||
/// Returns an opaque expression for an mlir::edsc::Stmt.
|
||||
edsc_stmt_t makeStmt(edsc_expr_t e);
|
||||
|
||||
/// Returns an opaque expression for an mlir::edsc::Indexed.
|
||||
edsc_indexed_t makeIndexed(edsc_expr_t expr);
|
||||
|
||||
/// Returns an indexed opaque expression with indices bound in the structure
|
||||
/// given an `indexed` and `indices`.
|
||||
/// Prerequisite:
|
||||
/// - `indexed` must not have been indexed previously.
|
||||
edsc_indexed_t index(edsc_indexed_t indexed, edsc_expr_list_t indices);
|
||||
|
||||
/// Returns an opaque expression that will emit an mlir::LoadOp.
|
||||
edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices);
|
||||
|
||||
/// Returns an opaque statement for an mlir::StoreOp.
|
||||
edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed,
|
||||
edsc_expr_list_t indices);
|
||||
|
||||
/// Returns an opaque statement for an mlir::SelectOp.
|
||||
edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs);
|
||||
|
||||
/// Returns an opaque statement for an mlir::ReturnOp.
|
||||
edsc_stmt_t Return(edsc_expr_list_t values);
|
||||
|
||||
/// Returns a single opaque statement that acts as an mlir block. At the moment
|
||||
/// this is pure syntactic sugar to allow lists of mlir::edsc::Stmt to be
|
||||
/// specified and emitted. In particular, block arguments are not currently
|
||||
/// supported.
|
||||
edsc_stmt_t Block(edsc_stmt_list_t enclosedStmts);
|
||||
|
||||
/// Returns an opaque statement for an mlir::ForInst with `enclosedStmts` nested
|
||||
/// below it.
|
||||
edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
|
||||
edsc_expr_t step, edsc_stmt_list_t enclosedStmts);
|
||||
|
||||
/// Returns an opaque statement for a perfectly nested set of mlir::ForInst with
|
||||
/// `enclosedStmts` nested below it.
|
||||
edsc_stmt_t ForNest(edsc_expr_list_t iv, edsc_expr_list_t lb,
|
||||
edsc_expr_list_t ub, edsc_expr_list_t step,
|
||||
edsc_stmt_list_t enclosedStmts);
|
||||
|
||||
/// Returns an opaque expression for the corresponding Binary operation.
|
||||
edsc_expr_t Add(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t Sub(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t Mul(edsc_expr_t e1, edsc_expr_t e2);
|
||||
// edsc_expr_t Div(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t LT(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t LE(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t GT(edsc_expr_t e1, edsc_expr_t e2);
|
||||
edsc_expr_t GE(edsc_expr_t e1, edsc_expr_t e2);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif
|
||||
|
||||
#endif // MLIR_C_CORE_H
|
||||
@@ -121,7 +121,7 @@ struct MLIREmitter {
|
||||
///
|
||||
/// Prerequisite:
|
||||
/// `memRef` is a Value of type MemRefType.
|
||||
SmallVector<edsc::Bindable, 8> makeBoundSizes(Value *memRef);
|
||||
SmallVector<edsc::Expr, 8> makeBoundSizes(Value *memRef);
|
||||
|
||||
private:
|
||||
FuncBuilder *builder;
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#ifndef MLIR_LIB_EDSC_TYPES_H_
|
||||
#define MLIR_LIB_EDSC_TYPES_H_
|
||||
|
||||
#include "mlir-c/Core.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
@@ -171,6 +172,9 @@ public:
|
||||
|
||||
Expr() : storage(nullptr) {}
|
||||
/* implicit */ Expr(ImplType *storage) : storage(storage) {}
|
||||
explicit Expr(edsc_expr_t expr)
|
||||
: storage(reinterpret_cast<ImplType *>(expr)) {}
|
||||
operator edsc_expr_t() { return edsc_expr_t{storage}; }
|
||||
|
||||
Expr(const Expr &other) : storage(other.storage) {}
|
||||
Expr &operator=(Expr other) {
|
||||
@@ -227,6 +231,8 @@ struct Bindable : public Expr {
|
||||
Bindable(Expr::ImplType *ptr) : Expr(ptr) {
|
||||
assert(!ptr || isa<Bindable>() && "expected Bindable");
|
||||
}
|
||||
explicit Bindable(const edsc_expr_t &expr) : Expr(expr) {}
|
||||
operator edsc_expr_t() { return edsc_expr_t{storage}; }
|
||||
|
||||
friend struct ScopedEDSCContext;
|
||||
|
||||
@@ -352,6 +358,9 @@ struct Stmt {
|
||||
Stmt(const Bindable &lhs, const Expr &rhs,
|
||||
llvm::ArrayRef<Stmt> stmts = llvm::ArrayRef<Stmt>());
|
||||
Stmt &operator=(const Expr &expr);
|
||||
explicit Stmt(edsc_stmt_t stmt)
|
||||
: storage(reinterpret_cast<ImplType *>(stmt)) {}
|
||||
operator edsc_stmt_t() { return edsc_stmt_t{storage}; }
|
||||
|
||||
operator Expr() const { return getLHS(); }
|
||||
|
||||
@@ -414,7 +423,7 @@ template <typename U> U Expr::dyn_cast() const {
|
||||
if (isa<U>()) {
|
||||
return U(storage);
|
||||
}
|
||||
return U(nullptr);
|
||||
return U((Expr::ImplType *)(nullptr));
|
||||
}
|
||||
template <typename U> U Expr::cast() const {
|
||||
assert(isa<U>());
|
||||
@@ -497,6 +506,9 @@ Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Expr> lbs,
|
||||
Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Bindable> lbs,
|
||||
llvm::ArrayRef<Bindable> ubs, llvm::ArrayRef<Bindable> steps,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts);
|
||||
Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Bindable> lbs,
|
||||
llvm::ArrayRef<Expr> ubs, llvm::ArrayRef<Bindable> steps,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts);
|
||||
|
||||
/// This helper class exists purely for sugaring purposes and allows writing
|
||||
/// expressions such as:
|
||||
@@ -508,7 +520,8 @@ Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Bindable> lbs,
|
||||
/// });
|
||||
/// ```
|
||||
struct Indexed {
|
||||
Indexed(Bindable m) : base(m), indices() {}
|
||||
Indexed(Bindable b) : base(b), indices() {}
|
||||
Indexed(Expr e) : base(e), indices() {}
|
||||
|
||||
/// Returns a new `Indexed`. As a consequence, an Indexed with attached
|
||||
/// indices can never be reused unless it is captured (e.g. via a Stmt).
|
||||
@@ -534,7 +547,7 @@ struct Indexed {
|
||||
Expr operator*(Expr e) const { return static_cast<Expr>(*this) * e; }
|
||||
|
||||
private:
|
||||
Bindable base;
|
||||
Expr base;
|
||||
llvm::SmallVector<Expr, 4> indices;
|
||||
};
|
||||
|
||||
|
||||
@@ -73,9 +73,11 @@ public:
|
||||
llvm::Error invoke(StringRef name, MutableArrayRef<void *> args);
|
||||
|
||||
private:
|
||||
// Ordering of llvmContext and jit is important for destruction purposes: the
|
||||
// jit must be destroyed before the context.
|
||||
llvm::LLVMContext llvmContext;
|
||||
// Private implementation of the JIT (PIMPL)
|
||||
std::unique_ptr<impl::OrcJIT> jit;
|
||||
llvm::LLVMContext llvmContext;
|
||||
};
|
||||
|
||||
template <typename... Args>
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "mlir-c/Core.h"
|
||||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/EDSC/MLIREmitter.h"
|
||||
#include "mlir/EDSC/Types.h"
|
||||
@@ -38,8 +39,9 @@ using llvm::errs;
|
||||
|
||||
#define DEBUG_TYPE "edsc"
|
||||
|
||||
namespace mlir {
|
||||
namespace edsc {
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::detail;
|
||||
|
||||
// Factors out the boilerplate that is needed to build and answer the
|
||||
// following simple question:
|
||||
@@ -89,7 +91,7 @@ static bool isFloatElement(const Value &v) {
|
||||
return getElementType(v).isa<FloatType>();
|
||||
}
|
||||
|
||||
Value *add(FuncBuilder *builder, Location location, Value *a, Value *b) {
|
||||
static Value *add(FuncBuilder *builder, Location location, Value *a, Value *b) {
|
||||
if (isIndexElement(*a)) {
|
||||
auto *context = builder->getContext();
|
||||
auto d0 = getAffineDimExpr(0, context);
|
||||
@@ -103,7 +105,7 @@ Value *add(FuncBuilder *builder, Location location, Value *a, Value *b) {
|
||||
return builder->create<AddFOp>(location, a, b)->getResult();
|
||||
}
|
||||
|
||||
Value *sub(FuncBuilder *builder, Location location, Value *a, Value *b) {
|
||||
static Value *sub(FuncBuilder *builder, Location location, Value *a, Value *b) {
|
||||
if (isIndexElement(*a)) {
|
||||
auto *context = builder->getContext();
|
||||
auto d0 = getAffineDimExpr(0, context);
|
||||
@@ -117,7 +119,7 @@ Value *sub(FuncBuilder *builder, Location location, Value *a, Value *b) {
|
||||
return builder->create<SubFOp>(location, a, b)->getResult();
|
||||
}
|
||||
|
||||
Value *mul(FuncBuilder *builder, Location location, Value *a, Value *b) {
|
||||
static Value *mul(FuncBuilder *builder, Location location, Value *a, Value *b) {
|
||||
if (!isFloatElement(*a)) {
|
||||
return builder->create<MulIOp>(location, a, b)->getResult();
|
||||
}
|
||||
@@ -138,7 +140,7 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) {
|
||||
}
|
||||
}
|
||||
|
||||
MLIREmitter &MLIREmitter::bind(Bindable e, Value *v) {
|
||||
MLIREmitter &mlir::edsc::MLIREmitter::bind(Bindable e, Value *v) {
|
||||
LLVM_DEBUG(printDefininingStatement(llvm::dbgs() << "\nBinding " << e << " @"
|
||||
<< e.getStoragePtr() << ": ",
|
||||
*v));
|
||||
@@ -151,7 +153,7 @@ MLIREmitter &MLIREmitter::bind(Bindable e, Value *v) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
Value *MLIREmitter::emit(Expr e) {
|
||||
Value *mlir::edsc::MLIREmitter::emit(Expr e) {
|
||||
auto it = ssaBindings.find(e);
|
||||
if (it != ssaBindings.end()) {
|
||||
return it->second;
|
||||
@@ -316,7 +318,7 @@ Value *MLIREmitter::emit(Expr e) {
|
||||
return res;
|
||||
}
|
||||
|
||||
SmallVector<Value *, 8> MLIREmitter::emit(ArrayRef<Expr> exprs) {
|
||||
SmallVector<Value *, 8> mlir::edsc::MLIREmitter::emit(ArrayRef<Expr> exprs) {
|
||||
return mlir::functional::map(
|
||||
[this](Expr e) {
|
||||
auto *res = this->emit(e);
|
||||
@@ -327,7 +329,7 @@ SmallVector<Value *, 8> MLIREmitter::emit(ArrayRef<Expr> exprs) {
|
||||
exprs);
|
||||
}
|
||||
|
||||
void MLIREmitter::emitStmt(const Stmt &stmt) {
|
||||
void mlir::edsc::MLIREmitter::emitStmt(const Stmt &stmt) {
|
||||
auto *block = builder->getBlock();
|
||||
auto ip = builder->getInsertionPoint();
|
||||
// Blocks are just a containing abstraction, they do not emit their RHS.
|
||||
@@ -351,7 +353,7 @@ void MLIREmitter::emitStmt(const Stmt &stmt) {
|
||||
builder->setInsertionPoint(block, ip);
|
||||
}
|
||||
|
||||
void MLIREmitter::emitStmts(ArrayRef<Stmt> stmts) {
|
||||
void mlir::edsc::MLIREmitter::emitStmts(ArrayRef<Stmt> stmts) {
|
||||
for (auto &stmt : stmts) {
|
||||
emitStmt(stmt);
|
||||
}
|
||||
@@ -390,7 +392,8 @@ static bool isDynamicSize(int size) { return size < 0; }
|
||||
/// and returns the vector with {%d0, %c3, %c4, %d3, %c5}.
|
||||
static SmallVector<Value *, 8> getMemRefSizes(FuncBuilder *b, Location loc,
|
||||
Value *memRef) {
|
||||
auto memRefType = memRef->getType().template cast<MemRefType>();
|
||||
assert(memRef->getType().isa<MemRefType>() && "Expected a MemRef value");
|
||||
MemRefType memRefType = memRef->getType().cast<MemRefType>();
|
||||
SmallVector<Value *, 8> res;
|
||||
res.reserve(memRefType.getShape().size());
|
||||
const auto &shape = memRefType.getShape();
|
||||
@@ -404,15 +407,165 @@ static SmallVector<Value *, 8> getMemRefSizes(FuncBuilder *b, Location loc,
|
||||
return res;
|
||||
}
|
||||
|
||||
SmallVector<edsc::Bindable, 8> MLIREmitter::makeBoundSizes(Value *memRef) {
|
||||
SmallVector<edsc::Expr, 8>
|
||||
mlir::edsc::MLIREmitter::makeBoundSizes(Value *memRef) {
|
||||
assert(memRef->getType().isa<MemRefType>() && "Expected a MemRef value");
|
||||
MemRefType memRefType = memRef->getType().cast<MemRefType>();
|
||||
auto memRefSizes = edsc::makeBindables(memRefType.getShape().size());
|
||||
auto memrefSizeValues = getMemRefSizes(getBuilder(), getLocation(), memRef);
|
||||
assert(memrefSizeValues.size() == memRefSizes.size());
|
||||
bindZipRange(llvm::zip(memRefSizes, memrefSizeValues));
|
||||
return memRefSizes;
|
||||
SmallVector<edsc::Expr, 8> res(memRefSizes.begin(), memRefSizes.end());
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
edsc_expr_t bindConstantBF16(edsc_mlir_emitter_t emitter, double value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Bindable b;
|
||||
e->bindConstant<mlir::ConstantFloatOp>(b, mlir::APFloat(value),
|
||||
e->getBuilder()->getBF16Type());
|
||||
return b;
|
||||
}
|
||||
|
||||
edsc_expr_t bindConstantF16(edsc_mlir_emitter_t emitter, float value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Bindable b;
|
||||
bool unused;
|
||||
mlir::APFloat val(value);
|
||||
val.convert(e->getBuilder()->getF16Type().getFloatSemantics(),
|
||||
mlir::APFloat::rmNearestTiesToEven, &unused);
|
||||
e->bindConstant<mlir::ConstantFloatOp>(b, val, e->getBuilder()->getF16Type());
|
||||
return b;
|
||||
}
|
||||
|
||||
edsc_expr_t bindConstantF32(edsc_mlir_emitter_t emitter, float value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Bindable b;
|
||||
e->bindConstant<mlir::ConstantFloatOp>(b, mlir::APFloat(value),
|
||||
e->getBuilder()->getF32Type());
|
||||
return b;
|
||||
}
|
||||
|
||||
edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Bindable b;
|
||||
e->bindConstant<mlir::ConstantFloatOp>(b, mlir::APFloat(value),
|
||||
e->getBuilder()->getF64Type());
|
||||
return b;
|
||||
}
|
||||
|
||||
edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value,
|
||||
unsigned bitwidth) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Bindable b;
|
||||
e->bindConstant<mlir::ConstantIntOp>(
|
||||
b, value, e->getBuilder()->getIntegerType(bitwidth));
|
||||
return b;
|
||||
}
|
||||
|
||||
edsc_expr_t bindConstantIndex(edsc_mlir_emitter_t emitter, int64_t value) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
Bindable b;
|
||||
e->bindConstant<mlir::ConstantIndexOp>(b, value);
|
||||
return b;
|
||||
}
|
||||
|
||||
unsigned getRankOfFunctionArgument(mlir_func_t function, unsigned pos) {
|
||||
auto *f = reinterpret_cast<mlir::Function *>(function);
|
||||
assert(pos < f->getNumArguments());
|
||||
auto *arg = *(f->getArguments().begin() + pos);
|
||||
if (auto memRefType = arg->getType().dyn_cast<mlir::MemRefType>()) {
|
||||
return memRefType.getRank();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
mlir_type_t getTypeOfFunctionArgument(mlir_func_t function, unsigned pos) {
|
||||
auto *f = reinterpret_cast<mlir::Function *>(function);
|
||||
assert(pos < f->getNumArguments());
|
||||
auto *arg = *(f->getArguments().begin() + pos);
|
||||
return mlir_type_t{arg->getType().getAsOpaquePointer()};
|
||||
}
|
||||
|
||||
edsc_expr_t bindFunctionArgument(edsc_mlir_emitter_t emitter,
|
||||
mlir_func_t function, unsigned pos) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
auto *f = reinterpret_cast<mlir::Function *>(function);
|
||||
assert(pos < f->getNumArguments());
|
||||
auto *arg = *(f->getArguments().begin() + pos);
|
||||
Bindable b;
|
||||
e->bind(b, arg);
|
||||
return Expr(b);
|
||||
}
|
||||
|
||||
void bindFunctionArguments(edsc_mlir_emitter_t emitter, mlir_func_t function,
|
||||
edsc_expr_list_t *result) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
auto *f = reinterpret_cast<mlir::Function *>(function);
|
||||
assert(result->n == f->getNumArguments());
|
||||
for (unsigned pos = 0; pos < result->n; ++pos) {
|
||||
auto *arg = *(f->getArguments().begin() + pos);
|
||||
Bindable b;
|
||||
e->bind(b, arg);
|
||||
result->exprs[pos] = Expr(b);
|
||||
}
|
||||
}
|
||||
|
||||
unsigned getBoundMemRefRank(edsc_mlir_emitter_t emitter,
|
||||
edsc_expr_t boundMemRef) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
auto *v = e->getValue(mlir::edsc::Expr(boundMemRef));
|
||||
auto memRefType = v->getType().cast<mlir::MemRefType>();
|
||||
return memRefType.getRank();
|
||||
}
|
||||
|
||||
void bindMemRefShape(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
|
||||
edsc_expr_list_t *result) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
auto *v = e->getValue(mlir::edsc::Expr(boundMemRef));
|
||||
auto memRefType = v->getType().cast<mlir::MemRefType>();
|
||||
auto rank = memRefType.getRank();
|
||||
assert(result->n == rank && "Unexpected memref shape binding results count");
|
||||
auto bindables = e->makeBoundSizes(v);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
result->exprs[i] = bindables[i];
|
||||
}
|
||||
}
|
||||
|
||||
void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
|
||||
edsc_expr_list_t *resultLbs, edsc_expr_list_t *resultUbs,
|
||||
edsc_expr_list_t *resultSteps) {
|
||||
auto *e = reinterpret_cast<mlir::edsc::MLIREmitter *>(emitter);
|
||||
auto *v = e->getValue(mlir::edsc::Expr(boundMemRef));
|
||||
auto memRefType = v->getType().cast<mlir::MemRefType>();
|
||||
auto rank = memRefType.getRank();
|
||||
assert(resultLbs->n == rank && "Unexpected memref binding results count");
|
||||
assert(resultUbs->n == rank && "Unexpected memref binding results count");
|
||||
assert(resultSteps->n == rank && "Unexpected memref binding results count");
|
||||
auto bindables = e->makeBoundSizes(v);
|
||||
for (unsigned i = 0; i < rank; ++i) {
|
||||
Bindable zero;
|
||||
e->bindConstant<mlir::ConstantIndexOp>(zero, 0);
|
||||
resultLbs->exprs[i] = zero;
|
||||
resultUbs->exprs[i] = bindables[i];
|
||||
Bindable one;
|
||||
e->bindConstant<mlir::ConstantIndexOp>(one, 1);
|
||||
resultSteps->exprs[i] = one;
|
||||
}
|
||||
}
|
||||
|
||||
#define DEFINE_EDSL_BINARY_OP(FUN_NAME, OP_SYMBOL) \
|
||||
edsc_expr_t FUN_NAME(edsc_expr_t e1, edsc_expr_t e2) { \
|
||||
return Expr(e1) OP_SYMBOL Expr(e2); \
|
||||
}
|
||||
|
||||
DEFINE_EDSL_BINARY_OP(Add, +);
|
||||
DEFINE_EDSL_BINARY_OP(Sub, -);
|
||||
DEFINE_EDSL_BINARY_OP(Mul, *);
|
||||
// DEFINE_EDSL_BINARY_OP(Div, /);
|
||||
DEFINE_EDSL_BINARY_OP(LT, <);
|
||||
DEFINE_EDSL_BINARY_OP(LE, <=);
|
||||
DEFINE_EDSL_BINARY_OP(GT, >);
|
||||
DEFINE_EDSL_BINARY_OP(GE, >=);
|
||||
|
||||
#undef DEFINE_EDSL_BINARY_OP
|
||||
|
||||
@@ -16,9 +16,14 @@
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/EDSC/Types.h"
|
||||
#include "mlir-c/Core.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
@@ -77,53 +82,57 @@ struct StmtStorage {
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
|
||||
ScopedEDSCContext::ScopedEDSCContext() {
|
||||
mlir::edsc::ScopedEDSCContext::ScopedEDSCContext() {
|
||||
Expr::globalAllocator() = &allocator;
|
||||
Bindable::resetIds();
|
||||
}
|
||||
|
||||
ScopedEDSCContext::~ScopedEDSCContext() { Expr::globalAllocator() = nullptr; }
|
||||
mlir::edsc::ScopedEDSCContext::~ScopedEDSCContext() {
|
||||
Expr::globalAllocator() = nullptr;
|
||||
}
|
||||
|
||||
ExprKind Expr::getKind() const { return storage->kind; }
|
||||
ExprKind mlir::edsc::Expr::getKind() const { return storage->kind; }
|
||||
|
||||
Expr Expr::operator+(Expr other) const {
|
||||
Expr mlir::edsc::Expr::operator+(Expr other) const {
|
||||
return BinaryExpr(ExprKind::Add, *this, other);
|
||||
}
|
||||
Expr Expr::operator-(Expr other) const {
|
||||
Expr mlir::edsc::Expr::operator-(Expr other) const {
|
||||
return BinaryExpr(ExprKind::Sub, *this, other);
|
||||
}
|
||||
Expr Expr::operator*(Expr other) const {
|
||||
Expr mlir::edsc::Expr::operator*(Expr other) const {
|
||||
return BinaryExpr(ExprKind::Mul, *this, other);
|
||||
}
|
||||
|
||||
Expr Expr::operator==(Expr other) const {
|
||||
Expr mlir::edsc::Expr::operator==(Expr other) const {
|
||||
return BinaryExpr(ExprKind::EQ, *this, other);
|
||||
}
|
||||
Expr Expr::operator!=(Expr other) const {
|
||||
Expr mlir::edsc::Expr::operator!=(Expr other) const {
|
||||
return BinaryExpr(ExprKind::NE, *this, other);
|
||||
}
|
||||
Expr Expr::operator<(Expr other) const {
|
||||
Expr mlir::edsc::Expr::operator<(Expr other) const {
|
||||
return BinaryExpr(ExprKind::LT, *this, other);
|
||||
}
|
||||
Expr Expr::operator<=(Expr other) const {
|
||||
Expr mlir::edsc::Expr::operator<=(Expr other) const {
|
||||
return BinaryExpr(ExprKind::LE, *this, other);
|
||||
}
|
||||
Expr Expr::operator>(Expr other) const {
|
||||
Expr mlir::edsc::Expr::operator>(Expr other) const {
|
||||
return BinaryExpr(ExprKind::GT, *this, other);
|
||||
}
|
||||
Expr Expr::operator>=(Expr other) const {
|
||||
Expr mlir::edsc::Expr::operator>=(Expr other) const {
|
||||
return BinaryExpr(ExprKind::GE, *this, other);
|
||||
}
|
||||
Expr Expr::operator&&(Expr other) const {
|
||||
Expr mlir::edsc::Expr::operator&&(Expr other) const {
|
||||
return BinaryExpr(ExprKind::And, *this, other);
|
||||
}
|
||||
Expr Expr::operator||(Expr other) const {
|
||||
Expr mlir::edsc::Expr::operator||(Expr other) const {
|
||||
return BinaryExpr(ExprKind::Or, *this, other);
|
||||
}
|
||||
|
||||
// Free functions.
|
||||
llvm::SmallVector<Bindable, 8> makeBindables(unsigned n) {
|
||||
llvm::SmallVector<Bindable, 8> mlir::edsc::makeBindables(unsigned n) {
|
||||
llvm::SmallVector<Bindable, 8> res;
|
||||
res.reserve(n);
|
||||
for (auto i = 0; i < n; ++i) {
|
||||
@@ -132,7 +141,16 @@ llvm::SmallVector<Bindable, 8> makeBindables(unsigned n) {
|
||||
return res;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Expr, 8> makeExprs(unsigned n) {
|
||||
static llvm::SmallVector<Bindable, 8> makeBindables(edsc_expr_list_t exprList) {
|
||||
llvm::SmallVector<Bindable, 8> exprs;
|
||||
exprs.reserve(exprList.n);
|
||||
for (unsigned i = 0; i < exprList.n; ++i) {
|
||||
exprs.push_back(Expr(exprList.exprs[i]).cast<Bindable>());
|
||||
}
|
||||
return exprs;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Expr, 8> mlir::edsc::makeExprs(unsigned n) {
|
||||
llvm::SmallVector<Expr, 8> res;
|
||||
res.reserve(n);
|
||||
for (auto i = 0; i < n; ++i) {
|
||||
@@ -141,7 +159,7 @@ llvm::SmallVector<Expr, 8> makeExprs(unsigned n) {
|
||||
return res;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Expr, 8> makeExprs(ArrayRef<Bindable> bindables) {
|
||||
llvm::SmallVector<Expr, 8> mlir::edsc::makeExprs(ArrayRef<Bindable> bindables) {
|
||||
llvm::SmallVector<Expr, 8> res;
|
||||
res.reserve(bindables.size());
|
||||
for (auto b : bindables) {
|
||||
@@ -150,29 +168,53 @@ llvm::SmallVector<Expr, 8> makeExprs(ArrayRef<Bindable> bindables) {
|
||||
return res;
|
||||
}
|
||||
|
||||
Expr alloc(llvm::ArrayRef<Expr> sizes, Type memrefType) {
|
||||
static llvm::SmallVector<Expr, 8> makeExprs(edsc_expr_list_t exprList) {
|
||||
llvm::SmallVector<Expr, 8> exprs;
|
||||
exprs.reserve(exprList.n);
|
||||
for (unsigned i = 0; i < exprList.n; ++i) {
|
||||
exprs.push_back(Expr(exprList.exprs[i]));
|
||||
}
|
||||
return exprs;
|
||||
}
|
||||
|
||||
static llvm::SmallVector<Stmt, 8> makeStmts(edsc_stmt_list_t enclosedStmts) {
|
||||
llvm::SmallVector<Stmt, 8> stmts;
|
||||
stmts.reserve(enclosedStmts.n);
|
||||
for (unsigned i = 0; i < enclosedStmts.n; ++i) {
|
||||
stmts.push_back(Stmt(enclosedStmts.stmts[i]));
|
||||
}
|
||||
return stmts;
|
||||
}
|
||||
|
||||
Expr mlir::edsc::alloc(llvm::ArrayRef<Expr> sizes, Type memrefType) {
|
||||
return VariadicExpr(ExprKind::Alloc, sizes, memrefType);
|
||||
}
|
||||
|
||||
Stmt Block(ArrayRef<Stmt> stmts) {
|
||||
Stmt mlir::edsc::Block(ArrayRef<Stmt> stmts) {
|
||||
return Stmt(StmtBlockLikeExpr(ExprKind::Block, {}), stmts);
|
||||
}
|
||||
|
||||
Expr dealloc(Expr memref) { return UnaryExpr(ExprKind::Dealloc, memref); }
|
||||
edsc_stmt_t Block(edsc_stmt_list_t enclosedStmts) {
|
||||
return Stmt(mlir::edsc::Block(makeStmts(enclosedStmts)));
|
||||
}
|
||||
|
||||
Stmt For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) {
|
||||
Expr mlir::edsc::dealloc(Expr memref) {
|
||||
return UnaryExpr(ExprKind::Dealloc, memref);
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) {
|
||||
Bindable idx;
|
||||
return For(idx, lb, ub, step, stmts);
|
||||
}
|
||||
|
||||
Stmt For(const Bindable &idx, Expr lb, Expr ub, Expr step,
|
||||
ArrayRef<Stmt> stmts) {
|
||||
Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step,
|
||||
ArrayRef<Stmt> stmts) {
|
||||
return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, {lb, ub, step}), stmts);
|
||||
}
|
||||
|
||||
Stmt For(MutableArrayRef<Bindable> indices, ArrayRef<Expr> lbs,
|
||||
ArrayRef<Expr> ubs, ArrayRef<Expr> steps,
|
||||
ArrayRef<Stmt> enclosedStmts) {
|
||||
Stmt mlir::edsc::For(MutableArrayRef<Bindable> indices, ArrayRef<Expr> lbs,
|
||||
ArrayRef<Expr> ubs, ArrayRef<Expr> steps,
|
||||
ArrayRef<Stmt> enclosedStmts) {
|
||||
assert(!indices.empty());
|
||||
assert(indices.size() == lbs.size());
|
||||
assert(indices.size() == ubs.size());
|
||||
@@ -185,14 +227,37 @@ Stmt For(MutableArrayRef<Bindable> indices, ArrayRef<Expr> lbs,
|
||||
return curStmt;
|
||||
}
|
||||
|
||||
Stmt For(llvm::MutableArrayRef<Bindable> indices, llvm::ArrayRef<Bindable> lbs,
|
||||
llvm::ArrayRef<Bindable> ubs, llvm::ArrayRef<Bindable> steps,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts) {
|
||||
Stmt mlir::edsc::For(llvm::MutableArrayRef<Bindable> indices,
|
||||
llvm::ArrayRef<Bindable> lbs, llvm::ArrayRef<Bindable> ubs,
|
||||
llvm::ArrayRef<Bindable> steps,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts) {
|
||||
return For(indices, SmallVector<Expr, 8>{lbs.begin(), lbs.end()},
|
||||
SmallVector<Expr, 8>{ubs.begin(), ubs.end()},
|
||||
SmallVector<Expr, 8>{steps.begin(), steps.end()}, enclosedStmts);
|
||||
}
|
||||
|
||||
Stmt mlir::edsc::For(llvm::MutableArrayRef<Bindable> indices,
|
||||
llvm::ArrayRef<Bindable> lbs, llvm::ArrayRef<Expr> ubs,
|
||||
llvm::ArrayRef<Bindable> steps,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts) {
|
||||
return For(indices, SmallVector<Expr, 8>{lbs.begin(), lbs.end()}, ubs,
|
||||
SmallVector<Expr, 8>{steps.begin(), steps.end()}, enclosedStmts);
|
||||
}
|
||||
|
||||
edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
|
||||
edsc_expr_t step, edsc_stmt_list_t enclosedStmts) {
|
||||
return Stmt(For(Expr(iv).cast<Bindable>(), Expr(lb), Expr(ub), Expr(step),
|
||||
makeStmts(enclosedStmts)));
|
||||
}
|
||||
|
||||
edsc_stmt_t ForNest(edsc_expr_list_t ivs, edsc_expr_list_t lbs,
|
||||
edsc_expr_list_t ubs, edsc_expr_list_t steps,
|
||||
edsc_stmt_list_t enclosedStmts) {
|
||||
auto bindables = makeBindables(ivs);
|
||||
return Stmt(For(bindables, makeExprs(lbs), makeExprs(ubs), makeExprs(steps),
|
||||
makeStmts(enclosedStmts)));
|
||||
}
|
||||
|
||||
template <typename BindableOrExpr>
|
||||
static Expr loadBuilder(Expr m, ArrayRef<BindableOrExpr> indices) {
|
||||
SmallVector<Expr, 8> exprs;
|
||||
@@ -200,12 +265,24 @@ static Expr loadBuilder(Expr m, ArrayRef<BindableOrExpr> indices) {
|
||||
exprs.append(indices.begin(), indices.end());
|
||||
return VariadicExpr(ExprKind::Load, exprs);
|
||||
}
|
||||
Expr load(Expr m, Expr index) { return loadBuilder<Expr>(m, {index}); }
|
||||
Expr load(Expr m, Bindable index) { return loadBuilder<Bindable>(m, {index}); }
|
||||
Expr load(Expr m, const llvm::SmallVectorImpl<Bindable> &indices) {
|
||||
Expr mlir::edsc::load(Expr m, Expr index) {
|
||||
return loadBuilder<Expr>(m, {index});
|
||||
}
|
||||
Expr mlir::edsc::load(Expr m, Bindable index) {
|
||||
return loadBuilder<Bindable>(m, {index});
|
||||
}
|
||||
Expr mlir::edsc::load(Expr m, const llvm::SmallVectorImpl<Bindable> &indices) {
|
||||
return loadBuilder(m, ArrayRef<Bindable>{indices.begin(), indices.end()});
|
||||
}
|
||||
Expr load(Expr m, ArrayRef<Expr> indices) { return loadBuilder(m, indices); }
|
||||
Expr mlir::edsc::load(Expr m, ArrayRef<Expr> indices) {
|
||||
return loadBuilder(m, indices);
|
||||
}
|
||||
|
||||
edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices) {
|
||||
Indexed i(Expr(indexed.base).cast<Bindable>());
|
||||
Expr res = i[makeExprs(indices)];
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename BindableOrExpr>
|
||||
static Expr storeBuilder(Expr val, Expr m, ArrayRef<BindableOrExpr> indices) {
|
||||
@@ -215,33 +292,49 @@ static Expr storeBuilder(Expr val, Expr m, ArrayRef<BindableOrExpr> indices) {
|
||||
exprs.append(indices.begin(), indices.end());
|
||||
return VariadicExpr(ExprKind::Store, exprs);
|
||||
}
|
||||
Expr store(Expr val, Expr m, Expr index) {
|
||||
Expr mlir::edsc::store(Expr val, Expr m, Expr index) {
|
||||
return storeBuilder<Expr>(val, m, {index});
|
||||
}
|
||||
Expr store(Expr val, Expr m, Bindable index) {
|
||||
Expr mlir::edsc::store(Expr val, Expr m, Bindable index) {
|
||||
return storeBuilder<Bindable>(val, m, {index});
|
||||
}
|
||||
Expr store(Expr val, Expr m, const llvm::SmallVectorImpl<Bindable> &indices) {
|
||||
Expr mlir::edsc::store(Expr val, Expr m,
|
||||
const llvm::SmallVectorImpl<Bindable> &indices) {
|
||||
return storeBuilder(val, m,
|
||||
ArrayRef<Bindable>{indices.begin(), indices.end()});
|
||||
}
|
||||
Expr store(Expr val, Expr m, ArrayRef<Expr> indices) {
|
||||
Expr mlir::edsc::store(Expr val, Expr m, ArrayRef<Expr> indices) {
|
||||
return storeBuilder(val, m, indices);
|
||||
}
|
||||
|
||||
Expr select(Expr cond, Expr lhs, Expr rhs) {
|
||||
edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed,
|
||||
edsc_expr_list_t indices) {
|
||||
Indexed i(Expr(indexed.base).cast<Bindable>());
|
||||
Indexed loc = i[makeExprs(indices)];
|
||||
return Stmt(loc = Expr(value));
|
||||
}
|
||||
|
||||
Expr mlir::edsc::select(Expr cond, Expr lhs, Expr rhs) {
|
||||
return TernaryExpr(ExprKind::Select, cond, lhs, rhs);
|
||||
}
|
||||
|
||||
Expr vector_type_cast(Expr memrefExpr, Type memrefType) {
|
||||
edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs) {
|
||||
return select(Expr(cond), Expr(lhs), Expr(rhs));
|
||||
}
|
||||
|
||||
Expr mlir::edsc::vector_type_cast(Expr memrefExpr, Type memrefType) {
|
||||
return VariadicExpr(ExprKind::VectorTypeCast, {memrefExpr}, {memrefType});
|
||||
}
|
||||
|
||||
Stmt Return(ArrayRef<Expr> values) {
|
||||
Stmt mlir::edsc::Return(ArrayRef<Expr> values) {
|
||||
return VariadicExpr(ExprKind::Return, values);
|
||||
}
|
||||
|
||||
void Expr::print(raw_ostream &os) const {
|
||||
edsc_stmt_t Return(edsc_expr_list_t values) {
|
||||
return Stmt(Return(makeExprs(values)));
|
||||
}
|
||||
|
||||
void mlir::edsc::Expr::print(raw_ostream &os) const {
|
||||
if (auto unbound = this->dyn_cast<Bindable>()) {
|
||||
os << "$" << unbound.getId();
|
||||
return;
|
||||
@@ -322,73 +415,77 @@ void Expr::print(raw_ostream &os) const {
|
||||
os << "unknown_kind(" << static_cast<int>(getKind()) << ")";
|
||||
}
|
||||
|
||||
void Expr::dump() const { this->print(llvm::errs()); }
|
||||
void mlir::edsc::Expr::dump() const { this->print(llvm::errs()); }
|
||||
|
||||
std::string Expr::str() const {
|
||||
std::string mlir::edsc::Expr::str() const {
|
||||
std::string res;
|
||||
llvm::raw_string_ostream os(res);
|
||||
this->print(os);
|
||||
return res;
|
||||
}
|
||||
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Expr &expr) {
|
||||
llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
|
||||
const Expr &expr) {
|
||||
expr.print(os);
|
||||
return os;
|
||||
}
|
||||
|
||||
Bindable::Bindable()
|
||||
mlir::edsc::Bindable::Bindable()
|
||||
: Expr(Expr::globalAllocator()->Allocate<detail::BindableStorage>()) {
|
||||
// Initialize with placement new.
|
||||
new (storage) detail::BindableStorage{Bindable::newId()};
|
||||
}
|
||||
|
||||
unsigned Bindable::getId() const {
|
||||
edsc_expr_t makeBindable() { return Bindable(); }
|
||||
|
||||
unsigned mlir::edsc::Bindable::getId() const {
|
||||
return static_cast<ImplType *>(storage)->id;
|
||||
}
|
||||
|
||||
unsigned &Bindable::newId() {
|
||||
unsigned &mlir::edsc::Bindable::newId() {
|
||||
static thread_local unsigned id = 0;
|
||||
return ++id;
|
||||
}
|
||||
|
||||
UnaryExpr::UnaryExpr(ExprKind kind, Expr expr)
|
||||
mlir::edsc::UnaryExpr::UnaryExpr(ExprKind kind, Expr expr)
|
||||
: Expr(Expr::globalAllocator()->Allocate<detail::UnaryExprStorage>()) {
|
||||
// Initialize with placement new.
|
||||
new (storage) detail::UnaryExprStorage{kind, expr};
|
||||
}
|
||||
Expr UnaryExpr::getExpr() const {
|
||||
Expr mlir::edsc::UnaryExpr::getExpr() const {
|
||||
return static_cast<ImplType *>(storage)->expr;
|
||||
}
|
||||
|
||||
BinaryExpr::BinaryExpr(ExprKind kind, Expr lhs, Expr rhs)
|
||||
mlir::edsc::BinaryExpr::BinaryExpr(ExprKind kind, Expr lhs, Expr rhs)
|
||||
: Expr(Expr::globalAllocator()->Allocate<detail::BinaryExprStorage>()) {
|
||||
// Initialize with placement new.
|
||||
new (storage) detail::BinaryExprStorage{kind, lhs, rhs};
|
||||
}
|
||||
Expr BinaryExpr::getLHS() const {
|
||||
Expr mlir::edsc::BinaryExpr::getLHS() const {
|
||||
return static_cast<ImplType *>(storage)->lhs;
|
||||
}
|
||||
Expr BinaryExpr::getRHS() const {
|
||||
Expr mlir::edsc::BinaryExpr::getRHS() const {
|
||||
return static_cast<ImplType *>(storage)->rhs;
|
||||
}
|
||||
|
||||
TernaryExpr::TernaryExpr(ExprKind kind, Expr cond, Expr lhs, Expr rhs)
|
||||
mlir::edsc::TernaryExpr::TernaryExpr(ExprKind kind, Expr cond, Expr lhs,
|
||||
Expr rhs)
|
||||
: Expr(Expr::globalAllocator()->Allocate<detail::TernaryExprStorage>()) {
|
||||
// Initialize with placement new.
|
||||
new (storage) detail::TernaryExprStorage{kind, cond, lhs, rhs};
|
||||
}
|
||||
Expr TernaryExpr::getCond() const {
|
||||
Expr mlir::edsc::TernaryExpr::getCond() const {
|
||||
return static_cast<ImplType *>(storage)->cond;
|
||||
}
|
||||
Expr TernaryExpr::getLHS() const {
|
||||
Expr mlir::edsc::TernaryExpr::getLHS() const {
|
||||
return static_cast<ImplType *>(storage)->lhs;
|
||||
}
|
||||
Expr TernaryExpr::getRHS() const {
|
||||
Expr mlir::edsc::TernaryExpr::getRHS() const {
|
||||
return static_cast<ImplType *>(storage)->rhs;
|
||||
}
|
||||
|
||||
VariadicExpr::VariadicExpr(ExprKind kind, ArrayRef<Expr> exprs,
|
||||
ArrayRef<Type> types)
|
||||
mlir::edsc::VariadicExpr::VariadicExpr(ExprKind kind, ArrayRef<Expr> exprs,
|
||||
ArrayRef<Type> types)
|
||||
: Expr(Expr::globalAllocator()->Allocate<detail::VariadicExprStorage>()) {
|
||||
// Initialize with placement new.
|
||||
auto exprStorage = Expr::globalAllocator()->Allocate<Expr>(exprs.size());
|
||||
@@ -399,15 +496,16 @@ VariadicExpr::VariadicExpr(ExprKind kind, ArrayRef<Expr> exprs,
|
||||
kind, ArrayRef<Expr>(exprStorage, exprs.size()),
|
||||
ArrayRef<Type>(typeStorage, types.size())};
|
||||
}
|
||||
ArrayRef<Expr> VariadicExpr::getExprs() const {
|
||||
ArrayRef<Expr> mlir::edsc::VariadicExpr::getExprs() const {
|
||||
return static_cast<ImplType *>(storage)->exprs;
|
||||
}
|
||||
ArrayRef<Type> VariadicExpr::getTypes() const {
|
||||
ArrayRef<Type> mlir::edsc::VariadicExpr::getTypes() const {
|
||||
return static_cast<ImplType *>(storage)->types;
|
||||
}
|
||||
|
||||
StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind, ArrayRef<Expr> exprs,
|
||||
ArrayRef<Type> types)
|
||||
mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind,
|
||||
ArrayRef<Expr> exprs,
|
||||
ArrayRef<Type> types)
|
||||
: Expr(Expr::globalAllocator()->Allocate<detail::VariadicExprStorage>()) {
|
||||
// Initialize with placement new.
|
||||
auto exprStorage = Expr::globalAllocator()->Allocate<Expr>(exprs.size());
|
||||
@@ -418,15 +516,15 @@ StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind, ArrayRef<Expr> exprs,
|
||||
kind, ArrayRef<Expr>(exprStorage, exprs.size()),
|
||||
ArrayRef<Type>(typeStorage, types.size())};
|
||||
}
|
||||
ArrayRef<Expr> StmtBlockLikeExpr::getExprs() const {
|
||||
ArrayRef<Expr> mlir::edsc::StmtBlockLikeExpr::getExprs() const {
|
||||
return static_cast<ImplType *>(storage)->exprs;
|
||||
}
|
||||
ArrayRef<Type> StmtBlockLikeExpr::getTypes() const {
|
||||
ArrayRef<Type> mlir::edsc::StmtBlockLikeExpr::getTypes() const {
|
||||
return static_cast<ImplType *>(storage)->types;
|
||||
}
|
||||
|
||||
Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts) {
|
||||
mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
|
||||
llvm::ArrayRef<Stmt> enclosedStmts) {
|
||||
storage = Expr::globalAllocator()->Allocate<detail::StmtStorage>();
|
||||
// Initialize with placement new.
|
||||
auto enclosedStmtStorage =
|
||||
@@ -437,24 +535,33 @@ Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
|
||||
lhs, rhs, ArrayRef<Stmt>(enclosedStmtStorage, enclosedStmts.size())};
|
||||
}
|
||||
|
||||
Stmt::Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> enclosedStmts)
|
||||
mlir::edsc::Stmt::Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> enclosedStmts)
|
||||
: Stmt(Bindable(), rhs, enclosedStmts) {}
|
||||
|
||||
Stmt &Stmt::operator=(const Expr &expr) {
|
||||
edsc_stmt_t makeStmt(edsc_expr_t e) {
|
||||
assert(e && "unexpected empty expression");
|
||||
return Stmt(Expr(e));
|
||||
}
|
||||
|
||||
Stmt &mlir::edsc::Stmt::operator=(const Expr &expr) {
|
||||
Stmt res(Bindable(), expr, {});
|
||||
std::swap(res.storage, this->storage);
|
||||
return *this;
|
||||
}
|
||||
|
||||
Bindable Stmt::getLHS() const { return static_cast<ImplType *>(storage)->lhs; }
|
||||
Bindable mlir::edsc::Stmt::getLHS() const {
|
||||
return static_cast<ImplType *>(storage)->lhs;
|
||||
}
|
||||
|
||||
Expr Stmt::getRHS() const { return static_cast<ImplType *>(storage)->rhs; }
|
||||
Expr mlir::edsc::Stmt::getRHS() const {
|
||||
return static_cast<ImplType *>(storage)->rhs;
|
||||
}
|
||||
|
||||
llvm::ArrayRef<Stmt> Stmt::getEnclosedStmts() const {
|
||||
llvm::ArrayRef<Stmt> mlir::edsc::Stmt::getEnclosedStmts() const {
|
||||
return storage->enclosedStmts;
|
||||
}
|
||||
|
||||
void Stmt::print(raw_ostream &os, Twine indent) const {
|
||||
void mlir::edsc::Stmt::print(raw_ostream &os, Twine indent) const {
|
||||
assert(storage && "Unexpected null storage,stmt must be bound to print");
|
||||
auto lhs = getLHS();
|
||||
auto rhs = getRHS();
|
||||
@@ -491,36 +598,97 @@ void Stmt::print(raw_ostream &os, Twine indent) const {
|
||||
}
|
||||
}
|
||||
|
||||
void Stmt::dump() const { this->print(llvm::errs()); }
|
||||
void mlir::edsc::Stmt::dump() const { this->print(llvm::errs()); }
|
||||
|
||||
std::string Stmt::str() const {
|
||||
std::string mlir::edsc::Stmt::str() const {
|
||||
std::string res;
|
||||
llvm::raw_string_ostream os(res);
|
||||
this->print(os);
|
||||
return res;
|
||||
}
|
||||
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Stmt &stmt) {
|
||||
llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
|
||||
const Stmt &stmt) {
|
||||
stmt.print(os);
|
||||
return os;
|
||||
}
|
||||
|
||||
Indexed Indexed::operator[](llvm::ArrayRef<Expr> indices) const {
|
||||
Indexed mlir::edsc::Indexed::operator[](llvm::ArrayRef<Expr> indices) const {
|
||||
Indexed res(base);
|
||||
res.indices = llvm::SmallVector<Expr, 4>(indices.begin(), indices.end());
|
||||
return res;
|
||||
}
|
||||
|
||||
Indexed Indexed::operator[](llvm::ArrayRef<Bindable> indices) const {
|
||||
Indexed mlir::edsc::Indexed::
|
||||
operator[](llvm::ArrayRef<Bindable> indices) const {
|
||||
return (*this)[llvm::ArrayRef<Expr>{indices.begin(), indices.end()}];
|
||||
}
|
||||
|
||||
Stmt Indexed::operator=(Expr expr) { // NOLINT: unconventional-assing-operator
|
||||
// NOLINTNEXTLINE: unconventional-assign-operator
|
||||
Stmt mlir::edsc::Indexed::operator=(Expr expr) {
|
||||
assert(!indices.empty() && "Expected attached indices to Indexed");
|
||||
assert(base);
|
||||
Stmt stmt(store(expr, base, indices));
|
||||
indices.clear();
|
||||
return stmt;
|
||||
}
|
||||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
|
||||
edsc_indexed_t makeIndexed(edsc_expr_t expr) {
|
||||
return edsc_indexed_t{expr, edsc_expr_list_t{nullptr, 0}};
|
||||
}
|
||||
|
||||
edsc_indexed_t index(edsc_indexed_t indexed, edsc_expr_list_t indices) {
|
||||
return edsc_indexed_t{indexed.base, indices};
|
||||
}
|
||||
|
||||
mlir_type_t makeScalarType(mlir_context_t context, const char *name,
|
||||
unsigned bitwidth) {
|
||||
mlir::MLIRContext *c = reinterpret_cast<mlir::MLIRContext *>(context);
|
||||
mlir_type_t res =
|
||||
llvm::StringSwitch<mlir_type_t>(name)
|
||||
.Case("bf16",
|
||||
mlir_type_t{mlir::Type::getBF16(c).getAsOpaquePointer()})
|
||||
.Case("f16", mlir_type_t{mlir::Type::getF16(c).getAsOpaquePointer()})
|
||||
.Case("f32", mlir_type_t{mlir::Type::getF32(c).getAsOpaquePointer()})
|
||||
.Case("f64", mlir_type_t{mlir::Type::getF64(c).getAsOpaquePointer()})
|
||||
.Case("index",
|
||||
mlir_type_t{mlir::Type::getIndex(c).getAsOpaquePointer()})
|
||||
.Case("i",
|
||||
mlir_type_t{
|
||||
mlir::Type::getInteger(bitwidth, c).getAsOpaquePointer()})
|
||||
.Default(mlir_type_t{nullptr});
|
||||
if (!res) {
|
||||
llvm_unreachable("Invalid type specifier");
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType,
|
||||
int64_list_t sizes) {
|
||||
auto t = mlir::MemRefType::get(
|
||||
llvm::ArrayRef<int64_t>(sizes.values, sizes.n),
|
||||
mlir::Type::getFromOpaquePointer(elemType),
|
||||
{mlir::AffineMap::getMultiDimIdentityMap(
|
||||
sizes.n, reinterpret_cast<mlir::MLIRContext *>(context))},
|
||||
0);
|
||||
return mlir_type_t{t.getAsOpaquePointer()};
|
||||
}
|
||||
|
||||
mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
|
||||
mlir_type_list_t outputs) {
|
||||
llvm::SmallVector<mlir::Type, 8> ins(inputs.n), outs(outputs.n);
|
||||
for (unsigned i = 0; i < inputs.n; ++i) {
|
||||
ins[i] = mlir::Type::getFromOpaquePointer(inputs.types[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < outputs.n; ++i) {
|
||||
ins[i] = mlir::Type::getFromOpaquePointer(outputs.types[i]);
|
||||
}
|
||||
auto ft = mlir::FunctionType::get(
|
||||
ins, outs, reinterpret_cast<mlir::MLIRContext *>(context));
|
||||
return mlir_type_t{ft.getAsOpaquePointer()};
|
||||
}
|
||||
|
||||
unsigned getFunctionArity(mlir_func_t function) {
|
||||
auto *f = reinterpret_cast<mlir::Function *>(function);
|
||||
return f->getNumArguments();
|
||||
}
|
||||
|
||||
@@ -278,9 +278,9 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(Module *m) {
|
||||
setupTargetTriple(llvmModule.get());
|
||||
packFunctionArguments(llvmModule.get());
|
||||
|
||||
engine->jit = std::move(*expectedJIT);
|
||||
if (auto err = engine->jit->addModule(std::move(llvmModule)))
|
||||
if (auto err = (*expectedJIT)->addModule(std::move(llvmModule)))
|
||||
return std::move(err);
|
||||
engine->jit = std::move(*expectedJIT);
|
||||
|
||||
return engine;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user